-
Notifications
You must be signed in to change notification settings - Fork 0
Improve retrieval path quality and metadata-aware expansion #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: cursor/advanced-rag-refactor-24ca
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -13,7 +13,6 @@ | |||||
| from app.core.llm_io_log import log_llm_io | ||||||
| from app.graphs.crag.state import CRAGState | ||||||
| from app.providers.llm_provider import get_llm | ||||||
| from app.providers.vectorstore_provider import get_vectorstore | ||||||
|
|
||||||
| logger = logging.getLogger(__name__) | ||||||
|
|
||||||
|
|
@@ -81,12 +80,16 @@ def rewrite_query(state: CRAGState) -> dict: | |||||
|
|
||||||
|
|
||||||
| def hybrid_retrieve(state: CRAGState) -> dict: | ||||||
| """Retrieve relevant child chunks using hybrid (vector + BM25) search. | ||||||
| """Retrieve relevant child chunks with multi-query RRF fusion. | ||||||
|
|
||||||
| When ``multi_query`` mode is enabled via env var ``MULTI_QUERY=1``, | ||||||
| generates additional query variants and fuses the retrieval results. | ||||||
| Uses ``hybrid_retriever.hybrid_retrieve`` per query variant and then fuses | ||||||
| across variants with RRF to improve recall while keeping deterministic ranking. | ||||||
| """ | ||||||
| from app.core.runtime_config import get_max_retrieval_docs, get_multi_query_enabled | ||||||
| from app.rag.retrievers.hybrid_retriever import ( | ||||||
| hybrid_retrieve as run_hybrid, | ||||||
| ) | ||||||
| from app.rag.retrievers.hybrid_retriever import reciprocal_rank_fusion | ||||||
|
|
||||||
| query = _active_query(state) | ||||||
| attempt = state.get("retrieval_attempt", 0) + 1 | ||||||
|
|
@@ -99,24 +102,50 @@ def hybrid_retrieve(state: CRAGState) -> dict: | |||||
| queries = generate_multi_query(query, n=3) | ||||||
| logger.info("Multi-query: %d variants", len(queries)) | ||||||
|
|
||||||
| store = get_vectorstore() | ||||||
| seen: dict[str, dict] = {} | ||||||
| retrieval_k = max(10, get_max_retrieval_docs() * 4) | ||||||
| ranked_lists: list[list[dict]] = [] | ||||||
| for q in queries: | ||||||
| for r in store.search(q, top_k=get_max_retrieval_docs()): | ||||||
| seen.setdefault(r["text"], r) | ||||||
| ranked_lists.append(run_hybrid(q, top_k=retrieval_k)) | ||||||
|
|
||||||
| children = list(seen.values()) | ||||||
| fused = reciprocal_rank_fusion(ranked_lists) | ||||||
| children = fused[:retrieval_k] | ||||||
| return {"retrieved_children": children, "retrieval_attempt": attempt} | ||||||
|
|
||||||
|
|
||||||
| # ── expand_context ──────────────────────────────────────────────────────────── | ||||||
|
|
||||||
|
|
||||||
| def expand_context(state: CRAGState) -> dict: | ||||||
| """Expand child hits to parent / larger chunks (small-to-big strategy).""" | ||||||
| """Expand child hits to parent chunks when ``parent_text`` is available.""" | ||||||
| children = state.get("retrieved_children", []) | ||||||
| logger.info("Expanding %d child chunks to parent context", len(children)) | ||||||
| return {"expanded_contexts": children} | ||||||
|
|
||||||
| expanded: list[dict] = [] | ||||||
| seen_texts: set[str] = set() | ||||||
| for child in children: | ||||||
| if not isinstance(child, dict): | ||||||
| text = str(child) | ||||||
| if text and text not in seen_texts: | ||||||
| seen_texts.add(text) | ||||||
| expanded.append({"text": text, "score": 0.0, "metadata": {}}) | ||||||
| continue | ||||||
|
|
||||||
| metadata = dict(child.get("metadata", {})) | ||||||
| parent_text = metadata.get("parent_text") | ||||||
| text = str(parent_text or child.get("text", "")) | ||||||
| if not text or text in seen_texts: | ||||||
| continue | ||||||
|
|
||||||
| seen_texts.add(text) | ||||||
| expanded.append( | ||||||
| { | ||||||
| "text": text, | ||||||
| "score": child.get("score", 0.0), | ||||||
| "metadata": metadata, | ||||||
| } | ||||||
| ) | ||||||
|
|
||||||
| return {"expanded_contexts": expanded or children} | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The fallback
Suggested change
|
||||||
|
|
||||||
|
|
||||||
| # ── rerank_context ──────────────────────────────────────────────────────────── | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -2,18 +2,20 @@ | |||||
|
|
||||||
| from __future__ import annotations | ||||||
|
|
||||||
| from typing import Any | ||||||
|
|
||||||
| from langchain_text_splitters import RecursiveCharacterTextSplitter | ||||||
|
|
||||||
|
|
||||||
| def recursive_chunk( | ||||||
| texts: list[str], | ||||||
| docs: list[str] | list[dict[str, Any]], | ||||||
| chunk_size: int = 512, | ||||||
| chunk_overlap: int = 64, | ||||||
| ) -> list[dict]: | ||||||
| """Split texts using LangChain's ``RecursiveCharacterTextSplitter``. | ||||||
| ) -> list[dict[str, Any]]: | ||||||
| """Split documents using LangChain's ``RecursiveCharacterTextSplitter``. | ||||||
|
|
||||||
| Args: | ||||||
| texts: List of raw document strings. | ||||||
| docs: Either raw text strings or ``{text, metadata}`` dicts. | ||||||
| chunk_size: Maximum characters per chunk. | ||||||
| chunk_overlap: Character overlap between adjacent chunks. | ||||||
|
|
||||||
|
|
@@ -24,8 +26,25 @@ def recursive_chunk( | |||||
| chunk_size=chunk_size, | ||||||
| chunk_overlap=chunk_overlap, | ||||||
| ) | ||||||
| chunks = [] | ||||||
| for text in texts: | ||||||
| for chunk in splitter.split_text(text): | ||||||
| chunks.append({"text": chunk, "metadata": {}}) | ||||||
|
|
||||||
| chunks: list[dict[str, Any]] = [] | ||||||
| for doc_idx, item in enumerate(docs): | ||||||
| if isinstance(item, dict): | ||||||
| text = str(item.get("text", "")) | ||||||
| base_meta = dict(item.get("metadata", {})) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the issue in the CRAG nodes,
Suggested change
|
||||||
| else: | ||||||
| text = str(item) | ||||||
| base_meta = {} | ||||||
|
|
||||||
| if not text: | ||||||
| continue | ||||||
|
|
||||||
| split_chunks = splitter.split_text(text) | ||||||
| for chunk_idx, chunk in enumerate(split_chunks): | ||||||
| meta = { | ||||||
| **base_meta, | ||||||
| "doc_index": base_meta.get("doc_index", doc_idx), | ||||||
| "chunk_index": chunk_idx, | ||||||
| } | ||||||
| chunks.append({"text": chunk, "metadata": meta}) | ||||||
| return chunks | ||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -67,9 +67,17 @@ def embed_query(self, text: str) -> list[float]: | |||||||||||||||||||||||||||||||||||||||||
| return self.embed_texts([text])[0] | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||||
| def _deterministic_id(text: str) -> str: | ||||||||||||||||||||||||||||||||||||||||||
| """Generate a deterministic UUID-v5 from text content for deduplication.""" | ||||||||||||||||||||||||||||||||||||||||||
| return str(uuid.uuid5(uuid.NAMESPACE_DNS, text)) | ||||||||||||||||||||||||||||||||||||||||||
| def _deterministic_id(text: str, metadata: dict[str, Any]) -> str: | ||||||||||||||||||||||||||||||||||||||||||
| """Generate a stable UUID-v5 using text + provenance metadata.""" | ||||||||||||||||||||||||||||||||||||||||||
| key_parts = [ | ||||||||||||||||||||||||||||||||||||||||||
| text, | ||||||||||||||||||||||||||||||||||||||||||
| str(metadata.get("source", "")), | ||||||||||||||||||||||||||||||||||||||||||
| str(metadata.get("doc_id", "")), | ||||||||||||||||||||||||||||||||||||||||||
| str(metadata.get("parent_id", "")), | ||||||||||||||||||||||||||||||||||||||||||
| str(metadata.get("chunk_index", "")), | ||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||
| key = "|".join(key_parts) | ||||||||||||||||||||||||||||||||||||||||||
| return str(uuid.uuid5(uuid.NAMESPACE_DNS, key)) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+72
to
81
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The use of a simple pipe separator
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
| # ── VectorStorePort interface ───────────────────────────────────────────── | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -105,7 +113,7 @@ def add_documents( | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| points = [ | ||||||||||||||||||||||||||||||||||||||||||
| PointStruct( | ||||||||||||||||||||||||||||||||||||||||||
| id=self._deterministic_id(text), | ||||||||||||||||||||||||||||||||||||||||||
| id=self._deterministic_id(text, meta), | ||||||||||||||||||||||||||||||||||||||||||
| vector=vec, | ||||||||||||||||||||||||||||||||||||||||||
| payload={"text": text, **meta}, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Calling
dict()onchild.get("metadata", {})will raise aTypeErrorif the"metadata"key exists in the dictionary but its value is explicitly set tonull(None). Usingor {}ensures a dictionary is always passed to the constructor.