Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions starter/studio-platform-starter-ai/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ dependencies {
testImplementation("org.mockito:mockito-core")
testImplementation("org.assertj:assertj-core")
testImplementation("org.springframework:spring-jdbc")
testImplementation("org.testcontainers:junit-jupiter")
testImplementation("org.testcontainers:postgresql")
testImplementation("com.h2database:h2")
testImplementation(project(":studio-platform"))
testImplementation(project(":starter:studio-platform-starter-chunking"))
testRuntimeOnly("org.postgresql:postgresql")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package studio.one.platform.ai.adapters.vector;

import static org.assertj.core.api.Assertions.assertThat;

import java.lang.reflect.Field;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import javax.xml.parsers.DocumentBuilderFactory;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.datasource.DriverManagerDataSource;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import org.w3c.dom.Element;

import studio.one.platform.ai.core.vector.VectorDocument;
import studio.one.platform.ai.core.vector.VectorSearchRequest;
import studio.one.platform.ai.core.vector.VectorSearchResult;

@Testcontainers
class PgVectorStoreAdapterV2PostgresTest {

@Container
static final PostgreSQLContainer<?> POSTGRES = new PostgreSQLContainer<>(
DockerImageName.parse("pgvector/pgvector:pg16").asCompatibleSubstituteFor("postgres"));

private PgVectorStoreAdapterV2 adapter;

@BeforeEach
void setUp() throws Exception {
JdbcTemplate jdbcTemplate = new JdbcTemplate(new DriverManagerDataSource(
POSTGRES.getJdbcUrl(),
POSTGRES.getUsername(),
POSTGRES.getPassword()));
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
jdbcTemplate.execute("DROP TABLE IF EXISTS tb_ai_document_chunk");
jdbcTemplate.execute("""
CREATE TABLE tb_ai_document_chunk (
id BIGSERIAL PRIMARY KEY,
object_type VARCHAR(100) NOT NULL,
object_id VARCHAR(100) NOT NULL,
chunk_index INTEGER NOT NULL,
text TEXT NOT NULL,
metadata JSONB NOT NULL,
embedding vector(2) NOT NULL,
CONSTRAINT uq_test_chunk UNIQUE (object_type, object_id, chunk_index)
)
""");
adapter = new PgVectorStoreAdapterV2(jdbcTemplate);
setField("upsertSql", sql("upsertChunk"));
setField("searchByObjectSql", sql("searchByObject"));
setField("hybridSearchByObjectSql", sql("hybridSearchByObject"));
adapter.upsert(List.of(
document("chunk-1", "attachment", "6", 0, "java backend", List.of(0.1, 0.2)),
document("chunk-2", "forums-post-attachment", "7", 0, "spring api", List.of(0.2, 0.3))));
}

@Test
void searchByObjectAllowsObjectTypeOnlyScopeWithNullObjectId() {
List<VectorSearchResult> results = adapter.searchByObject(
"attachment",
null,
new VectorSearchRequest(List.of(0.1, 0.2), 10));

assertThat(results).singleElement()
.extracting(result -> result.document().id())
.isEqualTo("chunk-1");
}

@Test
void hybridSearchByObjectAllowsObjectTypeOnlyScopeWithNullObjectId() {
List<VectorSearchResult> results = adapter.hybridSearchByObject(
"java",
"attachment",
null,
new VectorSearchRequest(List.of(0.1, 0.2), 10),
0.7,
0.3);

assertThat(results).singleElement()
.extracting(result -> result.document().id())
.isEqualTo("chunk-1");
}

private static VectorDocument document(
String id,
String objectType,
String objectId,
int chunkIndex,
String text,
List<Double> embedding) {
return new VectorDocument(id, text, Map.of(
"objectType", objectType,
"objectId", objectId,
"chunkIndex", chunkIndex,
"chunkId", id), embedding);
}

private void setField(String fieldName, Object value) throws Exception {
Field field = PgVectorStoreAdapterV2.class.getDeclaredField(fieldName);
field.setAccessible(true);
field.set(adapter, value);
}

private static String sql(String id) throws Exception {
try (var input = PgVectorStoreAdapterV2PostgresTest.class.getClassLoader()
.getResourceAsStream("sql/ai-sqlset.xml")) {
var document = DocumentBuilderFactory.newInstance()
.newDocumentBuilder()
.parse(new java.io.ByteArrayInputStream(Objects.requireNonNull(input).readAllBytes()));
var nodes = document.getElementsByTagName("sql-query");
for (int i = 0; i < nodes.getLength(); i++) {
Element element = (Element) nodes.item(i);
if (id.equals(element.getAttribute("id"))) {
return element.getTextContent().trim();
}
}
}
throw new IllegalArgumentException("SQL not found: " + id);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.time.Instant;
Expand All @@ -16,8 +18,10 @@
import studio.one.platform.ai.core.embedding.EmbeddingPort;
import studio.one.platform.ai.core.embedding.EmbeddingResponse;
import studio.one.platform.ai.core.embedding.EmbeddingVector;
import studio.one.platform.ai.core.vector.VectorDocument;
import studio.one.platform.ai.core.vector.VectorSearchHit;
import studio.one.platform.ai.core.vector.VectorSearchRequest;
import studio.one.platform.ai.core.vector.VectorSearchResult;
import studio.one.platform.ai.core.vector.VectorSearchResults;
import studio.one.platform.ai.core.vector.VectorStorePort;
import studio.one.platform.ai.core.vector.visualization.ProjectionAlgorithm;
Expand Down Expand Up @@ -125,6 +129,39 @@ void searchUsesRowVectorItemIdWhenChunkIdIsAbsent() {
.isEqualTo("row-7");
}

@Test
void searchWithTargetTypesUsesObjectTypeOnlyVectorScope() {
EmbeddingPort embeddingPort = mock(EmbeddingPort.class);
VectorStorePort vectorStorePort = mock(VectorStorePort.class);
VectorProjectionRepository projections = mock(VectorProjectionRepository.class);
VectorProjectionPointRepository points = new FakePointRepository(List.of(
new ProjectionPointView("chunk-1", "attachment", "6", "Document", 0.2, 0.4, null, Map.of())));
when(projections.findById("proj-1")).thenReturn(Optional.of(projection()));
when(embeddingPort.embed(any())).thenReturn(new EmbeddingResponse(List.of(
new EmbeddingVector("query", List.of(0.1, 0.2)))));
when(vectorStorePort.searchByObject(eq("attachment"), eq(null), any(VectorSearchRequest.class)))
.thenReturn(List.of(new VectorSearchResult(
new VectorDocument("chunk-1", "stored chunk", Map.of("chunkId", "chunk-1"), List.of()),
0.9)));
DefaultVectorSearchVisualizationService service = new DefaultVectorSearchVisualizationService(
embeddingPort,
vectorStorePort,
projections,
points);

VectorSearchVisualizationResult result = service.search(new VectorSearchVisualizationCommand(
"proj-1",
"java",
List.of("attachment"),
10,
null));

assertThat(result.results()).singleElement()
.extracting(VectorSearchVisualizationResult.ResultPoint::vectorItemId)
.isEqualTo("chunk-1");
verify(vectorStorePort).searchByObject(eq("attachment"), eq(null), any(VectorSearchRequest.class));
}

private VectorProjection projection() {
return new VectorProjection(
"proj-1",
Expand Down
8 changes: 4 additions & 4 deletions studio-platform-ai/src/main/resources/sql/ai-sqlset.xml
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
<![CDATA[
SELECT id, object_id, text, metadata, (embedding <-> :vector) AS distance
FROM tb_ai_document_chunk
WHERE (:objectType IS NULL OR object_type = :objectType)
AND (:objectId IS NULL OR object_id = :objectId)
WHERE (CAST(:objectType AS varchar) IS NULL OR object_type = CAST(:objectType AS varchar))
AND (CAST(:objectId AS varchar) IS NULL OR object_id = CAST(:objectId AS varchar))
ORDER BY embedding <-> :vector ASC
LIMIT :limit
]]>
Expand All @@ -56,8 +56,8 @@
ts_rank_cd(to_tsvector('simple', text || ' ' || COALESCE(metadata->>'keywordsText','')), plainto_tsquery(:query)) AS bm25,
((embedding <-> :vector) * :vectorWeight) - (COALESCE(ts_rank_cd(to_tsvector('simple', text || ' ' || COALESCE(metadata->>'keywordsText','')), plainto_tsquery(:query)),0) * :lexicalWeight) AS hybrid
FROM tb_ai_document_chunk
WHERE (:objectType IS NULL OR object_type = :objectType)
AND (:objectId IS NULL OR object_id = :objectId)
WHERE (CAST(:objectType AS varchar) IS NULL OR object_type = CAST(:objectType AS varchar))
AND (CAST(:objectId AS varchar) IS NULL OR object_id = CAST(:objectId AS varchar))
ORDER BY hybrid ASC
LIMIT :limit
]]>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package studio.one.platform.ai.core.vector;

import static org.assertj.core.api.Assertions.assertThat;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

import org.junit.jupiter.api.Test;

class VectorSqlSetContractTest {

@Test
void objectScopedSearchCastsNullableScopeParametersForPostgres() throws IOException {
String sqlset = new String(
Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("sql/ai-sqlset.xml"))
.readAllBytes(),
StandardCharsets.UTF_8);

assertThat(sqlset)
.contains("sql-query id=\"searchByObject\"")
.contains("sql-query id=\"hybridSearchByObject\"");
assertThat(sqlset.split("CAST\\(:objectType AS varchar\\) IS NULL OR object_type = CAST\\(:objectType AS varchar\\)", -1))
.hasSize(3);
assertThat(sqlset.split("CAST\\(:objectId AS varchar\\) IS NULL OR object_id = CAST\\(:objectId AS varchar\\)", -1))
.hasSize(3);
}
}
Loading