From c3ff93bc70c860a64046afb283ba8317bec65b6c Mon Sep 17 00:00:00 2001 From: "donghyuck, son" Date: Thu, 30 Apr 2026 19:32:06 +0900 Subject: [PATCH 1/2] [ai-assisted] feat(ai): support vector projection umap tsne Issue: - Closes #390 Why: - Vector Map needs server-side UMAP and t-SNE projection choices in addition to PCA. - The client should keep rendering stored x/y coordinates without embedding or projection calculation. What: - Add UMAP and TSNE ProjectionAlgorithm values. - Add dependency-free server-side UMAP/t-SNE-style projection generators using PCA initialization and embedding-neighbor refinement. - Register PCA/UMAP/TSNE default generator beans and preserve SPI replacement by preferring non-default same-algorithm generators. - Update vector projection docs and tests for algorithm parsing, generator output, service selection, and auto-configuration registration. Validation: - ./gradlew :studio-platform-ai:test :starter:studio-platform-starter-ai:test :starter:studio-platform-starter-ai-web:test (PASS) - ./gradlew test (PASS) - git diff --check (PASS) AI-Assisted: Yes Subagent used: Yes Delegated scope: code review for issue #390 before and after generator precedence fix --- .../studio-platform-starter-ai-web/README.md | 8 +- .../autoconfigure/AiWebAutoConfiguration.java | 21 ++- ...jectionGeneratorAutoConfigurationTest.java | 47 +++++ ...VectorVisualizationMgmtControllerTest.java | 35 ++++ .../DefaultVectorProjectionJobService.java | 15 +- .../DefaultVectorProjectionServiceTest.java | 87 +++++++++ studio-platform-ai/README.md | 7 +- .../NeighborVectorProjectionGenerator.java | 143 +++++++++++++++ .../PcaVectorProjectionGenerator.java | 148 +-------------- .../visualization/ProjectionAlgorithm.java | 4 +- .../ProjectionCoordinateSupport.java | 171 ++++++++++++++++++ .../TsneVectorProjectionGenerator.java | 21 +++ .../UmapVectorProjectionGenerator.java | 20 ++ ...NeighborVectorProjectionGeneratorTest.java | 62 +++++++ 14 files changed, 636 insertions(+), 153 
deletions(-) create mode 100644 starter/studio-platform-starter-ai-web/src/test/java/studio/one/platform/ai/autoconfigure/VectorProjectionGeneratorAutoConfigurationTest.java create mode 100644 studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGenerator.java create mode 100644 studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/ProjectionCoordinateSupport.java create mode 100644 studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/TsneVectorProjectionGenerator.java create mode 100644 studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/UmapVectorProjectionGenerator.java create mode 100644 studio-platform-ai/src/test/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGeneratorTest.java diff --git a/starter/studio-platform-starter-ai-web/README.md b/starter/studio-platform-starter-ai-web/README.md index 4d5b8c6f..9da3c285 100644 --- a/starter/studio-platform-starter-ai-web/README.md +++ b/starter/studio-platform-starter-ai-web/README.md @@ -113,8 +113,10 @@ studio: 원본 embedding 테이블은 변경하지 않고, `tb_ai_vector_projection`과 `tb_ai_vector_projection_point`에 projection job 상태와 미리 계산된 좌표만 저장한다. 화면 요청 시마다 고차원 벡터를 다시 projection하지 않는다. -기본 알고리즘은 `PCA`다. v1 구현은 Java 내장 연산으로 PCA 좌표를 계산하고, 후속 UMAP/t-SNE는 -`VectorProjectionGenerator` 구현을 추가해 확장한다. `targetTypes`는 UI 문서 분류가 아니라 +지원 알고리즘은 `PCA`, `UMAP`, `TSNE`다. 기본값은 `PCA`이며, UMAP/t-SNE는 서버가 PCA 초기 좌표를 +embedding cosine distance 기반 이웃 보존 방식으로 보정해 저장하는 deterministic 구현이다. 클라이언트는 알고리즘별 계산을 하지 않고 +points API가 내려준 `x`, `y` 좌표를 그대로 렌더링한다. 더 높은 품질의 projection이 필요하면 +`VectorProjectionGenerator` 구현을 교체해 확장한다. `targetTypes`는 UI 문서 분류가 아니라 `tb_ai_document_chunk.object_type`에 저장된 RAG index objectType 기준이다. 예를 들어 `attachment`, `forums-post-attachment`, 정책 object type 값처럼 색인 job이 사용한 objectType을 지정한다. `targetTypes`가 비어 있으면 전체 vector item을 대상으로 한다. 
@@ -133,7 +135,7 @@ Content-Type: application/json { "name": "NCS-과정-청크 벡터맵", "targetTypes": ["NCS_UNIT", "COURSE", "COURSE_CHUNK"], - "algorithm": "PCA", + "algorithm": "UMAP", "filters": { "useYn": "Y" } diff --git a/starter/studio-platform-starter-ai-web/src/main/java/studio/one/platform/ai/autoconfigure/AiWebAutoConfiguration.java b/starter/studio-platform-starter-ai-web/src/main/java/studio/one/platform/ai/autoconfigure/AiWebAutoConfiguration.java index db0d8682..5575d5b7 100644 --- a/starter/studio-platform-starter-ai-web/src/main/java/studio/one/platform/ai/autoconfigure/AiWebAutoConfiguration.java +++ b/starter/studio-platform-starter-ai-web/src/main/java/studio/one/platform/ai/autoconfigure/AiWebAutoConfiguration.java @@ -14,6 +14,8 @@ import org.springframework.context.annotation.Conditional; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; import org.springframework.core.env.Environment; import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate; @@ -31,6 +33,8 @@ import studio.one.platform.ai.core.vector.VectorStorePort; import studio.one.platform.ai.core.vector.visualization.ExistingVectorItemRepository; import studio.one.platform.ai.core.vector.visualization.PcaVectorProjectionGenerator; +import studio.one.platform.ai.core.vector.visualization.TsneVectorProjectionGenerator; +import studio.one.platform.ai.core.vector.visualization.UmapVectorProjectionGenerator; import studio.one.platform.ai.core.vector.visualization.VectorProjectionGenerator; import studio.one.platform.ai.core.vector.visualization.VectorProjectionPointRepository; import studio.one.platform.ai.core.vector.visualization.VectorProjectionRepository; @@ -154,11 +158,26 @@ VectorController vectorController( } @Bean + 
@Order(Ordered.LOWEST_PRECEDENCE) @ConditionalOnMissingBean - VectorProjectionGenerator vectorProjectionGenerator() { + PcaVectorProjectionGenerator pcaVectorProjectionGenerator() { return new PcaVectorProjectionGenerator(); } + @Bean + @Order(Ordered.LOWEST_PRECEDENCE) + @ConditionalOnMissingBean + UmapVectorProjectionGenerator umapVectorProjectionGenerator() { + return new UmapVectorProjectionGenerator(); + } + + @Bean + @Order(Ordered.LOWEST_PRECEDENCE) + @ConditionalOnMissingBean + TsneVectorProjectionGenerator tsneVectorProjectionGenerator() { + return new TsneVectorProjectionGenerator(); + } + @Bean(name = "vectorProjectionExecutor") @ConditionalOnMissingBean(name = "vectorProjectionExecutor") Executor vectorProjectionExecutor() { diff --git a/starter/studio-platform-starter-ai-web/src/test/java/studio/one/platform/ai/autoconfigure/VectorProjectionGeneratorAutoConfigurationTest.java b/starter/studio-platform-starter-ai-web/src/test/java/studio/one/platform/ai/autoconfigure/VectorProjectionGeneratorAutoConfigurationTest.java new file mode 100644 index 00000000..70cd2b68 --- /dev/null +++ b/starter/studio-platform-starter-ai-web/src/test/java/studio/one/platform/ai/autoconfigure/VectorProjectionGeneratorAutoConfigurationTest.java @@ -0,0 +1,47 @@ +package studio.one.platform.ai.autoconfigure; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.context.ConfigurationPropertiesAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import studio.one.platform.ai.core.chat.ChatPort; +import studio.one.platform.ai.autoconfigure.config.AiAdapterProperties; +import studio.one.platform.ai.core.embedding.EmbeddingPort; +import studio.one.platform.ai.core.registry.AiProviderRegistry; +import 
studio.one.platform.ai.core.vector.visualization.ProjectionAlgorithm; +import studio.one.platform.ai.core.vector.visualization.VectorProjectionGenerator; +import studio.one.platform.ai.service.pipeline.RagPipelineService; +import studio.one.platform.ai.service.prompt.PromptRenderer; + +class VectorProjectionGeneratorAutoConfigurationTest { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(ConfigurationPropertiesAutoConfiguration.class)) + .withUserConfiguration(AiWebAutoConfiguration.class) + .withBean(AiProviderRegistry.class, () -> mock(AiProviderRegistry.class)) + .withBean(RagPipelineService.class, () -> mock(RagPipelineService.class)) + .withBean(EmbeddingPort.class, () -> mock(EmbeddingPort.class)) + .withBean(ChatPort.class, () -> mock(ChatPort.class)) + .withBean(PromptRenderer.class, () -> mock(PromptRenderer.class)) + .withBean(AiAdapterProperties.class, AiAdapterProperties::new) + .withPropertyValues( + "studio.features.ai.enabled=true", + "studio.ai.endpoints.enabled=true"); + + @Test + void registersDefaultProjectionGeneratorsForSupportedAlgorithms() { + contextRunner.run(context -> { + assertThat(context).hasNotFailed(); + Map generators = context.getBeansOfType(VectorProjectionGenerator.class); + assertThat(generators.values()) + .extracting(VectorProjectionGenerator::algorithm) + .contains(ProjectionAlgorithm.PCA, ProjectionAlgorithm.UMAP, ProjectionAlgorithm.TSNE); + }); + } +} diff --git a/starter/studio-platform-starter-ai-web/src/test/java/studio/one/platform/ai/web/controller/VectorVisualizationMgmtControllerTest.java b/starter/studio-platform-starter-ai-web/src/test/java/studio/one/platform/ai/web/controller/VectorVisualizationMgmtControllerTest.java index f43fe9d1..86301edf 100644 --- a/starter/studio-platform-starter-ai-web/src/test/java/studio/one/platform/ai/web/controller/VectorVisualizationMgmtControllerTest.java +++ 
b/starter/studio-platform-starter-ai-web/src/test/java/studio/one/platform/ai/web/controller/VectorVisualizationMgmtControllerTest.java @@ -12,6 +12,7 @@ import java.util.Map; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.springframework.http.HttpStatus; import org.springframework.web.server.ResponseStatusException; @@ -21,6 +22,7 @@ import studio.one.platform.ai.core.vector.visualization.ProjectionStatus; import studio.one.platform.ai.core.vector.visualization.VectorItem; import studio.one.platform.ai.core.vector.visualization.VectorProjection; +import studio.one.platform.ai.service.visualization.VectorProjectionCreateCommand; import studio.one.platform.ai.service.visualization.VectorProjectionService; import studio.one.platform.ai.service.visualization.VectorSearchVisualizationService; import studio.one.platform.ai.web.dto.visualization.ProjectionCreateRequest; @@ -48,6 +50,39 @@ void createProjectionDefaultsToPcaAndReturnsRequestedStatus() { verify(projectionService).create(any()); } + @Test + void createProjectionAcceptsUmapAndTsneAlgorithms() { + VectorProjectionService projectionService = mock(VectorProjectionService.class); + when(projectionService.create(any())).thenReturn(projection(ProjectionStatus.REQUESTED)); + VectorVisualizationMgmtController controller = new VectorVisualizationMgmtController( + projectionService, + mock(VectorSearchVisualizationService.class)); + ArgumentCaptor captor = ArgumentCaptor.forClass(VectorProjectionCreateCommand.class); + + controller.createProjection(new ProjectionCreateRequest("UMAP map", List.of(), "UMAP", Map.of())); + controller.createProjection(new ProjectionCreateRequest("TSNE map", List.of(), "tsne", Map.of())); + + verify(projectionService, org.mockito.Mockito.times(2)).create(captor.capture()); + assertThat(captor.getAllValues()) + .extracting(VectorProjectionCreateCommand::algorithm) + .containsExactly(ProjectionAlgorithm.UMAP, ProjectionAlgorithm.TSNE); + } + + @Test + void 
createProjectionRejectsUnsupportedAlgorithm() { + VectorVisualizationMgmtController controller = new VectorVisualizationMgmtController( + mock(VectorProjectionService.class), + mock(VectorSearchVisualizationService.class)); + + assertThatThrownBy(() -> controller.createProjection(new ProjectionCreateRequest( + "bad map", + List.of(), + "MDS", + Map.of()))) + .isInstanceOf(ResponseStatusException.class) + .hasMessageContaining("UNSUPPORTED_PROJECTION_ALGORITHM"); + } + @Test void pointsReturnsClientOrientedShape() { VectorProjectionService projectionService = mock(VectorProjectionService.class); diff --git a/starter/studio-platform-starter-ai/src/main/java/studio/one/platform/ai/service/visualization/DefaultVectorProjectionJobService.java b/starter/studio-platform-starter-ai/src/main/java/studio/one/platform/ai/service/visualization/DefaultVectorProjectionJobService.java index 154d8f3f..b4ff05c6 100644 --- a/starter/studio-platform-starter-ai/src/main/java/studio/one/platform/ai/service/visualization/DefaultVectorProjectionJobService.java +++ b/starter/studio-platform-starter-ai/src/main/java/studio/one/platform/ai/service/visualization/DefaultVectorProjectionJobService.java @@ -6,7 +6,10 @@ import lombok.extern.slf4j.Slf4j; import studio.one.platform.ai.core.vector.visualization.ExistingVectorItemRepository; +import studio.one.platform.ai.core.vector.visualization.PcaVectorProjectionGenerator; import studio.one.platform.ai.core.vector.visualization.ProjectionStatus; +import studio.one.platform.ai.core.vector.visualization.TsneVectorProjectionGenerator; +import studio.one.platform.ai.core.vector.visualization.UmapVectorProjectionGenerator; import studio.one.platform.ai.core.vector.visualization.VectorItem; import studio.one.platform.ai.core.vector.visualization.VectorProjection; import studio.one.platform.ai.core.vector.visualization.VectorProjectionGenerator; @@ -67,12 +70,22 @@ public void run(String projectionId) { } private VectorProjectionGenerator 
generatorFor(VectorProjection projection) { - return generators.stream() + List matching = generators.stream() .filter(generator -> generator.algorithm() == projection.algorithm()) + .toList(); + return matching.stream() + .filter(generator -> !isDefaultGenerator(generator)) .findFirst() + .or(() -> matching.stream().findFirst()) .orElseThrow(() -> new IllegalArgumentException("UNSUPPORTED_PROJECTION_ALGORITHM")); } + private boolean isDefaultGenerator(VectorProjectionGenerator generator) { + return generator instanceof PcaVectorProjectionGenerator + || generator instanceof UmapVectorProjectionGenerator + || generator instanceof TsneVectorProjectionGenerator; + } + private List actualTargetTypes(List items) { return items.stream() .map(VectorItem::targetType) diff --git a/starter/studio-platform-starter-ai/src/test/java/studio/one/platform/ai/service/visualization/DefaultVectorProjectionServiceTest.java b/starter/studio-platform-starter-ai/src/test/java/studio/one/platform/ai/service/visualization/DefaultVectorProjectionServiceTest.java index 1c8acf93..93cd973b 100644 --- a/starter/studio-platform-starter-ai/src/test/java/studio/one/platform/ai/service/visualization/DefaultVectorProjectionServiceTest.java +++ b/starter/studio-platform-starter-ai/src/test/java/studio/one/platform/ai/service/visualization/DefaultVectorProjectionServiceTest.java @@ -16,6 +16,7 @@ import org.springframework.web.server.ResponseStatusException; import studio.one.platform.ai.core.vector.visualization.ExistingVectorItemRepository; +import studio.one.platform.ai.core.vector.visualization.PcaVectorProjectionGenerator; import studio.one.platform.ai.core.vector.visualization.ProjectionAlgorithm; import studio.one.platform.ai.core.vector.visualization.ProjectionPointPage; import studio.one.platform.ai.core.vector.visualization.ProjectionPointView; @@ -73,6 +74,72 @@ public List generate(String projectionId, List { + assertThat(point.x()).isEqualTo(0.7); + assertThat(point.y()).isEqualTo(0.8); + 
}); + } + @Test void pointsRejectsProjectionThatIsNotCompleted() { FakeProjectionRepository projections = new FakeProjectionRepository(); @@ -148,6 +215,26 @@ private static VectorItem item(String id) { return new VectorItem(id, "COURSE_CHUNK", "course-1", "label", "text", List.of(0.1, 0.2), "model", 2, Map.of(), Instant.now()); } + private static VectorProjectionGenerator generator(ProjectionAlgorithm algorithm) { + return generator(algorithm, 0.1, 0.2); + } + + private static VectorProjectionGenerator generator(ProjectionAlgorithm algorithm, double x, double y) { + return new VectorProjectionGenerator() { + @Override + public ProjectionAlgorithm algorithm() { + return algorithm; + } + + @Override + public List generate(String projectionId, List sourceItems, Instant createdAt) { + return sourceItems.stream() + .map(source -> new VectorProjectionPoint(projectionId, source.vectorItemId(), x, y, null, 0, createdAt)) + .toList(); + } + }; + } + private static final class FakeProjectionRepository implements VectorProjectionRepository { private final Map projections = new LinkedHashMap<>(); diff --git a/studio-platform-ai/README.md b/studio-platform-ai/README.md index 6a274a29..189d4da0 100644 --- a/studio-platform-ai/README.md +++ b/studio-platform-ai/README.md @@ -177,8 +177,11 @@ Core 계약은 `studio.one.platform.ai.core.vector.visualization` 패키지에 | `ExistingVectorItemRepository` | 기존 벡터 테이블을 읽는 adapter 포트 | | `VectorProjectionRepository` / `VectorProjectionPointRepository` | projection 상태와 좌표 저장소 포트 | -기본 `PcaVectorProjectionGenerator`는 추가 의존성 없이 PCA 좌표를 계산한다. 화면 API는 원본 embedding 값을 -반환하지 않으며, metadata는 표시용 allowlist로 제한한다. 원문 text는 벡터 항목 상세 조회에서만 제공한다. +기본 `PcaVectorProjectionGenerator`는 추가 의존성 없이 PCA 좌표를 계산한다. `UmapVectorProjectionGenerator`와 +`TsneVectorProjectionGenerator`는 PCA 초기 좌표를 embedding cosine distance 기반 이웃 보존 방식으로 보정하는 서버 측 deterministic 구현이다. +운영 화면은 `ProjectionAlgorithm` 값과 무관하게 저장된 `x`, `y` 좌표만 사용한다. 
더 높은 품질의 알고리즘이 +필요하면 동일한 `VectorProjectionGenerator` SPI로 교체한다. 화면 API는 원본 embedding 값을 반환하지 않으며, +metadata는 표시용 allowlist로 제한한다. 원문 text는 벡터 항목 상세 조회에서만 제공한다. - 기존 `VectorDocument`와 `VectorSearchResult`는 기존 호출자 호환성을 위해 유지한다. - 새 context assembly 계약은 아직 만들지 않는다. web context 조립은 `starter-ai-web`의 `RagContextBuilder`, chunk 주변 문맥 확장은 `studio-platform-chunking`의 `ChunkContextExpander`를 우선 사용한다. diff --git a/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGenerator.java b/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGenerator.java new file mode 100644 index 00000000..783639e0 --- /dev/null +++ b/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGenerator.java @@ -0,0 +1,143 @@ +package studio.one.platform.ai.core.vector.visualization; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +abstract class NeighborVectorProjectionGenerator implements VectorProjectionGenerator { + + private static final double EPSILON = 1.0e-9d; + private static final int MAX_DISTANCE_DIMENSIONS = 256; + + private final int neighborCount; + private final int iterations; + private final double attraction; + private final double repulsion; + private final double targetDistance; + + NeighborVectorProjectionGenerator( + int neighborCount, + int iterations, + double attraction, + double repulsion, + double targetDistance) { + this.neighborCount = neighborCount; + this.iterations = iterations; + this.attraction = attraction; + this.repulsion = repulsion; + this.targetDistance = targetDistance; + } + + @Override + public List generate(String projectionId, List items, Instant createdAt) { + List usable = ProjectionCoordinateSupport.usableItems(items); + if (usable.isEmpty()) { + return List.of(); + } + List coordinates = 
ProjectionCoordinateSupport.pcaCoordinates(usable); + if (coordinates.isEmpty()) { + return List.of(); + } + List> neighbors = neighbors( + usable, + Math.min(neighborCount, Math.max(1, usable.size() - 1))); + refine(coordinates, neighbors); + ProjectionCoordinateSupport.normalizeCoordinates(coordinates); + return ProjectionCoordinateSupport.points(projectionId, usable, coordinates, createdAt); + } + + private List> neighbors(List items, int limit) { + List> neighbors = new ArrayList<>(items.size()); + for (int i = 0; i < items.size(); i++) { + List candidates = new ArrayList<>(Math.max(0, items.size() - 1)); + for (int j = 0; j < items.size(); j++) { + if (i == j) { + continue; + } + double distance = embeddingDistance(items.get(i).embedding(), items.get(j).embedding()); + candidates.add(new Neighbor(j, distance, similarity(distance))); + } + candidates.sort(Comparator.comparingDouble(Neighbor::distance)); + neighbors.add(candidates.stream().limit(limit).toList()); + } + return neighbors; + } + + private void refine(List coordinates, List> neighbors) { + if (coordinates.size() <= 1) { + return; + } + for (int iteration = 0; iteration < iterations; iteration++) { + double[][] deltas = new double[coordinates.size()][2]; + for (int i = 0; i < coordinates.size(); i++) { + double[] current = coordinates.get(i); + for (Neighbor neighbor : neighbors.get(i)) { + double[] other = coordinates.get(neighbor.index()); + double dx = other[0] - current[0]; + double dy = other[1] - current[1]; + double distance = Math.sqrt(dx * dx + dy * dy) + EPSILON; + double force = attraction * neighbor.similarity() * (distance - targetDistance); + deltas[i][0] += force * dx / distance; + deltas[i][1] += force * dy / distance; + } + applySampledRepulsion(i, coordinates, deltas[i]); + } + for (int i = 0; i < coordinates.size(); i++) { + coordinates.get(i)[0] += deltas[i][0]; + coordinates.get(i)[1] += deltas[i][1]; + } + } + } + + private void applySampledRepulsion(int index, List coordinates, 
double[] delta) { + int samples = Math.min(16, coordinates.size() - 1); + if (samples <= 0) { + return; + } + double[] current = coordinates.get(index); + for (int sample = 1; sample <= samples; sample++) { + int otherIndex = Math.floorMod(index + sample * 37, coordinates.size()); + if (otherIndex == index) { + otherIndex = (otherIndex + 1) % coordinates.size(); + } + double[] other = coordinates.get(otherIndex); + double dx = current[0] - other[0]; + double dy = current[1] - other[1]; + double squared = dx * dx + dy * dy + EPSILON; + double force = repulsion / squared; + delta[0] += force * dx; + delta[1] += force * dy; + } + } + + private double similarity(double distance) { + return 1.0d / (1.0d + Math.max(0.0d, distance)); + } + + private double embeddingDistance(List left, List right) { + int dimensions = Math.min(left.size(), right.size()); + if (dimensions <= 0) { + return 1.0d; + } + int step = Math.max(1, dimensions / MAX_DISTANCE_DIMENSIONS); + double dot = 0.0d; + double leftNorm = 0.0d; + double rightNorm = 0.0d; + for (int i = 0; i < dimensions; i += step) { + double leftValue = left.get(i); + double rightValue = right.get(i); + dot += leftValue * rightValue; + leftNorm += leftValue * leftValue; + rightNorm += rightValue * rightValue; + } + if (leftNorm <= 0.0d || rightNorm <= 0.0d) { + return 1.0d; + } + double similarity = dot / (Math.sqrt(leftNorm) * Math.sqrt(rightNorm)); + return 1.0d - Math.max(-1.0d, Math.min(1.0d, similarity)); + } + + private record Neighbor(int index, double distance, double similarity) { + } +} diff --git a/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/PcaVectorProjectionGenerator.java b/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/PcaVectorProjectionGenerator.java index 4b344074..da72091d 100644 --- a/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/PcaVectorProjectionGenerator.java +++ 
b/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/PcaVectorProjectionGenerator.java @@ -1,8 +1,6 @@ package studio.one.platform.ai.core.vector.visualization; import java.time.Instant; -import java.util.ArrayList; -import java.util.Comparator; import java.util.List; /** @@ -10,8 +8,6 @@ */ public class PcaVectorProjectionGenerator implements VectorProjectionGenerator { - private static final int POWER_ITERATIONS = 60; - @Override public ProjectionAlgorithm algorithm() { return ProjectionAlgorithm.PCA; @@ -19,149 +15,11 @@ public ProjectionAlgorithm algorithm() { @Override public List generate(String projectionId, List items, Instant createdAt) { - List usable = items.stream() - .filter(item -> item.embedding() != null && !item.embedding().isEmpty()) - .toList(); + List usable = ProjectionCoordinateSupport.usableItems(items); if (usable.isEmpty()) { return List.of(); } - int dimensions = usable.stream() - .map(VectorItem::embedding) - .mapToInt(List::size) - .min() - .orElse(0); - if (dimensions <= 0) { - return List.of(); - } - double[][] centered = centeredMatrix(usable, dimensions); - double[][] covariance = covariance(centered, dimensions); - double[] first = principalComponent(covariance, null); - double[] second = dimensions == 1 ? 
new double[] {0.0d} : principalComponent(covariance, first); - List coordinates = new ArrayList<>(usable.size()); - for (double[] vector : centered) { - coordinates.add(new double[] {dot(vector, first), dot(vector, second)}); - } - normalizeCoordinates(coordinates); - List points = new ArrayList<>(usable.size()); - for (int i = 0; i < usable.size(); i++) { - VectorItem item = usable.get(i); - double[] coordinate = coordinates.get(i); - points.add(new VectorProjectionPoint( - projectionId, - item.vectorItemId(), - coordinate[0], - coordinate[1], - null, - i, - createdAt)); - } - return points; - } - - private double[][] centeredMatrix(List items, int dimensions) { - double[] means = new double[dimensions]; - for (VectorItem item : items) { - for (int i = 0; i < dimensions; i++) { - means[i] += item.embedding().get(i); - } - } - for (int i = 0; i < dimensions; i++) { - means[i] /= items.size(); - } - double[][] centered = new double[items.size()][dimensions]; - for (int row = 0; row < items.size(); row++) { - List embedding = items.get(row).embedding(); - for (int col = 0; col < dimensions; col++) { - centered[row][col] = embedding.get(col) - means[col]; - } - } - return centered; - } - - private double[][] covariance(double[][] centered, int dimensions) { - double[][] covariance = new double[dimensions][dimensions]; - int divisor = Math.max(1, centered.length - 1); - for (double[] row : centered) { - for (int i = 0; i < dimensions; i++) { - for (int j = i; j < dimensions; j++) { - covariance[i][j] += row[i] * row[j] / divisor; - } - } - } - for (int i = 0; i < dimensions; i++) { - for (int j = 0; j < i; j++) { - covariance[i][j] = covariance[j][i]; - } - } - return covariance; - } - - private double[] principalComponent(double[][] matrix, double[] orthogonalTo) { - int dimensions = matrix.length; - double[] vector = new double[dimensions]; - for (int i = 0; i < dimensions; i++) { - vector[i] = 1.0d / Math.sqrt(dimensions); - } - for (int iteration = 0; iteration < 
POWER_ITERATIONS; iteration++) { - double[] next = multiply(matrix, vector); - if (orthogonalTo != null) { - subtractProjection(next, orthogonalTo); - } - normalize(next); - vector = next; - } - return vector; - } - - private double[] multiply(double[][] matrix, double[] vector) { - double[] result = new double[vector.length]; - for (int row = 0; row < matrix.length; row++) { - for (int col = 0; col < vector.length; col++) { - result[row] += matrix[row][col] * vector[col]; - } - } - return result; - } - - private void subtractProjection(double[] vector, double[] basis) { - double scale = dot(vector, basis); - for (int i = 0; i < vector.length; i++) { - vector[i] -= scale * basis[i]; - } - } - - private void normalize(double[] vector) { - double norm = Math.sqrt(dot(vector, vector)); - if (norm == 0.0d || Double.isNaN(norm)) { - for (int i = 0; i < vector.length; i++) { - vector[i] = i == 0 ? 1.0d : 0.0d; - } - return; - } - for (int i = 0; i < vector.length; i++) { - vector[i] /= norm; - } - } - - private double dot(double[] left, double[] right) { - double result = 0.0d; - for (int i = 0; i < Math.min(left.length, right.length); i++) { - result += left[i] * right[i]; - } - return result; - } - - private void normalizeCoordinates(List coordinates) { - double maxAbs = coordinates.stream() - .flatMap(values -> List.of(Math.abs(values[0]), Math.abs(values[1])).stream()) - .max(Comparator.naturalOrder()) - .orElse(0.0d); - if (maxAbs <= 0.0d || Double.isNaN(maxAbs)) { - return; - } - for (double[] coordinate : coordinates) { - coordinate[0] /= maxAbs; - coordinate[1] /= maxAbs; - } + List coordinates = ProjectionCoordinateSupport.pcaCoordinates(usable); + return ProjectionCoordinateSupport.points(projectionId, usable, coordinates, createdAt); } } diff --git a/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/ProjectionAlgorithm.java 
b/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/ProjectionAlgorithm.java index bb7054b0..c9400a53 100644 --- a/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/ProjectionAlgorithm.java +++ b/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/ProjectionAlgorithm.java @@ -1,5 +1,7 @@ package studio.one.platform.ai.core.vector.visualization; public enum ProjectionAlgorithm { - PCA + PCA, + UMAP, + TSNE } diff --git a/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/ProjectionCoordinateSupport.java b/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/ProjectionCoordinateSupport.java new file mode 100644 index 00000000..c743288c --- /dev/null +++ b/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/ProjectionCoordinateSupport.java @@ -0,0 +1,171 @@ +package studio.one.platform.ai.core.vector.visualization; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +final class ProjectionCoordinateSupport { + + private static final int POWER_ITERATIONS = 60; + + private ProjectionCoordinateSupport() { + } + + static List usableItems(List items) { + return items.stream() + .filter(item -> item.embedding() != null && !item.embedding().isEmpty()) + .toList(); + } + + static List pcaCoordinates(List usable) { + if (usable.isEmpty()) { + return List.of(); + } + int dimensions = usable.stream() + .map(VectorItem::embedding) + .mapToInt(List::size) + .min() + .orElse(0); + if (dimensions <= 0) { + return List.of(); + } + double[][] centered = centeredMatrix(usable, dimensions); + double[][] covariance = covariance(centered, dimensions); + double[] first = principalComponent(covariance, null); + double[] second = dimensions == 1 ? 
new double[] {0.0d} : principalComponent(covariance, first); + List coordinates = new ArrayList<>(usable.size()); + for (double[] vector : centered) { + coordinates.add(new double[] {dot(vector, first), dot(vector, second)}); + } + normalizeCoordinates(coordinates); + return coordinates; + } + + static List points( + String projectionId, + List usable, + List coordinates, + java.time.Instant createdAt) { + List points = new ArrayList<>(Math.min(usable.size(), coordinates.size())); + for (int i = 0; i < usable.size() && i < coordinates.size(); i++) { + VectorItem item = usable.get(i); + double[] coordinate = coordinates.get(i); + points.add(new VectorProjectionPoint( + projectionId, + item.vectorItemId(), + coordinate[0], + coordinate[1], + null, + i, + createdAt)); + } + return points; + } + + static void normalizeCoordinates(List coordinates) { + double maxAbs = coordinates.stream() + .flatMap(values -> List.of(Math.abs(values[0]), Math.abs(values[1])).stream()) + .max(Comparator.naturalOrder()) + .orElse(0.0d); + if (maxAbs <= 0.0d || Double.isNaN(maxAbs)) { + return; + } + for (double[] coordinate : coordinates) { + coordinate[0] /= maxAbs; + coordinate[1] /= maxAbs; + } + } + + private static double[][] centeredMatrix(List items, int dimensions) { + double[] means = new double[dimensions]; + for (VectorItem item : items) { + for (int i = 0; i < dimensions; i++) { + means[i] += item.embedding().get(i); + } + } + for (int i = 0; i < dimensions; i++) { + means[i] /= items.size(); + } + double[][] centered = new double[items.size()][dimensions]; + for (int row = 0; row < items.size(); row++) { + List embedding = items.get(row).embedding(); + for (int col = 0; col < dimensions; col++) { + centered[row][col] = embedding.get(col) - means[col]; + } + } + return centered; + } + + private static double[][] covariance(double[][] centered, int dimensions) { + double[][] covariance = new double[dimensions][dimensions]; + int divisor = Math.max(1, centered.length - 1); + for 
(double[] row : centered) { + for (int i = 0; i < dimensions; i++) { + for (int j = i; j < dimensions; j++) { + covariance[i][j] += row[i] * row[j] / divisor; + } + } + } + for (int i = 0; i < dimensions; i++) { + for (int j = 0; j < i; j++) { + covariance[i][j] = covariance[j][i]; + } + } + return covariance; + } + + private static double[] principalComponent(double[][] matrix, double[] orthogonalTo) { + int dimensions = matrix.length; + double[] vector = new double[dimensions]; + for (int i = 0; i < dimensions; i++) { + vector[i] = 1.0d / Math.sqrt(dimensions); + } + for (int iteration = 0; iteration < POWER_ITERATIONS; iteration++) { + double[] next = multiply(matrix, vector); + if (orthogonalTo != null) { + subtractProjection(next, orthogonalTo); + } + normalize(next); + vector = next; + } + return vector; + } + + private static double[] multiply(double[][] matrix, double[] vector) { + double[] result = new double[vector.length]; + for (int row = 0; row < matrix.length; row++) { + for (int col = 0; col < vector.length; col++) { + result[row] += matrix[row][col] * vector[col]; + } + } + return result; + } + + private static void subtractProjection(double[] vector, double[] basis) { + double scale = dot(vector, basis); + for (int i = 0; i < vector.length; i++) { + vector[i] -= scale * basis[i]; + } + } + + private static void normalize(double[] vector) { + double norm = Math.sqrt(dot(vector, vector)); + if (norm == 0.0d || Double.isNaN(norm)) { + for (int i = 0; i < vector.length; i++) { + vector[i] = i == 0 ? 
1.0d : 0.0d; + } + return; + } + for (int i = 0; i < vector.length; i++) { + vector[i] /= norm; + } + } + + private static double dot(double[] left, double[] right) { + double result = 0.0d; + for (int i = 0; i < Math.min(left.length, right.length); i++) { + result += left[i] * right[i]; + } + return result; + } +} diff --git a/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/TsneVectorProjectionGenerator.java b/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/TsneVectorProjectionGenerator.java new file mode 100644 index 00000000..e5295679 --- /dev/null +++ b/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/TsneVectorProjectionGenerator.java @@ -0,0 +1,21 @@ +package studio.one.platform.ai.core.vector.visualization; + +/** + * Dependency-free t-SNE-style projection for management visualization. + *

+ * This generator uses a deterministic neighbor-preserving refinement over PCA + * initialization. It is intentionally bounded for management scatter plots and + * can be replaced by a library-backed implementation through + * {@link VectorProjectionGenerator}. + */ +public class TsneVectorProjectionGenerator extends NeighborVectorProjectionGenerator { + + public TsneVectorProjectionGenerator() { + super(30, 120, 0.025d, 0.0015d, 0.08d); + } + + @Override + public ProjectionAlgorithm algorithm() { + return ProjectionAlgorithm.TSNE; + } +} diff --git a/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/UmapVectorProjectionGenerator.java b/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/UmapVectorProjectionGenerator.java new file mode 100644 index 00000000..523fe8e5 --- /dev/null +++ b/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/UmapVectorProjectionGenerator.java @@ -0,0 +1,20 @@ +package studio.one.platform.ai.core.vector.visualization; + +/** + * Dependency-free UMAP-style projection for management visualization. + *

+ * The implementation starts from PCA coordinates and preserves local neighbors + * with deterministic attraction/repulsion steps so projection jobs remain + * server-side and repeatable without a native numerical dependency. + */ +public class UmapVectorProjectionGenerator extends NeighborVectorProjectionGenerator { + + public UmapVectorProjectionGenerator() { + super(12, 80, 0.035d, 0.0008d, 0.12d); + } + + @Override + public ProjectionAlgorithm algorithm() { + return ProjectionAlgorithm.UMAP; + } +} diff --git a/studio-platform-ai/src/test/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGeneratorTest.java b/studio-platform-ai/src/test/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGeneratorTest.java new file mode 100644 index 00000000..5ba861d9 --- /dev/null +++ b/studio-platform-ai/src/test/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGeneratorTest.java @@ -0,0 +1,62 @@ +package studio.one.platform.ai.core.vector.visualization; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Instant; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +class NeighborVectorProjectionGeneratorTest { + + @Test + void umapGeneratesNormalizedPoints() { + UmapVectorProjectionGenerator generator = new UmapVectorProjectionGenerator(); + + List<VectorProjectionPoint> points = generator.generate("proj-umap", items(), Instant.parse("2026-04-30T00:00:00Z")); + + assertThat(generator.algorithm()).isEqualTo(ProjectionAlgorithm.UMAP); + assertThat(points).hasSize(4); + assertThat(points).allSatisfy(point -> { + assertThat(point.projectionId()).isEqualTo("proj-umap"); + assertThat(point.x()).isBetween(-1.0, 1.0); + assertThat(point.y()).isBetween(-1.0, 1.0); + }); + } + + @Test + void tsneGeneratesNormalizedPoints() { + TsneVectorProjectionGenerator generator = new TsneVectorProjectionGenerator(); + + List<VectorProjectionPoint> points = generator.generate("proj-tsne", items(),
Instant.parse("2026-04-30T00:00:00Z")); + + assertThat(generator.algorithm()).isEqualTo(ProjectionAlgorithm.TSNE); + assertThat(points).hasSize(4); + assertThat(points).allSatisfy(point -> { + assertThat(point.projectionId()).isEqualTo("proj-tsne"); + assertThat(point.x()).isBetween(-1.0, 1.0); + assertThat(point.y()).isBetween(-1.0, 1.0); + }); + } + + @Test + void generatorsReturnEmptyWhenNoEmbeddingExists() { + assertThat(new UmapVectorProjectionGenerator().generate("proj-1", List.of(item("a", List.of())), Instant.now())) + .isEmpty(); + assertThat(new TsneVectorProjectionGenerator().generate("proj-1", List.of(item("a", List.of())), Instant.now())) + .isEmpty(); + } + + private List<VectorItem> items() { + return List.of( + item("a", List.of(1.0, 0.0, 0.0, 0.2)), + item("b", List.of(0.9, 0.1, 0.0, 0.3)), + item("c", List.of(0.0, 1.0, 0.1, 0.1)), + item("d", List.of(0.0, 0.8, 0.2, 0.0))); + } + + private static VectorItem item(String id, List<Double> embedding) { + return new VectorItem(id, "TYPE", "source", id, "text", embedding, "model", embedding.size(), Map.of(), Instant.now()); + } +} From f8e24814173f7ef794c086ac8f4bd21fcc5f4e9c Mon Sep 17 00:00:00 2001 From: "donghyuck, son" Date: Thu, 30 Apr 2026 19:59:52 +0900 Subject: [PATCH 2/2] [ai-assisted] fix(ai): harden projection generator wiring Issue: - #390 Why: - Review found that multiple default VectorProjectionGenerator beans could break single-type injection and that fixed-stride repulsion sampling collapsed for projection sizes divisible by 37. What: - Mark the default PCA projection generator as primary to preserve single VectorProjectionGenerator injection compatibility. - Add auto-configuration coverage for interface bean resolution and custom generator inclusion in the job service. - Use a projection-size coprime stride for repulsion sampling and add regression coverage for 37/74 item sizes.
Validation: - ./gradlew :studio-platform-ai:test :starter:studio-platform-starter-ai:test :starter:studio-platform-starter-ai-web:test (PASS) - ./gradlew test (PASS) - git diff --check (PASS) AI-Assisted: Yes Subagent used: Yes Delegated scope: PR #391 blocker review and re-review --- .../autoconfigure/AiWebAutoConfiguration.java | 2 + ...jectionGeneratorAutoConfigurationTest.java | 54 ++++++++++++++++++- .../NeighborVectorProjectionGenerator.java | 45 +++++++++++++--- ...NeighborVectorProjectionGeneratorTest.java | 10 ++++ 4 files changed, 103 insertions(+), 8 deletions(-) diff --git a/starter/studio-platform-starter-ai-web/src/main/java/studio/one/platform/ai/autoconfigure/AiWebAutoConfiguration.java b/starter/studio-platform-starter-ai-web/src/main/java/studio/one/platform/ai/autoconfigure/AiWebAutoConfiguration.java index 5575d5b7..c544b811 100644 --- a/starter/studio-platform-starter-ai-web/src/main/java/studio/one/platform/ai/autoconfigure/AiWebAutoConfiguration.java +++ b/starter/studio-platform-starter-ai-web/src/main/java/studio/one/platform/ai/autoconfigure/AiWebAutoConfiguration.java @@ -14,6 +14,7 @@ import org.springframework.context.annotation.Conditional; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Primary; import org.springframework.core.Ordered; import org.springframework.core.annotation.Order; import org.springframework.core.env.Environment; @@ -158,6 +159,7 @@ VectorController vectorController( } @Bean + @Primary @Order(Ordered.LOWEST_PRECEDENCE) @ConditionalOnMissingBean PcaVectorProjectionGenerator pcaVectorProjectionGenerator() { diff --git a/starter/studio-platform-starter-ai-web/src/test/java/studio/one/platform/ai/autoconfigure/VectorProjectionGeneratorAutoConfigurationTest.java b/starter/studio-platform-starter-ai-web/src/test/java/studio/one/platform/ai/autoconfigure/VectorProjectionGeneratorAutoConfigurationTest.java index 
70cd2b68..3c44be1a 100644 --- a/starter/studio-platform-starter-ai-web/src/test/java/studio/one/platform/ai/autoconfigure/VectorProjectionGeneratorAutoConfigurationTest.java +++ b/starter/studio-platform-starter-ai-web/src/test/java/studio/one/platform/ai/autoconfigure/VectorProjectionGeneratorAutoConfigurationTest.java @@ -9,13 +9,23 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.context.ConfigurationPropertiesAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.test.util.ReflectionTestUtils; -import studio.one.platform.ai.core.chat.ChatPort; import studio.one.platform.ai.autoconfigure.config.AiAdapterProperties; +import studio.one.platform.ai.core.chat.ChatPort; import studio.one.platform.ai.core.embedding.EmbeddingPort; import studio.one.platform.ai.core.registry.AiProviderRegistry; +import studio.one.platform.ai.core.vector.visualization.ExistingVectorItemRepository; import studio.one.platform.ai.core.vector.visualization.ProjectionAlgorithm; +import studio.one.platform.ai.core.vector.visualization.VectorItem; import studio.one.platform.ai.core.vector.visualization.VectorProjectionGenerator; +import studio.one.platform.ai.core.vector.visualization.VectorProjectionPoint; +import studio.one.platform.ai.core.vector.visualization.VectorProjectionPointRepository; +import studio.one.platform.ai.core.vector.visualization.VectorProjectionRepository; +import studio.one.platform.ai.service.visualization.DefaultVectorProjectionJobService; +import studio.one.platform.ai.service.visualization.DefaultVectorProjectionService; +import studio.one.platform.ai.service.visualization.VectorProjectionJobService; +import studio.one.platform.ai.service.visualization.VectorProjectionService; import studio.one.platform.ai.service.pipeline.RagPipelineService; import studio.one.platform.ai.service.prompt.PromptRenderer; @@ -42,6 +52,48 @@ void 
registersDefaultProjectionGeneratorsForSupportedAlgorithms() { assertThat(generators.values()) .extracting(VectorProjectionGenerator::algorithm) .contains(ProjectionAlgorithm.PCA, ProjectionAlgorithm.UMAP, ProjectionAlgorithm.TSNE); + assertThat(context.getBean(VectorProjectionGenerator.class).algorithm()).isEqualTo(ProjectionAlgorithm.PCA); }); } + + @Test + void customGeneratorIsIncludedInAutoConfiguredJobServiceWithDefaults() { + VectorProjectionGenerator customPca = new VectorProjectionGenerator() { + @Override + public ProjectionAlgorithm algorithm() { + return ProjectionAlgorithm.PCA; + } + + @Override + public java.util.List<VectorProjectionPoint> generate( + String projectionId, + java.util.List<VectorItem> items, + java.time.Instant createdAt) { + return java.util.List.of(); + } + }; + + contextRunner + .withBean(VectorProjectionGenerator.class, () -> customPca) + .withBean(VectorProjectionRepository.class, () -> mock(VectorProjectionRepository.class)) + .withBean(VectorProjectionPointRepository.class, () -> mock(VectorProjectionPointRepository.class)) + .withBean(ExistingVectorItemRepository.class, () -> mock(ExistingVectorItemRepository.class)) + .run(context -> { + assertThat(context).hasNotFailed(); + assertThat(context).hasSingleBean(VectorProjectionJobService.class); + assertThat(context).hasSingleBean(VectorProjectionService.class); + DefaultVectorProjectionJobService jobService = + (DefaultVectorProjectionJobService) context.getBean(VectorProjectionJobService.class); + @SuppressWarnings("unchecked") + java.util.List<VectorProjectionGenerator> generators = + (java.util.List<VectorProjectionGenerator>) ReflectionTestUtils.getField(jobService, "generators"); + assertThat(generators) + .contains(customPca) + .extracting(VectorProjectionGenerator::algorithm) + .contains(ProjectionAlgorithm.PCA, ProjectionAlgorithm.UMAP, ProjectionAlgorithm.TSNE); + DefaultVectorProjectionService projectionService = + (DefaultVectorProjectionService) context.getBean(VectorProjectionService.class); + assertThat(ReflectionTestUtils.getField(projectionService,
"jobService")).isSameAs(jobService); + }); + } } diff --git a/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGenerator.java b/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGenerator.java index 783639e0..e6616700 100644 --- a/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGenerator.java +++ b/studio-platform-ai/src/main/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGenerator.java @@ -91,16 +91,12 @@ private void refine(List<double[]> coordinates, List> neighbors) } private void applySampledRepulsion(int index, List<double[]> coordinates, double[] delta) { - int samples = Math.min(16, coordinates.size() - 1); - if (samples <= 0) { + List<Integer> peers = sampledPeerIndexes(index, coordinates.size(), 16); + if (peers.isEmpty()) { return; } double[] current = coordinates.get(index); - for (int sample = 1; sample <= samples; sample++) { - int otherIndex = Math.floorMod(index + sample * 37, coordinates.size()); - if (otherIndex == index) { - otherIndex = (otherIndex + 1) % coordinates.size(); - } + for (int otherIndex : peers) { double[] other = coordinates.get(otherIndex); double dx = current[0] - other[0]; double dy = current[1] - other[1]; @@ -111,6 +107,41 @@ private void applySampledRepulsion(int index, List<double[]> coordinates, double } } + static List<Integer> sampledPeerIndexes(int index, int size, int maxSamples) { + int samples = Math.min(maxSamples, size - 1); + if (samples <= 0) { + return List.of(); + } + int stride = coprimeStride(size); + List<Integer> peers = new ArrayList<>(samples); + for (int sample = 1; sample <= samples; sample++) { + int otherIndex = Math.floorMod(index + sample * stride, size); + if (otherIndex != index) { + peers.add(otherIndex); + } + } + return peers; + } + + private static int coprimeStride(int size) { + int stride = Math.min(37, Math.max(1, size - 1)); + while (stride > 1 &&
gcd(stride, size) != 1) { + stride--; + } + return stride; + } + + private static int gcd(int left, int right) { + int a = Math.abs(left); + int b = Math.abs(right); + while (b != 0) { + int next = a % b; + a = b; + b = next; + } + return a; + } + private double similarity(double distance) { return 1.0d / (1.0d + Math.max(0.0d, distance)); } diff --git a/studio-platform-ai/src/test/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGeneratorTest.java b/studio-platform-ai/src/test/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGeneratorTest.java index 5ba861d9..9e9b4412 100644 --- a/studio-platform-ai/src/test/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGeneratorTest.java +++ b/studio-platform-ai/src/test/java/studio/one/platform/ai/core/vector/visualization/NeighborVectorProjectionGeneratorTest.java @@ -48,6 +48,16 @@ void generatorsReturnEmptyWhenNoEmbeddingExists() { .isEmpty(); } + @Test + void repulsionSamplerKeepsDistinctPeersWhenSizeIsMultipleOfDefaultStride() { + assertThat(NeighborVectorProjectionGenerator.sampledPeerIndexes(0, 37, 16)) + .doesNotHaveDuplicates() + .hasSize(16); + assertThat(NeighborVectorProjectionGenerator.sampledPeerIndexes(0, 74, 16)) + .doesNotHaveDuplicates() + .hasSize(16); + } + private List<VectorItem> items() { return List.of( item("a", List.of(1.0, 0.0, 0.0, 0.2)),