diff --git a/starter/studio-platform-starter-ai/build.gradle.kts b/starter/studio-platform-starter-ai/build.gradle.kts index 577b0fdd..6aa6d1ad 100644 --- a/starter/studio-platform-starter-ai/build.gradle.kts +++ b/starter/studio-platform-starter-ai/build.gradle.kts @@ -44,7 +44,10 @@ dependencies { testImplementation("org.mockito:mockito-core") testImplementation("org.assertj:assertj-core") testImplementation("org.springframework:spring-jdbc") + testImplementation("org.testcontainers:junit-jupiter") + testImplementation("org.testcontainers:postgresql") testImplementation("com.h2database:h2") testImplementation(project(":studio-platform")) testImplementation(project(":starter:studio-platform-starter-chunking")) + testRuntimeOnly("org.postgresql:postgresql") } diff --git a/starter/studio-platform-starter-ai/src/test/java/studio/one/platform/ai/adapters/vector/PgVectorStoreAdapterV2PostgresTest.java b/starter/studio-platform-starter-ai/src/test/java/studio/one/platform/ai/adapters/vector/PgVectorStoreAdapterV2PostgresTest.java new file mode 100644 index 00000000..c2e3c28e --- /dev/null +++ b/starter/studio-platform-starter-ai/src/test/java/studio/one/platform/ai/adapters/vector/PgVectorStoreAdapterV2PostgresTest.java @@ -0,0 +1,127 @@ +package studio.one.platform.ai.adapters.vector; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.lang.reflect.Field; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import javax.xml.parsers.DocumentBuilderFactory; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.DriverManagerDataSource; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; +import org.w3c.dom.Element; + +import studio.one.platform.ai.core.vector.VectorDocument; +import studio.one.platform.ai.core.vector.VectorSearchRequest; +import studio.one.platform.ai.core.vector.VectorSearchResult; + +@Testcontainers +class PgVectorStoreAdapterV2PostgresTest { + + @Container + static final PostgreSQLContainer POSTGRES = new PostgreSQLContainer<>( + DockerImageName.parse("pgvector/pgvector:pg16").asCompatibleSubstituteFor("postgres")); + + private PgVectorStoreAdapterV2 adapter; + + @BeforeEach + void setUp() throws Exception { + JdbcTemplate jdbcTemplate = new JdbcTemplate(new DriverManagerDataSource( + POSTGRES.getJdbcUrl(), + POSTGRES.getUsername(), + POSTGRES.getPassword())); + jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector"); + jdbcTemplate.execute("DROP TABLE IF EXISTS tb_ai_document_chunk"); + jdbcTemplate.execute(""" + CREATE TABLE tb_ai_document_chunk ( + id BIGSERIAL PRIMARY KEY, + object_type VARCHAR(100) NOT NULL, + object_id VARCHAR(100) NOT NULL, + chunk_index INTEGER NOT NULL, + text TEXT NOT NULL, + metadata JSONB NOT NULL, + embedding vector(2) NOT NULL, + CONSTRAINT uq_test_chunk UNIQUE (object_type, object_id, chunk_index) + ) + """); + adapter = new PgVectorStoreAdapterV2(jdbcTemplate); + setField("upsertSql", sql("upsertChunk")); + setField("searchByObjectSql", sql("searchByObject")); + setField("hybridSearchByObjectSql", sql("hybridSearchByObject")); + adapter.upsert(List.of( + document("chunk-1", "attachment", "6", 0, "java backend", List.of(0.1, 0.2)), + document("chunk-2", "forums-post-attachment", "7", 0, "spring api", List.of(0.2, 0.3)))); + } + + @Test + void searchByObjectAllowsObjectTypeOnlyScopeWithNullObjectId() { + List results = adapter.searchByObject( + "attachment", + null, + new VectorSearchRequest(List.of(0.1, 0.2), 10)); + + assertThat(results).singleElement() + .extracting(result -> result.document().id()) + .isEqualTo("chunk-1"); + } + + @Test + void hybridSearchByObjectAllowsObjectTypeOnlyScopeWithNullObjectId() { + List results = adapter.hybridSearchByObject( + "java", + "attachment", + null, + new VectorSearchRequest(List.of(0.1, 0.2), 10), + 0.7, + 0.3); + + assertThat(results).singleElement() + .extracting(result -> result.document().id()) + .isEqualTo("chunk-1"); + } + + private static VectorDocument document( + String id, + String objectType, + String objectId, + int chunkIndex, + String text, + List embedding) { + return new VectorDocument(id, text, Map.of( + "objectType", objectType, + "objectId", objectId, + "chunkIndex", chunkIndex, + "chunkId", id), embedding); + } + + private void setField(String fieldName, Object value) throws Exception { + Field field = PgVectorStoreAdapterV2.class.getDeclaredField(fieldName); + field.setAccessible(true); + field.set(adapter, value); + } + + private static String sql(String id) throws Exception { + try (var input = PgVectorStoreAdapterV2PostgresTest.class.getClassLoader() + .getResourceAsStream("sql/ai-sqlset.xml")) { + var document = DocumentBuilderFactory.newInstance() + .newDocumentBuilder() + .parse(new java.io.ByteArrayInputStream(Objects.requireNonNull(input).readAllBytes())); + var nodes = document.getElementsByTagName("sql-query"); + for (int i = 0; i < nodes.getLength(); i++) { + Element element = (Element) nodes.item(i); + if (id.equals(element.getAttribute("id"))) { + return element.getTextContent().trim(); + } + } + } + throw new IllegalArgumentException("SQL not found: " + id); + } +} diff --git a/starter/studio-platform-starter-ai/src/test/java/studio/one/platform/ai/service/visualization/DefaultVectorSearchVisualizationServiceTest.java b/starter/studio-platform-starter-ai/src/test/java/studio/one/platform/ai/service/visualization/DefaultVectorSearchVisualizationServiceTest.java index de143ded..3b5dd753 100644 --- a/starter/studio-platform-starter-ai/src/test/java/studio/one/platform/ai/service/visualization/DefaultVectorSearchVisualizationServiceTest.java +++ b/starter/studio-platform-starter-ai/src/test/java/studio/one/platform/ai/service/visualization/DefaultVectorSearchVisualizationServiceTest.java @@ -2,7 +2,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.time.Instant; @@ -16,8 +18,10 @@ import studio.one.platform.ai.core.embedding.EmbeddingPort; import studio.one.platform.ai.core.embedding.EmbeddingResponse; import studio.one.platform.ai.core.embedding.EmbeddingVector; +import studio.one.platform.ai.core.vector.VectorDocument; import studio.one.platform.ai.core.vector.VectorSearchHit; import studio.one.platform.ai.core.vector.VectorSearchRequest; +import studio.one.platform.ai.core.vector.VectorSearchResult; import studio.one.platform.ai.core.vector.VectorSearchResults; import studio.one.platform.ai.core.vector.VectorStorePort; import studio.one.platform.ai.core.vector.visualization.ProjectionAlgorithm; @@ -125,6 +129,39 @@ void searchUsesRowVectorItemIdWhenChunkIdIsAbsent() { .isEqualTo("row-7"); } + @Test + void searchWithTargetTypesUsesObjectTypeOnlyVectorScope() { + EmbeddingPort embeddingPort = mock(EmbeddingPort.class); + VectorStorePort vectorStorePort = mock(VectorStorePort.class); + VectorProjectionRepository projections = mock(VectorProjectionRepository.class); + VectorProjectionPointRepository points = new FakePointRepository(List.of( + new ProjectionPointView("chunk-1", "attachment", "6", "Document", 0.2, 0.4, null, Map.of()))); + when(projections.findById("proj-1")).thenReturn(Optional.of(projection())); + when(embeddingPort.embed(any())).thenReturn(new EmbeddingResponse(List.of( + new EmbeddingVector("query", List.of(0.1, 0.2))))); + when(vectorStorePort.searchByObject(eq("attachment"), eq(null), any(VectorSearchRequest.class))) + .thenReturn(List.of(new VectorSearchResult( + new VectorDocument("chunk-1", "stored chunk", Map.of("chunkId", "chunk-1"), List.of()), + 0.9))); + DefaultVectorSearchVisualizationService service = new DefaultVectorSearchVisualizationService( + embeddingPort, + vectorStorePort, + projections, + points); + + VectorSearchVisualizationResult result = service.search(new VectorSearchVisualizationCommand( + "proj-1", + "java", + List.of("attachment"), + 10, + null)); + + assertThat(result.results()).singleElement() + .extracting(VectorSearchVisualizationResult.ResultPoint::vectorItemId) + .isEqualTo("chunk-1"); + verify(vectorStorePort).searchByObject(eq("attachment"), eq(null), any(VectorSearchRequest.class)); + } + private VectorProjection projection() { return new VectorProjection( "proj-1", diff --git a/studio-platform-ai/src/main/resources/sql/ai-sqlset.xml b/studio-platform-ai/src/main/resources/sql/ai-sqlset.xml index c648f0bf..98f3da41 100644 --- a/studio-platform-ai/src/main/resources/sql/ai-sqlset.xml +++ b/studio-platform-ai/src/main/resources/sql/ai-sqlset.xml @@ -30,8 +30,8 @@ :vector) AS distance FROM tb_ai_document_chunk - WHERE (:objectType IS NULL OR object_type = :objectType) - AND (:objectId IS NULL OR object_id = :objectId) + WHERE (CAST(:objectType AS varchar) IS NULL OR object_type = CAST(:objectType AS varchar)) + AND (CAST(:objectId AS varchar) IS NULL OR object_id = CAST(:objectId AS varchar)) ORDER BY embedding <-> :vector ASC LIMIT :limit ]]> @@ -56,8 +56,8 @@ ts_rank_cd(to_tsvector('simple', text || ' ' || COALESCE(metadata->>'keywordsText','')), plainto_tsquery(:query)) AS bm25, ((embedding <-> :vector) * :vectorWeight) - (COALESCE(ts_rank_cd(to_tsvector('simple', text || ' ' || COALESCE(metadata->>'keywordsText','')), plainto_tsquery(:query)),0) * :lexicalWeight) AS hybrid FROM tb_ai_document_chunk - WHERE (:objectType IS NULL OR object_type = :objectType) - AND (:objectId IS NULL OR object_id = :objectId) + WHERE (CAST(:objectType AS varchar) IS NULL OR object_type = CAST(:objectType AS varchar)) + AND (CAST(:objectId AS varchar) IS NULL OR object_id = CAST(:objectId AS varchar)) ORDER BY hybrid ASC LIMIT :limit ]]> diff --git a/studio-platform-ai/src/test/java/studio/one/platform/ai/core/vector/VectorSqlSetContractTest.java b/studio-platform-ai/src/test/java/studio/one/platform/ai/core/vector/VectorSqlSetContractTest.java new file mode 100644 index 00000000..e3260301 --- /dev/null +++ b/studio-platform-ai/src/test/java/studio/one/platform/ai/core/vector/VectorSqlSetContractTest.java @@ -0,0 +1,28 @@ +package studio.one.platform.ai.core.vector; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import org.junit.jupiter.api.Test; + +class VectorSqlSetContractTest { + + @Test + void objectScopedSearchCastsNullableScopeParametersForPostgres() throws IOException { + String sqlset = new String( + Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("sql/ai-sqlset.xml")) + .readAllBytes(), + StandardCharsets.UTF_8); + + assertThat(sqlset) + .contains("sql-query id=\"searchByObject\"") + .contains("sql-query id=\"hybridSearchByObject\""); + assertThat(sqlset.split("CAST\\(:objectType AS varchar\\) IS NULL OR object_type = CAST\\(:objectType AS varchar\\)", -1)) + .hasSize(3); + assertThat(sqlset.split("CAST\\(:objectId AS varchar\\) IS NULL OR object_id = CAST\\(:objectId AS varchar\\)", -1)) + .hasSize(3); + } +}