From 1d8a40dcab17d6bc612cfbb4f3ebb6492da2affe Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Mon, 20 Apr 2026 15:34:47 +0000 Subject: [PATCH] Improve RAG retrieval path and metadata-aware expansion Co-authored-by: jwlee --- app/graphs/crag/nodes.py | 51 +++++++++++++++++++----- app/rag/chunkers/recursive_chunker.py | 35 ++++++++++++---- app/rag/pipelines/ingest_pipeline.py | 3 +- app/rag/policies/expansion_policy.py | 9 ++++- app/storage/vectorstores/qdrant_store.py | 16 ++++++-- tests/unit/test_expansion_policy.py | 20 ++++++++++ tests/unit/test_retrievers.py | 7 ++++ 7 files changed, 115 insertions(+), 26 deletions(-) diff --git a/app/graphs/crag/nodes.py b/app/graphs/crag/nodes.py index 33bafa3..68e59fa 100644 --- a/app/graphs/crag/nodes.py +++ b/app/graphs/crag/nodes.py @@ -13,7 +13,6 @@ from app.core.llm_io_log import log_llm_io from app.graphs.crag.state import CRAGState from app.providers.llm_provider import get_llm -from app.providers.vectorstore_provider import get_vectorstore logger = logging.getLogger(__name__) @@ -81,12 +80,16 @@ def rewrite_query(state: CRAGState) -> dict: def hybrid_retrieve(state: CRAGState) -> dict: - """Retrieve relevant child chunks using hybrid (vector + BM25) search. + """Retrieve relevant child chunks with multi-query RRF fusion. - When ``multi_query`` mode is enabled via env var ``MULTI_QUERY=1``, - generates additional query variants and fuses the retrieval results. + Uses ``hybrid_retriever.hybrid_retrieve`` per query variant and then fuses + across variants with RRF to improve recall while keeping deterministic ranking. """ from app.core.runtime_config import get_max_retrieval_docs, get_multi_query_enabled + from app.rag.retrievers.hybrid_retriever import ( + hybrid_retrieve as run_hybrid, + ) + from app.rag.retrievers.hybrid_retriever import reciprocal_rank_fusion query = _active_query(state) attempt = state.get("retrieval_attempt", 0) + 1 @@ -99,13 +102,13 @@ def hybrid_retrieve(state: CRAGState) -> dict: queries = generate_multi_query(query, n=3) logger.info("Multi-query: %d variants", len(queries)) - store = get_vectorstore() - seen: dict[str, dict] = {} + retrieval_k = max(10, get_max_retrieval_docs() * 4) + ranked_lists: list[list[dict]] = [] for q in queries: - for r in store.search(q, top_k=get_max_retrieval_docs()): - seen.setdefault(r["text"], r) + ranked_lists.append(run_hybrid(q, top_k=retrieval_k)) - children = list(seen.values()) + fused = reciprocal_rank_fusion(ranked_lists) + children = fused[:retrieval_k] return {"retrieved_children": children, "retrieval_attempt": attempt} @@ -113,10 +116,36 @@ def hybrid_retrieve(state: CRAGState) -> dict: def expand_context(state: CRAGState) -> dict: - """Expand child hits to parent / larger chunks (small-to-big strategy).""" + """Expand child hits to parent chunks when ``parent_text`` is available.""" children = state.get("retrieved_children", []) logger.info("Expanding %d child chunks to parent context", len(children)) - return {"expanded_contexts": children} + + expanded: list[dict] = [] + seen_texts: set[str] = set() + for child in children: + if not isinstance(child, dict): + text = str(child) + if text and text not in seen_texts: + seen_texts.add(text) + expanded.append({"text": text, "score": 0.0, "metadata": {}}) + continue + + metadata = dict(child.get("metadata", {})) + parent_text = metadata.get("parent_text") + text = str(parent_text or child.get("text", "")) + if not text or text in seen_texts: + continue + + seen_texts.add(text) + expanded.append( + { + "text": text, + "score": child.get("score", 0.0), + "metadata": metadata, + } + ) + + return {"expanded_contexts": expanded or children} # ── rerank_context ──────────────────────────────────────────────────────────── diff --git a/app/rag/chunkers/recursive_chunker.py b/app/rag/chunkers/recursive_chunker.py index c08cd4a..72b1e47 100644 --- a/app/rag/chunkers/recursive_chunker.py +++ b/app/rag/chunkers/recursive_chunker.py @@ -2,18 +2,20 @@ from __future__ import annotations +from typing import Any + from langchain_text_splitters import RecursiveCharacterTextSplitter def recursive_chunk( - texts: list[str], + docs: list[str] | list[dict[str, Any]], chunk_size: int = 512, chunk_overlap: int = 64, -) -> list[dict]: - """Split texts using LangChain's ``RecursiveCharacterTextSplitter``. +) -> list[dict[str, Any]]: + """Split documents using LangChain's ``RecursiveCharacterTextSplitter``. Args: - texts: List of raw document strings. + docs: Either raw text strings or ``{text, metadata}`` dicts. chunk_size: Maximum characters per chunk. chunk_overlap: Character overlap between adjacent chunks. @@ -24,8 +26,25 @@ def recursive_chunk( chunk_size=chunk_size, chunk_overlap=chunk_overlap, ) - chunks = [] - for text in texts: - for chunk in splitter.split_text(text): - chunks.append({"text": chunk, "metadata": {}}) + + chunks: list[dict[str, Any]] = [] + for doc_idx, item in enumerate(docs): + if isinstance(item, dict): + text = str(item.get("text", "")) + base_meta = dict(item.get("metadata", {})) + else: + text = str(item) + base_meta = {} + + if not text: + continue + + split_chunks = splitter.split_text(text) + for chunk_idx, chunk in enumerate(split_chunks): + meta = { + **base_meta, + "doc_index": base_meta.get("doc_index", doc_idx), + "chunk_index": chunk_idx, + } + chunks.append({"text": chunk, "metadata": meta}) return chunks diff --git a/app/rag/pipelines/ingest_pipeline.py b/app/rag/pipelines/ingest_pipeline.py index 0406314..3225f03 100644 --- a/app/rag/pipelines/ingest_pipeline.py +++ b/app/rag/pipelines/ingest_pipeline.py @@ -41,8 +41,7 @@ def run_ingest( enriched = [extract_metadata(d) for d in cleaned] unique = dedup_documents(enriched) - texts = [d["text"] for d in unique] - chunks = recursive_chunk(texts, chunk_size=chunk_size, chunk_overlap=chunk_overlap) + chunks = recursive_chunk(unique, chunk_size=chunk_size, chunk_overlap=chunk_overlap) if not chunks: logger.warning("No chunks produced from %d docs", len(raw_docs)) diff --git a/app/rag/policies/expansion_policy.py b/app/rag/policies/expansion_policy.py index 8f97991..a96ccf1 100644 --- a/app/rag/policies/expansion_policy.py +++ b/app/rag/policies/expansion_policy.py @@ -9,6 +9,13 @@ from typing import Any +def _child_text_len(child: Any) -> int: + """Return text length for child item regardless of shape.""" + if isinstance(child, dict): + return len(str(child.get("text", ""))) + return len(str(child)) + + def should_expand(state: dict[str, Any]) -> bool: """Determine whether retrieved child chunks should be expanded. @@ -27,6 +34,6 @@ def should_expand(state: dict[str, Any]) -> bool: if not children: return False - avg_len = sum(len(c) for c in children) / len(children) + avg_len = sum(_child_text_len(c) for c in children) / len(children) # Expand if average child chunk is shorter than 300 characters return avg_len < 300 diff --git a/app/storage/vectorstores/qdrant_store.py b/app/storage/vectorstores/qdrant_store.py index 05b1093..0517339 100644 --- a/app/storage/vectorstores/qdrant_store.py +++ b/app/storage/vectorstores/qdrant_store.py @@ -67,9 +67,17 @@ def embed_query(self, text: str) -> list[float]: return self.embed_texts([text])[0] @staticmethod - def _deterministic_id(text: str) -> str: - """Generate a deterministic UUID-v5 from text content for deduplication.""" - return str(uuid.uuid5(uuid.NAMESPACE_DNS, text)) + def _deterministic_id(text: str, metadata: dict[str, Any]) -> str: + """Generate a stable UUID-v5 using text + provenance metadata.""" + key_parts = [ + text, + str(metadata.get("source", "")), + str(metadata.get("doc_id", "")), + str(metadata.get("parent_id", "")), + str(metadata.get("chunk_index", "")), + ] + key = "|".join(key_parts) + return str(uuid.uuid5(uuid.NAMESPACE_DNS, key)) # ── VectorStorePort interface ───────────────────────────────────────────── @@ -105,7 +113,7 @@ def add_documents( points = [ PointStruct( - id=self._deterministic_id(text), + id=self._deterministic_id(text, meta), vector=vec, payload={"text": text, **meta}, ) diff --git a/tests/unit/test_expansion_policy.py b/tests/unit/test_expansion_policy.py index 7e749bf..538ce6e 100644 --- a/tests/unit/test_expansion_policy.py +++ b/tests/unit/test_expansion_policy.py @@ -19,3 +19,23 @@ def test_long_chunks_no_expansion(): long_chunk = "x" * 500 state = {"retrieved_children": [long_chunk, long_chunk]} assert should_expand(state) is False + + +def test_dict_children_uses_text_length(): + state = { + "retrieved_children": [ + {"text": "x" * 20, "metadata": {}}, + {"text": "y" * 25, "metadata": {}}, + ] + } + assert should_expand(state) is True + + +def test_dict_children_long_text_no_expand(): + state = { + "retrieved_children": [ + {"text": "x" * 400, "metadata": {}}, + {"text": "y" * 500, "metadata": {}}, + ] + } + assert should_expand(state) is False diff --git a/tests/unit/test_retrievers.py b/tests/unit/test_retrievers.py index ed90a89..8540970 100644 --- a/tests/unit/test_retrievers.py +++ b/tests/unit/test_retrievers.py @@ -23,3 +23,10 @@ def test_rrf_single_list(): def test_rrf_empty_input(): assert reciprocal_rank_fusion([]) == [] + + +def test_rrf_prefers_consensus_doc(): + list_a = [{"text": "consensus", "score": 0.9}, {"text": "a_only", "score": 0.8}] + list_b = [{"text": "consensus", "score": 0.7}, {"text": "b_only", "score": 0.6}] + merged = reciprocal_rank_fusion([list_a, list_b]) + assert merged[0]["text"] == "consensus"