Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions starter/studio-platform-starter-ai-web/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,40 @@ RAG context는 이슈 #202부터 설정된 chunk 수와 문자 수를 넘지 않
retrieval hit content만 사용한다. 확장 후에도 아래 `max-chunks`, `max-chars`, `include-scores`
설정은 그대로 적용된다.

RAG Chat 응답은 실제 답변 생성 프롬프트에 포함된 근거를 `metadata.ragReferences` 배열로 반환한다.
순서는 system context의 `[1]`, `[2]` 순서와 같으며, context expansion 또는 fallback이 적용된 경우에도
최종 프롬프트에 들어간 content를 기준으로 한다. 클라이언트는 별도 management search를 다시 호출하지 않고
이 값으로 출처 UI를 구성할 수 있다.

```json
{
"metadata": {
"ragReferences": [
{
"index": 1,
"documentId": "3",
"sourceName": "sample.pdf",
"chunkId": "chunk-1",
"chunkOrder": 0,
"score": 0.91,
"content": "프롬프트에 포함된 근거 본문",
"page": 3,
"pageNumber": 3,
"sourceRef": "page[3]",
"metadata": {
"objectType": "attachment",
"objectId": "3"
}
}
]
}
}
```

`sourceName`은 `sourceName`, `title`, `filename`, `fileName`, `name` metadata 순서로 선택한다.
위치 정보는 metadata에 있으면 `page`/`pageNumber`, `slide`/`slideNumber`, `sourceRef`,
`section`, `heading`으로 함께 내려간다.

