Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/omop_emb/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class EmbeddingBackend(ABC, Generic[T]):
filters.
"""
DEFAULT_BASE_STORAGE_DIR = Path.home() / ".omop_emb"
DEFAULT_K_NEAREST = 10
def __init__(
self,
storage_base_dir: Optional[str | Path] = None,
Expand Down Expand Up @@ -383,7 +384,6 @@ def get_nearest_concepts(
query_embedding: ndarray,
metric_type: MetricType,
concept_filter: Optional[EmbeddingConceptFilter] = None,
k: int = 10,
_model_record: EmbeddingModelRecord,
) -> Tuple[Tuple[NearestConceptMatch, ...], ...]:
"""
Expand All @@ -402,9 +402,7 @@ def get_nearest_concepts(
metric_type : MetricType
Similarity or distance metric for nearest-neighbor search.
concept_filter : Optional[EmbeddingConceptFilter], optional
Optional filter restricting which OMOP concepts are considered.
k : int, optional
Number of nearest matches to return per query.
Optional filter restricting which OMOP concepts are considered. The "limit" field of the filter determines how many nearest neighbors are returned per query vector. If not set, defaults to the global DEFAULT_K_NEAREST.
_model_record : EmbeddingModelRecord
Internal registered-model record injected by ``@require_registered_model``.

Expand All @@ -413,6 +411,23 @@ def get_nearest_concepts(
Tuple[Tuple[NearestConceptMatch, ...], ...]
A tuple of tuples containing nearest concept matches for each query vector. The outer tuple corresponds to the query vectors in order, and each inner tuple contains the nearest matches for that query vector, sorted by similarity. Returned shape is (q, k) where q is the number of query vectors and k is the number of nearest neighbors returned per query.
"""

def validate_nearest_concepts_output(
    self,
    nearest_concepts: Tuple[Tuple["NearestConceptMatch", ...], ...],
    k: int,
    query_embeddings: ndarray,
) -> None:
    """
    Sanity-check the structure of a nearest-neighbor search result.

    Two invariants are verified, in this order:

    1. No per-query result holds more than ``k`` matches.
    2. There is exactly one per-query result for each entry along the
       first axis of ``query_embeddings``.

    Parameters
    ----------
    nearest_concepts : Tuple[Tuple[NearestConceptMatch, ...], ...]
        One inner tuple of matches per query vector.
    k : int
        Upper bound on the number of matches allowed per query.
    query_embeddings : ndarray
        The query vectors the search was run for; only ``shape[0]`` is read.

    Raises
    ------
    AssertionError
        If either invariant is violated. The error is raised explicitly
        (rather than via ``assert``) so the check is not stripped when
        Python runs with ``-O``.
    """
    # ``default=0`` keeps an empty result set valid instead of raising ValueError.
    largest_result = max((len(matches) for matches in nearest_concepts), default=0)
    if largest_result > k:
        raise AssertionError(
            f"Expected at most {k} nearest neighbors per query embedding, "
            f"but found a result tuple with {largest_result} entries."
        )

    expected_queries = query_embeddings.shape[0]
    if len(nearest_concepts) != expected_queries:
        raise AssertionError(
            f"Expected nearest concepts for {expected_queries} query embeddings, "
            f"but got {len(nearest_concepts)}."
        )



def has_any_embeddings(
self,
Expand Down
20 changes: 16 additions & 4 deletions src/omop_emb/backends/faiss/faiss_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def get_nearest_concepts(
metric_type: MetricType,
*,
concept_filter: Optional[EmbeddingConceptFilter] = None,
k: int = 10,
_model_record: EmbeddingModelRecord,
) -> Tuple[Tuple[NearestConceptMatch, ...], ...]:

Expand All @@ -237,8 +236,15 @@ def get_nearest_concepts(
if concept_filter is None:
# Easier to not do any filter if all are allowed
permitted_concept_ids = None
logger.debug(f"No concept filter provided. Setting number of returned nearest concepts (k) to default: {self.DEFAULT_K_NEAREST}")
k = self.DEFAULT_K_NEAREST
else:
permitted_concept_ids = np.array(list(permitted_concept_ids_storage.keys()), dtype=np.int64)
if concept_filter.limit is not None:
k = concept_filter.limit
else:
logger.debug(f"Concept filter provided without limit. Setting number of returned nearest concepts (k) to default: {self.DEFAULT_K_NEAREST}")
k = self.DEFAULT_K_NEAREST

distances, concept_ids = storage_manager.search(
query_vector=query_embeddings,
Expand All @@ -258,16 +264,22 @@ def get_nearest_concepts(
if row is None:
logger.warning(f"Concept ID {concept_id} returned by FAISS search but not found in permitted concept IDs. This indicates a mismatch between the FAISS index and the database registry. Skipping this result.")
continue

similarity = get_similarity_from_distance(distance.item(), metric_type)
assert isinstance(similarity, float), f"Expected similarity to be a float, got {type(similarity)}"
matches_per_query.append(NearestConceptMatch(
concept_id=int(concept_id),
concept_name=row.concept_name,
similarity=get_similarity_from_distance(distance, metric_type),
similarity=similarity,
is_standard=bool(row.is_standard),
is_active=bool(row.is_active),
))
matches.append(tuple(matches_per_query))
return tuple(matches)


matches_tuple = tuple(matches)
self.validate_nearest_concepts_output(matches_tuple, k, query_embeddings=query_embeddings)
return matches_tuple

@require_registered_model
def get_embeddings_by_concept_ids(
self,
Expand Down
27 changes: 24 additions & 3 deletions src/omop_emb/backends/pgvector/pgvector_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Mapping, Optional, Sequence, Type, Tuple

from numpy import ndarray
import logging
from sqlalchemy import Engine, text
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
Expand All @@ -22,6 +23,8 @@
)
from omop_emb.model_registry import EmbeddingModelRecord

logger = logging.getLogger(__name__)

class PGVectorEmbeddingBackend(EmbeddingBackend[PGVectorConceptIDEmbeddingTable]):
"""
pgvector-backed embedding backend for postgresql databases.
Expand Down Expand Up @@ -166,7 +169,6 @@ def get_nearest_concepts(
query_embeddings: ndarray,
metric_type: MetricType,
concept_filter: Optional[EmbeddingConceptFilter] = None,
k: int = 10,
_model_record: EmbeddingModelRecord,
) -> Tuple[Tuple[NearestConceptMatch, ...], ...]:
"""
Expand Down Expand Up @@ -197,13 +199,27 @@ def get_nearest_concepts(
)
self.validate_embeddings(embeddings=query_embeddings, dimensions=_model_record.dimensions)

# Guarantee that concept_filter has a limit set for K nearest neighbors
if concept_filter is None or concept_filter.limit is None:
logger.debug(f"No concept filter or concept filter limit provided. Setting number of returned nearest concepts (k) to default: {self.DEFAULT_K_NEAREST}")

if concept_filter is None:
concept_filter = EmbeddingConceptFilter(limit=self.DEFAULT_K_NEAREST)
elif concept_filter.limit is None:
concept_filter = EmbeddingConceptFilter(
concept_ids=concept_filter.concept_ids,
domains=concept_filter.domains,
vocabularies=concept_filter.vocabularies,
require_standard=concept_filter.require_standard,
limit=self.DEFAULT_K_NEAREST,
)

query_list = query_embeddings.tolist()
query = q_embedding_nearest_concepts(
embedding_table=embedding_table,
query_embeddings=query_list,
metric_type=metric_type,
concept_filter=concept_filter,
limit=k
)

rows = session.execute(query).all()
Expand All @@ -220,5 +236,10 @@ def get_nearest_concepts(
)
)

return tuple(tuple(matches) for matches in results)
matches_tuple = tuple(tuple(matches) for matches in results)

k = concept_filter.limit
assert k is not None, "Internal error: concept_filter.limit should have been set to a non-None value by this point."
self.validate_nearest_concepts_output(matches_tuple, k, query_embeddings=query_embeddings)
return matches_tuple

5 changes: 2 additions & 3 deletions src/omop_emb/backends/pgvector/pgvector_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def q_embedding_nearest_concepts(
query_embeddings: List[List[float]],
metric_type: MetricType,
concept_filter: Optional[EmbeddingConceptFilter] = None,
limit: int = 10,
) -> Select:
"""Constructs a SQL query to retrieve the nearest concepts for the given query embeddings, applying the specified metric and filters. The query uses a LATERAL join to compute distances/similarities for each query vector against all candidate concept embeddings, and then applies the necessary OMOP filters before returning the top K results per query vector.

Expand All @@ -148,7 +147,7 @@ def q_embedding_nearest_concepts(
metric_type : MetricType
The distance metric to use for nearest neighbor search (e.g., COSINE, L2, etc.). This will determine which pgvector distance function is used in the query.
concept_filter : Optional[EmbeddingConceptFilter], optional
An optional filter object containing criteria to filter the concepts (e.g., by concept_id, domain, vocabulary, standard_concept flag).
An optional filter object containing criteria to filter the concepts (e.g., by concept_id, domain, vocabulary, standard_concept flag). Its `limit` field also caps the number of nearest neighbors (K) returned per query vector.
limit : int, optional
The number of nearest neighbors (K) to return for each query embedding, by default 10.
"""
Expand Down Expand Up @@ -186,7 +185,7 @@ def q_embedding_nearest_concepts(
if concept_filter:
inner_stmt = concept_filter.apply(inner_stmt)

lateral_subq = inner_stmt.limit(limit).lateral("top_k")
lateral_subq = inner_stmt.lateral("top_k")

# Joins the Q vectors to their K nearest neighbors
stmt = (
Expand Down
28 changes: 8 additions & 20 deletions src/omop_emb/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,7 @@ def get_nearest_concepts(
*,
metric_type: MetricType,
concept_filter: Optional[EmbeddingConceptFilter] = None,
k: int = 10,
) -> Tuple[Mapping[int, float], ...]:

"""
Return nearest stored concepts for the query embedding.

Expand All @@ -176,18 +174,12 @@ def get_nearest_concepts(
metric_type : MetricType
The similarity or distance metric to use for nearest neighbor search. This must be compatible with the index type used by the database.
concept_filter : Optional[EmbeddingConceptFilter], optional
A filter to specify which concepts to consider as potential nearest neighbors.
vocabularies : Optional[Tuple[str, ...]], optional
If provided, only consider concepts from these vocabularies as potential nearest neighbors.
require_standard : bool, optional
If True, only consider standard concepts as potential nearest neighbors. By default False.
k : int, optional
K nearest neighbors to return for each query vector. Default is 10.
A filter to specify which concepts to consider as potential nearest neighbors. The `limit` field of this filter determines the number of neighbors returned.

Returns
-------
Tuple[Mapping[int, float], ...]
A tuple of dictionaries containing nearest concept matches for each query vector. The outer tuple corresponds to the query vectors in order, and each inner dictionary contains the nearest matches for that query vector, sorted by similarity. Returned shape is (q, k) where q is the number of query vectors and k is the number of nearest neighbors returned per query.
A tuple of dictionaries containing nearest concept matches for each query vector. The outer tuple corresponds to the query vectors in order, and each inner dictionary contains the nearest matches for that query vector, sorted by similarity. Returned shape is (q, limit) where q is the number of query vectors and limit is the number of nearest neighbors returned per query as determined by the `concept_filter` argument or backend default.
"""
if not isinstance(metric_type, MetricType):
raise TypeError(
Expand All @@ -199,10 +191,9 @@ def get_nearest_concepts(
index_type=index_type,
query_embeddings=query_embedding,
concept_filter=concept_filter,
metric_type=metric_type,
k=k
metric_type=metric_type
)
return tuple({match_per_query.concept_id: match_per_query.similarity for match_per_query in match} for match in nearest_concepts)
return tuple({match.concept_id: match.similarity for match in matches_per_query} for matches_per_query in nearest_concepts)

def get_nearest_concepts_by_texts(
self,
Expand All @@ -213,13 +204,13 @@ def get_nearest_concepts_by_texts(
*,
metric_type: MetricType,
concept_filter: Optional[EmbeddingConceptFilter] = None,
k: int = 10,
batch_size: Optional[int] = None
) -> Tuple[Mapping[int, float], ...]:

"""
Return nearest stored concepts for the query embedding. Convenience wrapper that embeds the query texts before performing the nearest neighbor search.

The number of neighbors returned is determined by the `limit` field of the `concept_filter` argument. If `limit` is not set, a backend default may be used.

Parameters
----------
session : Session
Expand All @@ -233,16 +224,14 @@ def get_nearest_concepts_by_texts(
metric_type : MetricType
The similarity or distance metric to use for nearest neighbor search. This should be compatible with the index type used by the model.
concept_filter : Optional[EmbeddingConceptFilter], optional
A filter to specify which concepts to consider as potential nearest neighbors.
k : int, optional
K nearest neighbors to return for each query vector. Default is 10.
A filter to specify which concepts to consider as potential nearest neighbors. The `limit` field of this filter determines the number of neighbors returned.
batch_size : Optional[int], optional
If provided, this batch size will be used when embedding the query texts. If not provided, the default batch size of the embedding client will be used.

Returns
-------
Tuple[Mapping[int, float], ...]
A tuple of dictionaries containing nearest concept matches for each query vector. The outer tuple corresponds to the query vectors in order, and each inner dictionary contains the nearest matches for that query vector, sorted by similarity. Returned shape is (q, k) where q is the number of query vectors and k is the number of nearest neighbors returned per query.
A tuple of dictionaries containing nearest concept matches for each query vector. The outer tuple corresponds to the query vectors in order, and each inner dictionary contains the nearest matches for that query vector, sorted by similarity. Returned shape is (q, limit) where q is the number of query vectors and limit is the number of nearest neighbors returned per query.
"""
if isinstance(query_texts, str):
query_texts = (query_texts,)
Expand All @@ -258,7 +247,6 @@ def get_nearest_concepts_by_texts(
query_embedding=query_embeddings,
metric_type=metric_type,
concept_filter=concept_filter,
k=k,
)

def get_embeddings_by_concept_ids(
Expand Down
5 changes: 4 additions & 1 deletion src/omop_emb/utils/embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ class EmbeddingConceptFilter:

This mirrors the current OMOP grounding needs without importing
``omop_graph`` or its search-constraint objects into ``omop_emb``.

The `limit` field determines the number of nearest neighbors returned by embedding search operations. If not set, a backend default may be used.
"""

concept_ids: Optional[tuple[int, ...]] = None
domains: Optional[tuple[str, ...]] = None
vocabularies: Optional[tuple[str, ...]] = None
require_standard: bool = False
limit: Optional[int] = None

def apply(self, query: Select) -> Select:
if self.concept_ids is not None:
Expand All @@ -35,7 +38,7 @@ def apply(self, query: Select) -> Select:
if self.require_standard:
query = query.where(Concept.standard_concept.in_(["S", "C"]))

return query
return query.limit(self.limit)


@dataclass(frozen=True)
Expand Down
13 changes: 5 additions & 8 deletions tests/shared_backend_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_nearest_neighbor_search(self, session, backend, mock_llm_client, index_
index_type=index_type,
query_embeddings=query_embeddings,
metric_type=MetricType.L2,
k=1,
concept_filter=EmbeddingConceptFilter(limit=1),
)

assert len(results) == 1
Expand All @@ -182,15 +182,13 @@ def test_nearest_neighbor_with_domain_filter(self, session, backend, mock_llm_cl
concept_ids=concept_ids,
embeddings=embeddings,
)

results = backend.get_nearest_concepts(
session=session,
model_name=MODEL_NAME,
index_type=index_type,
query_embeddings=TEST_CONCEPT_EMB,
concept_filter=EmbeddingConceptFilter(domains=("Condition",)),
concept_filter=EmbeddingConceptFilter(domains=("Condition",), limit=10),
metric_type=MetricType.L2,
k=10,
)

expected_ids = {CONCEPTS["Hypertension"].concept_id, CONCEPTS["Diabetes"].concept_id}
Expand Down Expand Up @@ -223,9 +221,8 @@ def test_nearest_neighbor_with_vocabulary_filter(self, session, backend, mock_ll
model_name=MODEL_NAME,
index_type=index_type,
query_embeddings=TEST_CONCEPT_EMB,
concept_filter=EmbeddingConceptFilter(vocabularies=("RxNorm",)),
concept_filter=EmbeddingConceptFilter(vocabularies=("RxNorm",), limit=10),
metric_type=MetricType.L2,
k=10,
)

assert len(results) == 1
Expand Down Expand Up @@ -297,7 +294,7 @@ def test_l2_similarity_exact_values(self, session, backend, mock_llm_client, ind
index_type=index_type,
query_embeddings=TEST_CONCEPT_EMB,
metric_type=MetricType.L2,
k=10,
concept_filter=EmbeddingConceptFilter(limit=10),
)

expected_similarities = {
Expand Down Expand Up @@ -351,7 +348,7 @@ def test_cosine_similarity_exact_values(self, session, backend, mock_llm_client,
index_type=index_type,
query_embeddings=TEST_CONCEPT_EMB,
metric_type=MetricType.COSINE,
k=10,
concept_filter=EmbeddingConceptFilter(limit=10),
)

# For 1D unit vectors, cos_sim = (norm_a) * (norm_b)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import pytest
import numpy as np
from unittest.mock import Mock
from sqlalchemy.orm import Session

from omop_emb.interface import EmbeddingInterface
from omop_emb.utils.embedding_utils import EmbeddingConceptFilter
from omop_emb.config import IndexType, MetricType
from omop_emb.backends.base import NearestConceptMatch
from .conftest import CONCEPTS, MODEL_NAME, EMBEDDING_DIM
Expand Down Expand Up @@ -162,7 +162,7 @@ def test_search_return_structure_for_backends(self, session, mock_llm_client, ba
index_type=index_type,
query_embedding=query_embedding,
metric_type=MetricType.COSINE,
k=2,
concept_filter=EmbeddingConceptFilter(limit=2),
)

assert isinstance(result, tuple)
Expand Down
Loading