Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import org.openmetadata.schema.utils.JsonUtils;
import org.openmetadata.service.Entity;
import org.openmetadata.service.limits.Limits;
import org.openmetadata.service.search.vector.OpenSearchVectorService;
import org.openmetadata.service.search.vector.VectorIndexService;
import org.openmetadata.service.search.vector.utils.DTOs.VectorSearchResponse;
import org.openmetadata.service.security.Authorizer;
import org.openmetadata.service.security.auth.CatalogSecurityContext;
Expand Down Expand Up @@ -42,7 +42,7 @@ public Map<String, Object> execute(
"Semantic search is not enabled. Configure vector embeddings in the OpenMetadata server settings.");
}

OpenSearchVectorService vectorService = OpenSearchVectorService.getInstance();
VectorIndexService vectorService = Entity.getSearchRepository().getVectorIndexService();
if (vectorService == null) {
return errorResponse("Vector search service is not initialized");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.openmetadata.service.apps.bundles.searchIndex;

import static org.openmetadata.service.workflows.searchIndex.ReindexingUtil.ENTITY_TYPE_KEY;
import static org.openmetadata.service.workflows.searchIndex.ReindexingUtil.RECREATE_CONTEXT;
import static org.openmetadata.service.workflows.searchIndex.ReindexingUtil.TARGET_INDEX_KEY;

import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -18,12 +19,16 @@
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Phaser;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
Expand All @@ -40,15 +45,20 @@
import org.openmetadata.service.apps.bundles.searchIndex.stats.StatsResult;
import org.openmetadata.service.exception.EntityNotFoundException;
import org.openmetadata.service.exception.SearchIndexException;
import org.openmetadata.service.search.ReindexContext;
import org.openmetadata.service.search.SearchRepository;
import org.openmetadata.service.search.elasticsearch.ElasticSearchClient;
import org.openmetadata.service.search.elasticsearch.EsUtils;
import org.openmetadata.service.search.vector.VectorDocBuilder;
import org.openmetadata.service.search.vector.VectorIndexService;
import org.openmetadata.service.search.vector.utils.AvailableEntityTypes;

/**
* Elasticsearch implementation of the bulk sink, using the new Java API client with a custom bulk handler.
*/
@Slf4j
public class ElasticSearchBulkSink implements BulkSink {
private static final int MAX_VECTOR_THREADS = 10;
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final JacksonJsonpMapper JACKSON_JSONP_MAPPER =
new JacksonJsonpMapper(OBJECT_MAPPER);
Expand Down Expand Up @@ -107,6 +117,13 @@ public static synchronized void resetDocBuildPoolSize() {
// Failure callback
private volatile FailureCallback failureCallback;

// Vector embedding fields
private final ExecutorService vectorExecutor;
private final Phaser phaser;
private final CopyOnWriteArrayList<Thread> pendingThreads;
private final AtomicLong vectorSuccess = new AtomicLong(0);
private final AtomicLong vectorFailed = new AtomicLong(0);

public ElasticSearchBulkSink(
SearchRepository searchRepository,
int batchSize,
Expand All @@ -117,6 +134,10 @@ public ElasticSearchBulkSink(
this.searchClient = (ElasticSearchClient) searchRepository.getSearchClient();
this.batchSize = batchSize;
this.maxConcurrentRequests = maxConcurrentRequests;
this.vectorExecutor =
Executors.newFixedThreadPool(MAX_VECTOR_THREADS, Thread.ofVirtual().factory());
this.phaser = new Phaser(1);
this.pendingThreads = new CopyOnWriteArrayList<>();

// Initialize stats
stats.withTotalRecords(0).withSuccessRecords(0).withFailedRecords(0);
Expand Down Expand Up @@ -204,6 +225,10 @@ public void write(List<?> entities, Map<String, Object> contextData) throws Exce
CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();
} else {
List<EntityInterface> entityInterfaces = (List<EntityInterface>) entities;
ReindexContext reindexContext =
contextData.containsKey(RECREATE_CONTEXT)
? (ReindexContext) contextData.get(RECREATE_CONTEXT)
: null;

// Add entities to search index in parallel
List<CompletableFuture<Void>> futures =
Expand All @@ -216,9 +241,9 @@ public void write(List<?> entities, Map<String, Object> contextData) throws Exce
.toList();
CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();

// Process vector embeddings in batch (no-op in base class)
if (embeddingsEnabled) {
addEntitiesToVectorIndexBatch(bulkProcessor, entityInterfaces, recreateIndex);
addEntitiesToVectorIndexBatch(
bulkProcessor, entityInterfaces, recreateIndex, reindexContext, tracker);
}
}
} catch (Exception e) {
Expand Down Expand Up @@ -422,15 +447,17 @@ public StepStats getProcessStats() {
@Override
public void close() {
try {
// Flush any pending requests
awaitVectorCompletion(60);

bulkProcessor.flush();

// Wait for completion
boolean terminated = bulkProcessor.awaitClose(60, TimeUnit.SECONDS);
if (!terminated) {
LOG.warn("Bulk processor did not terminate within timeout");
}

vectorExecutor.shutdown();

// Final stats update to ensure all processed records are reflected
updateStats();

Expand Down Expand Up @@ -508,17 +535,161 @@ public void updateConcurrentRequests(int concurrentRequests) {
LOG.info("Concurrent requests updated to: {}", concurrentRequests);
}

/**
 * Reports whether vector embeddings should be generated for the given entity type.
 *
 * <p>All of the following must hold: the search repository reports vector embedding as enabled,
 * the repository exposes a non-null vector index service, and {@code AvailableEntityTypes}
 * considers the entity type vector-indexable.
 *
 * @param entityType entity type name to check
 * @return {@code true} only when every condition above holds
 */
boolean isVectorEmbeddingEnabledForEntity(String entityType) {
  return searchRepository.isVectorEmbeddingEnabled()
      && searchRepository.getVectorIndexService() != null
      && AvailableEntityTypes.isVectorIndexable(entityType);
}

void addEntitiesToVectorIndexBatch(
CustomBulkProcessor bulkProcessor, List<EntityInterface> entities, boolean recreateIndex) {
// TODO: Implement Elasticsearch vector embedding support
CustomBulkProcessor bulkProcessor,
List<EntityInterface> entities,
boolean recreateIndex,
ReindexContext reindexContext,
StageStatsTracker tracker) {
if (entities.isEmpty()) {
return;
}

VectorIndexService vectorService = searchRepository.getVectorIndexService();
if (vectorService == null) {
return;
}

String entityType = entities.getFirst().getEntityReference().getType();
if (!AvailableEntityTypes.isVectorIndexable(entityType)) {
return;
}

String canonicalIndex = VectorIndexService.getClusteredIndexName();
String finalTargetIndex = canonicalIndex;
String finalSourceIndex = null;

if (reindexContext != null) {
String stagedIndex =
reindexContext.getStagedIndex(VectorIndexService.VECTOR_INDEX_KEY).orElse(null);
if (stagedIndex != null) {
finalSourceIndex = canonicalIndex;
finalTargetIndex = stagedIndex;
}
}

String srcIdx = finalSourceIndex;
String tgtIdx = finalTargetIndex;

Map<String, String> existingFingerprints = Map.of();
if (srcIdx != null) {
List<String> parentIds = new ArrayList<>(entities.size());
for (EntityInterface entity : entities) {
parentIds.add(entity.getId().toString());
}
existingFingerprints = vectorService.getExistingFingerprintsBatch(srcIdx, parentIds);
}

for (EntityInterface entity : entities) {
String parentId = entity.getId().toString();
String existingFp = existingFingerprints.get(parentId);
String currentFp = VectorDocBuilder.computeFingerprintForEntity(entity);

if (existingFp != null && existingFp.equals(currentFp) && srcIdx != null) {
submitVectorTask(
() ->
processMigration(
vectorService, srcIdx, tgtIdx, parentId, currentFp, entity, tracker));
} else {
submitVectorTask(() -> processEmbedding(vectorService, entity, tgtIdx, tracker));
}
}
}

/**
 * Tries to reuse existing vector documents for one entity by copying them from the source index
 * to the target index, and falls back to {@code processEmbedding} when that does not succeed.
 *
 * <p>A successful copy increments the success counter and, when a tracker is present, records a
 * {@code SUCCESS} vector result. If the copy reports {@code false}, or any exception is raised
 * inside the attempt, the entity is handed to {@code processEmbedding} for the target index.
 *
 * @param vectorService service that performs the copy
 * @param sourceIndex index the existing vector documents are copied from
 * @param targetIndex index the vector documents are copied (or re-embedded) into
 * @param parentId id of the entity, used to select its vector documents
 * @param fingerprint fingerprint passed to the copy call
 * @param entity entity to re-embed when the copy does not succeed
 * @param tracker optional stats tracker; may be {@code null}
 */
private void processMigration(
    VectorIndexService vectorService,
    String sourceIndex,
    String targetIndex,
    String parentId,
    String fingerprint,
    EntityInterface entity,
    StageStatsTracker tracker) {
  try {
    boolean copied =
        vectorService.copyExistingVectorDocuments(
            sourceIndex, targetIndex, parentId, fingerprint);
    if (!copied) {
      // Nothing was copied: compute the embedding from scratch instead.
      processEmbedding(vectorService, entity, targetIndex, tracker);
      return;
    }
    vectorSuccess.incrementAndGet();
    if (tracker != null) {
      tracker.recordVector(StatsResult.SUCCESS);
    }
  } catch (Exception migrationError) {
    LOG.warn(
        "Vector migration failed for parent_id={}, falling back to recomputation: {}",
        parentId,
        migrationError.getMessage());
    processEmbedding(vectorService, entity, targetIndex, tracker);
  }
}

/**
 * Asks the vector service to update the embeddings of one entity in the target index and
 * records the outcome.
 *
 * <p>On success the success counter is incremented and, when a tracker is present, a
 * {@code SUCCESS} vector result is recorded. Any exception is caught here: the failure counter
 * is incremented, a {@code FAILED} result is recorded when a tracker is present, and the error
 * is logged. The exception is not rethrown.
 *
 * @param vectorService service that computes and writes the embeddings
 * @param entity entity whose embeddings are updated
 * @param targetIndex index the embeddings are written to
 * @param tracker optional stats tracker; may be {@code null}
 */
private void processEmbedding(
    VectorIndexService vectorService,
    EntityInterface entity,
    String targetIndex,
    StageStatsTracker tracker) {
  final boolean tracked = tracker != null;
  try {
    vectorService.updateVectorEmbeddings(entity, targetIndex);
    vectorSuccess.incrementAndGet();
    if (!tracked) {
      return;
    }
    tracker.recordVector(StatsResult.SUCCESS);
  } catch (Exception embeddingError) {
    vectorFailed.incrementAndGet();
    if (tracked) {
      tracker.recordVector(StatsResult.FAILED);
    }
    LOG.error(
        "Vector embedding failed for entity {}: {}",
        entity.getId(),
        embeddingError.getMessage(),
        embeddingError);
  }
}

/**
 * Submits a vector task to the vector executor and tracks it as a phaser party until it ends.
 *
 * <p>A party is registered before submission and deregistered in the task's {@code finally}
 * block, so the phaser's party count reflects tasks that are queued or running. While the task
 * runs, its thread is kept in {@code pendingThreads}.
 *
 * <p>If the executor refuses the task (for example a {@code RejectedExecutionException} after
 * shutdown), the party registered for it is deregistered again before the failure is rethrown.
 * Otherwise that party would never arrive and later completion waits would stall until their
 * timeout.
 *
 * @param task work to run on the vector executor
 */
private void submitVectorTask(Runnable task) {
  phaser.register();
  try {
    vectorExecutor.submit(
        () -> {
          Thread current = Thread.currentThread();
          pendingThreads.add(current);
          try {
            task.run();
          } finally {
            pendingThreads.remove(current);
            phaser.arriveAndDeregister();
          }
        });
  } catch (RuntimeException | Error submitFailure) {
    // The task never started, so its finally block will not run: undo the registration here.
    phaser.arriveAndDeregister();
    throw submitFailure;
  }
}

/**
 * Waits until no vector task is pending or the timeout elapses.
 *
 * <p>The phaser is created with one party for the sink itself, and every submitted task
 * registers one more party until it finishes. This method therefore treats an unarrived count
 * of at most one as "no task pending" and polls that count against a deadline. The sink's own
 * party deliberately never arrives: arriving on every call made a second call after a timeout
 * consume a pending task's arrival, which could report completion while a task was still
 * running. Leaving the phaser untouched keeps repeated and concurrent calls safe.
 *
 * <p>Only a snapshot is observed, so tasks submitted while waiting may or may not be covered.
 *
 * @param timeoutSeconds maximum time to wait; zero or negative values do not wait
 * @return {@code true} if no task was pending before the deadline; {@code false} on timeout or
 *     when the waiting thread is interrupted (the interrupt flag is restored)
 */
@Override
public boolean awaitVectorCompletion(int timeoutSeconds) {
  final long pollIntervalMillis = 10L;
  final long deadlineNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(timeoutSeconds);
  try {
    while (phaser.getUnarrivedParties() > 1) {
      long remainingNanos = deadlineNanos - System.nanoTime();
      if (remainingNanos <= 0) {
        LOG.warn("Timeout waiting for vector completion after {}s", timeoutSeconds);
        return false;
      }
      // Sleep at least 1 ms and never past the deadline.
      long remainingMillis = Math.max(1L, TimeUnit.NANOSECONDS.toMillis(remainingNanos));
      Thread.sleep(Math.min(pollIntervalMillis, remainingMillis));
    }
    return true;
  } catch (InterruptedException e) {
    Thread.currentThread().interrupt();
    return false;
  }
}

/**
 * Returns the phaser's unarrived party count minus one, floored at zero.
 *
 * <p>NOTE(review): the subtraction presumably discounts the single party the phaser is created
 * with for the sink itself, leaving the parties registered by submitted tasks; confirm the
 * sink's party is still unarrived whenever this is read.
 *
 * @return the number of unarrived parties excluding one, never negative
 */
@Override
public int getPendingVectorTaskCount() {
  int unarrived = phaser.getUnarrivedParties();
  return unarrived > 1 ? unarrived - 1 : 0;
}

/**
 * Builds a stats snapshot of the vector work recorded so far.
 *
 * <p>Each counter is read exactly once, so the reported total always equals the reported
 * success count plus the reported failure count, even while vector tasks are still updating
 * the counters.
 *
 * @return a new {@code StepStats} holding total, success and failed vector record counts
 */
@Override
public StepStats getVectorStats() {
  long succeeded = vectorSuccess.get();
  long failed = vectorFailed.get();
  return new StepStats()
      .withTotalRecords((int) (succeeded + failed))
      .withSuccessRecords((int) succeeded)
      .withFailedRecords((int) failed);
}

public static class CustomBulkProcessor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.service.Entity;
import org.openmetadata.service.resources.Collection;
import org.openmetadata.service.search.vector.OpenSearchVectorService;
import org.openmetadata.service.search.vector.VectorIndexService;
import org.openmetadata.service.search.vector.utils.DTOs.FingerprintResponse;
import org.openmetadata.service.search.vector.utils.DTOs.VectorSearchRequest;
import org.openmetadata.service.search.vector.utils.DTOs.VectorSearchResponse;
Expand Down Expand Up @@ -75,7 +75,7 @@ public Response vectorSearchPost(
.build();
}

OpenSearchVectorService vectorService = OpenSearchVectorService.getInstance();
VectorIndexService vectorService = Entity.getSearchRepository().getVectorIndexService();
if (vectorService == null) {
return Response.status(Response.Status.SERVICE_UNAVAILABLE)
.entity("{\"error\":\"Vector search service is not initialized\"}")
Expand Down Expand Up @@ -119,7 +119,7 @@ public Response getFingerprint(
.build();
}

OpenSearchVectorService vectorService = OpenSearchVectorService.getInstance();
VectorIndexService vectorService = Entity.getSearchRepository().getVectorIndexService();
if (vectorService == null) {
return Response.status(Response.Status.SERVICE_UNAVAILABLE)
.entity("{\"error\":\"Vector search service is not initialized\"}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import lombok.extern.slf4j.Slf4j;
import org.openmetadata.search.IndexMapping;
import org.openmetadata.service.Entity;
import org.openmetadata.service.search.vector.OpenSearchVectorService;
import org.openmetadata.service.search.vector.VectorIndexService;

@Slf4j
Expand All @@ -17,7 +16,7 @@ public ReindexContext reCreateIndexes(Set<String> entities) {
searchRepository.initializeVectorSearchService();

Set<String> allEntities = new HashSet<>(entities);
if (OpenSearchVectorService.getInstance() != null) {
if (searchRepository.getVectorIndexService() != null) {
allEntities.add(VectorIndexService.VECTOR_INDEX_KEY);
}

Expand All @@ -28,7 +27,7 @@ public ReindexContext reCreateIndexes(Set<String> entities) {
protected void recreateIndexFromMapping(
ReindexContext context, IndexMapping indexMapping, String entityType) {
if (VectorIndexService.VECTOR_INDEX_KEY.equals(entityType)
&& OpenSearchVectorService.getInstance() == null) {
&& Entity.getSearchRepository().getVectorIndexService() == null) {
LOG.info("Skipping vector index recreation - vector service not initialized");
return;
}
Expand All @@ -38,7 +37,7 @@ protected void recreateIndexFromMapping(
@Override
public void promoteEntityIndex(EntityReindexContext context, boolean reindexSuccess) {
if (VectorIndexService.VECTOR_INDEX_KEY.equals(context.getEntityType())
&& OpenSearchVectorService.getInstance() == null) {
&& Entity.getSearchRepository().getVectorIndexService() == null) {
return;
}
super.promoteEntityIndex(context, reindexSuccess);
Expand Down
Loading
Loading