```yaml
studio:
ai:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -71,6 +72,7 @@
import studio.one.platform.ai.core.rag.RagRetrievalDiagnostics;
import studio.one.platform.ai.core.rag.RagSearchRequest;
import studio.one.platform.ai.core.rag.RagSearchResult;
import studio.one.platform.ai.core.vector.VectorRecord;
import studio.one.platform.ai.autoconfigure.AiWebRagProperties;
import studio.one.platform.ai.service.pipeline.RagPipelineOptions;
import studio.one.platform.ai.service.pipeline.RagPipelineService;
Expand Down Expand Up @@ -441,6 +443,7 @@ private ResponseEntity<ApiResponse<ChatResponseDto>> chatWithRagInternal(
appendConversation(principal, memory, chat.messages().stream().map(this::toDomainMessage).toList(), response);
boolean exposeDiagnostics = shouldExposeDiagnostics(request);
Map<String, Object> extraMetadata = memoryMetadata(memory, memoryMessageCount);
extraMetadata.put("ragReferences", ragReferences(contextResult.usedResults()));
if (exposeDiagnostics && contextResult.diagnostics() != null) {
extraMetadata.put("ragContextDiagnostics", contextResult.diagnostics().toMetadata());
}
Expand Down Expand Up @@ -690,6 +693,101 @@ private ChatResponseDto toDto(
return new ChatResponseDto(messages, response.model(), metadata);
}

private List<Map<String, Object>> ragReferences(List<RagSearchResult> results) {
if (results == null || results.isEmpty()) {
return List.of();
}
List<Map<String, Object>> references = new ArrayList<>(results.size());
for (int i = 0; i < results.size(); i++) {
references.add(ragReference(i + 1, results.get(i)));
}
return List.copyOf(references);
}

private Map<String, Object> ragReference(int index, RagSearchResult result) {
Map<String, Object> metadata = result.metadata() == null ? Map.of() : result.metadata();
Map<String, Object> reference = new LinkedHashMap<>();
reference.put("index", index);
String documentId = firstText(metadata,
VectorRecord.KEY_DOCUMENT_ID,
"documentId",
"sourceDocumentId");
if (documentId == null) {
documentId = result.documentId();
}
put(reference, "documentId", documentId);
String sourceName = firstText(metadata, "sourceName", "title", "filename", "fileName", "name");
put(reference, "sourceName", sourceName == null ? documentId : sourceName);
String chunkId = firstText(metadata, VectorRecord.KEY_CHUNK_ID, "chunkId");
put(reference, "chunkId", chunkId == null ? documentId : chunkId);
put(reference, "chunkOrder", firstInteger(metadata, "chunkOrder", VectorRecord.KEY_CHUNK_INDEX));
put(reference, "score", result.score());
put(reference, "content", result.content());
Integer page = firstInteger(metadata, VectorRecord.KEY_PAGE, "page", "pageNumber");
if (page != null) {
reference.put("page", page);
reference.put("pageNumber", page);
}
Integer slide = firstInteger(metadata, VectorRecord.KEY_SLIDE, "slide", "slideNumber");
if (slide != null) {
reference.put("slide", slide);
reference.put("slideNumber", slide);
}
put(reference, "sourceRef", firstText(metadata, VectorRecord.KEY_SOURCE_REF, "sourceRef", "sourceRefs"));
put(reference, "section", firstText(metadata, "section", VectorRecord.KEY_HEADING_PATH, "headingPath"));
put(reference, "heading", firstText(metadata, "heading", VectorRecord.KEY_HEADING_PATH, "headingPath"));
if (!metadata.isEmpty()) {
reference.put("metadata", metadata);
}
return Map.copyOf(reference);
}

private void put(Map<String, Object> target, String key, Object value) {
if (value != null) {
target.put(key, value);
}
}

private String firstText(Map<String, Object> metadata, String... keys) {
for (String key : keys) {
Object value = metadata.get(key);
if (value instanceof String text && !text.isBlank()) {
return text.trim();
}
if (value != null && !(value instanceof String)) {
String text = value.toString();
if (!text.isBlank()) {
return text.trim();
}
}
}
return null;
}

private Integer firstInteger(Map<String, Object> metadata, String... keys) {
for (String key : keys) {
Integer value = integer(metadata.get(key));
if (value != null) {
return value;
}
}
return null;
}

private Integer integer(Object value) {
if (value instanceof Number number) {
return number.intValue();
}
if (value instanceof String text && !text.isBlank()) {
try {
return Integer.parseInt(text.trim());
} catch (NumberFormatException ignored) {
return null;
}
}
return null;
}

private boolean shouldExposeDiagnostics(ChatRagRequestDto request) {
return allowClientDebug && Boolean.TRUE.equals(request.debug());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package studio.one.platform.ai.web.controller;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
Expand Down Expand Up @@ -107,6 +108,7 @@ public BuildResult buildWithDiagnostics(List<RagSearchResult> results, List<RagS
String strategy = null;
String fallbackReason = expansionSupported ? null : "disabled";
boolean contextLimitHit = false;
List<RagSearchResult> usedResults = new ArrayList<>();
for (int i = 0; i < count; i++) {
RagSearchResult original = results.get(i);
ExpansionAttempt attempt = expandResultWithDiagnostics(original, expansionCandidates);
Expand All @@ -122,13 +124,15 @@ public BuildResult buildWithDiagnostics(List<RagSearchResult> results, List<RagS
fallbackReason = "context_limit";
break;
}
usedResults.add(original);
fallbackHitCount++;
fallbackReason = "context_limit";
if (strategy == null) {
strategy = attempt.strategy();
}
continue;
}
usedResults.add(attempt.result());
if (attempt.expanded()) {
expandedHitCount++;
if (strategy == null) {
Expand All @@ -155,7 +159,8 @@ public BuildResult buildWithDiagnostics(List<RagSearchResult> results, List<RagS
fallbackHitCount,
candidateCount,
resultCount,
fallbackReason));
fallbackReason),
usedResults);
}

private boolean appendWithinLimit(StringBuilder sb, String chunk) {
Expand Down Expand Up @@ -339,7 +344,15 @@ private boolean hasText(String value) {
return value != null && !value.isBlank();
}

public record BuildResult(String context, Diagnostics diagnostics) {
public record BuildResult(String context, Diagnostics diagnostics, List<RagSearchResult> usedResults) {

public BuildResult(String context, Diagnostics diagnostics) {
this(context, diagnostics, List.of());
}

public BuildResult {
usedResults = usedResults == null ? List.of() : List.copyOf(usedResults);
}
}

public record Diagnostics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,54 @@ void ragChatAddsContextAndClientSystemPromptAndSearchesByObject() {
assertThat(chatCaptor.getValue().messages().get(1).role().name()).isEqualTo("USER");
}

@Test
@SuppressWarnings("unchecked")
void ragChatReturnsReferencesForPromptContext() {
when(ragPipelineService.search(any(RagSearchRequest.class)))
.thenReturn(List.of(new RagSearchResult(
"doc-1",
"file text",
Map.of(
"sourceName", "sample.pdf",
RagContextBuilder.KEY_CHUNK_ID, "chunk-1",
ChunkMetadata.KEY_CHUNK_ORDER, 7,
"page", 3,
"sourceRef", "page[3]"),
0.9d)));

ChatResponseDto response = controller.chatWithRag(new ChatRagRequestDto(
new ChatRequestDto(
null,
null,
List.of(new ChatMessageDto("user", "summarize")),
null,
null,
null,
null,
null,
null),
"summary",
3,
"attachment",
"123")).getBody().getData();

List<Map<String, Object>> references = (List<Map<String, Object>>) response.metadata().get("ragReferences");
assertThat(references).hasSize(1);
assertThat(references.get(0))
.containsEntry("index", 1)
.containsEntry("documentId", "doc-1")
.containsEntry("sourceName", "sample.pdf")
.containsEntry("chunkId", "chunk-1")
.containsEntry("chunkOrder", 7)
.containsEntry("score", 0.9d)
.containsEntry("content", "file text")
.containsEntry("page", 3)
.containsEntry("pageNumber", 3)
.containsEntry("sourceRef", "page[3]");
assertThat((Map<String, Object>) references.get(0).get("metadata"))
.containsEntry("sourceName", "sample.pdf");
}

@Test
void ragChatAllowsNonAttachmentObjectScope() {
ArgumentCaptor<RagSearchRequest> ragCaptor = ArgumentCaptor.forClass(RagSearchRequest.class);
Expand Down Expand Up @@ -1162,6 +1210,11 @@ void ragChatFallsBackToSearchChunkWhenExpandedContextExceedsLimit() {
.containsEntry("applied", false)
.containsEntry("fallbackReason", "context_limit")
.containsEntry("fallbackHitCount", 1);
List<Map<String, Object>> references = (List<Map<String, Object>>) response.metadata().get("ragReferences");
assertThat(references).hasSize(1);
assertThat(references.get(0))
.containsEntry("content", "seed body")
.containsEntry("chunkId", "chunk-2");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,31 @@ void fallsBackToOriginalChunkWhenExpandedContextExceedsCharacterBudget() {
.containsEntry("expandedHitCount", 0)
.containsEntry("fallbackHitCount", 1)
.containsEntry("fallbackReason", "context_limit");
assertThat(result.usedResults())
.extracting(RagSearchResult::content)
.containsExactly("seed");
}

@Test
void reportsExpandedResultsInPromptOrder() {
RagContextBuilder builder = new RagContextBuilder(8, 12_000, true, TestWindowChunkContextExpander.asList());

RagContextBuilder.BuildResult result = builder.buildWithDiagnostics(
List.of(
result("chunk-2", "seed", metadata("chunk-2")),
result("chunk-4", "tail", metadata("chunk-4", "chunk-3", null, 3))),
List.of(
result("chunk-1", "previous", metadata("chunk-1", null, "chunk-2", 0)),
result("chunk-2", "seed", metadata("chunk-2", "chunk-1", "chunk-3", 1)),
result("chunk-3", "next", metadata("chunk-3", "chunk-2", "chunk-4", 2)),
result("chunk-4", "tail", metadata("chunk-4", "chunk-3", null, 3))));

assertThat(result.context())
.contains("[1] docId=chunk-2")
.contains("[2] docId=chunk-4");
assertThat(result.usedResults())
.extracting(RagSearchResult::content)
.containsExactly("previous\nseed\nnext\ntail", "previous\nseed\nnext\ntail");
}

@Test
Expand Down
Loading