Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions starter/studio-platform-starter-ai-web/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,10 @@ studio:
원본 embedding 테이블은 변경하지 않고, `tb_ai_vector_projection`과 `tb_ai_vector_projection_point`에
projection job 상태와 미리 계산된 좌표만 저장한다. 화면 요청 시마다 고차원 벡터를 다시 projection하지 않는다.

기본 알고리즘은 `PCA`다. v1 구현은 Java 내장 연산으로 PCA 좌표를 계산하고, 후속 UMAP/t-SNE는
`VectorProjectionGenerator` 구현을 추가해 확장한다. `targetTypes`는 UI 문서 분류가 아니라
지원 알고리즘은 `PCA`, `UMAP`, `TSNE`다. 기본값은 `PCA`이며, UMAP/t-SNE는 서버가 PCA 초기 좌표를
embedding cosine distance 기반 이웃 보존 방식으로 보정해 저장하는 deterministic 구현이다. 클라이언트는 알고리즘별 계산을 하지 않고
points API가 내려준 `x`, `y` 좌표를 그대로 렌더링한다. 더 높은 품질의 projection이 필요하면
`VectorProjectionGenerator` 구현을 교체해 확장한다. `targetTypes`는 UI 문서 분류가 아니라
`tb_ai_document_chunk.object_type`에 저장된 RAG index objectType 기준이다. 예를 들어 `attachment`,
`forums-post-attachment`, 정책 object type 값처럼 색인 job이 사용한 objectType을 지정한다.
`targetTypes`가 비어 있으면 전체 vector item을 대상으로 한다.
Expand All @@ -133,7 +135,7 @@ Content-Type: application/json
{
"name": "NCS-과정-청크 벡터맵",
"targetTypes": ["NCS_UNIT", "COURSE", "COURSE_CHUNK"],
"algorithm": "PCA",
"algorithm": "UMAP",
"filters": {
"useYn": "Y"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import org.springframework.context.annotation.Conditional;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.core.env.Environment;
import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
Expand All @@ -31,6 +34,8 @@
import studio.one.platform.ai.core.vector.VectorStorePort;
import studio.one.platform.ai.core.vector.visualization.ExistingVectorItemRepository;
import studio.one.platform.ai.core.vector.visualization.PcaVectorProjectionGenerator;
import studio.one.platform.ai.core.vector.visualization.TsneVectorProjectionGenerator;
import studio.one.platform.ai.core.vector.visualization.UmapVectorProjectionGenerator;
import studio.one.platform.ai.core.vector.visualization.VectorProjectionGenerator;
import studio.one.platform.ai.core.vector.visualization.VectorProjectionPointRepository;
import studio.one.platform.ai.core.vector.visualization.VectorProjectionRepository;
Expand Down Expand Up @@ -154,11 +159,27 @@ VectorController vectorController(
}

@Bean
@Primary
@Order(Ordered.LOWEST_PRECEDENCE)
@ConditionalOnMissingBean
VectorProjectionGenerator vectorProjectionGenerator() {
PcaVectorProjectionGenerator pcaVectorProjectionGenerator() {
return new PcaVectorProjectionGenerator();
}

@Bean
@Order(Ordered.LOWEST_PRECEDENCE)
@ConditionalOnMissingBean
UmapVectorProjectionGenerator umapVectorProjectionGenerator() {
return new UmapVectorProjectionGenerator();
}

@Bean
@Order(Ordered.LOWEST_PRECEDENCE)
@ConditionalOnMissingBean
TsneVectorProjectionGenerator tsneVectorProjectionGenerator() {
return new TsneVectorProjectionGenerator();
}

@Bean(name = "vectorProjectionExecutor")
@ConditionalOnMissingBean(name = "vectorProjectionExecutor")
Executor vectorProjectionExecutor() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package studio.one.platform.ai.autoconfigure;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;

import java.util.Map;

import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.context.ConfigurationPropertiesAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.test.util.ReflectionTestUtils;

import studio.one.platform.ai.autoconfigure.config.AiAdapterProperties;
import studio.one.platform.ai.core.chat.ChatPort;
import studio.one.platform.ai.core.embedding.EmbeddingPort;
import studio.one.platform.ai.core.registry.AiProviderRegistry;
import studio.one.platform.ai.core.vector.visualization.ExistingVectorItemRepository;
import studio.one.platform.ai.core.vector.visualization.ProjectionAlgorithm;
import studio.one.platform.ai.core.vector.visualization.VectorItem;
import studio.one.platform.ai.core.vector.visualization.VectorProjectionGenerator;
import studio.one.platform.ai.core.vector.visualization.VectorProjectionPoint;
import studio.one.platform.ai.core.vector.visualization.VectorProjectionPointRepository;
import studio.one.platform.ai.core.vector.visualization.VectorProjectionRepository;
import studio.one.platform.ai.service.visualization.DefaultVectorProjectionJobService;
import studio.one.platform.ai.service.visualization.DefaultVectorProjectionService;
import studio.one.platform.ai.service.visualization.VectorProjectionJobService;
import studio.one.platform.ai.service.visualization.VectorProjectionService;
import studio.one.platform.ai.service.pipeline.RagPipelineService;
import studio.one.platform.ai.service.prompt.PromptRenderer;

class VectorProjectionGeneratorAutoConfigurationTest {

private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(ConfigurationPropertiesAutoConfiguration.class))
.withUserConfiguration(AiWebAutoConfiguration.class)
.withBean(AiProviderRegistry.class, () -> mock(AiProviderRegistry.class))
.withBean(RagPipelineService.class, () -> mock(RagPipelineService.class))
.withBean(EmbeddingPort.class, () -> mock(EmbeddingPort.class))
.withBean(ChatPort.class, () -> mock(ChatPort.class))
.withBean(PromptRenderer.class, () -> mock(PromptRenderer.class))
.withBean(AiAdapterProperties.class, AiAdapterProperties::new)
.withPropertyValues(
"studio.features.ai.enabled=true",
"studio.ai.endpoints.enabled=true");

@Test
void registersDefaultProjectionGeneratorsForSupportedAlgorithms() {
contextRunner.run(context -> {
assertThat(context).hasNotFailed();
Map<String, VectorProjectionGenerator> generators = context.getBeansOfType(VectorProjectionGenerator.class);
assertThat(generators.values())
.extracting(VectorProjectionGenerator::algorithm)
.contains(ProjectionAlgorithm.PCA, ProjectionAlgorithm.UMAP, ProjectionAlgorithm.TSNE);
assertThat(context.getBean(VectorProjectionGenerator.class).algorithm()).isEqualTo(ProjectionAlgorithm.PCA);
});
}

@Test
void customGeneratorIsIncludedInAutoConfiguredJobServiceWithDefaults() {
VectorProjectionGenerator customPca = new VectorProjectionGenerator() {
@Override
public ProjectionAlgorithm algorithm() {
return ProjectionAlgorithm.PCA;
}

@Override
public java.util.List<VectorProjectionPoint> generate(
String projectionId,
java.util.List<VectorItem> items,
java.time.Instant createdAt) {
return java.util.List.of();
}
};

contextRunner
.withBean(VectorProjectionGenerator.class, () -> customPca)
.withBean(VectorProjectionRepository.class, () -> mock(VectorProjectionRepository.class))
.withBean(VectorProjectionPointRepository.class, () -> mock(VectorProjectionPointRepository.class))
.withBean(ExistingVectorItemRepository.class, () -> mock(ExistingVectorItemRepository.class))
.run(context -> {
assertThat(context).hasNotFailed();
assertThat(context).hasSingleBean(VectorProjectionJobService.class);
assertThat(context).hasSingleBean(VectorProjectionService.class);
DefaultVectorProjectionJobService jobService =
(DefaultVectorProjectionJobService) context.getBean(VectorProjectionJobService.class);
@SuppressWarnings("unchecked")
java.util.List<VectorProjectionGenerator> generators =
(java.util.List<VectorProjectionGenerator>) ReflectionTestUtils.getField(jobService, "generators");
assertThat(generators)
.contains(customPca)
.extracting(VectorProjectionGenerator::algorithm)
.contains(ProjectionAlgorithm.PCA, ProjectionAlgorithm.UMAP, ProjectionAlgorithm.TSNE);
DefaultVectorProjectionService projectionService =
(DefaultVectorProjectionService) context.getBean(VectorProjectionService.class);
assertThat(ReflectionTestUtils.getField(projectionService, "jobService")).isSameAs(jobService);
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.util.Map;

import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpStatus;
import org.springframework.web.server.ResponseStatusException;

Expand All @@ -21,6 +22,7 @@
import studio.one.platform.ai.core.vector.visualization.ProjectionStatus;
import studio.one.platform.ai.core.vector.visualization.VectorItem;
import studio.one.platform.ai.core.vector.visualization.VectorProjection;
import studio.one.platform.ai.service.visualization.VectorProjectionCreateCommand;
import studio.one.platform.ai.service.visualization.VectorProjectionService;
import studio.one.platform.ai.service.visualization.VectorSearchVisualizationService;
import studio.one.platform.ai.web.dto.visualization.ProjectionCreateRequest;
Expand Down Expand Up @@ -48,6 +50,39 @@ void createProjectionDefaultsToPcaAndReturnsRequestedStatus() {
verify(projectionService).create(any());
}

@Test
void createProjectionAcceptsUmapAndTsneAlgorithms() {
VectorProjectionService projectionService = mock(VectorProjectionService.class);
when(projectionService.create(any())).thenReturn(projection(ProjectionStatus.REQUESTED));
VectorVisualizationMgmtController controller = new VectorVisualizationMgmtController(
projectionService,
mock(VectorSearchVisualizationService.class));
ArgumentCaptor<VectorProjectionCreateCommand> captor = ArgumentCaptor.forClass(VectorProjectionCreateCommand.class);

controller.createProjection(new ProjectionCreateRequest("UMAP map", List.of(), "UMAP", Map.of()));
controller.createProjection(new ProjectionCreateRequest("TSNE map", List.of(), "tsne", Map.of()));

verify(projectionService, org.mockito.Mockito.times(2)).create(captor.capture());
assertThat(captor.getAllValues())
.extracting(VectorProjectionCreateCommand::algorithm)
.containsExactly(ProjectionAlgorithm.UMAP, ProjectionAlgorithm.TSNE);
}

@Test
void createProjectionRejectsUnsupportedAlgorithm() {
VectorVisualizationMgmtController controller = new VectorVisualizationMgmtController(
mock(VectorProjectionService.class),
mock(VectorSearchVisualizationService.class));

assertThatThrownBy(() -> controller.createProjection(new ProjectionCreateRequest(
"bad map",
List.of(),
"MDS",
Map.of())))
.isInstanceOf(ResponseStatusException.class)
.hasMessageContaining("UNSUPPORTED_PROJECTION_ALGORITHM");
}

@Test
void pointsReturnsClientOrientedShape() {
VectorProjectionService projectionService = mock(VectorProjectionService.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

import lombok.extern.slf4j.Slf4j;
import studio.one.platform.ai.core.vector.visualization.ExistingVectorItemRepository;
import studio.one.platform.ai.core.vector.visualization.PcaVectorProjectionGenerator;
import studio.one.platform.ai.core.vector.visualization.ProjectionStatus;
import studio.one.platform.ai.core.vector.visualization.TsneVectorProjectionGenerator;
import studio.one.platform.ai.core.vector.visualization.UmapVectorProjectionGenerator;
import studio.one.platform.ai.core.vector.visualization.VectorItem;
import studio.one.platform.ai.core.vector.visualization.VectorProjection;
import studio.one.platform.ai.core.vector.visualization.VectorProjectionGenerator;
Expand Down Expand Up @@ -67,12 +70,22 @@ public void run(String projectionId) {
}

private VectorProjectionGenerator generatorFor(VectorProjection projection) {
return generators.stream()
List<VectorProjectionGenerator> matching = generators.stream()
.filter(generator -> generator.algorithm() == projection.algorithm())
.toList();
return matching.stream()
.filter(generator -> !isDefaultGenerator(generator))
.findFirst()
.or(() -> matching.stream().findFirst())
.orElseThrow(() -> new IllegalArgumentException("UNSUPPORTED_PROJECTION_ALGORITHM"));
}

private boolean isDefaultGenerator(VectorProjectionGenerator generator) {
return generator instanceof PcaVectorProjectionGenerator
|| generator instanceof UmapVectorProjectionGenerator
|| generator instanceof TsneVectorProjectionGenerator;
}

private List<String> actualTargetTypes(List<VectorItem> items) {
return items.stream()
.map(VectorItem::targetType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.springframework.web.server.ResponseStatusException;

import studio.one.platform.ai.core.vector.visualization.ExistingVectorItemRepository;
import studio.one.platform.ai.core.vector.visualization.PcaVectorProjectionGenerator;
import studio.one.platform.ai.core.vector.visualization.ProjectionAlgorithm;
import studio.one.platform.ai.core.vector.visualization.ProjectionPointPage;
import studio.one.platform.ai.core.vector.visualization.ProjectionPointView;
Expand Down Expand Up @@ -73,6 +74,72 @@ public List<VectorProjectionPoint> generate(String projectionId, List<VectorItem
assertThat(points.points).hasSize(1);
}

@Test
void createRunsRequestedNonPcaGenerator() {
FakeProjectionRepository projections = new FakeProjectionRepository();
FakePointRepository points = new FakePointRepository();
FakeItemRepository items = new FakeItemRepository(List.of(item("chunk-1")));
DefaultVectorProjectionJobService job = new DefaultVectorProjectionJobService(
projections,
points,
items,
List.of(generator(ProjectionAlgorithm.UMAP)));
DefaultVectorProjectionService service = new DefaultVectorProjectionService(
projections,
points,
items,
job,
Runnable::run);

VectorProjection projection = service.create(new VectorProjectionCreateCommand(
"map",
ProjectionAlgorithm.UMAP,
List.of("COURSE_CHUNK"),
Map.of(),
"tester"));

assertThat(projection.algorithm()).isEqualTo(ProjectionAlgorithm.UMAP);
VectorProjection saved = projections.findById(projection.projectionId()).orElseThrow();
assertThat(saved.status()).isEqualTo(ProjectionStatus.COMPLETED);
assertThat(saved.algorithm()).isEqualTo(ProjectionAlgorithm.UMAP);
assertThat(points.points).singleElement()
.extracting(VectorProjectionPoint::projectionId)
.isEqualTo(projection.projectionId());
}

@Test
void createPrefersCustomGeneratorWhenDefaultGeneratorHasSameAlgorithm() {
FakeProjectionRepository projections = new FakeProjectionRepository();
FakePointRepository points = new FakePointRepository();
FakeItemRepository items = new FakeItemRepository(List.of(item("chunk-1")));
DefaultVectorProjectionJobService job = new DefaultVectorProjectionJobService(
projections,
points,
items,
List.of(new PcaVectorProjectionGenerator(), generator(ProjectionAlgorithm.PCA, 0.7, 0.8)));
DefaultVectorProjectionService service = new DefaultVectorProjectionService(
projections,
points,
items,
job,
Runnable::run);

VectorProjection projection = service.create(new VectorProjectionCreateCommand(
"map",
ProjectionAlgorithm.PCA,
List.of("COURSE_CHUNK"),
Map.of(),
"tester"));

assertThat(projections.findById(projection.projectionId()).orElseThrow().status())
.isEqualTo(ProjectionStatus.COMPLETED);
assertThat(points.points).singleElement()
.satisfies(point -> {
assertThat(point.x()).isEqualTo(0.7);
assertThat(point.y()).isEqualTo(0.8);
});
}

@Test
void pointsRejectsProjectionThatIsNotCompleted() {
FakeProjectionRepository projections = new FakeProjectionRepository();
Expand Down Expand Up @@ -148,6 +215,26 @@ private static VectorItem item(String id) {
return new VectorItem(id, "COURSE_CHUNK", "course-1", "label", "text", List.of(0.1, 0.2), "model", 2, Map.of(), Instant.now());
}

private static VectorProjectionGenerator generator(ProjectionAlgorithm algorithm) {
return generator(algorithm, 0.1, 0.2);
}

private static VectorProjectionGenerator generator(ProjectionAlgorithm algorithm, double x, double y) {
return new VectorProjectionGenerator() {
@Override
public ProjectionAlgorithm algorithm() {
return algorithm;
}

@Override
public List<VectorProjectionPoint> generate(String projectionId, List<VectorItem> sourceItems, Instant createdAt) {
return sourceItems.stream()
.map(source -> new VectorProjectionPoint(projectionId, source.vectorItemId(), x, y, null, 0, createdAt))
.toList();
}
};
}

private static final class FakeProjectionRepository implements VectorProjectionRepository {
private final Map<String, VectorProjection> projections = new LinkedHashMap<>();

Expand Down
Loading
Loading