diff --git a/openmetadata-mcp/src/main/java/org/openmetadata/mcp/tools/SemanticSearchTool.java b/openmetadata-mcp/src/main/java/org/openmetadata/mcp/tools/SemanticSearchTool.java index 6f9e12def868..4ec10b65d2d6 100644 --- a/openmetadata-mcp/src/main/java/org/openmetadata/mcp/tools/SemanticSearchTool.java +++ b/openmetadata-mcp/src/main/java/org/openmetadata/mcp/tools/SemanticSearchTool.java @@ -10,7 +10,7 @@ import org.openmetadata.schema.utils.JsonUtils; import org.openmetadata.service.Entity; import org.openmetadata.service.limits.Limits; -import org.openmetadata.service.search.vector.OpenSearchVectorService; +import org.openmetadata.service.search.vector.VectorIndexService; import org.openmetadata.service.search.vector.utils.DTOs.VectorSearchResponse; import org.openmetadata.service.security.Authorizer; import org.openmetadata.service.security.auth.CatalogSecurityContext; @@ -42,7 +42,7 @@ public Map execute( "Semantic search is not enabled. Configure vector embeddings in the OpenMetadata server settings."); } - OpenSearchVectorService vectorService = OpenSearchVectorService.getInstance(); + VectorIndexService vectorService = Entity.getSearchRepository().getVectorIndexService(); if (vectorService == null) { return errorResponse("Vector search service is not initialized"); } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/apps/bundles/searchIndex/ElasticSearchBulkSink.java b/openmetadata-service/src/main/java/org/openmetadata/service/apps/bundles/searchIndex/ElasticSearchBulkSink.java index f1e7c990b690..3377f0ec1057 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/apps/bundles/searchIndex/ElasticSearchBulkSink.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/apps/bundles/searchIndex/ElasticSearchBulkSink.java @@ -15,6 +15,7 @@ import java.io.StringWriter; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import 
java.util.Map; import java.util.Optional; @@ -51,6 +52,9 @@ import org.openmetadata.service.search.elasticsearch.ElasticSearchClient; import org.openmetadata.service.search.elasticsearch.EsUtils; import org.openmetadata.service.search.indexes.ColumnSearchIndex; +import org.openmetadata.service.search.vector.ElasticSearchVectorService; +import org.openmetadata.service.search.vector.VectorDocBuilder; +import org.openmetadata.service.search.vector.utils.AvailableEntityTypes; /** * Elasticsearch implementation using new Java API client with custom bulk handler @@ -125,6 +129,10 @@ public static synchronized void resetDocBuildPoolSize() { private final ConcurrentLinkedDeque> pendingColumnFutures = new ConcurrentLinkedDeque<>(); + // Vector embedding stats (incremented inline during addEntity) + private final AtomicLong vectorSuccess = new AtomicLong(0); + private final AtomicLong vectorFailed = new AtomicLong(0); + public ElasticSearchBulkSink( SearchRepository searchRepository, int batchSize, @@ -243,13 +251,28 @@ public void write(List entities, Map contextData) throws Exce } else { List entityInterfaces = (List) entities; - // Add entities to search index in parallel + boolean embeddingsEnabled = isVectorEmbeddingEnabledForEntity(entityType); + + Map existingFingerprints = Collections.emptyMap(); + if (embeddingsEnabled && !recreateIndex) { + existingFingerprints = fetchExistingFingerprints(entityInterfaces, indexName); + } + + Map finalFingerprints = existingFingerprints; List> futures = entityInterfaces.stream() .map( entity -> CompletableFuture.runAsync( - () -> addEntity(entity, indexName, recreateIndex, tracker), + () -> + addEntity( + entity, + indexName, + recreateIndex, + reindexContext, + tracker, + embeddingsEnabled, + finalFingerprints), DOC_BUILD_EXECUTOR)) .toList(); CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join(); @@ -300,11 +323,22 @@ protected StageStatsTracker extractTracker(Map contextData) { private static final int 
BULK_OPERATION_METADATA_OVERHEAD = 150; private void addEntity( - EntityInterface entity, String indexName, boolean recreateIndex, StageStatsTracker tracker) { + EntityInterface entity, + String indexName, + boolean recreateIndex, + ReindexContext reindexContext, + StageStatsTracker tracker, + boolean embeddingsEnabled, + Map existingFingerprints) { try { String entityType = Entity.getEntityTypeFromObject(entity); Object searchIndexDoc = Entity.buildSearchIndex(entityType, entity).buildSearchIndexDoc(); String json = JsonUtils.pojoToJson(searchIndexDoc); + + if (embeddingsEnabled) { + json = enrichWithEmbedding(entity, json, recreateIndex, existingFingerprints, tracker); + } + String docId = entity.getId().toString(); long rawDocSize = (long) json.getBytes(StandardCharsets.UTF_8).length; long estimatedSize = rawDocSize + BULK_OPERATION_METADATA_OVERHEAD; @@ -749,6 +783,83 @@ public void updateConcurrentRequests(int concurrentRequests) { LOG.info("Concurrent requests updated to: {}", concurrentRequests); } + boolean isVectorEmbeddingEnabledForEntity(String entityType) { + return searchRepository.isVectorEmbeddingEnabled() + && ElasticSearchVectorService.getInstance() != null + && AvailableEntityTypes.isVectorIndexable(entityType); + } + + @SuppressWarnings("unchecked") + private String enrichWithEmbedding( + EntityInterface entity, + String json, + boolean recreateIndex, + Map existingFingerprints, + StageStatsTracker tracker) { + try { + ElasticSearchVectorService vectorService = ElasticSearchVectorService.getInstance(); + if (vectorService == null) { + return json; + } + + if (!recreateIndex) { + String currentFp = VectorDocBuilder.computeFingerprintForEntity(entity); + String existingFp = existingFingerprints.get(entity.getId().toString()); + if (existingFp != null && existingFp.equals(currentFp)) { + vectorSuccess.incrementAndGet(); + if (tracker != null) { + tracker.recordVector(StatsResult.SUCCESS); + } + return json; + } + } + + Map embeddingFields = 
vectorService.generateEmbeddingFields(entity); + Map docMap = OBJECT_MAPPER.readValue(json, Map.class); + docMap.putAll(embeddingFields); + + vectorSuccess.incrementAndGet(); + if (tracker != null) { + tracker.recordVector(StatsResult.SUCCESS); + } + return OBJECT_MAPPER.writeValueAsString(docMap); + } catch (Exception e) { + LOG.warn( + "Failed to generate embeddings for entity {}: {}", entity.getId(), e.getMessage(), e); + vectorFailed.incrementAndGet(); + if (tracker != null) { + tracker.recordVector(StatsResult.FAILED); + } + return json; + } + } + + private Map fetchExistingFingerprints( + List entities, String indexName) { + try { + ElasticSearchVectorService vectorService = ElasticSearchVectorService.getInstance(); + if (vectorService == null) { + return Collections.emptyMap(); + } + List entityIds = new ArrayList<>(entities.size()); + for (EntityInterface entity : entities) { + entityIds.add(entity.getId().toString()); + } + return vectorService.getExistingFingerprintsBatch(indexName, entityIds); + } catch (Exception e) { + LOG.warn("Failed to fetch existing fingerprints: {}", e.getMessage()); + return Collections.emptyMap(); + } + } + + @Override + public StepStats getVectorStats() { + return new StepStats() + .withTotalRecords((int) (vectorSuccess.get() + vectorFailed.get())) + .withSuccessRecords((int) vectorSuccess.get()) + .withFailedRecords((int) vectorFailed.get()); + } + public static class CustomBulkProcessor { private final ElasticsearchAsyncClient asyncClient; private final List buffer = new ArrayList<>(); diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/resources/search/VectorSearchResource.java b/openmetadata-service/src/main/java/org/openmetadata/service/resources/search/VectorSearchResource.java index 06678c6662e3..ed7018dccabf 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/resources/search/VectorSearchResource.java +++ 
b/openmetadata-service/src/main/java/org/openmetadata/service/resources/search/VectorSearchResource.java @@ -1,23 +1,28 @@ package org.openmetadata.service.resources.search; import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.media.Content; import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.GET; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; import jakarta.ws.rs.core.Context; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.SecurityContext; import java.util.Collections; +import java.util.UUID; import lombok.extern.slf4j.Slf4j; import org.openmetadata.service.Entity; import org.openmetadata.service.resources.Collection; -import org.openmetadata.service.search.vector.OpenSearchVectorService; +import org.openmetadata.service.search.vector.VectorIndexService; +import org.openmetadata.service.search.vector.utils.DTOs.FingerprintResponse; import org.openmetadata.service.search.vector.utils.DTOs.VectorSearchRequest; import org.openmetadata.service.search.vector.utils.DTOs.VectorSearchResponse; import org.openmetadata.service.security.Authorizer; @@ -70,7 +75,7 @@ public Response vectorSearchPost( .build(); } - OpenSearchVectorService vectorService = OpenSearchVectorService.getInstance(); + VectorIndexService vectorService = Entity.getSearchRepository().getVectorIndexService(); if (vectorService == null) { return Response.status(Response.Status.SERVICE_UNAVAILABLE) .entity("{\"error\":\"Vector search service is not initialized\"}") @@ -97,4 +102,57 @@ public Response vectorSearchPost( .build(); } } + + @GET + @Path("/fingerprint") + @Operation( + operationId = "getFingerprint", + summary = "Get vector fingerprint", 
+ description = "Returns the existing fingerprint for a given entity.") + public Response getFingerprint( + @Context SecurityContext securityContext, + @Parameter(description = "Parent entity ID", required = true) @QueryParam("parentId") + String parentId) { + authorizer.authorizeAdmin(securityContext); + + if (!Entity.getSearchRepository().isVectorEmbeddingEnabled()) { + return Response.status(Response.Status.SERVICE_UNAVAILABLE) + .entity("{\"error\":\"Vector search is not enabled\"}") + .build(); + } + + VectorIndexService vectorService = Entity.getSearchRepository().getVectorIndexService(); + if (vectorService == null) { + return Response.status(Response.Status.SERVICE_UNAVAILABLE) + .entity("{\"error\":\"Vector search service is not initialized\"}") + .build(); + } + + if (parentId == null || parentId.isBlank()) { + return Response.status(Response.Status.BAD_REQUEST) + .entity("{\"error\":\"parentId is required\"}") + .build(); + } + try { + UUID.fromString(parentId); + } catch (IllegalArgumentException e) { + return Response.status(Response.Status.BAD_REQUEST) + .entity("{\"error\":\"Invalid parentId format\"}") + .build(); + } + + try { + String indexName = vectorService.getIndexAlias(); + String fingerprint = vectorService.getExistingFingerprint(indexName, parentId); + FingerprintResponse response = + new FingerprintResponse( + parentId, indexName, fingerprint, fingerprint != null ? 
"Found" : "Not found"); + return Response.ok(response).build(); + } catch (Exception e) { + LOG.error("Failed to get fingerprint for {}: {}", parentId, e.getMessage(), e); + return Response.status(Response.Status.INTERNAL_SERVER_ERROR) + .entity("{\"error\":\"An internal error occurred\"}") + .build(); + } + } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/search/RecreateWithEmbeddings.java b/openmetadata-service/src/main/java/org/openmetadata/service/search/RecreateWithEmbeddings.java index 98781b51a272..8c041f6e3d2c 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/search/RecreateWithEmbeddings.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/search/RecreateWithEmbeddings.java @@ -9,8 +9,7 @@ public class RecreateWithEmbeddings extends DefaultRecreateHandler { @Override public ReindexContext reCreateIndexes(Set entities) { - SearchRepository searchRepository = Entity.getSearchRepository(); - searchRepository.initializeVectorSearchService(); + Entity.getSearchRepository().initializeVectorSearchService(); return super.reCreateIndexes(entities); } @@ -18,13 +17,10 @@ public ReindexContext reCreateIndexes(Set entities) { public void finalizeReindex(EntityReindexContext context, boolean reindexSuccess) { super.finalizeReindex(context, reindexSuccess); - if (reindexSuccess) { - SearchRepository searchRepository = Entity.getSearchRepository(); - if (searchRepository.isVectorEmbeddingEnabled()) { - LOG.info( - "Reindex finalized for entity type '{}' with vector embeddings enabled", - context.getEntityType()); - } + if (reindexSuccess && Entity.getSearchRepository().isVectorEmbeddingEnabled()) { + LOG.info( + "Reindex finalized for entity type '{}' with vector embeddings enabled", + context.getEntityType()); } } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchRepository.java 
b/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchRepository.java index 9028199b2b36..9585dfb3963a 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchRepository.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchRepository.java @@ -135,7 +135,9 @@ import org.openmetadata.service.search.nlq.NLQService; import org.openmetadata.service.search.nlq.NLQServiceFactory; import org.openmetadata.service.search.opensearch.OpenSearchClient; +import org.openmetadata.service.search.vector.ElasticSearchVectorService; import org.openmetadata.service.search.vector.OpenSearchVectorService; +import org.openmetadata.service.search.vector.VectorSearchQueryBuilder; import org.openmetadata.service.search.vector.VectorEmbeddingHandler; import org.openmetadata.service.search.vector.VectorIndexService; import org.openmetadata.service.search.vector.client.BedrockEmbeddingClient; @@ -416,9 +418,11 @@ public synchronized void initializeVectorSearchService() { OpenSearchVectorService.init(osClient, embeddingClient); this.vectorIndexService = OpenSearchVectorService.getInstance(); } else { - LOG.warn( - "Vector embedding is only supported with OpenSearch. 
Elasticsearch support is planned."); - return; + es.co.elastic.clients.elasticsearch.ElasticsearchClient esClient = + ((ElasticSearchClient) getSearchClient()).getNewClient(); + int knnMultiplier = resolveKnnNumCandidatesMultiplier(cfg); + ElasticSearchVectorService.init(esClient, embeddingClient, language, knnMultiplier); + this.vectorIndexService = ElasticSearchVectorService.getInstance(); } this.vectorEmbeddingHandler = new VectorEmbeddingHandler(vectorIndexService); @@ -438,6 +442,15 @@ public synchronized void initializeVectorSearchService() { } } + private static int resolveKnnNumCandidatesMultiplier(ElasticSearchConfiguration cfg) { + NaturalLanguageSearchConfiguration nlCfg = cfg.getNaturalLanguageSearch(); + if (nlCfg != null && nlCfg.getKnnNumCandidatesMultiplier() != null + && nlCfg.getKnnNumCandidatesMultiplier() >= 1) { + return nlCfg.getKnnNumCandidatesMultiplier(); + } + return VectorSearchQueryBuilder.DEFAULT_KNN_NUM_CANDIDATES_MULTIPLIER; + } + public void ensureHybridSearchPipeline() { if (!isVectorEmbeddingEnabled() || !vectorServiceInitialized) { return; @@ -588,10 +601,9 @@ public void deleteIndex(IndexMapping indexMapping) { } private String getIndexMapping(IndexMapping indexMapping) { + String mappingFile = indexMapping.getIndexMappingFile(); try (InputStream in = - getClass() - .getResourceAsStream( - String.format(indexMapping.getIndexMappingFile(), language.toLowerCase()))) { + getClass().getResourceAsStream(String.format(mappingFile, language.toLowerCase()))) { assert in != null; return new String(in.readAllBytes()); } catch (Exception e) { @@ -601,11 +613,7 @@ private String getIndexMapping(IndexMapping indexMapping) { } public String readIndexMapping(IndexMapping indexMapping) { - String mapping = getIndexMapping(indexMapping); - if (isVectorEmbeddingEnabled() && embeddingClient != null && mapping != null) { - mapping = reformatVectorIndexWithDimension(mapping, embeddingClient.getDimension()); - } - return mapping; + return 
getIndexMapping(indexMapping); } /** @@ -3035,27 +3043,6 @@ private static List copyWithInheritedFlag(List return inheritedReferences; } - private String reformatVectorIndexWithDimension(String mapping, int dimension) { - try { - com.fasterxml.jackson.databind.ObjectMapper mapper = - new com.fasterxml.jackson.databind.ObjectMapper(); - JsonNode root = mapper.readTree(mapping); - if (root.has("mappings")) { - JsonNode mappings = root.get("mappings"); - com.fasterxml.jackson.databind.node.ObjectNode meta = - ((com.fasterxml.jackson.databind.node.ObjectNode) mappings).putObject("_meta"); - meta.put( - "embedding_model", - embeddingClient != null ? embeddingClient.getModelId() : "unknown") - .put("embedding_dimension", dimension); - } - return mapper.writeValueAsString(root); - } catch (Exception e) { - LOG.warn("Failed to set embedding _meta in mapping JSON", e); - return mapping; - } - } - protected EmbeddingClient createEmbeddingClient(ElasticSearchConfiguration esConfig) { NaturalLanguageSearchConfiguration config = esConfig.getNaturalLanguageSearch(); String provider = diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/search/elasticsearch/ElasticSearchIndexManager.java b/openmetadata-service/src/main/java/org/openmetadata/service/search/elasticsearch/ElasticSearchIndexManager.java index dcd56dd53d5b..dfa42955bb9a 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/search/elasticsearch/ElasticSearchIndexManager.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/search/elasticsearch/ElasticSearchIndexManager.java @@ -1,5 +1,7 @@ package org.openmetadata.service.search.elasticsearch; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import es.co.elastic.clients.elasticsearch.ElasticsearchClient; import es.co.elastic.clients.elasticsearch._types.ElasticsearchException; import es.co.elastic.clients.elasticsearch.indices.CreateIndexRequest; @@ -31,6 +33,7 
@@ */ @Slf4j public class ElasticSearchIndexManager implements IndexManagementClient { + private static final ObjectMapper MAPPER = new ObjectMapper(); private final ElasticsearchClient client; private final String clusterAlias; private final boolean isClientAvailable; @@ -83,12 +86,17 @@ public void updateIndex(IndexMapping indexMapping, String indexMappingContent) { try { String indexName = indexMapping.getIndexName(clusterAlias); + String transformedContent = + (indexMappingContent != null && !indexMappingContent.isEmpty()) + ? EsUtils.enrichIndexMappingForElasticsearch(indexMappingContent) + : indexMappingContent; + String mappingsJson = extractMappingsJson(transformedContent); PutMappingRequest request = PutMappingRequest.of( builder -> { builder.index(indexName); - if (indexMappingContent != null) { - builder.withJson(new StringReader(indexMappingContent)); + if (mappingsJson != null) { + builder.withJson(new StringReader(mappingsJson)); } return builder; }); @@ -141,14 +149,36 @@ public void createIndex(String indexName, String indexMappingContent) { } } + private String extractMappingsJson(String indexMappingContent) { + if (indexMappingContent == null) { + return null; + } + try { + JsonNode root = MAPPER.readTree(indexMappingContent); + JsonNode mappings = root.get("mappings"); + if (mappings != null) { + return MAPPER.writeValueAsString(mappings); + } + return indexMappingContent; + } catch (IOException e) { + LOG.warn( + "Failed to extract mappings from index content, using full content: {}", e.getMessage()); + return indexMappingContent; + } + } + private void createIndexInternal(String indexName, String indexMappingContent) throws IOException { + String enrichedContent = + (indexMappingContent != null && !indexMappingContent.isEmpty()) + ? 
EsUtils.enrichIndexMappingForElasticsearch(indexMappingContent) + : indexMappingContent; CreateIndexRequest request = CreateIndexRequest.of( builder -> { builder.index(indexName); - if (indexMappingContent != null) { - builder.withJson(new StringReader(indexMappingContent)); + if (enrichedContent != null) { + builder.withJson(new StringReader(enrichedContent)); } return builder; }); diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/search/elasticsearch/EsUtils.java b/openmetadata-service/src/main/java/org/openmetadata/service/search/elasticsearch/EsUtils.java index 8cdb92dbef79..6098fdad1c9e 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/search/elasticsearch/EsUtils.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/search/elasticsearch/EsUtils.java @@ -524,4 +524,56 @@ private static void buildSearchSourceFilter( } } } + + /** + * Enriches an Elasticsearch index mapping with vector search support. When the mapping contains + * a {@code fingerprint} field (the signal that this index stores embedded entity docs), injects a + * {@code dense_vector} embedding field and records {@code _meta} with the model ID and dimension. + * + *

The embedding dimension is resolved from the active {@link + * org.openmetadata.service.search.vector.client.EmbeddingClient}. If embeddings are disabled or + * the client is unavailable the mapping is returned unchanged. + */ + public static String enrichIndexMappingForElasticsearch(String indexMappingContent) { + if (nullOrEmpty(indexMappingContent)) { + throw new IllegalArgumentException("Empty Index Mapping Content."); + } + JsonNode rootNode = JsonUtils.readTree(indexMappingContent); + addDenseVectorSettings(rootNode); + return rootNode.toString(); + } + + static void addDenseVectorSettings(JsonNode rootNode) { + JsonNode properties = rootNode.path("mappings").path("properties"); + if (properties.isMissingNode() || !properties.has("fingerprint")) { + return; + } + + org.openmetadata.service.search.SearchRepository searchRepository = + org.openmetadata.service.Entity.getSearchRepository(); + if (searchRepository == null + || !searchRepository.isVectorEmbeddingEnabled() + || searchRepository.getEmbeddingClient() == null) { + return; + } + + int dimension = searchRepository.getEmbeddingClient().getDimension(); + + com.fasterxml.jackson.databind.node.ObjectNode embeddingNode = mapper.createObjectNode(); + embeddingNode.put("type", "dense_vector"); + embeddingNode.put("dims", dimension); + embeddingNode.put("index", true); + embeddingNode.put("similarity", "cosine"); + ((com.fasterxml.jackson.databind.node.ObjectNode) properties).set("embedding", embeddingNode); + + JsonNode mappings = rootNode.path("mappings"); + if (!mappings.isMissingNode()) { + com.fasterxml.jackson.databind.node.ObjectNode meta = + ((com.fasterxml.jackson.databind.node.ObjectNode) mappings).putObject("_meta"); + meta.put( + "embedding_model", + searchRepository.getEmbeddingClient().getModelId()) + .put("embedding_dimension", dimension); + } + } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/ElasticSearchVectorService.java 
package org.openmetadata.service.search.vector;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import es.co.elastic.clients.elasticsearch.ElasticsearchClient;
import es.co.elastic.clients.transport.rest5_client.Rest5ClientTransport;
import es.co.elastic.clients.transport.rest5_client.low_level.Request;
import es.co.elastic.clients.transport.rest5_client.low_level.Response;
import es.co.elastic.clients.transport.rest5_client.low_level.Rest5Client;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.schema.EntityInterface;
import org.openmetadata.service.events.lifecycle.EntityLifecycleEventDispatcher;
import org.openmetadata.service.search.vector.client.EmbeddingClient;
import org.openmetadata.service.search.vector.utils.DTOs.VectorSearchResponse;

/**
 * Elasticsearch implementation of {@link VectorIndexService}.
 *
 * <p>Embeds free-text queries through the configured {@link EmbeddingClient}, issues native
 * Elasticsearch kNN searches via the low-level {@link Rest5Client}, groups chunk hits by their
 * parent entity for paging, and maintains per-entity embedding documents using fingerprint
 * comparison to skip unchanged entities.
 *
 * <p>A process-wide singleton is published through {@link #init}; {@link #search} is stateless
 * apart from the underlying clients.
 */
@Slf4j
public class ElasticSearchVectorService implements VectorIndexService {
  private static final ObjectMapper MAPPER = new ObjectMapper();
  // Each raw page over-fetches by this factor because several chunk documents can share one
  // parent entity, and paging is counted in distinct parents, not raw hits.
  private static final int OVER_FETCH_MULTIPLIER = 2;

  private static volatile ElasticSearchVectorService instance;

  private final ElasticsearchClient client;
  private final Rest5Client restClient;
  @Getter private final EmbeddingClient embeddingClient;
  // Normalized (lowercase) language code; defaults to "en". NOTE(review): not read in this
  // class — presumably consumed elsewhere (e.g. doc building); confirm before removing.
  private final String language;
  private final int knnNumCandidatesMultiplier;

  /**
   * @param client high-level Elasticsearch client; its transport must be a
   *     {@link Rest5ClientTransport} so the low-level REST client can be extracted
   * @param embeddingClient produces query/document embeddings
   * @param language optional language code (null defaults to "en")
   * @param knnNumCandidatesMultiplier multiplier for kNN {@code num_candidates}; non-positive
   *     values fall back to {@link VectorSearchQueryBuilder#DEFAULT_KNN_NUM_CANDIDATES_MULTIPLIER}
   */
  public ElasticSearchVectorService(
      ElasticsearchClient client,
      EmbeddingClient embeddingClient,
      String language,
      int knnNumCandidatesMultiplier) {
    this.client = client;
    this.restClient = extractRestClient(client);
    this.embeddingClient = embeddingClient;
    this.language = language != null ? language.toLowerCase(Locale.ROOT) : "en";
    this.knnNumCandidatesMultiplier =
        knnNumCandidatesMultiplier > 0
            ? knnNumCandidatesMultiplier
            : VectorSearchQueryBuilder.DEFAULT_KNN_NUM_CANDIDATES_MULTIPLIER;
  }

  public ElasticSearchVectorService(
      ElasticsearchClient client, EmbeddingClient embeddingClient, String language) {
    this(
        client,
        embeddingClient,
        language,
        VectorSearchQueryBuilder.DEFAULT_KNN_NUM_CANDIDATES_MULTIPLIER);
  }

  public ElasticSearchVectorService(ElasticsearchClient client, EmbeddingClient embeddingClient) {
    this(client, embeddingClient, "en");
  }

  /** Extracts the low-level REST client; throws ClassCastException for non-Rest5 transports. */
  private static Rest5Client extractRestClient(ElasticsearchClient client) {
    Rest5ClientTransport transport = (Rest5ClientTransport) client._transport();
    return transport.restClient();
  }

  /**
   * Initializes (or re-initializes) the singleton and registers the vector-embedding lifecycle
   * handler. On re-init the previously registered handler is unregistered first so events are
   * not dispatched to a stale instance.
   */
  public static synchronized void init(
      ElasticsearchClient client,
      EmbeddingClient embeddingClient,
      String language,
      int knnNumCandidatesMultiplier) {
    if (instance != null) {
      LOG.warn("ElasticSearchVectorService already initialized, reinitializing");
      EntityLifecycleEventDispatcher.getInstance().unregisterHandler("VectorEmbeddingHandler");
    }
    ElasticSearchVectorService svc =
        new ElasticSearchVectorService(
            client, embeddingClient, language, knnNumCandidatesMultiplier);
    // Register the handler before publishing so readers never observe a half-wired instance.
    svc.registerVectorEmbeddingHandler();
    instance = svc;
    LOG.info(
        "ElasticSearchVectorService initialized with model={}, dimension={}",
        embeddingClient.getModelId(),
        embeddingClient.getDimension());
  }

  public static synchronized void init(
      ElasticsearchClient client, EmbeddingClient embeddingClient, String language) {
    init(
        client,
        embeddingClient,
        language,
        VectorSearchQueryBuilder.DEFAULT_KNN_NUM_CANDIDATES_MULTIPLIER);
  }

  /** @return the published singleton, or null if {@link #init} has not run */
  public static ElasticSearchVectorService getInstance() {
    return instance;
  }

  /** Best-effort registration; failure is logged and does not abort initialization. */
  private void registerVectorEmbeddingHandler() {
    try {
      VectorEmbeddingHandler handler = new VectorEmbeddingHandler(this);
      EntityLifecycleEventDispatcher.getInstance().registerHandler(handler);
      LOG.info("Registered VectorEmbeddingHandler for entity lifecycle events");
    } catch (Exception e) {
      LOG.error("Failed to register VectorEmbeddingHandler", e);
    }
  }

  /**
   * Runs a kNN vector search, grouping chunk hits by parent entity and paging in units of
   * distinct parents.
   *
   * @param query free text to embed and search with
   * @param filters field -> accepted values; null is treated as no filters
   * @param size maximum number of distinct parents to return
   * @param from number of distinct parents to skip
   * @param k kNN k parameter
   * @param threshold minimum score; hits below it are dropped (0 disables filtering)
   * @throws RuntimeException wrapping any embedding or transport failure
   */
  @Override
  public VectorSearchResponse search(
      String query,
      Map<String, List<String>> filters,
      int size,
      int from,
      int k,
      double threshold) {
    long start = System.currentTimeMillis();
    try {
      float[] queryVector = embeddingClient.embed(query);
      // Defensive: the query builder iterates the filter map directly.
      Map<String, List<String>> safeFilters = filters != null ? filters : Map.of();
      // Insertion-ordered so parents keep descending best-score order across pages.
      LinkedHashMap<String, List<Map<String, Object>>> byParent = new LinkedHashMap<>();
      int rawOffset = 0;
      long totalHits = -1L;
      boolean exhausted = false;
      // +1 lets us detect whether more parents exist beyond the requested page.
      int requestedParents = from + size + 1;
      int overFetchSize =
          Math.max(requestedParents * OVER_FETCH_MULTIPLIER, OVER_FETCH_MULTIPLIER);
      if (threshold <= 0.0) {
        // Without score filtering no hits are discarded, so never ask for more than k.
        overFetchSize = Math.min(overFetchSize, k);
      }

      String indexName = getIndexAlias();
      while (!exhausted && byParent.size() < requestedParents) {
        String queryJson =
            VectorSearchQueryBuilder.buildNativeESQuery(
                queryVector, overFetchSize, rawOffset, k, safeFilters, knnNumCandidatesMultiplier);
        String responseBody =
            executeGenericRequest("POST", "/" + indexName + "/_search", queryJson);

        JsonNode root = MAPPER.readTree(responseBody);
        JsonNode hitsNode = root.path("hits").path("hits");
        totalHits = extractTotalHits(root);

        int pageHitCount = collectSearchHits(hitsNode, threshold, byParent);
        if (pageHitCount == 0) {
          exhausted = true;
          break;
        }

        rawOffset += pageHitCount;
        // With a known total, stop when the offset reaches it; otherwise infer exhaustion
        // from a short page.
        exhausted = totalHits >= 0 ? rawOffset >= totalHits : pageHitCount < overFetchSize;
      }

      // Flatten: skip `from` parents, emit every chunk of the next `size` parents.
      List<Map<String, Object>> results = new ArrayList<>();
      int parentCount = 0;
      int skipped = 0;
      for (List<Map<String, Object>> chunks : byParent.values()) {
        if (skipped < from) {
          skipped++;
          continue;
        }
        if (parentCount >= size) {
          break;
        }
        results.addAll(chunks);
        parentCount++;
      }

      boolean hasMore = byParent.size() > (from + parentCount);
      long tookMillis = System.currentTimeMillis() - start;
      return new VectorSearchResponse(
          tookMillis, results, totalHits >= 0 ? totalHits : null, hasMore);
    } catch (Exception e) {
      LOG.error("Vector search failed: {}", e.getMessage(), e);
      throw new RuntimeException("Vector search failed", e);
    }
  }

  /**
   * Folds one page of hits into {@code byParent}, dropping hits below {@code threshold}.
   *
   * @return the number of RAW hits on the page (including filtered ones) so the caller's
   *     pagination offset advances correctly
   */
  @SuppressWarnings("unchecked")
  private static int collectSearchHits(
      JsonNode hitsNode,
      double threshold,
      LinkedHashMap<String, List<Map<String, Object>>> byParent) {
    int pageHitCount = 0;
    for (JsonNode hit : hitsNode) {
      pageHitCount++;
      double score = hit.path("_score").asDouble(0.0);
      if (score < threshold) {
        continue;
      }
      Map<String, Object> hitMap = MAPPER.convertValue(hit.path("_source"), Map.class);
      hitMap.put("_score", score);
      // Chunks group under their parent entity; orphan chunks fall back to the document _id.
      String parentId = (String) hitMap.getOrDefault("parentId", hit.path("_id").asText());
      byParent.computeIfAbsent(parentId, ignored -> new ArrayList<>()).add(hitMap);
    }
    return pageHitCount;
  }

  /** Handles both integral ("total": 5) and object ("total": {"value": 5}) shapes; -1 if absent. */
  private static long extractTotalHits(JsonNode root) {
    JsonNode totalNode = root.path("hits").path("total");
    if (totalNode.isIntegralNumber()) {
      return totalNode.asLong(-1L);
    }
    if (totalNode.isObject()) {
      return totalNode.path("value").asLong(-1L);
    }
    return -1L;
  }

  /**
   * Executes a raw REST request and returns the response body.
   *
   * @throws RuntimeException wrapping transport errors and any response with status >= 400
   */
  @Override
  public String executeGenericRequest(String method, String endpoint, String body) {
    try {
      Request request = new Request(method, endpoint);
      if (body != null) {
        request.setJsonEntity(body);
      }
      Response response = restClient.performRequest(request);
      int statusCode = response.getStatusCode();
      try (InputStream is = response.getEntity().getContent()) {
        String responseBody = new String(is.readAllBytes(), StandardCharsets.UTF_8);
        if (statusCode >= 400) {
          throw new IOException(
              "Elasticsearch request failed with status " + statusCode + ": " + responseBody);
        }
        return responseBody;
      }
    } catch (Exception e) {
      LOG.error("Generic request failed: {} {}", method, endpoint, e);
      throw new RuntimeException("Elasticsearch generic request failed", e);
    }
  }

  @Override
  public Map<String, Object> generateEmbeddingFields(EntityInterface entity) {
    return VectorDocBuilder.buildEmbeddingFields(entity, embeddingClient);
  }

  /**
   * Best-effort embedding refresh: regenerates and partially updates the entity's embedding
   * fields unless its content fingerprint is unchanged. Failures are logged, never thrown.
   */
  @Override
  public void updateEntityEmbedding(EntityInterface entity, String entityIndexName) {
    try {
      String entityId = entity.getId().toString();
      String existingFingerprint = getExistingFingerprint(entityIndexName, entityId);
      String currentFingerprint = VectorDocBuilder.computeFingerprintForEntity(entity);

      if (currentFingerprint.equals(existingFingerprint)) {
        LOG.debug("Skipping entity {} - fingerprint unchanged", entityId);
        return;
      }

      Map<String, Object> embeddingFields = generateEmbeddingFields(entity);
      partialUpdateEntity(entityIndexName, entityId, embeddingFields);
    } catch (Exception e) {
      LOG.error("Failed to update embedding for entity {}: {}", entity.getId(), e.getMessage(), e);
    }
  }

  /**
   * Looks up the stored fingerprint for one document, or null if absent/unreachable.
   *
   * <p>Uses the {@code ids} query — the DSL intended for {@code _id} lookups — matching the
   * batch variant below (previously this built a {@code term} query on {@code _id}).
   */
  @Override
  public String getExistingFingerprint(String indexName, String entityId) {
    try {
      String query =
          "{\"size\":1,\"_source\":[\"fingerprint\"],"
              + "\"query\":{\"ids\":{\"values\":[\""
              + VectorSearchQueryBuilder.escape(entityId)
              + "\"]}}}";
      String response = executeGenericRequest("POST", "/" + indexName + "/_search", query);
      JsonNode root = MAPPER.readTree(response);
      JsonNode hits = root.path("hits").path("hits");
      if (hits.isArray() && !hits.isEmpty()) {
        return hits.get(0).path("_source").path("fingerprint").asText(null);
      }
    } catch (Exception e) {
      // Missing fingerprint only means "re-embed"; debug level is deliberate.
      LOG.debug(
          "Failed to get fingerprint for entityId={} in index={}: {}",
          entityId,
          indexName,
          e.getMessage());
    }
    return null;
  }

  /**
   * Batch fingerprint lookup via an {@code ids} query.
   *
   * @return entityId -> fingerprint for documents that have one; empty map on error or empty input
   */
  @Override
  public Map<String, String> getExistingFingerprintsBatch(
      String indexName, List<String> entityIds) {
    if (entityIds == null || entityIds.isEmpty()) {
      return Collections.emptyMap();
    }
    try {
      StringBuilder idsArray = new StringBuilder("[");
      for (int i = 0; i < entityIds.size(); i++) {
        if (i > 0) idsArray.append(',');
        idsArray
            .append("\"")
            .append(VectorSearchQueryBuilder.escape(entityIds.get(i)))
            .append("\"");
      }
      idsArray.append("]");

      String query =
          "{\"size\":"
              + entityIds.size()
              + ",\"_source\":[\"fingerprint\"]"
              + ",\"query\":{\"ids\":{\"values\":"
              + idsArray
              + "}}}";

      String response = executeGenericRequest("POST", "/" + indexName + "/_search", query);
      JsonNode root = MAPPER.readTree(response);
      JsonNode hits = root.path("hits").path("hits");

      Map<String, String> result = new HashMap<>();
      for (JsonNode hit : hits) {
        String id = hit.path("_id").asText();
        String fp = hit.path("_source").path("fingerprint").asText(null);
        if (id != null && fp != null) {
          result.put(id, fp);
        }
      }
      return result;
    } catch (Exception e) {
      LOG.error("Failed to batch get fingerprints in index={}: {}", indexName, e.getMessage(), e);
      return Collections.emptyMap();
    }
  }

  /**
   * Partial (_update) write of the embedding fields onto an existing document, with conflict
   * retries. Best-effort: failures are logged, never thrown.
   */
  public void partialUpdateEntity(
      String indexName, String entityId, Map<String, Object> embeddingFields) {
    try {
      String docJson = MAPPER.writeValueAsString(embeddingFields);
      String updateBody = "{\"doc\":" + docJson + "}";
      executeGenericRequest(
          "POST", "/" + indexName + "/_update/" + entityId + "?retry_on_conflict=3", updateBody);
    } catch (Exception e) {
      LOG.error(
          "Failed to partial update entity {} in {}: {}", entityId, indexName, e.getMessage(), e);
    }
  }

  /**
   * Closes the client transport. NOTE(review): this closes the transport of the injected
   * client, which may be shared with other consumers — confirm ownership before calling.
   */
  public void close() {
    try {
      if (client != null && client._transport() != null) {
        client._transport().close();
      }
    } catch (Exception e) {
      LOG.warn("Error closing Elasticsearch transport: {}", e.getMessage());
    }
  }
}
a/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/OpenSearchVectorService.java b/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/OpenSearchVectorService.java index da311770d44f..1861c9293270 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/OpenSearchVectorService.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/OpenSearchVectorService.java @@ -42,9 +42,11 @@ public OpenSearchVectorService(OpenSearchClient client, EmbeddingClient embeddin public static synchronized void init(OpenSearchClient client, EmbeddingClient embeddingClient) { if (instance != null) { LOG.warn("OpenSearchVectorService already initialized, reinitializing"); + EntityLifecycleEventDispatcher.getInstance().unregisterHandler("VectorEmbeddingHandler"); } - instance = new OpenSearchVectorService(client, embeddingClient); - instance.registerVectorEmbeddingHandler(); + OpenSearchVectorService svc = new OpenSearchVectorService(client, embeddingClient); + svc.registerVectorEmbeddingHandler(); + instance = svc; LOG.info( "OpenSearchVectorService initialized with model={}, dimension={}", embeddingClient.getModelId(), @@ -198,7 +200,7 @@ public VectorSearchResponse search( overFetchSize = Math.min(overFetchSize, k); } - String aliasName = getSearchAlias(); + String aliasName = getIndexAlias(); while (!exhausted && byParent.size() < requestedParents) { String queryJson = VectorSearchQueryBuilder.build( @@ -280,6 +282,7 @@ private static long extractTotalHits(JsonNode root) { return -1L; } + @Override public String getExistingFingerprint(String indexName, String entityId) { try { String query = @@ -303,6 +306,7 @@ public String getExistingFingerprint(String indexName, String entityId) { return null; } + @Override public Map getExistingFingerprintsBatch( String indexName, List entityIds) { if (entityIds == null || entityIds.isEmpty()) { @@ -359,7 +363,8 @@ public void 
partialUpdateEntity( } } - String executeGenericRequest(String method, String endpoint, String body) { + @Override + public String executeGenericRequest(String method, String endpoint, String body) { try { OpenSearchGenericClient genericClient = client.generic(); var request = Requests.builder().endpoint(endpoint).method(method).json(body).build(); @@ -387,15 +392,4 @@ String executeGenericRequest(String method, String endpoint, String body) { } } - private String getSearchAlias() { - try { - String clusterAlias = Entity.getSearchRepository().getClusterAlias(); - if (clusterAlias == null || clusterAlias.isEmpty()) { - return VECTOR_EMBEDDING_ALIAS; - } - return clusterAlias + "_" + VECTOR_EMBEDDING_ALIAS; - } catch (Exception ex) { - return VECTOR_EMBEDDING_ALIAS; - } - } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/VectorIndexService.java b/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/VectorIndexService.java index b0b2c6e72625..eb3e870fffe1 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/VectorIndexService.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/VectorIndexService.java @@ -3,11 +3,13 @@ import java.util.List; import java.util.Map; import org.openmetadata.schema.EntityInterface; +import org.openmetadata.service.Entity; import org.openmetadata.service.search.vector.utils.DTOs.VectorSearchResponse; public interface VectorIndexService { String VECTOR_EMBEDDING_ALIAS = "dataAssetEmbeddings"; + String VECTOR_INDEX_KEY = "vectorEmbeddings"; Map generateEmbeddingFields(EntityInterface entity); @@ -15,4 +17,22 @@ public interface VectorIndexService { VectorSearchResponse search( String query, Map> filters, int size, int from, int k, double threshold); + + String getExistingFingerprint(String indexName, String entityId); + + Map getExistingFingerprintsBatch(String indexName, List entityIds); + + String 
executeGenericRequest(String method, String endpoint, String body); + + default String getIndexAlias() { + try { + String clusterAlias = Entity.getSearchRepository().getClusterAlias(); + if (clusterAlias == null || clusterAlias.isEmpty()) { + return VECTOR_EMBEDDING_ALIAS; + } + return clusterAlias + "_" + VECTOR_EMBEDDING_ALIAS; + } catch (Exception ex) { + return VECTOR_EMBEDDING_ALIAS; + } + } } diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/VectorSearchQueryBuilder.java b/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/VectorSearchQueryBuilder.java index b731d7062eee..99f11afe7d86 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/VectorSearchQueryBuilder.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/VectorSearchQueryBuilder.java @@ -14,6 +14,7 @@ public class VectorSearchQueryBuilder { private static final Logger LOG = LoggerFactory.getLogger(VectorSearchQueryBuilder.class); private static final String ANY = "__ANY__"; private static final String NONE = "__NONE__"; + public static final int DEFAULT_KNN_NUM_CANDIDATES_MULTIPLIER = 2; /** Build a full search request body (size + _source + query) for standalone vector search. 
*/ public static String build( @@ -66,17 +67,56 @@ private static void appendKnnQuery( // Build filter inside knn for efficient k-NN filtering sb.append(",\"filter\":{\"bool\":{\"must\":["); + appendFilterMustClauses(sb, filters); + sb.append("]}}"); // close must array and bool - // Only include documents where deleted=false - sb.append("{\"term\":{\"deleted\":false}}"); + sb.append("}}}}"); // close embedding, knn, query + } + + public static String buildNativeESQuery( + float[] vector, int size, int from, int k, Map> filters) { + return buildNativeESQuery(vector, size, from, k, filters, DEFAULT_KNN_NUM_CANDIDATES_MULTIPLIER); + } - // Then add user-specified filters + public static String buildNativeESQuery( + float[] vector, + int size, + int from, + int k, + Map> filters, + int numCandidatesMultiplier) { + int numCandidates = Math.max(k * numCandidatesMultiplier, 100); + StringBuilder sb = + new StringBuilder(512) + .append("{\"size\":") + .append(size) + .append(",\"from\":") + .append(from) + .append(",\"_source\":{\"excludes\":[\"embedding\"]}") + .append(",\"knn\":{") + .append("\"field\":\"embedding\"") + .append(",\"query_vector\":") + .append(Arrays.toString(vector)) + .append(",\"k\":") + .append(k) + .append(",\"num_candidates\":") + .append(numCandidates); + + sb.append(",\"filter\":{\"bool\":{\"must\":["); + appendFilterMustClauses(sb, filters); + sb.append("]}}"); // close must array and bool + + sb.append("}}"); // close knn object + return sb.toString(); + } + + private static void appendFilterMustClauses( + StringBuilder sb, Map> filters) { + sb.append("{\"term\":{\"deleted\":false}}"); for (var e : filters.entrySet()) { String field = e.getKey(); List values = e.getValue(); if (values == null || values.isEmpty()) continue; - - // Handle custom properties that will come with "customProperties." 
if (field.startsWith("customProperties.")) { sb.append(','); appendCustomPropertiesFilter(sb, field, values); @@ -126,10 +166,6 @@ private static void appendKnnQuery( } } } - - sb.append("]}}"); // close must array and bool - - sb.append("}}}"); // close embedding, knn, wrapper } private static void appendFlat(StringBuilder sb, String field, List vals) { diff --git a/openmetadata-service/src/test/java/org/openmetadata/service/apps/bundles/searchIndex/ElasticSearchBulkSinkSimpleTest.java b/openmetadata-service/src/test/java/org/openmetadata/service/apps/bundles/searchIndex/ElasticSearchBulkSinkSimpleTest.java index 110dac14b6da..ba9d0d4d27e3 100644 --- a/openmetadata-service/src/test/java/org/openmetadata/service/apps/bundles/searchIndex/ElasticSearchBulkSinkSimpleTest.java +++ b/openmetadata-service/src/test/java/org/openmetadata/service/apps/bundles/searchIndex/ElasticSearchBulkSinkSimpleTest.java @@ -80,4 +80,12 @@ void testContextDataHandling() { recreateIndex = (Boolean) contextData.getOrDefault("recreateIndex", false); assertEquals(false, recreateIndex); } + + @Test + void testIsVectorEmbeddingEnabledForEntity() { + assertEquals(false, elasticSearchBulkSink.isVectorEmbeddingEnabledForEntity("table")); + assertEquals(false, elasticSearchBulkSink.isVectorEmbeddingEnabledForEntity("user")); + assertEquals(false, elasticSearchBulkSink.isVectorEmbeddingEnabledForEntity("dashboard")); + } + } diff --git a/openmetadata-service/src/test/java/org/openmetadata/service/resources/search/VectorSearchResourceTest.java b/openmetadata-service/src/test/java/org/openmetadata/service/resources/search/VectorSearchResourceTest.java new file mode 100644 index 000000000000..759ebe443411 --- /dev/null +++ b/openmetadata-service/src/test/java/org/openmetadata/service/resources/search/VectorSearchResourceTest.java @@ -0,0 +1,146 @@ +package org.openmetadata.service.resources.search; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static 
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.SecurityContext;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.openmetadata.service.Entity;
import org.openmetadata.service.search.SearchRepository;
import org.openmetadata.service.search.vector.VectorIndexService;
import org.openmetadata.service.security.Authorizer;
import org.openmetadata.service.security.AuthorizationException;

/**
 * Unit tests for the fingerprint lookup endpoint of VectorSearchResource.
 *
 * <p>Entity.getSearchRepository() is a static accessor, so each test mocks it with
 * Mockito's MockedStatic inside a try-with-resources block to keep the static stub
 * scoped to that test.
 */
class VectorSearchResourceTest {

  private Authorizer mockAuthorizer;
  private SecurityContext mockSecurityContext;
  private SearchRepository mockSearchRepository;
  private VectorIndexService mockVectorService;
  private VectorSearchResource resource;

  @BeforeEach
  void setUp() {
    mockAuthorizer = mock(Authorizer.class);
    mockSecurityContext = mock(SecurityContext.class);
    mockSearchRepository = mock(SearchRepository.class);
    mockVectorService = mock(VectorIndexService.class);
    resource = new VectorSearchResource(mockAuthorizer);
  }

  /** Non-admin callers must be rejected before the vector service is ever consulted. */
  @Test
  void testGetFingerprintRequiresAdmin() {
    doThrow(new AuthorizationException("Forbidden"))
        .when(mockAuthorizer)
        .authorizeAdmin(mockSecurityContext);

    try (MockedStatic<Entity> entityMock = mockStatic(Entity.class)) {
      entityMock.when(Entity::getSearchRepository).thenReturn(mockSearchRepository);
      when(mockSearchRepository.isVectorEmbeddingEnabled()).thenReturn(true);

      try {
        resource.getFingerprint(mockSecurityContext, UUID.randomUUID().toString());
      } catch (AuthorizationException e) {
        // Authorization failed first: the service must not have been touched.
        verify(mockVectorService, never()).getExistingFingerprint(any(), any());
        return;
      }
      throw new AssertionError("Expected AuthorizationException");
    }
  }

  @Test
  void testGetFingerprintReturnsFoundWhenFingerprintExists() {
    doNothing().when(mockAuthorizer).authorizeAdmin(mockSecurityContext);

    try (MockedStatic<Entity> entityMock = mockStatic(Entity.class)) {
      entityMock.when(Entity::getSearchRepository).thenReturn(mockSearchRepository);
      when(mockSearchRepository.isVectorEmbeddingEnabled()).thenReturn(true);
      when(mockSearchRepository.getVectorIndexService()).thenReturn(mockVectorService);
      when(mockVectorService.getIndexAlias()).thenReturn("table_search_index");

      String entityId = UUID.randomUUID().toString();
      when(mockVectorService.getExistingFingerprint("table_search_index", entityId))
          .thenReturn("abc123");

      Response response = resource.getFingerprint(mockSecurityContext, entityId);

      assertEquals(Response.Status.OK.getStatusCode(), response.getStatus());
    }
  }

  /**
   * A missing fingerprint still yields HTTP 200 (presumably with a found=false payload).
   * NOTE(review): the test name says "NotFound" but asserts 200, not 404 — confirm the
   * endpoint's intended contract for missing fingerprints.
   */
  @Test
  void testGetFingerprintReturnsNotFoundWhenFingerprintMissing() {
    doNothing().when(mockAuthorizer).authorizeAdmin(mockSecurityContext);

    try (MockedStatic<Entity> entityMock = mockStatic(Entity.class)) {
      entityMock.when(Entity::getSearchRepository).thenReturn(mockSearchRepository);
      when(mockSearchRepository.isVectorEmbeddingEnabled()).thenReturn(true);
      when(mockSearchRepository.getVectorIndexService()).thenReturn(mockVectorService);
      when(mockVectorService.getIndexAlias()).thenReturn("table_search_index");

      String entityId = UUID.randomUUID().toString();
      when(mockVectorService.getExistingFingerprint("table_search_index", entityId))
          .thenReturn(null);

      Response response = resource.getFingerprint(mockSecurityContext, entityId);

      assertEquals(Response.Status.OK.getStatusCode(), response.getStatus());
    }
  }

  /** Malformed UUIDs are rejected with 400 before any index access. */
  @Test
  void testGetFingerprintReturnsBadRequestForInvalidUuid() {
    doNothing().when(mockAuthorizer).authorizeAdmin(mockSecurityContext);

    try (MockedStatic<Entity> entityMock = mockStatic(Entity.class)) {
      entityMock.when(Entity::getSearchRepository).thenReturn(mockSearchRepository);
      when(mockSearchRepository.isVectorEmbeddingEnabled()).thenReturn(true);
      when(mockSearchRepository.getVectorIndexService()).thenReturn(mockVectorService);

      Response response = resource.getFingerprint(mockSecurityContext, "not-a-uuid");

      assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), response.getStatus());
    }
  }

  /** A null id is treated the same as an invalid one: 400. */
  @Test
  void testGetFingerprintReturnsBadRequestForMissingParentId() {
    doNothing().when(mockAuthorizer).authorizeAdmin(mockSecurityContext);

    try (MockedStatic<Entity> entityMock = mockStatic(Entity.class)) {
      entityMock.when(Entity::getSearchRepository).thenReturn(mockSearchRepository);
      when(mockSearchRepository.isVectorEmbeddingEnabled()).thenReturn(true);
      when(mockSearchRepository.getVectorIndexService()).thenReturn(mockVectorService);

      Response response = resource.getFingerprint(mockSecurityContext, null);

      assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), response.getStatus());
    }
  }

  /** When vector embeddings are disabled the endpoint responds 503 regardless of input. */
  @Test
  void testGetFingerprintReturnsServiceUnavailableWhenDisabled() {
    doNothing().when(mockAuthorizer).authorizeAdmin(mockSecurityContext);

    try (MockedStatic<Entity> entityMock = mockStatic(Entity.class)) {
      entityMock.when(Entity::getSearchRepository).thenReturn(mockSearchRepository);
      when(mockSearchRepository.isVectorEmbeddingEnabled()).thenReturn(false);

      Response response =
          resource.getFingerprint(mockSecurityContext, UUID.randomUUID().toString());

      assertEquals(Response.Status.SERVICE_UNAVAILABLE.getStatusCode(), response.getStatus());
    }
  }
}
b/openmetadata-service/src/test/java/org/openmetadata/service/search/SearchRepositoryBehaviorTest.java @@ -2136,30 +2136,10 @@ void initializeLineageComponentsDelegatesWhenSearchClientExists() throws Excepti } @Test - void reformatVectorIndexWithDimensionAddsMetaAndPreservesInvalidJson() throws Exception { - EmbeddingClient embeddingClient = mock(EmbeddingClient.class); - when(embeddingClient.getModelId()).thenReturn("test-model"); - setPrivateField(repository, "embeddingClient", embeddingClient); - - String updated = - (String) - invokePrivateMethod( - repository, - "reformatVectorIndexWithDimension", - new Class[] {String.class, int.class}, - "{\"mappings\":{}}", - 768); - - assertTrue(updated.contains("\"embedding_model\":\"test-model\"")); - assertTrue(updated.contains("\"embedding_dimension\":768")); - assertEquals( - "not-json", - invokePrivateMethod( - repository, - "reformatVectorIndexWithDimension", - new Class[] {String.class, int.class}, - "not-json", - 384)); + void readIndexMappingReturnsMappingForKnownIndex() { + String mapping = repository.readIndexMapping(TABLE_MAPPING); + assertNotNull(mapping); + assertFalse(mapping.isBlank()); } @Test diff --git a/openmetadata-service/src/test/java/org/openmetadata/service/search/elasticsearch/ElasticSearchIndexManagerTest.java b/openmetadata-service/src/test/java/org/openmetadata/service/search/elasticsearch/ElasticSearchIndexManagerTest.java index c95512062059..8ee009fe20ed 100644 --- a/openmetadata-service/src/test/java/org/openmetadata/service/search/elasticsearch/ElasticSearchIndexManagerTest.java +++ b/openmetadata-service/src/test/java/org/openmetadata/service/search/elasticsearch/ElasticSearchIndexManagerTest.java @@ -327,6 +327,21 @@ void testUpdateIndex_HandlesInvalidJson() { verifyNoInteractions(indicesClient); } + @Test + void testUpdateIndex_ExtractsMappingsFromFullIndexJson() throws IOException { + // putMapping only accepts the mappings sub-object, not a full index JSON with settings/aliases + 
String fullIndexJson = + "{\"settings\":{\"number_of_shards\":1}," + + "\"mappings\":{\"properties\":{\"field1\":{\"type\":\"text\"}}}," + + "\"aliases\":{}}"; + when(indexMapping.getIndexName(CLUSTER_ALIAS)).thenReturn(TEST_INDEX); + when(indicesClient.putMapping(any(PutMappingRequest.class))).thenReturn(putMappingResponse); + + assertDoesNotThrow(() -> indexManager.updateIndex(indexMapping, fullIndexJson)); + + verify(indicesClient).putMapping(any(PutMappingRequest.class)); + } + @Test void testCreateIndex_ClientNotAvailable() { ElasticSearchIndexManager managerWithNullClient = diff --git a/openmetadata-service/src/test/java/org/openmetadata/service/search/elasticsearch/EsUtilsTest.java b/openmetadata-service/src/test/java/org/openmetadata/service/search/elasticsearch/EsUtilsTest.java index 0f9714a8e49d..bc2efc55119a 100644 --- a/openmetadata-service/src/test/java/org/openmetadata/service/search/elasticsearch/EsUtilsTest.java +++ b/openmetadata-service/src/test/java/org/openmetadata/service/search/elasticsearch/EsUtilsTest.java @@ -1,6 +1,7 @@ package org.openmetadata.service.search.elasticsearch; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertSame; @@ -503,6 +504,42 @@ void testSearchEntityByKeyThrowsWhenMultipleMatchesExist() throws Exception { } } + @Test + void enrichIndexMappingThrowsOnNullOrEmptyInput() { + assertThrows(IllegalArgumentException.class, () -> EsUtils.enrichIndexMappingForElasticsearch(null)); + assertThrows(IllegalArgumentException.class, () -> EsUtils.enrichIndexMappingForElasticsearch("")); + } + + @Test + void enrichIndexMappingSkipsEmbeddingWhenFingerprintFieldAbsent() { + String mapping = "{\"mappings\":{\"properties\":{\"name\":{\"type\":\"keyword\"}}}}"; + String result = 
EsUtils.enrichIndexMappingForElasticsearch(mapping); + assertFalse(result.contains("dense_vector"), "Should not add embedding when fingerprint field is absent"); + } + + @Test + void enrichIndexMappingInjectsEmbeddingWhenFingerprintPresentAndVectorEnabled() { + String mapping = "{\"mappings\":{\"properties\":{\"name\":{\"type\":\"keyword\"},\"fingerprint\":{\"type\":\"keyword\"}}}}"; + + org.openmetadata.service.search.vector.client.EmbeddingClient mockEmbeddingClient = + org.mockito.Mockito.mock(org.openmetadata.service.search.vector.client.EmbeddingClient.class); + org.mockito.Mockito.when(mockEmbeddingClient.getDimension()).thenReturn(768); + org.mockito.Mockito.when(mockEmbeddingClient.getModelId()).thenReturn("test-model"); + + try (MockedStatic entityMock = mockStatic(Entity.class)) { + entityMock.when(Entity::getSearchRepository).thenReturn(searchRepository); + org.mockito.Mockito.when(searchRepository.isVectorEmbeddingEnabled()).thenReturn(true); + org.mockito.Mockito.when(searchRepository.getEmbeddingClient()).thenReturn(mockEmbeddingClient); + + String result = EsUtils.enrichIndexMappingForElasticsearch(mapping); + + assertTrue(result.contains("\"dense_vector\""), "Should add dense_vector field"); + assertTrue(result.contains("\"dims\":768"), "Should set correct dimension"); + assertTrue(result.contains("\"embedding_model\":\"test-model\""), "Should add _meta.embedding_model"); + assertTrue(result.contains("\"embedding_dimension\":768"), "Should add _meta.embedding_dimension"); + } + } + @Test void testSearchEntitiesUsesResolvedAliasAndPostFilter() throws Exception { try (MockedStatic entity = mockStatic(Entity.class)) { diff --git a/openmetadata-service/src/test/java/org/openmetadata/service/search/vector/ElasticSearchVectorServiceTest.java b/openmetadata-service/src/test/java/org/openmetadata/service/search/vector/ElasticSearchVectorServiceTest.java new file mode 100644 index 000000000000..9ad740b8f61d --- /dev/null +++ 
b/openmetadata-service/src/test/java/org/openmetadata/service/search/vector/ElasticSearchVectorServiceTest.java @@ -0,0 +1,492 @@ +package org.openmetadata.service.search.vector; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.List; + +import es.co.elastic.clients.elasticsearch.ElasticsearchClient; +import es.co.elastic.clients.transport.rest5_client.Rest5ClientTransport; +import es.co.elastic.clients.transport.rest5_client.low_level.Request; +import es.co.elastic.clients.transport.rest5_client.low_level.Response; +import es.co.elastic.clients.transport.rest5_client.low_level.Rest5Client; +import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.hc.core5.http.HttpEntity; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.openmetadata.service.search.vector.client.EmbeddingClient; +import org.openmetadata.service.search.vector.utils.DTOs; + +class ElasticSearchVectorServiceTest { + + private static final String EMPTY_HITS_RESPONSE = + "{\"hits\":{\"total\":{\"value\":0},\"hits\":[]}}"; + + private ElasticSearchVectorService vectorService; + private ElasticsearchClient mockEsClient; + private Rest5Client mockRestClient; + private EmbeddingClient mockEmbeddingClient; + + @BeforeEach + void setup() throws Exception { + mockEsClient = mock(ElasticsearchClient.class); + Rest5ClientTransport mockTransport = mock(Rest5ClientTransport.class); + mockRestClient = mock(Rest5Client.class); + + 
when(mockEsClient._transport()).thenReturn(mockTransport); + when(mockTransport.restClient()).thenReturn(mockRestClient); + + mockEmbeddingClient = mock(EmbeddingClient.class); + when(mockEmbeddingClient.embed(any(String.class))).thenReturn(new float[] {0.1f, 0.2f, 0.3f}); + + vectorService = new ElasticSearchVectorService(mockEsClient, mockEmbeddingClient); + } + + @Test + void testThresholdFilteringRemovesLowScoreResults() throws Exception { + String esResponse = + """ + { + "hits": { + "total": {"value": 4}, + "hits": [ + {"_score": 0.9, "_source": {"parentId": "parent1", "chunkIndex": 0, "text": "High score chunk"}}, + {"_score": 0.7, "_source": {"parentId": "parent2", "chunkIndex": 0, "text": "Medium score chunk"}}, + {"_score": 0.4, "_source": {"parentId": "parent3", "chunkIndex": 0, "text": "Low score chunk"}}, + {"_score": 0.2, "_source": {"parentId": "parent4", "chunkIndex": 0, "text": "Very low score chunk"}} + ] + } + } + """; + + mockRestClientResponse(esResponse); + + DTOs.VectorSearchResponse results = vectorService.search("test query", Map.of(), 10, 0, 100, 0.5); + + assertNotNull(results); + assertEquals(2, results.hits.size(), "Should return 2 results (scores 0.9 and 0.7)"); + for (Map result : results.hits) { + double score = (double) result.get("_score"); + assertTrue(score >= 0.5, "All results should have score >= 0.5, got: " + score); + } + } + + @Test + void testScoreFieldIncludedInResults() throws Exception { + String esResponse = + """ + { + "hits": { + "total": {"value": 1}, + "hits": [ + {"_score": 0.85, "_source": {"parentId": "parent1", "chunkIndex": 0, "text": "Test chunk"}} + ] + } + } + """; + + mockRestClientResponse(esResponse); + + DTOs.VectorSearchResponse results = vectorService.search("test query", Map.of(), 10, 0, 100, 0.0); + + assertEquals(1, results.hits.size()); + assertTrue(results.hits.get(0).containsKey("_score"), "Result should contain _score field"); + assertEquals(0.85, (double) results.hits.get(0).get("_score"), 
0.001); + } + + @Test + void testParentGroupingLimitsDistinctParents() throws Exception { + // size=2 → requestedParents=3; 4 distinct parents in response causes loop to exit after 1 page + String esResponse = + """ + { + "hits": { + "total": {"value": 8}, + "hits": [ + {"_score": 0.9, "_source": {"parentId": "parent1", "chunkIndex": 0}}, + {"_score": 0.88, "_source": {"parentId": "parent1", "chunkIndex": 1}}, + {"_score": 0.85, "_source": {"parentId": "parent1", "chunkIndex": 2}}, + {"_score": 0.8, "_source": {"parentId": "parent2", "chunkIndex": 0}}, + {"_score": 0.78, "_source": {"parentId": "parent2", "chunkIndex": 1}}, + {"_score": 0.7, "_source": {"parentId": "parent3", "chunkIndex": 0}}, + {"_score": 0.68, "_source": {"parentId": "parent3", "chunkIndex": 1}}, + {"_score": 0.6, "_source": {"parentId": "parent4", "chunkIndex": 0}} + ] + } + } + """; + + mockRestClientResponse(esResponse); + + DTOs.VectorSearchResponse results = vectorService.search("test query", Map.of(), 2, 0, 100, 0.0); + + assertEquals(5, results.hits.size(), "Should return all chunks from first 2 parents (3+2=5)"); + long distinctParents = results.hits.stream().map(r -> r.get("parentId")).distinct().count(); + assertEquals(2, distinctParents, "Should have chunks from exactly 2 distinct parents"); + } + + @Test + void testZeroThresholdReturnsAllResults() throws Exception { + String esResponse = + """ + { + "hits": { + "total": {"value": 3}, + "hits": [ + {"_score": 0.9, "_source": {"parentId": "p1", "chunkIndex": 0}}, + {"_score": 0.5, "_source": {"parentId": "p2", "chunkIndex": 0}}, + {"_score": 0.1, "_source": {"parentId": "p3", "chunkIndex": 0}} + ] + } + } + """; + + mockRestClientResponse(esResponse); + + DTOs.VectorSearchResponse results = vectorService.search("test query", Map.of(), 10, 0, 100, 0.0); + + assertEquals(3, results.hits.size(), "With threshold 0.0, should return all 3 results"); + } + + @Test + void testHighThresholdFiltersAllResults() throws Exception { + String 
esResponse = + """ + { + "hits": { + "total": {"value": 3}, + "hits": [ + {"_score": 0.5, "_source": {"parentId": "p1", "chunkIndex": 0}}, + {"_score": 0.3, "_source": {"parentId": "p2", "chunkIndex": 0}}, + {"_score": 0.1, "_source": {"parentId": "p3", "chunkIndex": 0}} + ] + } + } + """; + + mockRestClientResponse(esResponse); + + DTOs.VectorSearchResponse results = vectorService.search("test query", Map.of(), 10, 0, 100, 0.9); + + assertEquals(0, results.hits.size(), "With threshold 0.9, all results should be filtered out"); + } + + @Test + void testChunksWithoutParentIdGroupedByDocumentId() throws Exception { + // Chunks without parentId in _source fall back to the document's _id field + String esResponse = + """ + { + "hits": { + "total": {"value": 3}, + "hits": [ + {"_score": 0.9, "_source": {"parentId": "p1", "chunkIndex": 0}}, + {"_id": "orphan-123", "_score": 0.8, "_source": {"chunkIndex": 0, "text": "orphan chunk"}}, + {"_score": 0.7, "_source": {"parentId": "p2", "chunkIndex": 0}} + ] + } + } + """; + + mockRestClientResponse(esResponse); + + DTOs.VectorSearchResponse results = vectorService.search("test query", Map.of(), 10, 0, 100, 0.0); + + assertEquals(3, results.hits.size(), "Orphan chunk should be included, grouped by document _id"); + } + + @Test + void testRequestedSizeLimitsDistinctParents() throws Exception { + // size=3 → requestedParents=4; 10 distinct parents in response exits loop immediately + String esResponse = + """ + { + "hits": { + "total": {"value": 10}, + "hits": [ + {"_score": 0.9, "_source": {"parentId": "p1", "chunkIndex": 0}}, + {"_score": 0.8, "_source": {"parentId": "p2", "chunkIndex": 0}}, + {"_score": 0.7, "_source": {"parentId": "p3", "chunkIndex": 0}}, + {"_score": 0.6, "_source": {"parentId": "p4", "chunkIndex": 0}}, + {"_score": 0.5, "_source": {"parentId": "p5", "chunkIndex": 0}}, + {"_score": 0.4, "_source": {"parentId": "p6", "chunkIndex": 0}}, + {"_score": 0.3, "_source": {"parentId": "p7", "chunkIndex": 0}}, + 
{"_score": 0.2, "_source": {"parentId": "p8", "chunkIndex": 0}}, + {"_score": 0.15, "_source": {"parentId": "p9", "chunkIndex": 0}}, + {"_score": 0.1, "_source": {"parentId": "p10", "chunkIndex": 0}} + ] + } + } + """; + + mockRestClientResponse(esResponse); + + DTOs.VectorSearchResponse results = vectorService.search("test query", Map.of(), 3, 0, 100, 0.0); + + assertEquals(3, results.hits.size(), "Should limit to 3 distinct parents"); + long distinctParents = results.hits.stream().map(r -> r.get("parentId")).distinct().count(); + assertEquals(3, distinctParents, "Should have exactly 3 distinct parents"); + } + + @Test + void testEmptyHitsResponseReturnsEmptyList() throws Exception { + String esResponse = + """ + { + "hits": { + "total": {"value": 0}, + "hits": [] + } + } + """; + + mockRestClientResponse(esResponse); + + DTOs.VectorSearchResponse results = vectorService.search("test query", Map.of(), 10, 0, 100, 0.0); + + assertNotNull(results); + assertTrue(results.hits.isEmpty(), "Empty hits should return empty list"); + } + + @Test + void testFromSkipsParentsNotChunks() throws Exception { + // from=1 should skip 1 parent (p1), not 1 raw chunk + // size=2, from=1 → requestedParents=4; 4 distinct parents in response exits loop after 1 page + String esResponse = + """ + { + "hits": { + "total": {"value": 4}, + "hits": [ + {"_score": 0.9, "_source": {"parentId": "p1", "chunkIndex": 0}}, + {"_score": 0.8, "_source": {"parentId": "p2", "chunkIndex": 0}}, + {"_score": 0.7, "_source": {"parentId": "p3", "chunkIndex": 0}}, + {"_score": 0.6, "_source": {"parentId": "p4", "chunkIndex": 0}} + ] + } + } + """; + + mockRestClientResponse(esResponse); + + DTOs.VectorSearchResponse results = vectorService.search("test query", Map.of(), 2, 1, 100, 0.0); + + assertEquals(2, results.hits.size(), "Should return 2 parents after skipping 1"); + assertEquals("p2", results.hits.get(0).get("parentId")); + assertEquals("p3", results.hits.get(1).get("parentId")); + } + + @Test + void 
testHasMoreTrueWhenExtraParentFetched() throws Exception { + // size=2, from=0 → requestedParents=3; 4 parents fetched → hasMore=true + String esResponse = + """ + { + "hits": { + "total": {"value": 4}, + "hits": [ + {"_score": 0.9, "_source": {"parentId": "p1", "chunkIndex": 0}}, + {"_score": 0.8, "_source": {"parentId": "p2", "chunkIndex": 0}}, + {"_score": 0.7, "_source": {"parentId": "p3", "chunkIndex": 0}}, + {"_score": 0.6, "_source": {"parentId": "p4", "chunkIndex": 0}} + ] + } + } + """; + + mockRestClientResponse(esResponse); + + DTOs.VectorSearchResponse results = vectorService.search("test query", Map.of(), 2, 0, 100, 0.0); + + assertEquals(2, results.hits.size()); + assertTrue(results.hasMore, "hasMore should be true when extra parent was fetched"); + } + + @Test + void testHasMoreFalseWhenNoExtraParent() throws Exception { + // size=10, from=0 → requestedParents=11; only 2 parents available → hasMore=false + String esResponse = + """ + { + "hits": { + "total": {"value": 2}, + "hits": [ + {"_score": 0.9, "_source": {"parentId": "p1", "chunkIndex": 0}}, + {"_score": 0.8, "_source": {"parentId": "p2", "chunkIndex": 0}} + ] + } + } + """; + + mockRestClientResponse(esResponse); + + DTOs.VectorSearchResponse results = vectorService.search("test query", Map.of(), 10, 0, 100, 0.0); + + assertEquals(2, results.hits.size()); + assertFalse(results.hasMore, "hasMore should be false when fewer parents than requested"); + } + + @Test + void testTotalHitsPopulatedFromResponse() throws Exception { + String esResponse = + """ + { + "hits": { + "total": {"value": 3}, + "hits": [ + {"_score": 0.9, "_source": {"parentId": "p1", "chunkIndex": 0}}, + {"_score": 0.8, "_source": {"parentId": "p2", "chunkIndex": 0}}, + {"_score": 0.7, "_source": {"parentId": "p3", "chunkIndex": 0}} + ] + } + } + """; + + mockRestClientResponse(esResponse); + + DTOs.VectorSearchResponse results = vectorService.search("test query", Map.of(), 10, 0, 100, 0.0); + + assertEquals(Long.valueOf(3), 
results.totalHits); + } + + @Test + void testTotalHitsNullWhenMissingFromResponse() throws Exception { + // No "total" field in first response; second call returns empty to terminate the loop + String esResponse = + """ + { + "hits": { + "hits": [ + {"_score": 0.9, "_source": {"parentId": "p1", "chunkIndex": 0}} + ] + } + } + """; + + mockRestClientResponseSequence(esResponse, EMPTY_HITS_RESPONSE); + + DTOs.VectorSearchResponse results = vectorService.search("test query", Map.of(), 10, 0, 100, 0.0); + + assertNull(results.totalHits, "totalHits should be null when not present in response"); + } + + @Test + void testGetExistingFingerprintReturnsNullWhenNotFound() throws Exception { + String esResponse = "{\"hits\":{\"total\":{\"value\":0},\"hits\":[]}}"; + + mockRestClientResponse(esResponse); + + String fingerprint = vectorService.getExistingFingerprint("vector_search_index", "unknown-id"); + + assertNull(fingerprint, "Should return null when no fingerprint found"); + } + + @Test + void testGetExistingFingerprintReturnsValueWhenFound() throws Exception { + String esResponse = + """ + { + "hits": { + "total": {"value": 1}, + "hits": [ + {"_source": {"fingerprint": "abc123"}} + ] + } + } + """; + + mockRestClientResponse(esResponse); + + String fingerprint = + vectorService.getExistingFingerprint("vector_search_index", "some-entity-id"); + + assertEquals("abc123", fingerprint); + } + + @Test + void testGetExistingFingerprintsBatchReturnsEmptyForNullInput() { + Map result = vectorService.getExistingFingerprintsBatch("index", null); + assertTrue(result.isEmpty()); + } + + @Test + void testGetExistingFingerprintsBatchReturnsEmptyForEmptyInput() { + Map result = + vectorService.getExistingFingerprintsBatch("index", java.util.List.of()); + assertTrue(result.isEmpty()); + } + + /** Returns a fresh stream on every call — safe for multi-iteration loops. 
*/ + private void mockRestClientResponse(String responseJson) throws Exception { + Response mockResponse = mock(Response.class); + HttpEntity mockEntity = mock(HttpEntity.class); + when(mockRestClient.performRequest(any(Request.class))).thenReturn(mockResponse); + when(mockResponse.getEntity()).thenReturn(mockEntity); + when(mockEntity.getContent()) + .thenAnswer( + inv -> new ByteArrayInputStream(responseJson.getBytes(StandardCharsets.UTF_8))); + } + + @Test + void testNumCandidatesMultiplierFromConfigIsApplied() throws Exception { + int configuredMultiplier = 5; + int k = 50; + // num_candidates = max(50 * 5, 100) = 250 + int expectedNumCandidates = 250; + + ElasticSearchVectorService svc = + new ElasticSearchVectorService(mockEsClient, mockEmbeddingClient, "en", configuredMultiplier); + + List capturedBodies = new java.util.ArrayList<>(); + Response mockResponse = mock(Response.class); + HttpEntity mockEntity = mock(HttpEntity.class); + when(mockRestClient.performRequest(any(Request.class))) + .thenAnswer( + inv -> { + Request req = inv.getArgument(0); + org.apache.hc.core5.http.HttpEntity entity = req.getEntity(); + if (entity != null) { + try (java.io.InputStream is = entity.getContent()) { + capturedBodies.add(new String(is.readAllBytes(), StandardCharsets.UTF_8)); + } + } + return mockResponse; + }); + when(mockResponse.getEntity()).thenReturn(mockEntity); + when(mockEntity.getContent()) + .thenAnswer( + inv -> new ByteArrayInputStream(EMPTY_HITS_RESPONSE.getBytes(StandardCharsets.UTF_8))); + + svc.search("test query", Map.of(), 10, 0, k, 0.0); + + assertFalse(capturedBodies.isEmpty(), "Expected at least one request to be captured"); + String requestBody = capturedBodies.get(0); + assertTrue( + requestBody.contains("\"num_candidates\":" + expectedNumCandidates), + "Expected num_candidates=" + expectedNumCandidates + " in: " + requestBody); + } + + /** Returns each response in sequence; repeats the last one if more calls are made. 
*/ + private void mockRestClientResponseSequence(String... responses) throws Exception { + Response mockResponse = mock(Response.class); + HttpEntity mockEntity = mock(HttpEntity.class); + AtomicInteger callCount = new AtomicInteger(0); + when(mockRestClient.performRequest(any(Request.class))).thenReturn(mockResponse); + when(mockResponse.getEntity()).thenReturn(mockEntity); + when(mockEntity.getContent()) + .thenAnswer( + inv -> { + int idx = Math.min(callCount.getAndIncrement(), responses.length - 1); + return new ByteArrayInputStream(responses[idx].getBytes(StandardCharsets.UTF_8)); + }); + } +} diff --git a/openmetadata-service/src/test/java/org/openmetadata/service/search/vector/VectorSearchQueryBuilderTest.java b/openmetadata-service/src/test/java/org/openmetadata/service/search/vector/VectorSearchQueryBuilderTest.java index 52d46d853d52..e5456ca9c541 100644 --- a/openmetadata-service/src/test/java/org/openmetadata/service/search/vector/VectorSearchQueryBuilderTest.java +++ b/openmetadata-service/src/test/java/org/openmetadata/service/search/vector/VectorSearchQueryBuilderTest.java @@ -698,9 +698,193 @@ void testIgnoresOnlyUnrecognizedFilterKeys() throws Exception { JsonNode root = MAPPER.readTree(query); JsonNode mustFilters = root.get("query").get("knn").get("embedding").get("filter").get("bool").get("must"); - + // Should have only 1 filter: deleted=false assertEquals(1, mustFilters.size()); assertFalse(mustFilters.get(0).get("term").get("deleted").asBoolean()); } + + // ------------------------------------------------------------------------- + // buildNativeESQuery tests (Elasticsearch 8.x/9.x top-level knn format) + // ------------------------------------------------------------------------- + + @Test + void testNativeESQueryTopLevelKnnStructure() throws Exception { + float[] vector = {0.1f, 0.2f, 0.3f}; + int size = 10; + int k = 100; + + + String query = VectorSearchQueryBuilder.buildNativeESQuery(vector, size, 0, k, Map.of()); + + JsonNode root = 
MAPPER.readTree(query); + assertEquals(size, root.get("size").asInt()); + + // Must have top-level "knn", NOT "query" + assertTrue(root.has("knn"), "ES native query must have top-level 'knn'"); + assertTrue(!root.has("query"), "ES native query must not have 'query' key"); + + JsonNode knn = root.get("knn"); + assertEquals("embedding", knn.get("field").asText()); + assertEquals(k, knn.get("k").asInt()); + assertNotNull(knn.get("query_vector")); + assertTrue(knn.get("query_vector").isArray()); + assertEquals(3, knn.get("query_vector").size()); + } + + @Test + void testNativeESQueryNumCandidates() throws Exception { + float[] vector = {0.1f}; + + // default multiplier (2): k * 2 < 100 → num_candidates should be 100 + String query1 = VectorSearchQueryBuilder.buildNativeESQuery(vector, 10, 0, 30, Map.of()); + JsonNode root1 = MAPPER.readTree(query1); + assertEquals(100, root1.get("knn").get("num_candidates").asInt()); + + // default multiplier (2): k * 2 > 100 → num_candidates should be k * 2 + String query2 = VectorSearchQueryBuilder.buildNativeESQuery(vector, 10, 0, 200, Map.of()); + JsonNode root2 = MAPPER.readTree(query2); + assertEquals(400, root2.get("knn").get("num_candidates").asInt()); + + // custom multiplier (5): num_candidates = max(k * 5, 100) + String query3 = VectorSearchQueryBuilder.buildNativeESQuery(vector, 10, 0, 100, Map.of(), 5); + JsonNode root3 = MAPPER.readTree(query3); + assertEquals(500, root3.get("knn").get("num_candidates").asInt()); + } + + @Test + void testNativeESQueryAlwaysHasDeletedFilter() throws Exception { + float[] vector = {0.1f, 0.2f}; + + String query = VectorSearchQueryBuilder.buildNativeESQuery(vector, 10, 0, 100, Map.of()); + + JsonNode root = MAPPER.readTree(query); + JsonNode mustFilters = root.get("knn").get("filter").get("bool").get("must"); + + assertNotNull(mustFilters); + assertTrue(mustFilters.isArray()); + assertTrue(mustFilters.size() >= 1); + assertEquals(false, 
mustFilters.get(0).get("term").get("deleted").asBoolean()); + } + + @Test + void testNativeESQueryWithEntityTypeFilter() throws Exception { + float[] vector = {0.5f}; + Map> filters = Map.of("entityType", List.of("table", "dashboard")); + + String query = VectorSearchQueryBuilder.buildNativeESQuery(vector, 5, 0, 50, filters); + + JsonNode root = MAPPER.readTree(query); + JsonNode mustFilters = root.get("knn").get("filter").get("bool").get("must"); + + assertEquals(2, mustFilters.size()); + JsonNode entityTypeFilter = mustFilters.get(1); + assertTrue(entityTypeFilter.has("terms")); + JsonNode entityTypes = entityTypeFilter.get("terms").get("entityType"); + assertEquals(2, entityTypes.size()); + assertEquals("table", entityTypes.get(0).asText()); + assertEquals("dashboard", entityTypes.get(1).asText()); + } + + @Test + void testNativeESQueryWithOwnersFilter() throws Exception { + float[] vector = {0.1f}; + Map> filters = Map.of("owners", List.of("user1", "team2")); + + String query = VectorSearchQueryBuilder.buildNativeESQuery(vector, 10, 0, 100, filters); + + JsonNode root = MAPPER.readTree(query); + JsonNode mustFilters = root.get("knn").get("filter").get("bool").get("must"); + + assertEquals(2, mustFilters.size()); + JsonNode ownersFilter = mustFilters.get(1); + assertTrue(ownersFilter.has("bool")); + JsonNode shouldClauses = ownersFilter.get("bool").get("should"); + assertNotNull(shouldClauses); + assertEquals(2, shouldClauses.size()); + + String ownersJson = shouldClauses.toString(); + assertTrue(ownersJson.contains("user1")); + assertTrue(ownersJson.contains("team2")); + } + + @Test + void testNativeESQueryWithTagsFilter() throws Exception { + float[] vector = {0.1f, 0.2f}; + Map> filters = Map.of("tags", List.of("PII.Sensitive")); + + String query = VectorSearchQueryBuilder.buildNativeESQuery(vector, 10, 0, 100, filters); + + JsonNode root = MAPPER.readTree(query); + JsonNode mustFilters = root.get("knn").get("filter").get("bool").get("must"); + + 
assertEquals(2, mustFilters.size()); + JsonNode tagsFilter = mustFilters.get(1); + assertTrue(tagsFilter.has("term")); + assertEquals("PII.Sensitive", tagsFilter.get("term").get("tags.tagFQN").asText()); + } + + @Test + void testNativeESQueryWithMultipleFilters() throws Exception { + float[] vector = {0.1f, 0.2f}; + Map> filters = + Map.of( + "entityType", List.of("table"), + "tier", List.of("Tier.Tier1"), + "serviceType", List.of("BigQuery")); + + String query = VectorSearchQueryBuilder.buildNativeESQuery(vector, 10, 0, 100, filters); + + JsonNode root = MAPPER.readTree(query); + JsonNode mustFilters = root.get("knn").get("filter").get("bool").get("must"); + + assertEquals(4, mustFilters.size(), "deleted=false + 3 user filters"); + String filtersJson = mustFilters.toString(); + assertTrue(filtersJson.contains("entityType")); + assertTrue(filtersJson.contains("tier")); + assertTrue(filtersJson.contains("serviceType")); + } + + @Test + void testNativeESQuerySourceExcludesEmbedding() throws Exception { + float[] vector = {0.1f}; + + String query = VectorSearchQueryBuilder.buildNativeESQuery(vector, 10, 0, 100, Map.of()); + + JsonNode root = MAPPER.readTree(query); + JsonNode excludes = root.get("_source").get("excludes"); + assertNotNull(excludes); + assertTrue(excludes.isArray()); + assertEquals("embedding", excludes.get(0).asText()); + } + + @Test + void testNativeESQueryAndOpenSearchQueryProduceSameFilters() throws Exception { + float[] vector = {0.1f, 0.2f}; + Map> filters = + Map.of( + "entityType", List.of("table"), + "owners", List.of("alice"), + "tier", List.of("Tier.Gold")); + + String osQuery = VectorSearchQueryBuilder.build(vector, 10, 0, 100, filters, 0.0); + String esQuery = VectorSearchQueryBuilder.buildNativeESQuery(vector, 10, 0, 100, filters); + + JsonNode osFilters = + MAPPER + .readTree(osQuery) + .get("query") + .get("knn") + .get("embedding") + .get("filter") + .get("bool") + .get("must"); + JsonNode esFilters = 
MAPPER.readTree(esQuery).get("knn").get("filter").get("bool").get("must"); + + assertEquals( + osFilters.size(), + esFilters.size(), + "Both queries should produce the same number of filter clauses"); + assertEquals(osFilters.toString(), esFilters.toString(), "Filter clauses should be identical"); + } } diff --git a/openmetadata-spec/src/main/resources/json/schema/configuration/elasticSearchConfiguration.json b/openmetadata-spec/src/main/resources/json/schema/configuration/elasticSearchConfiguration.json index 38ca4c3f5779..78775d4547ac 100644 --- a/openmetadata-spec/src/main/resources/json/schema/configuration/elasticSearchConfiguration.json +++ b/openmetadata-spec/src/main/resources/json/schema/configuration/elasticSearchConfiguration.json @@ -159,6 +159,12 @@ "default": 10, "minimum": 1 }, + "knnNumCandidatesMultiplier": { + "description": "Multiplier applied to k when computing num_candidates for Elasticsearch kNN vector search. num_candidates = max(k * multiplier, 100). Higher values improve recall at the cost of latency. Defaults to 2.", + "type": "integer", + "default": 2, + "minimum": 1 + }, "providerClass": { "description": "Fully qualified class name of the NLQService implementation to use", "type": "string", diff --git a/openmetadata-ui/src/main/resources/ui/src/generated/configuration/aiPlatformConfiguration.ts b/openmetadata-ui/src/main/resources/ui/src/generated/configuration/aiPlatformConfiguration.ts index ef75b021c9f5..f1f5cb02eca5 100644 --- a/openmetadata-ui/src/main/resources/ui/src/generated/configuration/aiPlatformConfiguration.ts +++ b/openmetadata-ui/src/main/resources/ui/src/generated/configuration/aiPlatformConfiguration.ts @@ -89,11 +89,7 @@ export interface GrpcConfiguration { */ port: number; /** - * Deadline (minutes) Collate enforces on an AI Platform streaming response. Carried on the - * gRPC call, so the AI Platform reads it from context and wraps up gracefully. 
The chat - * lock sweeper uses streamDeadlineMinutes + 2 as its default stale-lock ceiling (override - * via COLLATE_CHAT_LOCK_MAX_DURATION_MINUTES). Capped at 60 minutes; for longer tasks - * prefer async job + polling over a single long-lived stream. + * Deadline (minutes) enforced on a streaming response from the gRPC server. */ streamDeadlineMinutes?: number; [property: string]: any; diff --git a/openmetadata-ui/src/main/resources/ui/src/generated/configuration/elasticSearchConfiguration.ts b/openmetadata-ui/src/main/resources/ui/src/generated/configuration/elasticSearchConfiguration.ts index c3eba5bf8b23..79dee6cf735b 100644 --- a/openmetadata-ui/src/main/resources/ui/src/generated/configuration/elasticSearchConfiguration.ts +++ b/openmetadata-ui/src/main/resources/ui/src/generated/configuration/elasticSearchConfiguration.ts @@ -138,6 +138,12 @@ export interface NaturalLanguageSearch { * Weight for BM25 keyword search results in hybrid RRF pipeline (0.0-1.0) */ keywordWeight?: number; + /** + * Multiplier applied to k when computing num_candidates for Elasticsearch kNN vector + * search. num_candidates = max(k * multiplier, 100). Higher values improve recall at the + * cost of latency. Defaults to 2. + */ + knnNumCandidatesMultiplier?: number; /** * Maximum number of concurrent embedding API requests. 
Controls the semaphore used to * throttle calls to the embedding provider and prevent overwhelming HTTP/2 connection diff --git a/openmetadata-ui/src/main/resources/ui/src/generated/settings/settings.ts b/openmetadata-ui/src/main/resources/ui/src/generated/settings/settings.ts index a20a2d9106a0..6be6f6fb76c6 100644 --- a/openmetadata-ui/src/main/resources/ui/src/generated/settings/settings.ts +++ b/openmetadata-ui/src/main/resources/ui/src/generated/settings/settings.ts @@ -2184,6 +2184,12 @@ export interface NaturalLanguageSearch { * Weight for BM25 keyword search results in hybrid RRF pipeline (0.0-1.0) */ keywordWeight?: number; + /** + * Multiplier applied to k when computing num_candidates for Elasticsearch kNN vector + * search. num_candidates = max(k * multiplier, 100). Higher values improve recall at the + * cost of latency. Defaults to 2. + */ + knnNumCandidatesMultiplier?: number; /** * Maximum number of concurrent embedding API requests. Controls the semaphore used to * throttle calls to the embedding provider and prevent overwhelming HTTP/2 connection