From 3eb5354de25d7ad3256ff7476518424eca7caf08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 13 Apr 2026 14:40:03 +0800 Subject: [PATCH 01/56] bump version --- src/sirchmunk/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sirchmunk/version.py b/src/sirchmunk/version.py index 87b826d..73c89eb 100644 --- a/src/sirchmunk/version.py +++ b/src/sirchmunk/version.py @@ -1 +1 @@ -__version__ = "0.0.7+main" +__version__ = "0.0.8+main" From b72a878c2947c63cae2fcda090353071cad4b29c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 13 Apr 2026 17:19:12 +0800 Subject: [PATCH 02/56] Introduce Sirchmunk Learnings (insights from pageindex and LLM wiki) --- src/sirchmunk/cli/cli.py | 250 ++++++ src/sirchmunk/learnings/README.md | 218 ++++++ src/sirchmunk/learnings/__init__.py | 29 +- src/sirchmunk/learnings/compiler.py | 840 +++++++++++++++++++++ src/sirchmunk/learnings/knowledge_base.py | 33 +- src/sirchmunk/learnings/lint.py | 213 ++++++ src/sirchmunk/learnings/tree_indexer.py | 444 +++++++++++ src/sirchmunk/llm/prompts.py | 87 +++ src/sirchmunk/schema/knowledge.py | 12 + src/sirchmunk/search.py | 113 +++ src/sirchmunk/storage/knowledge_storage.py | 69 +- 11 files changed, 2294 insertions(+), 14 deletions(-) create mode 100644 src/sirchmunk/learnings/README.md create mode 100644 src/sirchmunk/learnings/compiler.py create mode 100644 src/sirchmunk/learnings/lint.py create mode 100644 src/sirchmunk/learnings/tree_indexer.py diff --git a/src/sirchmunk/cli/cli.py b/src/sirchmunk/cli/cli.py index 8919732..9d09762 100644 --- a/src/sirchmunk/cli/cli.py +++ b/src/sirchmunk/cli/cli.py @@ -6,6 +6,7 @@ sirchmunk init - Initialize working directory + generate .env sirchmunk serve - Start the API server (backend only) sirchmunk search - Perform a search query + sirchmunk compile - Compile documents into knowledge indices sirchmunk web init - Build WebUI frontend (requires Node.js) sirchmunk web serve - Start API + WebUI (single port) sirchmunk web serve --dev - Start API + Next.js dev server (dual port) @@ -1225,6 +1226,207 @@ def cmd_mcp_version(args: argparse.Namespace) -> int: return 0 +# ------------------------------------------------------------------ +# sirchmunk compile +# ------------------------------------------------------------------ + +def cmd_compile(args: argparse.Namespace) -> int: + """Compile document collections into structured knowledge indices. + + Builds PageIndex-style tree indices and LLM Wiki-style knowledge + clusters for downstream search acceleration. 
+ + Args: + args: Command-line arguments + + Returns: + Exit code (0 for success, non-zero for failure) + """ + try: + work_path = Path( + getattr(args, "work_path", None) or str(_get_default_work_path()) + ).expanduser().resolve() + os.environ["SIRCHMUNK_WORK_PATH"] = str(work_path) + + env_file = work_path / ".env" + if env_file.exists(): + _load_env_file(env_file) + + paths = args.paths or None + if not paths: + print(" Error: --paths is required for compile.") + print(" Usage: sirchmunk compile --paths /data/docs") + return 1 + + # Status mode + if getattr(args, "status", False): + return asyncio.run(_compile_status(paths, work_path)) + + # Lint mode + if getattr(args, "lint", False): + return asyncio.run(_compile_lint( + work_path, auto_fix=getattr(args, "fix", False), + )) + + # Normal compile + incremental = not getattr(args, "full", False) + return asyncio.run(_compile_run( + paths=paths, + work_path=work_path, + incremental=incremental, + max_files=getattr(args, "max_files", None), + concurrency=getattr(args, "concurrency", 3), + shallow=getattr(args, "shallow", False), + )) + + except KeyboardInterrupt: + print("\n Compile cancelled.") + return 130 + except Exception as e: + logger.error(f"Compile failed: {e}", exc_info=True) + print(f" Compile error: {e}") + return 1 + + +async def _compile_run( + paths: list, + work_path: Path, + incremental: bool = True, + max_files: Optional[int] = None, + concurrency: int = 3, + shallow: bool = False, +) -> int: + """Execute compile using AgenticSearch.""" + from sirchmunk.search import AgenticSearch + from sirchmunk.llm.openai_chat import OpenAIChat + + llm_api_key = os.getenv("LLM_API_KEY", "") + if not llm_api_key: + print(" LLM_API_KEY is not set.") + print(" Configure it in ~/.sirchmunk/.env or set the environment variable.") + return 1 + + llm = OpenAIChat( + base_url=os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"), + api_key=llm_api_key, + model=os.getenv("LLM_MODEL_NAME", "gpt-5.2"), + ) + + searcher = AgenticSearch(llm=llm, work_path=str(work_path)) + + print("=" * 60) + print(" Sirchmunk Knowledge Compile") + print("=" * 60) + print() + print(f" Paths: {', '.join(paths)}") + print(f" Incremental: {incremental}") + if shallow: + print(" Mode: shallow (tree indexing skipped)") + if max_files: + print(f" Max files: {max_files} (importance sampling)") + print() + + report = await searcher.compile( + paths=paths, + incremental=incremental, + shallow=shallow, + max_files=max_files, + concurrency=concurrency, + ) + + print() + print("=" * 60) + print(" Compile Report") + print("=" * 60) + print() + print(f" Total files: {report.get('total_files', 0)}") + print(f" Files added: {report.get('files_added', 0)}") + print(f" Files modified: {report.get('files_modified', 0)}") + print(f" Files skipped: {report.get('files_skipped', 0)}") + if report.get("files_sampled"): + print(f" Files sampled: {report['files_sampled']}") + print(f" Trees built: {report.get('trees_built', 0)}") + print(f" Clusters created: {report.get('clusters_created', 0)}") + print(f" Clusters merged: {report.get('clusters_merged', 0)}") + print(f" Cross-refs: {report.get('cross_refs_built', 0)}") + print(f" Elapsed: {report.get('elapsed_seconds', 0):.1f}s") + if report.get("errors"): + print(f" Errors: {len(report['errors'])}") + for err in report["errors"][:5]: + print(f" - {err}") + print() + + return 0 + + +async def _compile_status(paths: list, work_path: Path) -> int: + """Show compile status.""" + from sirchmunk.search import AgenticSearch + from 
sirchmunk.llm.openai_chat import OpenAIChat + + llm = OpenAIChat( + base_url=os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"), + api_key=os.getenv("LLM_API_KEY", ""), + model=os.getenv("LLM_MODEL_NAME", "gpt-5.2"), + ) + + searcher = AgenticSearch(llm=llm, work_path=str(work_path)) + status = await searcher.compile_status(paths=paths) + + print("=" * 60) + print(" Compile Status") + print("=" * 60) + print() + print(f" Compiled files: {status.get('total_compiled_files', 0)}") + print(f" Tree indices: {status.get('total_trees', 0)}") + print(f" Clusters: {status.get('total_clusters', 0)}") + print(f" Last compile: {status.get('last_compile_at', 'Never')}") + print() + + return 0 + + +async def _compile_lint(work_path: Path, auto_fix: bool = False) -> int: + """Run knowledge lint checks.""" + from sirchmunk.search import AgenticSearch + from sirchmunk.llm.openai_chat import OpenAIChat + + llm = OpenAIChat( + base_url=os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"), + api_key=os.getenv("LLM_API_KEY", ""), + model=os.getenv("LLM_MODEL_NAME", "gpt-5.2"), + ) + + searcher = AgenticSearch(llm=llm, work_path=str(work_path)) + report = await searcher.compile_lint(auto_fix=auto_fix) + + print("=" * 60) + print(" Knowledge Lint Report") + print("=" * 60) + print() + print(f" Clusters checked: {report.get('total_clusters_checked', 0)}") + print(f" Trees checked: {report.get('total_trees_checked', 0)}") + print(f" Errors: {report.get('errors', 0)}") + print(f" Warnings: {report.get('warnings', 0)}") + if auto_fix: + print(f" Auto-fixes: {report.get('auto_fixes_applied', 0)}") + print() + + issues = report.get("issues", []) + if issues: + for issue in issues[:20]: + severity = issue.get("severity", "info").upper() + msg = issue.get("message", "") + cid = issue.get("cluster_id", "") + fixed = " [FIXED]" if issue.get("auto_fixed") else "" + print(f" [{severity}] {msg} {f'(cluster={cid})' if cid else ''}{fixed}") + if len(issues) > 20: + print(f" ... and {len(issues) - 20} more") + print() + + return 0 + + # ------------------------------------------------------------------ # sirchmunk upload # ------------------------------------------------------------------ @@ -1435,6 +1637,54 @@ def create_parser() -> argparse.ArgumentParser: ) search_parser.set_defaults(func=cmd_search) + # === compile command === + compile_parser = subparsers.add_parser( + "compile", + help="Compile document collections into knowledge indices", + description=( + "Compile documents into structured knowledge indices (tree + clusters). " + "Optional step after 'sirchmunk init'." 
+ ), + ) + compile_parser.add_argument( + "--paths", nargs="+", required=True, + help="Directories or files to compile", + ) + compile_parser.add_argument( + "--full", action="store_true", default=False, + help="Force full recompile (ignore incremental cache)", + ) + compile_parser.add_argument( + "--max-files", type=int, default=None, + help="Max files to process (triggers importance sampling for large sets)", + ) + compile_parser.add_argument( + "--concurrency", type=int, default=3, + help="Max parallel file compilations (default: 3)", + ) + compile_parser.add_argument( + "--shallow", action="store_true", default=False, + help="Skip tree indexing — use direct LLM summarisation only (faster)", + ) + compile_parser.add_argument( + "--status", action="store_true", default=False, + help="Show compile status instead of running compile", + ) + compile_parser.add_argument( + "--lint", action="store_true", default=False, + help="Run knowledge health checks", + ) + compile_parser.add_argument( + "--fix", action="store_true", default=False, + help="Auto-fix lint issues (use with --lint)", + ) + compile_parser.add_argument( + "--work-path", + default=None, + help="Working directory (default: ~/.sirchmunk)", + ) + compile_parser.set_defaults(func=cmd_compile) + # === web command group === web_parser = subparsers.add_parser( "web", diff --git a/src/sirchmunk/learnings/README.md b/src/sirchmunk/learnings/README.md new file mode 100644 index 0000000..0fc1bbe --- /dev/null +++ b/src/sirchmunk/learnings/README.md @@ -0,0 +1,218 @@ +# Sirchmunk Learnings Module + +The `sirchmunk/learnings` module implements **knowledge compilation and continuous learning** capabilities. It houses the core logic for transforming raw document collections into structured, searchable knowledge networks. + +## Architecture Overview + +``` +learnings/ +├── __init__.py # Public API exports +├── knowledge_base.py # Runtime knowledge builder (search-time) +├── evidence_processor.py # Monte Carlo evidence sampling +├── compiler.py # Offline knowledge compiler (compile-time) +├── tree_indexer.py # PageIndex-style document tree indexer +├── lint.py # Knowledge network health checks +└── README.md # This file +``` + +### Design Philosophy + +The module fuses insights from three frameworks: + +1. **PageIndex** (VectifyAI) — Hierarchical tree indexing replaces brute-force vector search with LLM reasoning-based navigation. The key insight: *similarity ≠ relevance*. + +2. **LLM Wiki** (Karpathy) — Documents are not merely "indexed" but "compiled" into an interlinked knowledge network that compounds over time. Knowledge clusters grow richer with each compile cycle. + +3. **NotebookLM** (Google) — Strict source grounding ensures every claim traces back to original evidence. The `EvidenceUnit` system provides full provenance. + +### Compile vs. Search + +| Aspect | Compile (offline) | Search (runtime) | +|--------|-------------------|-------------------| +| **When** | `sirchmunk compile` | `sirchmunk search` | +| **Speed** | Minutes (batch) | Seconds (interactive) | +| **Purpose** | Build indices + knowledge | Answer queries | +| **Module** | `compiler.py` (uses `tree_indexer.py`) | `knowledge_base.py`, `evidence_processor.py` | +| **Required** | Optional | Always available | + +Compile products are automatically leveraged by search when present, but search functions independently without them. + +--- + +## Components + +### DocumentTreeIndexer (`tree_indexer.py`) + +Builds hierarchical JSON tree indices for structured long documents. 
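A cached index is plain JSON whose fields mirror `DocumentTree.to_json()` and `TreeNode.to_dict()`. A minimal sketch of the on-disk shape — the path, hash, titles, and character offsets below are hypothetical:

```json
{
  "file_path": "/data/docs/spec.pdf",
  "file_hash": "a1b2c3d4e5f6",
  "created_at": "2026-04-13T09:00:00+00:00",
  "total_chars": 182000,
  "total_pages": null,
  "root": {
    "node_id": "N000000",
    "title": "Document",
    "summary": "Synthesized from the section summaries below.",
    "char_range": [0, 182000],
    "level": 0,
    "page_range": null,
    "children": [
      {
        "node_id": "N001200",
        "title": "2. Architecture",
        "summary": "Compile pipeline, storage layout, and cross-references.",
        "char_range": [1200, 64000],
        "level": 1,
        "page_range": null,
        "children": []
      }
    ]
  }
}
```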
+ +**Key concepts:** +- Only triggers for documents ≥ 50KB in eligible formats (PDF, DOCX, MD, HTML, etc.) +- LLM analyzes document structure recursively (up to 4 levels deep) +- Each node stores: title, summary, character range +- Query-time navigation: LLM selects relevant branches instead of scanning everything + +**Data structures:** +- `TreeNode` — Single node with `node_id`, `title`, `summary`, `char_range`, `children` +- `DocumentTree` — Complete tree for a document, JSON-serializable, cached by file hash + +**Usage:** +```python +indexer = DocumentTreeIndexer(llm=llm, cache_dir=cache_path) + +# Build (async, LLM-powered) +tree = await indexer.build_tree(file_path, content, max_depth=4) + +# Navigate (async, LLM-powered branch selection) +leaves = await indexer.navigate(tree, query="How does X work?") +for leaf in leaves: + relevant_text = content[leaf.char_range[0]:leaf.char_range[1]] + +# Cache check (sync) +if indexer.has_tree(file_path): + tree = indexer.load_tree(file_path) +``` + +### KnowledgeCompiler (`compiler.py`) + +Orchestrates the unified compile pipeline. + +**Four-phase pipeline:** +1. **File Discovery & Change Detection** — Scans paths, compares with manifest for incremental processing +2. **Per-File Compile** — Unified pipeline per file: tree-if-eligible → summary → topics → rich evidence +3. **Knowledge Aggregation** — Merges into existing clusters or creates new ones (three-tier similarity) +4. **Cross-Reference Building** — Creates `WeakSemanticEdge` links between related clusters + +**Unified single-file pipeline:** +For each file, the compiler runs a single pipeline instead of separate "tree" and "wiki" modes: +- If the file is ≥ 50KB and in an eligible format, a tree is built first. The root node's summary is synthesized from children's section summaries via LLM, and `EvidenceUnit` snippets + `tree_path` are populated directly from tree leaves. +- If the file is small or `shallow=True`, a direct LLM summary is generated instead. +- In both cases, topics are extracted and a `KnowledgeCluster` is created/merged. + +**Three-tier similarity strategy:** +| Similarity | Action | +|-----------|--------| +| ≥ 0.80 | Merge into existing cluster, re-compute embedding | +| 0.50 – 0.79 | Create new cluster + build `embed_sim` weak edges | +| < 0.50 | Create standalone cluster | + +**Importance probability sampling** (`ImportanceSampler`): +For large datasets, select a representative subset using weighted random sampling: +- File size (log-scaled): larger files contain more information +- Novelty: uncompiled files get 4× weight over already-compiled ones +- Extension diversity: structured formats (PDF, DOCX) get 1.5× boost + +**Key data structures:** +- `CompileManifest` — Tracks file hashes and cluster associations for incremental compile +- `FileManifestEntry` — Per-file state (hash, compile timestamp, tree flag, cluster IDs) +- `CompileReport` — Statistics from a compile run +- `CompileStatus` — Quick status snapshot + +### KnowledgeLint (`lint.py`) + +Health checks for the knowledge network (inspired by LLM Wiki's Lint operation). 
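Each finding is a `LintIssue` collected into a `LintReport`; `report.to_dict()` produces a shape like the following — the counts and cluster ID here are hypothetical:

```json
{
  "total_clusters_checked": 42,
  "total_trees_checked": 17,
  "errors": 0,
  "warnings": 3,
  "auto_fixes_applied": 1,
  "issues": [
    {
      "severity": "warning",
      "category": "stale_evidence",
      "message": "2 evidence source(s) no longer exist",
      "cluster_id": "C1a2b3c4d5e",
      "file_path": null,
      "auto_fixed": false
    }
  ]
}
```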
+ +**Checks performed:** +- **Empty clusters** — Clusters with minimal or no content +- **Stale evidence** — Evidence pointing to files that no longer exist +- **Orphan clusters** — Clusters with no evidence and no queries +- **Isolated clusters** — Clusters with no cross-references +- **Orphan trees** — Tree cache files without matching manifest entries +- **Stale manifest** — Manifest entries pointing to deleted files + +**Auto-fix capabilities:** +- Deprecate clusters where all evidence sources are gone +- Remove orphan tree cache files + +### KnowledgeBase (`knowledge_base.py`) + +Runtime knowledge builder used during search operations. + +**Tree-aware evidence extraction:** +When a tree index exists for a file, `_extract_evidence_for_file()` navigates to relevant sections first, then runs Monte Carlo sampling within those narrowed regions. This dramatically improves precision for large documents. + +### MonteCarloEvidenceSampling (`evidence_processor.py`) + +Statistical sampling for finding relevant regions in documents. Used both at compile-time and search-time. + +--- + +## CLI Interface + +```bash +# Compile documents (optional, after sirchmunk init) +sirchmunk compile --paths /data/docs /data/reports + +# Incremental compile (default, skips unchanged files) +sirchmunk compile --paths /data/docs + +# Full recompile +sirchmunk compile --paths /data/docs --full + +# Importance sampling for large datasets +sirchmunk compile --paths /data/docs --max-files 100 + +# Shallow mode: skip tree indexing, use direct LLM summarisation +sirchmunk compile --paths /data/docs --shallow + +# Check compile status +sirchmunk compile --paths /data/docs --status + +# Run health checks +sirchmunk compile --paths /data/docs --lint +sirchmunk compile --paths /data/docs --lint --fix +``` + +## Python SDK + +```python +from sirchmunk.search import AgenticSearch + +searcher = AgenticSearch(work_path="~/.sirchmunk") + +# Compile +report = await searcher.compile( + paths=["/data/docs"], + incremental=True, + shallow=False, # set True to skip tree indexing + max_files=100, # importance sampling + concurrency=3, +) + +# Status +status = await searcher.compile_status(paths=["/data/docs"]) + +# Lint +lint_report = await searcher.compile_lint(auto_fix=True) + +# Search (automatically uses compile products when available) +result = await searcher.search("query", paths=["/data/docs"]) +``` + +--- + +## Cache Layout + +``` +{work_path}/.cache/ +├── compile/ +│ ├── manifest.json # Compile manifest (incremental state) +│ └── trees/ +│ ├── {file_hash_1}.json # Tree index for document 1 +│ └── {file_hash_2}.json # Tree index for document 2 +└── knowledge/ + └── knowledge_clusters.parquet # Persistent cluster storage (DuckDB + Parquet) +``` + +## Schema Extensions + +The compile feature extends existing schemas: + +- **`EvidenceUnit`** — Added `tree_path` (node IDs from tree navigation) and `page_range` (character offsets) +- **`KnowledgeCluster`** — Added `merge_count` (tracks compile-time merge frequency for lifecycle promotion: ≥ 3 merges → `STABLE`) + +## Design Principles + +- **SOLID compliance**: Each class has a single responsibility; dependencies are injected via constructor +- **Optional by design**: Compile never breaks existing search functionality +- **Incremental**: Only processes changed files; manifest tracks state across runs +- **Production-ready**: Bounded concurrency, error isolation per file, graceful schema migration diff --git a/src/sirchmunk/learnings/__init__.py b/src/sirchmunk/learnings/__init__.py 
index 0829846..bc14211 100644 --- a/src/sirchmunk/learnings/__init__.py +++ b/src/sirchmunk/learnings/__init__.py @@ -1 +1,28 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. \ No newline at end of file +# Copyright (c) ModelScope Contributors. All rights reserved. + +from sirchmunk.learnings.compiler import ( + CompileManifest, + CompileReport, + CompileStatus, + ImportanceSampler, + KnowledgeCompiler, +) +from sirchmunk.learnings.lint import KnowledgeLint, LintReport +from sirchmunk.learnings.tree_indexer import ( + DocumentTree, + DocumentTreeIndexer, + TreeNode, +) + +__all__ = [ + "CompileManifest", + "CompileReport", + "CompileStatus", + "DocumentTree", + "DocumentTreeIndexer", + "ImportanceSampler", + "KnowledgeCompiler", + "KnowledgeLint", + "LintReport", + "TreeNode", +] \ No newline at end of file diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py new file mode 100644 index 0000000..3c2b0da --- /dev/null +++ b/src/sirchmunk/learnings/compiler.py @@ -0,0 +1,840 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Knowledge compiler — orchestrates offline compile of document collections. + +Fuses PageIndex (tree indexing) and LLM Wiki (knowledge compilation network) +into a single compile pipeline that produces structured tree indices and +knowledge clusters for downstream search acceleration. +""" + +import asyncio +import json +import math +import os +import random +import hashlib +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from sirchmunk.learnings.tree_indexer import ( + DocumentTree, + DocumentTreeIndexer, +) +from sirchmunk.llm.openai_chat import OpenAIChat +from sirchmunk.schema.knowledge import ( + AbstractionLevel, + EvidenceUnit, + KnowledgeCluster, + Lifecycle, + WeakSemanticEdge, +) +from sirchmunk.storage.knowledge_storage import KnowledgeStorage +from sirchmunk.utils import LogCallback, create_logger +from sirchmunk.utils.file_utils import fast_extract, get_fast_hash + +# Concurrency cap for LLM-heavy file processing +_DEFAULT_CONCURRENCY = 3 + +# Similarity threshold for merging into existing clusters during compile +_MERGE_SIMILARITY_THRESHOLD = 0.75 + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + +@dataclass +class FileManifestEntry: + """State of a single file in the compile manifest.""" + + file_hash: str + compiled_at: str + has_tree: bool + cluster_ids: List[str] + size_bytes: int + + def to_dict(self) -> Dict[str, Any]: + return { + "file_hash": self.file_hash, + "compiled_at": self.compiled_at, + "has_tree": self.has_tree, + "cluster_ids": self.cluster_ids, + "size_bytes": self.size_bytes, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "FileManifestEntry": + return cls( + file_hash=data["file_hash"], + compiled_at=data["compiled_at"], + has_tree=data.get("has_tree", False), + cluster_ids=data.get("cluster_ids", []), + size_bytes=data.get("size_bytes", 0), + ) + + +@dataclass +class CompileManifest: + """Tracks compiled file states for incremental processing.""" + + version: str = "1.0" + last_compile_at: Optional[str] = None + files: Dict[str, FileManifestEntry] = field(default_factory=dict) + + def to_json(self) -> str: + return json.dumps({ + "version": self.version, + "last_compile_at": 
self.last_compile_at, + "files": {k: v.to_dict() for k, v in self.files.items()}, + }, ensure_ascii=False, indent=2) + + @classmethod + def from_json(cls, json_str: str) -> "CompileManifest": + data = json.loads(json_str) + files = { + k: FileManifestEntry.from_dict(v) + for k, v in data.get("files", {}).items() + } + return cls( + version=data.get("version", "1.0"), + last_compile_at=data.get("last_compile_at"), + files=files, + ) + + +@dataclass +class FileEntry: + """Discovered file pending compilation.""" + + path: str + size_bytes: int + file_hash: str + + +@dataclass +class ChangeSet: + """Delta between discovered files and the manifest.""" + + added: List[FileEntry] = field(default_factory=list) + modified: List[FileEntry] = field(default_factory=list) + deleted: List[str] = field(default_factory=list) + unchanged: List[str] = field(default_factory=list) + + +@dataclass +class FileCompileResult: + """Result of compiling a single file.""" + + path: str + tree: Optional[DocumentTree] = None + summary: str = "" + topics: List[str] = field(default_factory=list) + evidence: Optional[EvidenceUnit] = None + cluster_ids: List[str] = field(default_factory=list) + error: Optional[str] = None + + +@dataclass +class CompileReport: + """Summary report of a compile run.""" + + total_files: int = 0 + files_added: int = 0 + files_modified: int = 0 + files_skipped: int = 0 + files_deleted: int = 0 + files_sampled: int = 0 + trees_built: int = 0 + clusters_created: int = 0 + clusters_merged: int = 0 + cross_refs_built: int = 0 + errors: List[str] = field(default_factory=list) + elapsed_seconds: float = 0.0 + + def to_dict(self) -> Dict[str, Any]: + return { + "total_files": self.total_files, + "files_added": self.files_added, + "files_modified": self.files_modified, + "files_skipped": self.files_skipped, + "files_deleted": self.files_deleted, + "files_sampled": self.files_sampled, + "trees_built": self.trees_built, + "clusters_created": self.clusters_created, + "clusters_merged": self.clusters_merged, + "cross_refs_built": self.cross_refs_built, + "errors": self.errors, + "elapsed_seconds": round(self.elapsed_seconds, 2), + } + + +@dataclass +class CompileStatus: + """Status snapshot of the compile state.""" + + total_compiled_files: int = 0 + total_clusters: int = 0 + total_trees: int = 0 + last_compile_at: Optional[str] = None + manifest_path: str = "" + + +# --------------------------------------------------------------------------- +# Importance probability sampler +# --------------------------------------------------------------------------- + +class ImportanceSampler: + """Select a representative subset of files using importance-based probability. + + Sampling strategy for large datasets: + - Larger files get higher probability (they contain more information). + - Uncompiled (new) files are prioritised over previously compiled ones. + - Files with rare extensions get a mild boost (diversity signal). + - The final probability is proportional to a composite importance score. 
+ """ + + def __init__(self, max_files: int, seed: Optional[int] = None): + self._max_files = max_files + self._rng = random.Random(seed) + + def sample(self, files: List[FileEntry], manifest: CompileManifest) -> List[FileEntry]: + """Return up to *max_files* entries sampled by importance.""" + if len(files) <= self._max_files: + return files + + scores = [self._score(f, manifest) for f in files] + total = sum(scores) or 1.0 + probs = [s / total for s in scores] + + selected_indices = set() + attempts = 0 + while len(selected_indices) < self._max_files and attempts < len(files) * 3: + idx = self._weighted_choice(probs) + selected_indices.add(idx) + attempts += 1 + + return [files[i] for i in sorted(selected_indices)] + + def _score(self, entry: FileEntry, manifest: CompileManifest) -> float: + """Compute composite importance score.""" + # Size factor: log-scaled, bounded + size_score = math.log2(max(entry.size_bytes, 1024)) / 20.0 + + # Novelty factor: new files are more important + novelty = 2.0 if entry.path not in manifest.files else 0.5 + + # Extension diversity: rare extensions get a mild boost + ext = Path(entry.path).suffix.lower() + diversity = 1.5 if ext in {".pdf", ".docx", ".doc", ".tex"} else 1.0 + + return size_score * novelty * diversity + + def _weighted_choice(self, probs: List[float]) -> int: + r = self._rng.random() + cumulative = 0.0 + for i, p in enumerate(probs): + cumulative += p + if r <= cumulative: + return i + return len(probs) - 1 + + +# --------------------------------------------------------------------------- +# Compiler +# --------------------------------------------------------------------------- + +class KnowledgeCompiler: + """Orchestrate compile pipeline: file discovery -> tree indexing -> knowledge aggregation.""" + + # File extensions eligible for compilation + _ELIGIBLE_EXTENSIONS = { + ".pdf", ".docx", ".doc", ".md", ".markdown", ".html", ".htm", + ".rst", ".tex", ".txt", ".pptx", ".xlsx", + } + + def __init__( + self, + llm: OpenAIChat, + embedding_client: Optional[Any], + knowledge_storage: KnowledgeStorage, + tree_indexer: DocumentTreeIndexer, + work_path: Union[str, Path], + log_callback: LogCallback = None, + ): + self._llm = llm + self._embedding = embedding_client + self._storage = knowledge_storage + self._tree_indexer = tree_indexer + self._work_path = Path(work_path).expanduser().resolve() + self._log = create_logger(log_callback=log_callback) + + self._compile_dir = self._work_path / ".cache" / "compile" + self._compile_dir.mkdir(parents=True, exist_ok=True) + self._manifest_path = self._compile_dir / "manifest.json" + + # ------------------------------------------------------------------ # + # Public API # + # ------------------------------------------------------------------ # + + async def compile( + self, + paths: List[str], + *, + incremental: bool = True, + shallow: bool = False, + max_files: Optional[int] = None, + concurrency: int = _DEFAULT_CONCURRENCY, + ) -> CompileReport: + """Execute the unified knowledge compile pipeline. + + Args: + paths: Directories or files to compile. + incremental: Skip unchanged files. + shallow: Skip tree building even for eligible files — use direct + LLM summarisation only (faster, lower quality). + max_files: Cap on files to process (triggers importance sampling). + concurrency: Max parallel file compilations. 
+ """ + import time + t0 = time.monotonic() + report = CompileReport() + + # Phase 1: discover and diff + await self._log.info("[Compile] Phase 1: File discovery & change detection") + manifest = self._load_manifest() + discovered = await self._discover_files(paths) + report.total_files = len(discovered) + await self._log.info(f"[Compile] Discovered {len(discovered)} eligible files") + + if incremental: + changes = self._detect_changes(discovered, manifest) + to_compile = changes.added + changes.modified + report.files_skipped = len(changes.unchanged) + report.files_deleted = len(changes.deleted) + for deleted_path in changes.deleted: + manifest.files.pop(deleted_path, None) + else: + to_compile = discovered + report.files_skipped = 0 + + report.files_added = len([f for f in to_compile if f.path not in manifest.files]) + report.files_modified = len(to_compile) - report.files_added + + # Phase 1.5: importance sampling for large datasets + if max_files and len(to_compile) > max_files: + await self._log.info( + f"[Compile] Applying importance sampling: {len(to_compile)} -> {max_files} files" + ) + sampler = ImportanceSampler(max_files=max_files) + to_compile = sampler.sample(to_compile, manifest) + report.files_sampled = len(to_compile) + + if not to_compile: + await self._log.info("[Compile] No files to compile (all up-to-date)") + report.elapsed_seconds = time.monotonic() - t0 + return report + + await self._log.info( + f"[Compile] Phase 2: Processing {len(to_compile)} files " + f"(concurrency={concurrency})" + ) + + # Phase 2: compile files with bounded concurrency + semaphore = asyncio.Semaphore(concurrency) + results: List[FileCompileResult] = [] + + async def _bounded(entry: FileEntry) -> FileCompileResult: + async with semaphore: + return await self._compile_single_file(entry, shallow=shallow) + + tasks = [_bounded(f) for f in to_compile] + for coro in asyncio.as_completed(tasks): + result = await coro + results.append(result) + if result.error: + report.errors.append(f"{result.path}: {result.error}") + else: + if result.tree: + report.trees_built += 1 + # Update manifest + manifest.files[result.path] = FileManifestEntry( + file_hash=get_fast_hash(result.path) or "", + compiled_at=datetime.now(timezone.utc).isoformat(), + has_tree=result.tree is not None, + cluster_ids=result.cluster_ids, + size_bytes=Path(result.path).stat().st_size if Path(result.path).exists() else 0, + ) + + # Phase 3: aggregate results into knowledge network + await self._log.info("[Compile] Phase 3: Knowledge aggregation") + for r in results: + if r.error or not r.summary: + continue + created, merged = await self._aggregate_to_knowledge_network(r) + report.clusters_created += created + report.clusters_merged += merged + + # Phase 4: cross-references + await self._log.info("[Compile] Phase 4: Building cross-references") + report.cross_refs_built = await self._build_cross_references(results) + + # Phase 5: persist manifest + manifest.last_compile_at = datetime.now(timezone.utc).isoformat() + self._save_manifest(manifest) + self._storage.force_sync() + + report.elapsed_seconds = time.monotonic() - t0 + await self._log.info( + f"[Compile] Done in {report.elapsed_seconds:.1f}s — " + f"trees={report.trees_built}, created={report.clusters_created}, " + f"merged={report.clusters_merged}, errors={len(report.errors)}" + ) + return report + + async def get_status(self, paths: List[str]) -> CompileStatus: + """Return current compile status for the given paths.""" + manifest = self._load_manifest() + path_set = 
{str(Path(p).resolve()) for p in paths} + + compiled_count = 0 + tree_count = 0 + cluster_ids: Set[str] = set() + for fp, entry in manifest.files.items(): + for p in path_set: + if fp.startswith(p): + compiled_count += 1 + if entry.has_tree: + tree_count += 1 + cluster_ids.update(entry.cluster_ids) + break + + return CompileStatus( + total_compiled_files=compiled_count, + total_clusters=len(cluster_ids), + total_trees=tree_count, + last_compile_at=manifest.last_compile_at, + manifest_path=str(self._manifest_path), + ) + + # ------------------------------------------------------------------ # + # File discovery and change detection # + # ------------------------------------------------------------------ # + + async def _discover_files(self, paths: List[str]) -> List[FileEntry]: + """Walk paths and return all compilation-eligible files.""" + entries: List[FileEntry] = [] + seen: Set[str] = set() + + for base in paths: + base_path = Path(base).expanduser().resolve() + if base_path.is_file(): + candidates = [base_path] + elif base_path.is_dir(): + candidates = sorted(base_path.rglob("*")) + else: + continue + + for fp in candidates: + if not fp.is_file(): + continue + if fp.suffix.lower() not in self._ELIGIBLE_EXTENSIONS: + continue + abs_path = str(fp.resolve()) + if abs_path in seen: + continue + seen.add(abs_path) + fh = get_fast_hash(abs_path) + if fh is None: + continue + entries.append(FileEntry( + path=abs_path, + size_bytes=fp.stat().st_size, + file_hash=fh, + )) + + return entries + + def _detect_changes( + self, discovered: List[FileEntry], manifest: CompileManifest, + ) -> ChangeSet: + """Compare discovered files against the manifest for incremental compile.""" + changes = ChangeSet() + current_paths = {f.path for f in discovered} + + for entry in discovered: + prev = manifest.files.get(entry.path) + if prev is None: + changes.added.append(entry) + elif prev.file_hash != entry.file_hash: + changes.modified.append(entry) + else: + changes.unchanged.append(entry.path) + + for old_path in manifest.files: + if old_path not in current_paths: + changes.deleted.append(old_path) + + return changes + + # ------------------------------------------------------------------ # + # Single-file compilation # + # ------------------------------------------------------------------ # + + async def _compile_single_file( + self, + entry: FileEntry, + *, + shallow: bool = False, + ) -> FileCompileResult: + """Unified compile pipeline: tree-if-eligible -> summary -> topics -> evidence. + + When *shallow* is True (or file is ineligible for tree indexing), + the pipeline skips tree building and summarises via a direct LLM call. 
+ """ + result = FileCompileResult(path=entry.path) + try: + await self._log.info(f"[Compile] Processing: {Path(entry.path).name}") + + extraction = await fast_extract(file_path=entry.path) + content = extraction.content + if not content or len(content.strip()) < 100: + result.error = "Insufficient text content" + return result + + use_tree = ( + not shallow + and DocumentTreeIndexer.should_build_tree(entry.path, len(content)) + ) + + if use_tree: + result.tree = await self._tree_indexer.build_tree( + entry.path, content, + ) + + result.summary = await self._extract_summary( + entry.path, content, result.tree, + ) + result.topics = await self._extract_topics(result.summary) + result.evidence = self._build_evidence(entry, content, result) + + except Exception as exc: + result.error = str(exc) + await self._log.warning(f"[Compile] Failed: {entry.path}: {exc}") + + return result + + async def _extract_summary( + self, + file_path: str, + content: str, + tree: Optional[DocumentTree] = None, + ) -> str: + """Generate a document-level summary. + + When a tree is available its root already contains an LLM-synthesized + summary (produced by ``_synthesize_root_summary`` during tree build), + so we reuse it directly — no redundant LLM call. + """ + if tree and tree.root and tree.root.summary: + return tree.root.summary + + preview = content[:16000] if len(content) > 16000 else content + from sirchmunk.llm.prompts import COMPILE_DOC_SUMMARY + prompt = COMPILE_DOC_SUMMARY.format( + file_name=Path(file_path).name, + document_content=preview, + ) + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + return resp.content.strip() + + def _build_evidence( + self, + entry: FileEntry, + content: str, + result: FileCompileResult, + ) -> EvidenceUnit: + """Build an EvidenceUnit, populating snippets/tree_path from tree leaves.""" + from sirchmunk.schema.metadata import FileInfo + + snippets: List[str] = [] + tree_path: Optional[List[str]] = None + + if result.tree and result.tree.root: + leaves = result.tree.root.all_leaves() + tree_path = [leaf.node_id for leaf in leaves] + for leaf in leaves: + start, end = leaf.char_range + snippet = content[start:end][:500] + if snippet.strip(): + snippets.append(snippet) + + return EvidenceUnit( + doc_id=FileInfo.get_cache_key(entry.path), + file_or_url=Path(entry.path), + summary=result.summary, + is_found=True, + snippets=snippets, + tree_path=tree_path, + extracted_at=datetime.now(timezone.utc), + ) + + async def _extract_topics(self, summary: str) -> List[str]: + """Extract key topics/entities from a document summary.""" + from sirchmunk.llm.prompts import COMPILE_TOPIC_EXTRACTION + prompt = COMPILE_TOPIC_EXTRACTION.format(summary=summary) + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + try: + raw = resp.content.strip() + if raw.startswith("["): + parsed = json.loads(raw) + if isinstance(parsed, list): + return [str(t) for t in parsed if t] + return [t.strip() for t in raw.split(",") if t.strip()] + except (json.JSONDecodeError, TypeError): + return [] + + # ------------------------------------------------------------------ # + # Knowledge aggregation (LLM Wiki Ingest) # + # ------------------------------------------------------------------ # + + async def _aggregate_to_knowledge_network( + self, result: FileCompileResult, + ) -> Tuple[int, int]: + """Aggregate a file's compile result into the knowledge network. 
+ + Three-tier similarity strategy (per design doc): + - similarity >= 0.80 → merge into existing cluster + - 0.50 <= sim < 0.80 → create new cluster + weak edge to similar + - similarity < 0.50 → create standalone cluster + + Returns: + (clusters_created, clusters_merged) + """ + created, merged = 0, 0 + if not result.summary: + return created, merged + + embedding = self._encode_text(result.summary) + + # Search for similar existing clusters across a wider range + best_match: Optional[Dict[str, Any]] = None + if embedding is not None: + similar = await self._storage.search_similar_clusters( + query_embedding=embedding, + top_k=3, + similarity_threshold=0.50, + ) + if similar: + best_match = similar[0] + + if best_match and best_match["similarity"] >= 0.80: + # Tier 1: merge into existing cluster + cluster = await self._storage.get(best_match["id"]) + if cluster: + await self._merge_into_cluster(cluster, result) + # Re-compute embedding for merged content + await self._update_cluster_embedding(cluster) + result.cluster_ids.append(cluster.id) + merged += 1 + return created, merged + + # Create a new cluster (Tier 2 or Tier 3) + cluster = await self._create_cluster(result) + if cluster: + result.cluster_ids.append(cluster.id) + await self._store_cluster_embedding(cluster, embedding, result.summary) + created += 1 + + # Tier 2: build weak edges to moderately similar clusters + if best_match and best_match["similarity"] >= 0.50: + for s in (similar or []): + if s["similarity"] >= 0.50: + target = await self._storage.get(s["id"]) + if target: + self._add_edge(cluster, target.id, "embed_sim", s["similarity"]) + self._add_edge(target, cluster.id, "embed_sim", s["similarity"]) + await self._storage.update(target) + await self._storage.update(cluster) + + return created, merged + + def _encode_text(self, text: str) -> Optional[Any]: + """Encode text to embedding vector, returns None on failure.""" + if not self._embedding: + return None + try: + return self._embedding.encode(text) + except Exception: + return None + + async def _store_cluster_embedding( + self, cluster: KnowledgeCluster, embedding: Optional[Any], text: str, + ) -> None: + """Store embedding for a cluster if available.""" + if embedding is None or not self._embedding: + return + text_hash = hashlib.md5(text.encode()).hexdigest() + vec = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding) + await self._storage.store_embedding( + cluster.id, vec, + self._embedding.model_id or "default", + text_hash, + ) + + async def _update_cluster_embedding(self, cluster: KnowledgeCluster) -> None: + """Re-compute and store embedding after content merge.""" + content_text = str(cluster.content)[:2000] if cluster.content else "" + if not content_text: + return + embedding = self._encode_text(content_text) + await self._store_cluster_embedding(cluster, embedding, content_text) + + async def _merge_into_cluster( + self, + cluster: KnowledgeCluster, + result: FileCompileResult, + ) -> None: + """Merge a file compile result into an existing cluster.""" + # Append evidence + if result.evidence: + existing_doc_ids = {e.doc_id for e in cluster.evidences} + if result.evidence.doc_id not in existing_doc_ids: + cluster.evidences.append(result.evidence) + + # Enrich content via LLM merge + from sirchmunk.llm.prompts import COMPILE_MERGE_KNOWLEDGE + prompt = COMPILE_MERGE_KNOWLEDGE.format( + existing_content=str(cluster.content)[:3000], + new_summary=result.summary[:3000], + ) + resp = await self._llm.achat([{"role": "user", "content": 
prompt}]) + cluster.content = resp.content.strip() + + # Update metadata + cluster.search_results = list(set( + (cluster.search_results or []) + [result.path] + )) + merge_count = getattr(cluster, "merge_count", 0) or 0 + cluster.merge_count = merge_count + 1 + + # Lifecycle promotion + if cluster.merge_count >= 3 and cluster.lifecycle == Lifecycle.EMERGING: + cluster.lifecycle = Lifecycle.STABLE + + await self._storage.update(cluster) + + async def _create_cluster( + self, result: FileCompileResult, + ) -> Optional[KnowledgeCluster]: + """Create a new KnowledgeCluster from a file compile result.""" + cluster_text = result.summary + cluster_id = f"C{hashlib.sha256(cluster_text.encode('utf-8')).hexdigest()[:10]}" + + name = Path(result.path).stem[:60] + if result.topics: + name = result.topics[0][:60] + + cluster = KnowledgeCluster( + id=cluster_id, + name=name, + description=[result.summary[:500]], + content=result.summary, + evidences=[result.evidence] if result.evidence else [], + patterns=result.topics[:5], + lifecycle=Lifecycle.EMERGING, + confidence=0.5, + abstraction_level=AbstractionLevel.TECHNIQUE, + hotness=0.3, + search_results=[result.path], + ) + + ok = await self._storage.insert(cluster) + return cluster if ok else None + + # ------------------------------------------------------------------ # + # Cross-references # + # ------------------------------------------------------------------ # + + async def _build_cross_references( + self, results: List[FileCompileResult], + ) -> int: + """Build co-occurrence edges between clusters that share source files. + + Two clusters are co-occurring when the same source file contributed + evidence to both (e.g., different sections compiled into different + clusters). Includes historical data from the manifest. 
+ """ + # Build a complete map: cluster_id -> set of source file paths + cluster_to_files: Dict[str, Set[str]] = {} + + # From current compile results + for r in results: + for cid in r.cluster_ids: + cluster_to_files.setdefault(cid, set()).add(r.path) + + # From manifest (historical data) + manifest = self._load_manifest() + for fp, entry in manifest.files.items(): + for cid in entry.cluster_ids: + cluster_to_files.setdefault(cid, set()).add(fp) + + # Find cluster pairs that share at least one source file + cluster_ids = list(cluster_to_files.keys()) + edges_created = 0 + pairs_seen: Set[Tuple[str, str]] = set() + + for i in range(len(cluster_ids)): + for j in range(i + 1, len(cluster_ids)): + cid_a, cid_b = cluster_ids[i], cluster_ids[j] + shared = cluster_to_files[cid_a] & cluster_to_files[cid_b] + if not shared: + continue + + pair_key = (min(cid_a, cid_b), max(cid_a, cid_b)) + if pair_key in pairs_seen: + continue + pairs_seen.add(pair_key) + + weight = min(len(shared) * 0.25, 1.0) + c_a = await self._storage.get(cid_a) + c_b = await self._storage.get(cid_b) + if c_a and c_b: + self._add_edge(c_a, cid_b, "co_occur", weight) + self._add_edge(c_b, cid_a, "co_occur", weight) + await self._storage.update(c_a) + await self._storage.update(c_b) + edges_created += 1 + + return edges_created + + @staticmethod + def _add_edge( + cluster: KnowledgeCluster, target_id: str, source: str, weight: float, + ) -> None: + """Add or update a WeakSemanticEdge on a cluster.""" + for edge in cluster.related_clusters: + if edge.target_cluster_id == target_id and edge.source == source: + edge.weight = max(edge.weight, weight) + return + cluster.related_clusters.append( + WeakSemanticEdge(target_cluster_id=target_id, weight=weight, source=source) + ) + + # ------------------------------------------------------------------ # + # Manifest I/O # + # ------------------------------------------------------------------ # + + def _load_manifest(self) -> CompileManifest: + if self._manifest_path.exists(): + try: + return CompileManifest.from_json( + self._manifest_path.read_text(encoding="utf-8") + ) + except Exception: + pass + return CompileManifest() + + def _save_manifest(self, manifest: CompileManifest) -> None: + self._manifest_path.write_text(manifest.to_json(), encoding="utf-8") diff --git a/src/sirchmunk/learnings/knowledge_base.py b/src/sirchmunk/learnings/knowledge_base.py index 387b368..bd2946c 100644 --- a/src/sirchmunk/learnings/knowledge_base.py +++ b/src/sirchmunk/learnings/knowledge_base.py @@ -120,11 +120,14 @@ async def _extract_evidence_for_file( confidence_threshold: float, top_k_snippets: int, verbose: bool, + tree_indexer=None, ) -> Optional[EvidenceUnit]: - """Extract evidence from a single file via Monte Carlo sampling. + """Extract evidence from a single file. - Performs text extraction followed by LLM-driven region-of-interest - identification. Designed to run concurrently for multiple files. + When a tree index exists for the file, uses LLM-driven tree navigation + to locate relevant sections precisely, then runs Monte Carlo sampling + within those narrowed regions. Falls back to full-document Monte + Carlo sampling otherwise. Args: file_path_or_url: Absolute path or URL to the document. @@ -133,6 +136,7 @@ async def _extract_evidence_for_file( confidence_threshold: Minimum confidence for evidence acceptance. top_k_snippets: Maximum evidence snippets per document. verbose: Whether to enable verbose logging. + tree_indexer: Optional DocumentTreeIndexer for tree-based navigation. 
Returns: EvidenceUnit on success, None on extraction failure. @@ -141,6 +145,28 @@ async def _extract_evidence_for_file( extraction_result = await fast_extract(file_path=file_path_or_url) doc_content: str = extraction_result.content + tree_path_ids = None + + # Try tree-based navigation for focused extraction + if tree_indexer is not None: + tree = tree_indexer.load_tree(file_path_or_url) + if tree is not None: + await self._log.info( + f"[KnowledgeBase] Using tree index for {Path(file_path_or_url).name}" + ) + leaves = await tree_indexer.navigate(tree, query) + if leaves: + # Narrow doc_content to matched regions + tree_path_ids = [n.node_id for n in leaves] + segments = [] + for node in leaves: + start, end = node.char_range + segment = doc_content[start:end] + if segment.strip(): + segments.append(segment) + if segments: + doc_content = "\n\n---\n\n".join(segments) + sampler = MonteCarloEvidenceSampling( llm=self.llm, doc_content=doc_content, @@ -162,6 +188,7 @@ async def _extract_evidence_for_file( snippets=roi_result.snippets, extracted_at=datetime.now(), conflict_group=[], + tree_path=tree_path_ids, ) self.llm_usages.extend(sampler.llm_usages) return evidence_unit diff --git a/src/sirchmunk/learnings/lint.py b/src/sirchmunk/learnings/lint.py new file mode 100644 index 0000000..e5baa6f --- /dev/null +++ b/src/sirchmunk/learnings/lint.py @@ -0,0 +1,213 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Knowledge lint — health checks and auto-fixes for the knowledge network. + +Inspired by LLM Wiki's Lint operation: validates cluster integrity, +detects stale evidence, and cleans orphaned tree indices. +""" + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Union + +from sirchmunk.schema.knowledge import KnowledgeCluster, Lifecycle +from sirchmunk.storage.knowledge_storage import KnowledgeStorage +from sirchmunk.utils import LogCallback, create_logger + + +@dataclass +class LintIssue: + """A single lint finding.""" + + severity: str # "error", "warning", "info" + category: str # "stale_evidence", "orphan_tree", "empty_cluster", etc. 
+ message: str + cluster_id: Optional[str] = None + file_path: Optional[str] = None + auto_fixed: bool = False + + def to_dict(self) -> Dict[str, Any]: + return { + "severity": self.severity, + "category": self.category, + "message": self.message, + "cluster_id": self.cluster_id, + "file_path": self.file_path, + "auto_fixed": self.auto_fixed, + } + + +@dataclass +class LintReport: + """Summary of a lint run.""" + + total_clusters_checked: int = 0 + total_trees_checked: int = 0 + issues: List[LintIssue] = field(default_factory=list) + auto_fixes_applied: int = 0 + + @property + def errors(self) -> int: + return sum(1 for i in self.issues if i.severity == "error") + + @property + def warnings(self) -> int: + return sum(1 for i in self.issues if i.severity == "warning") + + def to_dict(self) -> Dict[str, Any]: + return { + "total_clusters_checked": self.total_clusters_checked, + "total_trees_checked": self.total_trees_checked, + "errors": self.errors, + "warnings": self.warnings, + "auto_fixes_applied": self.auto_fixes_applied, + "issues": [i.to_dict() for i in self.issues], + } + + +class KnowledgeLint: + """Validate the health of the knowledge network and apply auto-fixes.""" + + def __init__( + self, + knowledge_storage: KnowledgeStorage, + work_path: Union[str, Path], + log_callback: LogCallback = None, + ): + self._storage = knowledge_storage + self._work_path = Path(work_path).expanduser().resolve() + self._tree_dir = self._work_path / ".cache" / "compile" / "trees" + self._manifest_path = self._work_path / ".cache" / "compile" / "manifest.json" + self._log = create_logger(log_callback=log_callback) + + async def run(self, *, auto_fix: bool = False) -> LintReport: + """Execute all lint checks and optionally apply auto-fixes.""" + report = LintReport() + + await self._log.info("[Lint] Starting knowledge health check") + + # Check clusters + await self._check_clusters(report, auto_fix=auto_fix) + + # Check orphaned tree caches + await self._check_orphan_trees(report, auto_fix=auto_fix) + + # Check manifest consistency + await self._check_manifest(report) + + await self._log.info( + f"[Lint] Done — clusters={report.total_clusters_checked}, " + f"trees={report.total_trees_checked}, " + f"errors={report.errors}, warnings={report.warnings}, " + f"fixes={report.auto_fixes_applied}" + ) + return report + + async def _check_clusters(self, report: LintReport, auto_fix: bool) -> None: + """Validate each knowledge cluster.""" + all_clusters = await self._storage.find("", limit=10000) + report.total_clusters_checked = len(all_clusters) + + for cluster in all_clusters: + # Check: empty content + if not cluster.content or ( + isinstance(cluster.content, str) and len(cluster.content.strip()) < 10 + ): + report.issues.append(LintIssue( + severity="warning", + category="empty_cluster", + message=f"Cluster has empty or minimal content", + cluster_id=cluster.id, + )) + + # Check: stale evidence (source files no longer exist) + stale_count = 0 + for ev in cluster.evidences: + fp = str(ev.file_or_url) + if fp.startswith("/") and not Path(fp).exists(): + stale_count += 1 + + if stale_count > 0: + report.issues.append(LintIssue( + severity="warning", + category="stale_evidence", + message=f"{stale_count} evidence source(s) no longer exist", + cluster_id=cluster.id, + )) + + if auto_fix and stale_count == len(cluster.evidences): + cluster.lifecycle = Lifecycle.DEPRECATED + await self._storage.update(cluster) + report.auto_fixes_applied += 1 + report.issues[-1].auto_fixed = True + + # Check: no queries and no 
evidences (orphan cluster) + if not cluster.evidences and not cluster.queries: + report.issues.append(LintIssue( + severity="info", + category="orphan_cluster", + message="Cluster has no evidence and no queries", + cluster_id=cluster.id, + )) + + # Check: isolated cluster (no WeakSemanticEdge connections) + if not cluster.related_clusters and cluster.evidences: + report.issues.append(LintIssue( + severity="info", + category="isolated_cluster", + message="Cluster has no cross-references to other clusters", + cluster_id=cluster.id, + )) + + async def _check_orphan_trees(self, report: LintReport, auto_fix: bool) -> None: + """Find tree cache files whose source documents no longer exist.""" + if not self._tree_dir.exists(): + return + + manifest = self._load_manifest() + # Build set of valid file hashes from the manifest + valid_hashes: Set[str] = set() + for entry_data in manifest.get("files", {}).values(): + fh = entry_data.get("file_hash", "") + if fh: + valid_hashes.add(fh) + + tree_files = list(self._tree_dir.glob("*.json")) + report.total_trees_checked = len(tree_files) + + for tf in tree_files: + tree_hash = tf.stem + if tree_hash not in valid_hashes: + report.issues.append(LintIssue( + severity="info", + category="orphan_tree", + message=f"Tree cache has no matching manifest entry", + file_path=str(tf), + )) + if auto_fix: + tf.unlink(missing_ok=True) + report.auto_fixes_applied += 1 + report.issues[-1].auto_fixed = True + + async def _check_manifest(self, report: LintReport) -> None: + """Validate manifest references.""" + manifest = self._load_manifest() + files = manifest.get("files", {}) + + for fp, entry_data in files.items(): + if not Path(fp).exists(): + report.issues.append(LintIssue( + severity="warning", + category="stale_manifest", + message=f"Manifest references non-existent file", + file_path=fp, + )) + + def _load_manifest(self) -> Dict[str, Any]: + if self._manifest_path.exists(): + try: + return json.loads(self._manifest_path.read_text(encoding="utf-8")) + except Exception: + pass + return {} diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py new file mode 100644 index 0000000..53ebf0b --- /dev/null +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -0,0 +1,444 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Document tree indexer — PageIndex-inspired hierarchical structure analysis. + +Builds a JSON tree index for structured long documents (PDF, DOCX, MD, HTML) +so that downstream search can navigate via LLM reasoning instead of brute-force +Monte Carlo sampling. 
+""" + +import json +import re +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +from sirchmunk.llm.openai_chat import OpenAIChat +from sirchmunk.utils import LogCallback, create_logger +from sirchmunk.utils.file_utils import get_fast_hash + +# File-size threshold: skip tree indexing for small files +_TREE_MIN_CHARS = 50_000 # 50 K characters + +# Extensions eligible for tree indexing +_TREE_EXTENSIONS = { + ".pdf", ".docx", ".doc", ".md", ".markdown", + ".html", ".htm", ".rst", ".tex", ".txt", +} + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + +@dataclass +class TreeNode: + """Single node in the document tree.""" + + node_id: str + title: str + summary: str + char_range: Tuple[int, int] # [start, end) in the extracted text + level: int = 0 + page_range: Optional[Tuple[int, int]] = None + children: List["TreeNode"] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "node_id": self.node_id, + "title": self.title, + "summary": self.summary, + "char_range": list(self.char_range), + "level": self.level, + "page_range": list(self.page_range) if self.page_range else None, + "children": [c.to_dict() for c in self.children], + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TreeNode": + children = [cls.from_dict(c) for c in data.get("children", [])] + pr = data.get("page_range") + return cls( + node_id=data["node_id"], + title=data["title"], + summary=data["summary"], + char_range=tuple(data["char_range"]), + level=data.get("level", 0), + page_range=tuple(pr) if pr else None, + children=children, + ) + + @property + def leaf(self) -> bool: + return len(self.children) == 0 + + def all_leaves(self) -> List["TreeNode"]: + """Return all leaf nodes under this subtree.""" + if self.leaf: + return [self] + leaves: List["TreeNode"] = [] + for c in self.children: + leaves.extend(c.all_leaves()) + return leaves + + +@dataclass +class DocumentTree: + """Complete tree index for a single document.""" + + file_path: str + file_hash: str + created_at: str + total_chars: int + total_pages: Optional[int] = None + root: Optional[TreeNode] = None + + def to_json(self) -> str: + return json.dumps({ + "file_path": self.file_path, + "file_hash": self.file_hash, + "created_at": self.created_at, + "total_chars": self.total_chars, + "total_pages": self.total_pages, + "root": self.root.to_dict() if self.root else None, + }, ensure_ascii=False, indent=2) + + @classmethod + def from_json(cls, json_str: str) -> "DocumentTree": + data = json.loads(json_str) + root = TreeNode.from_dict(data["root"]) if data.get("root") else None + return cls( + file_path=data["file_path"], + file_hash=data["file_hash"], + created_at=data["created_at"], + total_chars=data["total_chars"], + total_pages=data.get("total_pages"), + root=root, + ) + + +# --------------------------------------------------------------------------- +# Indexer +# --------------------------------------------------------------------------- + +class DocumentTreeIndexer: + """Build and cache PageIndex-style hierarchical tree indices for documents.""" + + def __init__( + self, + llm: OpenAIChat, + cache_dir: Union[str, Path], + log_callback: LogCallback = None, + ): + self._llm = llm + self._cache_dir = Path(cache_dir) + self._cache_dir.mkdir(parents=True, exist_ok=True) + 
self._log = create_logger(log_callback=log_callback) + + # ------------------------------------------------------------------ # + # Public API # + # ------------------------------------------------------------------ # + + async def build_tree( + self, + file_path: str, + content: str, + *, + max_depth: int = 4, + force_rebuild: bool = False, + total_pages: Optional[int] = None, + ) -> Optional[DocumentTree]: + """Build a tree index for a document. + + Returns None when the document is too small or unstructured. + """ + file_hash = get_fast_hash(file_path) + if file_hash is None: + return None + + if not force_rebuild: + cached = self._load_cache(file_hash) + if cached is not None: + await self._log.info(f"[TreeIndexer] Cache hit for {Path(file_path).name}") + return cached + + if len(content) < _TREE_MIN_CHARS: + return None + + ext = Path(file_path).suffix.lower() + if ext not in _TREE_EXTENSIONS: + return None + + await self._log.info( + f"[TreeIndexer] Building tree for {Path(file_path).name} " + f"({len(content)} chars, depth={max_depth})" + ) + + root = await self._build_node(content, level=0, max_depth=max_depth) + if root is None: + return None + + tree = DocumentTree( + file_path=file_path, + file_hash=file_hash, + created_at=datetime.now(timezone.utc).isoformat(), + total_chars=len(content), + total_pages=total_pages, + root=root, + ) + self._save_cache(file_hash, tree) + await self._log.info( + f"[TreeIndexer] Built tree: {self._count_nodes(root)} nodes, " + f"depth={self._max_node_depth(root)}" + ) + return tree + + async def navigate( + self, + tree: DocumentTree, + query: str, + *, + max_results: int = 3, + ) -> List[TreeNode]: + """Reasoning-based tree navigation: LLM selects the most relevant branches. + + Returns up to *max_results* leaf nodes with their char_range for + precise evidence extraction. 
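+
+        Illustrative sketch (the cache path and client setup are placeholders,
+        not project defaults)::
+
+            indexer = DocumentTreeIndexer(llm=llm, cache_dir="/tmp/tree-cache")
+            tree = await indexer.build_tree("/data/report.pdf", content=text)
+            if tree:
+                nodes = await indexer.navigate(tree, "quarterly revenue", max_results=3)
+                excerpts = [text[s:e] for s, e in (n.char_range for n in nodes)]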
+ """ + if tree.root is None: + return [] + + candidates = tree.root.children if tree.root.children else [tree.root] + if not candidates: + return [tree.root] + + selected = await self._select_children(candidates, query) + if not selected: + return [] + + result_leaves: List[TreeNode] = [] + for node in selected: + if node.leaf: + result_leaves.append(node) + else: + deeper = await self._select_children(node.children, query) + for d in (deeper or node.children[:1]): + result_leaves.extend(d.all_leaves()[:max_results]) + + # Deduplicate and cap + seen_ids = set() + unique: List[TreeNode] = [] + for n in result_leaves: + if n.node_id not in seen_ids: + seen_ids.add(n.node_id) + unique.append(n) + return unique[:max_results] + + def load_tree(self, file_path: str) -> Optional[DocumentTree]: + """Load a cached tree index for the given file (sync).""" + file_hash = get_fast_hash(file_path) + if file_hash is None: + return None + return self._load_cache(file_hash) + + def has_tree(self, file_path: str) -> bool: + """Check whether a cached tree index exists for the file.""" + file_hash = get_fast_hash(file_path) + if file_hash is None: + return False + return self._cache_path(file_hash).exists() + + # ------------------------------------------------------------------ # + # Internals # + # ------------------------------------------------------------------ # + + async def _build_node( + self, text: str, level: int, max_depth: int, + offset: int = 0, + ) -> Optional[TreeNode]: + """Recursively build tree nodes via LLM structure analysis.""" + from sirchmunk.llm.prompts import COMPILE_TREE_STRUCTURE + + preview = text[:12000] if len(text) > 12000 else text + prompt = COMPILE_TREE_STRUCTURE.format( + document_content=preview, + max_sections=8, + ) + + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + sections = self._parse_sections(resp.content, text) + + if not sections: + return TreeNode( + node_id=f"N{offset:06d}", + title="Document", + summary=text[:300], + char_range=(offset, offset + len(text)), + level=level, + ) + + children: List[TreeNode] = [] + for i, sec in enumerate(sections): + child = TreeNode( + node_id=f"N{sec['start'] + offset:06d}", + title=sec["title"], + summary=sec["summary"], + char_range=(sec["start"] + offset, sec["end"] + offset), + level=level + 1, + ) + section_text = text[sec["start"]:sec["end"]] + if level + 1 < max_depth and len(section_text) > _TREE_MIN_CHARS: + deeper = await self._build_node( + section_text, level + 1, max_depth, offset=sec["start"] + offset, + ) + if deeper and deeper.children: + child.children = deeper.children + children.append(child) + + root_summary = await self._synthesize_root_summary(children) + + return TreeNode( + node_id=f"N{offset:06d}", + title="Document", + summary=root_summary, + char_range=(offset, offset + len(text)), + level=level, + children=children, + ) + + async def _synthesize_root_summary(self, children: List[TreeNode]) -> str: + """Synthesize a document-level summary from children's section summaries.""" + if not children: + return "" + from sirchmunk.llm.prompts import COMPILE_SYNTHESIZE_SUMMARY + sections_text = "\n".join( + f"- {c.title}: {c.summary}" for c in children + ) + prompt = COMPILE_SYNTHESIZE_SUMMARY.format(sections=sections_text) + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + return resp.content.strip() + + def _parse_sections( + self, llm_output: str, full_text: str, + ) -> List[Dict[str, Any]]: + """Parse LLM section output into [{title, summary, start, end}, 
...].""" + # Try JSON array first + try: + raw = llm_output + # Strip markdown fences + raw = re.sub(r"^```(?:json)?\s*", "", raw, flags=re.MULTILINE) + raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE).strip() + m = re.search(r"\[.*\]", raw, re.DOTALL) + if m: + items = json.loads(m.group()) + return self._resolve_positions(items, full_text) + except (json.JSONDecodeError, TypeError): + pass + return [] + + @staticmethod + def _resolve_positions( + items: List[Dict[str, Any]], full_text: str, + ) -> List[Dict[str, Any]]: + """Resolve section start/end character offsets from marker text.""" + resolved: List[Dict[str, Any]] = [] + prev_end = 0 + text_lower = full_text.lower() + for item in items: + title = item.get("title", "") + summary = item.get("summary", "") + marker = item.get("start_marker", title) + + pos = text_lower.find(marker.lower(), prev_end) if marker else -1 + start = pos if pos >= 0 else prev_end + + end_marker = item.get("end_marker", "") + if end_marker: + epos = text_lower.find(end_marker.lower(), start + 1) + end = epos if epos > start else min(start + 50000, len(full_text)) + else: + end = min(start + 50000, len(full_text)) + + resolved.append({ + "title": title, + "summary": summary, + "start": start, + "end": end, + }) + prev_end = end + + # Fix gaps: each section ends where the next begins + for i in range(len(resolved) - 1): + resolved[i]["end"] = resolved[i + 1]["start"] + if resolved: + resolved[-1]["end"] = len(full_text) + + return [s for s in resolved if s["end"] > s["start"]] + + async def _select_children( + self, nodes: List[TreeNode], query: str, + ) -> List[TreeNode]: + """LLM-driven branch selection: pick the most relevant children.""" + if len(nodes) <= 2: + return nodes + + listing = "\n".join( + f"[{i}] {n.title}: {n.summary[:150]}" + for i, n in enumerate(nodes) + ) + prompt = ( + f"Given the query: \"{query}\"\n\n" + f"Select the 1-2 most relevant sections (by index number):\n{listing}\n\n" + f"Return ONLY a JSON array of index numbers, e.g. 
[0, 2]" + ) + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + try: + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + indices = json.loads(m.group()) + return [nodes[i] for i in indices if 0 <= i < len(nodes)] + except (json.JSONDecodeError, IndexError, TypeError): + pass + return nodes[:2] + + # ------------------------------------------------------------------ # + # Cache I/O # + # ------------------------------------------------------------------ # + + def _cache_path(self, file_hash: str) -> Path: + return self._cache_dir / f"{file_hash}.json" + + def _save_cache(self, file_hash: str, tree: DocumentTree) -> None: + path = self._cache_path(file_hash) + path.write_text(tree.to_json(), encoding="utf-8") + + def _load_cache(self, file_hash: str) -> Optional[DocumentTree]: + path = self._cache_path(file_hash) + if not path.exists(): + return None + try: + return DocumentTree.from_json(path.read_text(encoding="utf-8")) + except Exception: + return None + + # ------------------------------------------------------------------ # + # Helpers # + # ------------------------------------------------------------------ # + + @staticmethod + def _count_nodes(node: TreeNode) -> int: + return 1 + sum(DocumentTreeIndexer._count_nodes(c) for c in node.children) + + @staticmethod + def _max_node_depth(node: TreeNode) -> int: + if not node.children: + return node.level + return max(DocumentTreeIndexer._max_node_depth(c) for c in node.children) + + @staticmethod + def should_build_tree(file_path: str, content_length: int) -> bool: + """Determine whether a file is eligible for tree indexing.""" + ext = Path(file_path).suffix.lower() + return ext in _TREE_EXTENSIONS and content_length >= _TREE_MIN_CHARS diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 1a07e64..b3ded32 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -423,3 +423,90 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: true/false true/false """ + + +# --------------------------------------------------------------------------- +# Knowledge Compile prompts +# --------------------------------------------------------------------------- + +COMPILE_TREE_STRUCTURE = """Analyze the following document and identify its natural hierarchical structure (chapters, sections, subsections). + +### Document Content (may be truncated) +{document_content} + +### Output Requirements +Return a JSON array of top-level sections. Each section object must have: +- "title": Section heading or descriptive title +- "summary": 1-2 sentence summary of the section content +- "start_marker": A short text string (5-15 words) that appears verbatim at the start of this section in the document +- "end_marker": A short text string that appears at the start of the NEXT section (empty for the last section) + +Maximum {max_sections} sections. Identify only the most significant structural boundaries. + +### Output Format +Return ONLY a JSON array, no extra text: +[ + {{"title": "...", "summary": "...", "start_marker": "...", "end_marker": "..."}}, + ... +] +""" + + +COMPILE_SYNTHESIZE_SUMMARY = """Synthesize a comprehensive document summary from the following section summaries. + +### Section Summaries +{sections} + +### Output +Provide a unified, coherent summary in 3-8 sentences that captures the document's overall topic, key arguments, and conclusions. Do not simply list the sections — weave them into a natural narrative. 
+Write in the same language as the section summaries.""" + + +COMPILE_DOC_SUMMARY = """Summarize the following document concisely, capturing the key topics, arguments, conclusions, and important details. + +### File: {file_name} + +### Document Content (may be truncated) +{document_content} + +### Output +Provide a comprehensive summary in 3-8 sentences. Focus on: +1. What is this document about (main topic/purpose) +2. Key findings, arguments, or conclusions +3. Important details, data points, or methodologies + +Write the summary in the same language as the document content.""" + + +COMPILE_TOPIC_EXTRACTION = """Extract the 3-5 most important topics, concepts, or entities from the following summary. + +### Summary +{summary} + +### Output +Return ONLY a JSON array of topic strings, no extra text: +["topic1", "topic2", "topic3"] + +Rules: +- Each topic should be 1-4 words +- Prefer specific, domain-relevant terms over generic ones +- Use the same language as the summary""" + + +COMPILE_MERGE_KNOWLEDGE = """You are merging new information into an existing knowledge cluster. + +### Existing Knowledge +{existing_content} + +### New Information +{new_summary} + +### Task +Produce an updated, unified summary that: +1. Preserves all important information from the existing knowledge +2. Integrates the new information, avoiding redundancy +3. Highlights any contradictions or complementary perspectives +4. Maintains a coherent, well-structured narrative + +### Output +Return ONLY the merged summary text (no extra tags or metadata). Keep the same language as the inputs.""" diff --git a/src/sirchmunk/schema/knowledge.py b/src/sirchmunk/schema/knowledge.py index 336963d..2a6e149 100644 --- a/src/sirchmunk/schema/knowledge.py +++ b/src/sirchmunk/schema/knowledge.py @@ -57,6 +57,12 @@ class EvidenceUnit: # IDs of conflict group if this evidence contradicts others conflict_group: Optional[List[str]] = None + # Tree-index node path from root to the matched node (e.g. ["N000000", "N001234"]) + tree_path: Optional[List[str]] = None + + # Character range within the document for precise evidence location + page_range: Optional[List[int]] = None + def to_dict(self) -> Dict[str, Any]: """ Serialize EvidenceUnit to a dictionary. 
@@ -69,6 +75,8 @@ def to_dict(self) -> Dict[str, Any]: "snippets": self.snippets, "extracted_at": self.extracted_at.isoformat(), "conflict_group": self.conflict_group, + "tree_path": self.tree_path, + "page_range": self.page_range, } @@ -234,6 +242,9 @@ class KnowledgeCluster: # Used for semantic similarity matching and cluster reuse queries: List[str] = None + # Number of times this cluster has been merged with new evidence during compile + merge_count: int = 0 + def __post_init__(self): if self.related_clusters is None: self.related_clusters = [] @@ -391,5 +402,6 @@ def to_dict(self) -> Dict[str, Any]: "related_clusters": [rc.to_dict() for rc in self.related_clusters], "search_results": self.search_results, "queries": self.queries, + "merge_count": self.merge_count, } diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index aee1c16..52f9650 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -919,6 +919,119 @@ def _ensure_tool_registry( self._tool_registry_key = cache_key return registry + # ------------------------------------------------------------------ + # Knowledge compile entry point + # ------------------------------------------------------------------ + + async def compile( + self, + paths: Optional[Union[str, Path, List[str], List[Path]]] = None, + *, + incremental: bool = True, + shallow: bool = False, + max_files: Optional[int] = None, + concurrency: int = 3, + ) -> Dict[str, Any]: + """Compile document collections into structured knowledge indices. + + Optional offline pre-processing step that builds tree indices and + knowledge clusters. Products are automatically leveraged by + subsequent search() calls. + + Args: + paths: Directories or files to compile. Falls back to self.paths. + incremental: Skip unchanged files (default True). + shallow: Skip tree building — use direct LLM summarisation only. + max_files: Cap on files — triggers importance sampling for large sets. + concurrency: Max parallel file compilations. + + Returns: + CompileReport as a dict. 
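+
+        Example (illustrative; ``searcher`` is an AgenticSearch instance and the
+        report keys shown are two of the returned counters)::
+
+            report = await searcher.compile(paths=["/data/docs"], incremental=True)
+            print(report["trees_built"], report["clusters_created"])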
+ """ + from sirchmunk.learnings.compiler import KnowledgeCompiler + from sirchmunk.learnings.tree_indexer import DocumentTreeIndexer + + resolved = self._resolve_paths(paths) + await self._logger.info( + f"[Compile] Starting compile for {len(resolved)} path(s)" + ) + + tree_cache = self.work_path / ".cache" / "compile" / "trees" + _cb = getattr(self._logger, 'log_callback', None) + tree_indexer = DocumentTreeIndexer( + llm=self.llm, + cache_dir=tree_cache, + log_callback=_cb, + ) + + compiler = KnowledgeCompiler( + llm=self.llm, + embedding_client=self.embedding_client, + knowledge_storage=self.knowledge_storage, + tree_indexer=tree_indexer, + work_path=self.work_path, + log_callback=_cb, + ) + + report = await compiler.compile( + paths=resolved, + incremental=incremental, + shallow=shallow, + max_files=max_files, + concurrency=concurrency, + ) + + return report.to_dict() + + async def compile_status( + self, + paths: Optional[Union[str, Path, List[str], List[Path]]] = None, + ) -> Dict[str, Any]: + """Return current compile status for the given paths.""" + from sirchmunk.learnings.compiler import KnowledgeCompiler + from sirchmunk.learnings.tree_indexer import DocumentTreeIndexer + + resolved = self._resolve_paths(paths) + + tree_cache = self.work_path / ".cache" / "compile" / "trees" + tree_indexer = DocumentTreeIndexer( + llm=self.llm, cache_dir=tree_cache, + ) + + compiler = KnowledgeCompiler( + llm=self.llm, + embedding_client=self.embedding_client, + knowledge_storage=self.knowledge_storage, + tree_indexer=tree_indexer, + work_path=self.work_path, + ) + + status = await compiler.get_status(resolved) + return { + "total_compiled_files": status.total_compiled_files, + "total_clusters": status.total_clusters, + "total_trees": status.total_trees, + "last_compile_at": status.last_compile_at, + "manifest_path": status.manifest_path, + } + + async def compile_lint( + self, + *, + auto_fix: bool = False, + ) -> Dict[str, Any]: + """Run knowledge health checks and optionally auto-fix issues.""" + from sirchmunk.learnings.lint import KnowledgeLint + + linter = KnowledgeLint( + knowledge_storage=self.knowledge_storage, + work_path=self.work_path, + log_callback=getattr(self._logger, 'log_callback', None), + ) + + report = await linter.run(auto_fix=auto_fix) + return report.to_dict() + # ------------------------------------------------------------------ # Unified search entry point # ------------------------------------------------------------------ diff --git a/src/sirchmunk/storage/knowledge_storage.py b/src/sirchmunk/storage/knowledge_storage.py index 0a99168..e62c1cf 100644 --- a/src/sirchmunk/storage/knowledge_storage.py +++ b/src/sirchmunk/storage/knowledge_storage.py @@ -107,6 +107,10 @@ def _load_from_parquet(self): variable-length ``FLOAT[]`` from Parquet's list encoding, breaking ``list_cosine_similarity`` which requires matching fixed-size types. + Handles schema evolution gracefully: if the parquet file has fewer + columns than the current schema (e.g., missing ``merge_count``), + missing columns are filled with defaults instead of failing. + Also records the file's modification time so that ``_check_and_reload()`` can detect external changes later. 
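+
+        For a parquet file written before ``merge_count`` existed, the generated
+        statement looks roughly like (column list abbreviated)::
+
+            INSERT INTO knowledge_clusters
+            SELECT id, name, ..., queries, 0 AS merge_count,
+                   embedding_vector, embedding_model,
+                   embedding_timestamp, embedding_text_hash
+            FROM read_parquet('<parquet_file>')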
""" @@ -117,11 +121,38 @@ def _load_from_parquet(self): self.db.drop_table(self.table_name, if_exists=True) # Create table with explicit schema (preserves FLOAT[384]) self._create_table() - # Insert data from parquet — DuckDB casts to the declared types - self.db.execute( - f"INSERT INTO {self.table_name} " - f"SELECT * FROM read_parquet('{self.parquet_file}')" - ) + # Detect parquet columns to handle schema evolution + try: + pq_cols = self.db.fetch_all( + f"SELECT column_name FROM parquet_schema('{self.parquet_file}')" + ) + pq_col_names = {row[0] for row in pq_cols} + except Exception: + pq_col_names = None + + if pq_col_names is not None: + # Build column-by-column SELECT with defaults for missing cols + schema_cols = list(self._get_schema_columns()) + select_parts = [] + for col_name in schema_cols: + if col_name in pq_col_names: + select_parts.append(col_name) + elif col_name == "merge_count": + select_parts.append("0 AS merge_count") + else: + select_parts.append(f"NULL AS {col_name}") + select_clause = ", ".join(select_parts) + self.db.execute( + f"INSERT INTO {self.table_name} " + f"SELECT {select_clause} FROM read_parquet('{self.parquet_file}')" + ) + else: + # Fallback: try direct SELECT * (works when schemas match) + self.db.execute( + f"INSERT INTO {self.table_name} " + f"SELECT * FROM read_parquet('{self.parquet_file}')" + ) + count = self.db.get_table_count(self.table_name) # Record mtime for stale-detection self._parquet_loaded_mtime = pq.stat().st_mtime @@ -138,6 +169,18 @@ def _load_from_parquet(self): self._create_table() self._parquet_loaded_mtime = 0.0 + def _get_schema_columns(self) -> List[str]: + """Return the ordered list of column names in the canonical schema.""" + return [ + "id", "name", "description", "content", "scripts", "resources", + "evidences", "patterns", "constraints", "confidence", + "abstraction_level", "landmark_potential", "hotness", "lifecycle", + "create_time", "last_modified", "version", "related_clusters", + "search_results", "queries", "merge_count", + "embedding_vector", "embedding_model", "embedding_timestamp", + "embedding_text_hash", + ] + def _check_and_reload(self): """Check if the parquet file was modified externally and reload if so. @@ -190,6 +233,7 @@ def _create_table(self): "related_clusters": "VARCHAR", # JSON array "search_results": "VARCHAR", # JSON array "queries": "VARCHAR", # JSON array of historical queries + "merge_count": "INTEGER", # compile merge counter "embedding_vector": "FLOAT[384]", # 384-dim embedding vector "embedding_model": "VARCHAR", # Model identifier "embedding_timestamp": "TIMESTAMP", # Embedding computation time @@ -338,21 +382,22 @@ def _cluster_to_row(self, cluster: KnowledgeCluster) -> Dict[str, Any]: "related_clusters": json.dumps([rc.to_dict() for rc in cluster.related_clusters]), "search_results": json.dumps(cluster.search_results) if cluster.search_results else None, "queries": json.dumps(cluster.queries) if cluster.queries else None, + "merge_count": cluster.merge_count or 0, } def _row_to_cluster(self, row: tuple) -> KnowledgeCluster: """ Convert database row to KnowledgeCluster. 
- Expected row structure (24 columns): + Expected row structure (25 columns): id, name, description, content, scripts, resources, evidences, patterns, constraints, confidence, abstraction_level, landmark_potential, hotness, lifecycle, create_time, last_modified, version, related_clusters, search_results, queries, - embedding_vector, embedding_model, embedding_timestamp, embedding_text_hash + merge_count, embedding_vector, embedding_model, embedding_timestamp, embedding_text_hash """ - if len(row) != 24: + if len(row) != 25: raise ValueError( - f"Expected 24 columns in knowledge_clusters row, got {len(row)}. " + f"Expected 25 columns in knowledge_clusters row, got {len(row)}. " f"Please ensure the table schema is up to date." ) @@ -361,6 +406,7 @@ def _row_to_cluster(self, row: tuple) -> KnowledgeCluster: id, name, description, content, scripts, resources, evidences, patterns, constraints, confidence, abstraction_level, landmark_potential, hotness, lifecycle, create_time, last_modified, version, related_clusters, search_results, queries, + merge_count, _embedding_vector, _embedding_model, _embedding_timestamp, _embedding_text_hash ) = row @@ -400,7 +446,9 @@ def _row_to_cluster(self, row: tuple) -> KnowledgeCluster: is_found=ev_dict["is_found"], snippets=ev_dict["snippets"], extracted_at=extracted_at_parsed or datetime.now(), - conflict_group=ev_dict.get("conflict_group") + conflict_group=ev_dict.get("conflict_group"), + tree_path=ev_dict.get("tree_path"), + page_range=ev_dict.get("page_range"), )) # Parse constraints @@ -463,6 +511,7 @@ def _row_to_cluster(self, row: tuple) -> KnowledgeCluster: related_clusters=related_clusters_parsed, search_results=search_results_parsed, queries=queries_parsed, + merge_count=merge_count or 0, ) # ------------------------------------------------------------------ # From c4f4b166df34f5e4258eefce9fae687ab1dae82d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 13 Apr 2026 20:07:35 +0800 Subject: [PATCH 03/56] improve compile infer --- src/sirchmunk/learnings/README.md | 30 ++ src/sirchmunk/learnings/knowledge_base.py | 4 + src/sirchmunk/search.py | 556 ++++++++++++++++++++-- 3 files changed, 563 insertions(+), 27 deletions(-) diff --git a/src/sirchmunk/learnings/README.md b/src/sirchmunk/learnings/README.md index 0fc1bbe..92bc22b 100644 --- a/src/sirchmunk/learnings/README.md +++ b/src/sirchmunk/learnings/README.md @@ -37,6 +37,36 @@ The module fuses insights from three frameworks: Compile products are automatically leveraged by search when present, but search functions independently without them. 
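+
+A minimal sketch of that contract (the work path and document locations are
+placeholders; the exact return shape of `search()` is elided):
+
+```python
+searcher = AgenticSearch(llm=llm, work_path="~/.sirchmunk")
+
+# Optional offline step: builds tree indices and knowledge clusters
+await searcher.compile(paths=["/data/docs"])
+
+# The same call works with or without a prior compile; compiled products
+# only make it faster and better grounded.
+result = await searcher.search("What changed in Q3?", paths=["/data/docs"])
+```
+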
+### How Search Consumes Compile Products + +``` +Compile products Search consumption path +───────────────── ────────────────────────────────────────────── +KnowledgeCluster ─┬─ FAST + DEEP Phase 0: embedding similarity + .content │ reuse (instant short-circuit, no LLM cost) + .embedding │ → enriched with evidence snippets + .evidences[].file_or_url │ + ├─ DEEP Phase 1: _probe_knowledge_cache() + │ fuzzy text search → file path discovery + │ +WeakSemanticEdge ├─ DEEP Phase 1: one-hop graph expansion + .related_clusters │ follows edges to gather neighbour files + │ +DocumentTree (.json) └─ DEEP Phase 3: tree-navigated evidence + via tree_indexer _build_cluster() → knowledge_base.build() + → _extract_evidence_for_file(tree_indexer) + → narrows doc to relevant sections before + Monte Carlo sampling +``` + +| Compile product | FAST | DEEP | +|-----------------|------|------| +| Cluster embedding reuse | Yes | Yes | +| Evidence snippets in reused content | Yes | Yes | +| Fuzzy cluster → file path hints | — | Yes | +| Graph edge expansion (neighbours) | — | Yes | +| Tree-navigated evidence extraction | — | Yes | + --- ## Components diff --git a/src/sirchmunk/learnings/knowledge_base.py b/src/sirchmunk/learnings/knowledge_base.py index bd2946c..7296f71 100644 --- a/src/sirchmunk/learnings/knowledge_base.py +++ b/src/sirchmunk/learnings/knowledge_base.py @@ -208,6 +208,7 @@ async def build( top_k_snippets: Optional[int] = 5, confidence_threshold: Optional[float] = 8.0, verbose: bool = True, + tree_indexer=None, ) -> Union[KnowledgeCluster, None]: """Build a knowledge cluster from retrieved information and metadata. @@ -223,6 +224,8 @@ async def build( top_k_snippets: Max evidence snippets per file. confidence_threshold: Min confidence for evidence acceptance. verbose: Enable verbose logging. + tree_indexer: Optional DocumentTreeIndexer for tree-navigated + evidence extraction (uses compiled tree indices when available). Returns: KnowledgeCluster on success, None if no evidence found. @@ -250,6 +253,7 @@ async def build( confidence_threshold=confidence_threshold, top_k_snippets=top_k_snippets, verbose=verbose, + tree_indexer=tree_indexer, ) for info in retrieved_infos ] diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 52f9650..63d3ba8 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -8,6 +8,7 @@ import os import re import traceback +from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Tuple, Union @@ -81,6 +82,43 @@ _NO_RESULTS_MESSAGE = "No results found." +# Soft-similarity threshold for gradient cluster reuse (P2) +_SOFT_SIM_THRESHOLD = 0.65 + + +@dataclass +class SoftClusterHit: + """Signals from clusters that are related but below the hard reuse threshold. + + Carries structured hints (keywords, file paths, background context) that + downstream retrieval phases can exploit without short-circuiting the search. + """ + + patterns: List[str] + file_paths: List[str] + context_summary: str + cluster_ids: List[str] + + +@dataclass +class KnowledgeProbeResult: + """Rich result from knowledge cache probing (P3). + + Replaces the flat ``List[str]`` that ``_probe_knowledge_cache`` used to return. 
+ """ + + file_paths: List[str] + extra_keywords: List[str] + background_context: str + + +@dataclass +class CompileHints: + """Zero-LLM hints gathered from compile manifest and tree cache (P4).""" + + file_paths: List[str] + extra_keywords: List[str] + class AgenticSearch(BaseSearch): @@ -460,6 +498,72 @@ async def _try_reuse_cluster(self, query: str, paths: Optional[List[str]] = None ) return None + async def _try_soft_reuse( + self, query: str, paths: Optional[List[str]] = None, + ) -> Optional[SoftClusterHit]: + """Gradient reuse: extract structured hints from moderately similar clusters. + + Called when ``_try_reuse_cluster`` misses (similarity < hard threshold). + Uses a softer threshold to find clusters that are *related* but not + close enough for full reuse. Returns patterns, file paths, and a + background context summary that downstream phases can exploit. + """ + if not self.embedding_client or not self.embedding_client.is_ready(): + return None + + try: + query_embedding = (await self.embedding_client.embed([query]))[0] + similar = await self.knowledge_storage.search_similar_clusters( + query_embedding=query_embedding, + top_k=5, + similarity_threshold=_SOFT_SIM_THRESHOLD, + search_paths=paths, + ) + if not similar: + return None + + patterns: List[str] = [] + file_paths: List[str] = [] + context_parts: List[str] = [] + cluster_ids: List[str] = [] + seen_paths: set = set() + + for match in similar: + cid = match["id"] + cluster_ids.append(cid) + c = await self.knowledge_storage.get(cid) + if not c: + continue + for p in getattr(c, "patterns", []) or []: + if p and p not in patterns: + patterns.append(p) + for ev in getattr(c, "evidences", []): + fp = str(getattr(ev, "file_or_url", "")) + if fp and fp not in seen_paths and Path(fp).exists(): + seen_paths.add(fp) + file_paths.append(fp) + content = c.content + if isinstance(content, list): + content = "\n".join(content) + if content: + context_parts.append(str(content)[:500]) + + if not patterns and not file_paths: + return None + + await self._logger.info( + f"[SoftReuse] {len(similar)} soft hits: " + f"{len(patterns)} patterns, {len(file_paths)} files" + ) + return SoftClusterHit( + patterns=patterns[:10], + file_paths=file_paths[:10], + context_summary="\n\n".join(context_parts[:3]), + cluster_ids=cluster_ids, + ) + except Exception: + return None + def _add_query_to_cluster(self, cluster: KnowledgeCluster, query: str) -> None: """ Add query to cluster's queries list with FIFO strategy. @@ -478,6 +582,36 @@ def _add_query_to_cluster(self, cluster: KnowledgeCluster, query: str) -> None: # Remove oldest queries (from the beginning) cluster.queries = cluster.queries[-self.max_queries_per_cluster:] + @staticmethod + def _enrich_reused_content(cluster: KnowledgeCluster) -> str: + """Build the answer text from a reused cluster. + + When the cluster carries compiled evidence with non-empty snippets + (populated during ``sirchmunk compile``), appends them as supporting + excerpts so the user sees both the summary and the underlying source + material. 
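+
+        The enriched answer has roughly this shape (filenames are placeholders)::
+
+            <cluster summary text>
+
+            ---
+            Supporting evidence:
+            [report.pdf] first matching excerpt ...
+
+            [notes.md] second matching excerpt ...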
+ """ + content = cluster.content + if isinstance(content, list): + content = "\n".join(content) + content = str(content or "") + + evidence_parts: List[str] = [] + for ev in getattr(cluster, "evidences", []): + snippets = getattr(ev, "snippets", None) + if not snippets: + continue + source = str(getattr(ev, "file_or_url", "unknown")) + for snip in snippets: + text = snip if isinstance(snip, str) else snip.get("snippet", "") + if text and text.strip(): + evidence_parts.append(f"[{Path(source).name}] {text.strip()}") + + if evidence_parts: + content += "\n\n---\nSupporting evidence:\n" + "\n\n".join(evidence_parts[:5]) + + return content + async def _save_cluster_with_embedding(self, cluster: KnowledgeCluster) -> None: """Save knowledge cluster to persistent storage, compute embedding, and flush to parquet. @@ -1256,17 +1390,17 @@ async def _search_deep( # ============================================================== reused = await self._try_reuse_cluster(query, paths) if reused is not None: - content = reused.content - if isinstance(content, list): - content = "\n".join(content) - return str(content), reused, context + return self._enrich_reused_content(reused), reused, context + + # P2: gradient reuse — extract hints from moderately similar clusters + soft_hit = await self._try_soft_reuse(query, paths) await self._logger.info(f"[search] Starting multi-path retrieval for: '{query[:80]}'") # ============================================================== - # Phase 1: Parallel probing — all four paths fire concurrently + # Phase 1: Parallel probing — five paths fire concurrently # ============================================================== - await self._logger.info("[Phase 1] Parallel probing: keywords + dir_scan + knowledge + spec_cache") + await self._logger.info("[Phase 1] Parallel probing: keywords + dir_scan + knowledge + spec_cache + tree_index") context.increment_loop() phase1_results = await asyncio.gather( @@ -1274,24 +1408,53 @@ async def _search_deep( self._probe_dir_scan(paths, enable_dir_scan), self._probe_knowledge_cache(query), self._load_spec_context(paths, stale_hours=spec_stale_hours), + self._probe_tree_index(query), return_exceptions=True, ) kw_result = phase1_results[0] if not isinstance(phase1_results[0], Exception) else ({}, []) scan_result = phase1_results[1] if not isinstance(phase1_results[1], Exception) else None - knowledge_hits = phase1_results[2] if not isinstance(phase1_results[2], Exception) else [] + knowledge_probe = phase1_results[2] if not isinstance(phase1_results[2], Exception) else KnowledgeProbeResult([], [], "") spec_context = phase1_results[3] if not isinstance(phase1_results[3], Exception) else "" + tree_hits = phase1_results[4] if not isinstance(phase1_results[4], Exception) else [] - for i, label in enumerate(["keywords", "dir_scan", "knowledge", "spec_cache"]): + for i, label in enumerate(["keywords", "dir_scan", "knowledge", "spec_cache", "tree_index"]): if isinstance(phase1_results[i], Exception): await self._logger.warning(f"[Phase 1] {label} probe failed: {phase1_results[i]}") + # Backwards compat: knowledge_probe may be a plain list from old code paths + if isinstance(knowledge_probe, list): + knowledge_probe = KnowledgeProbeResult(file_paths=knowledge_probe, extra_keywords=[], background_context="") + query_keywords, initial_keywords = kw_result if isinstance(kw_result, tuple) else ({}, []) + # P2: inject soft-hit patterns into keywords + if soft_hit: + for p in soft_hit.patterns: + if p not in initial_keywords: + 
initial_keywords.append(p) + if p not in query_keywords: + query_keywords[p] = 0.6 + + # P3: inject extra keywords from structured knowledge probe + for kw in knowledge_probe.extra_keywords: + if kw not in initial_keywords: + initial_keywords.append(kw) + if kw not in query_keywords: + query_keywords[kw] = 0.5 + + # P2 + P3: append background context for Phase 4 LLM prompt + if soft_hit and soft_hit.context_summary: + spec_context = f"{spec_context}\n\n{soft_hit.context_summary}" if spec_context else soft_hit.context_summary + if knowledge_probe.background_context: + spec_context = f"{spec_context}\n\n{knowledge_probe.background_context}" if spec_context else knowledge_probe.background_context + await self._logger.info( f"[Phase 1] Results: keywords={len(initial_keywords)}, " f"dir_scan={'OK' if scan_result else 'N/A'}, " - f"knowledge_hits={len(knowledge_hits)}, " + f"knowledge_files={len(knowledge_probe.file_paths)}, " + f"tree_hits={len(tree_hits)}, " + f"soft_hit={'YES' if soft_hit else 'NO'}, " f"spec_cache={'YES' if spec_context else 'NO'}" ) @@ -1336,12 +1499,16 @@ async def _search_deep( # ============================================================== # Phase 3: Merge file paths + build KnowledgeCluster + # P1 tree hits get highest priority; P2 soft-hit files next # ============================================================== context.increment_loop() + extra_knowledge_files = knowledge_probe.file_paths + if soft_hit: + extra_knowledge_files = soft_hit.file_paths + extra_knowledge_files merged_files = self._merge_file_paths( - keyword_files=keyword_files, + keyword_files=list(tree_hits) + keyword_files, dir_scan_files=dir_scan_files, - knowledge_hits=knowledge_hits, + knowledge_hits=extra_knowledge_files, ) await self._logger.info(f"[Phase 3] Merged {len(merged_files)} unique candidate files") @@ -1352,6 +1519,19 @@ async def _search_deep( query_keywords=query_keywords, top_k_files=top_k_files, ) + # ============================================================== + # Phase 3.5: Graph context enrichment (P5) + # Append related knowledge from graph neighbours to cluster content + # so the answer-generation LLM has richer context. 
+ # ============================================================== + graph_ctx = "" + if cluster: + graph_ctx = await self._gather_graph_context(cluster) + if graph_ctx and cluster.content: + if isinstance(cluster.content, list): + cluster.content = "\n".join(cluster.content) + cluster.content = f"{cluster.content}\n\n{graph_ctx}" + # ============================================================== # Phase 4: Generate answer — cluster summary or ReAct refinement # ============================================================== @@ -1383,9 +1563,11 @@ async def _search_deep( answer, should_save = await self._summarise_cluster_fallback(query) else: await self._logger.info("[Phase 4] Evidence insufficient, launching ReAct refinement") + # P5: enrich ReAct context with graph knowledge + react_spec = f"{spec_context}\n\n{graph_ctx}" if graph_ctx else spec_context react_answer, context = await self._react_refinement( query=query, paths=paths, - initial_keywords=initial_keywords, spec_context=spec_context, + initial_keywords=initial_keywords, spec_context=react_spec, enable_dir_scan=enable_dir_scan, max_loops=max_loops, max_token_budget=max_token_budget, max_depth=max_depth, include=include, exclude=exclude, @@ -1751,11 +1933,11 @@ async def _search_fast( # ============================================================== reused = await self._try_reuse_cluster(query, paths) if reused is not None: - content = reused.content - if isinstance(content, list): - content = "\n".join(content) await self._logger.success("[FAST] Reused cached knowledge cluster") - return str(content), reused, context + return self._enrich_reused_content(reused), reused, context + + # P2: gradient reuse — structured hints from moderately similar clusters + soft_hit = await self._try_soft_reuse(query, paths) # ============================================================== # Step 1: LLM query analysis only (dir scan deferred until needed) @@ -1833,6 +2015,38 @@ async def _search_fast( msg = f"Could not extract search terms from query: '{query}'" return msg, None, context + # ============================================================== + # Step 1.5: Compile-aware enrichment (P2 + P4, zero LLM calls) + # ============================================================== + all_kw_set = set(primary + fallback) + + # P2: inject soft-hit patterns as fallback keywords + if soft_hit: + for p in soft_hit.patterns: + if p not in all_kw_set: + fallback.append(p) + all_kw_set.add(p) + keyword_idfs.setdefault(p, 0.6) + + # P4: compile hints from manifest + tree cache + compile_hints = await self._probe_compile_hints(primary + fallback) + for kw in compile_hints.extra_keywords: + if kw not in all_kw_set: + fallback.append(kw) + all_kw_set.add(kw) + keyword_idfs.setdefault(kw, 0.5) + + compile_hint_files: List[str] = [] + if soft_hit: + compile_hint_files.extend(soft_hit.file_paths) + compile_hint_files.extend(compile_hints.file_paths) + + if compile_hint_files: + await self._logger.info( + f"[FAST:Step1.5] Compile hints: {len(compile_hint_files)} files, " + f"{len(compile_hints.extra_keywords)} extra keywords" + ) + await self._logger.info( f"[FAST:Step1] Primary: {primary}, Fallback: {fallback}" ) @@ -1870,6 +2084,17 @@ async def _search_fast( fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, **rga_kwargs ) + # --- Fallback: compile-hint files when rga misses (P2+P4) --- + if not best_files and compile_hint_files: + used_level = "compile_hint" + await self._logger.info( + f"[FAST:Step2] rga miss — using {len(compile_hint_files)} compile-hint 
files" + ) + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in compile_hint_files[:top_k_files] + ] + # --- Fallback: use dir_scan only when rga misses and dir scan is enabled --- if not best_files and enable_dir_scan: scan_result = await self._probe_dir_scan(paths, enable=True, max_files=300) @@ -2592,32 +2817,251 @@ async def _probe_dir_scan( async def _probe_knowledge_cache( self, query: str, - ) -> List[str]: - """Search knowledge cache for related clusters, return known file paths. + ) -> KnowledgeProbeResult: + """Structured knowledge probe: embedding search with graph expansion. - Returns: - List of file paths from previously cached clusters. + Uses embedding similarity (threshold 0.50) when available, falling back + to SQL LIKE. Extracts file paths, topic keywords, and background + context from matched clusters and their graph neighbours. """ + empty = KnowledgeProbeResult([], [], "") try: - clusters = await self.knowledge_storage.find(query, limit=3) + clusters: List[KnowledgeCluster] = [] + + # Prefer embedding search for semantic quality + if self.embedding_client and self.embedding_client.is_ready(): + try: + qe = (await self.embedding_client.embed([query]))[0] + similar = await self.knowledge_storage.search_similar_clusters( + query_embedding=qe, top_k=5, similarity_threshold=0.50, + ) + for m in (similar or []): + c = await self.knowledge_storage.get(m["id"]) + if c: + clusters.append(c) + except Exception: + pass + + # Fallback to SQL LIKE when embedding unavailable or empty if not clusters: - return [] + clusters = await self.knowledge_storage.find(query, limit=3) + if not clusters: + return empty + + seen_paths: set = set() file_paths: List[str] = [] - for c in clusters: + extra_keywords: List[str] = [] + context_parts: List[str] = [] + seen_kw: set = set() + + def _collect_cluster(c: KnowledgeCluster) -> None: for ev in getattr(c, "evidences", []): fp = str(getattr(ev, "file_or_url", "")) - if fp and Path(fp).exists(): + if fp and fp not in seen_paths and Path(fp).exists(): + seen_paths.add(fp) file_paths.append(fp) + for p in getattr(c, "patterns", []) or []: + if p and p.lower() not in seen_kw: + seen_kw.add(p.lower()) + extra_keywords.append(p) + content = c.content + if isinstance(content, list): + content = "\n".join(content) + if content: + context_parts.append(str(content)[:500]) + + for c in clusters: + _collect_cluster(c) + + # One-hop graph expansion via WeakSemanticEdge + neighbour_ids: set = set() + for c in clusters: + for edge in getattr(c, "related_clusters", []): + tid = getattr(edge, "target_cluster_id", None) + if tid and tid not in neighbour_ids: + neighbour_ids.add(tid) + + for nid in list(neighbour_ids)[:6]: + try: + neighbour = await self.knowledge_storage.get(nid) + if neighbour: + _collect_cluster(neighbour) + except Exception: + pass if file_paths: await self._logger.info( - f"[Probe:Knowledge] Found {len(file_paths)} files from cached clusters" + f"[Probe:Knowledge] {len(file_paths)} files, " + f"{len(extra_keywords)} keywords from " + f"{len(clusters)} clusters + {len(neighbour_ids)} neighbours" + ) + + return KnowledgeProbeResult( + file_paths=file_paths, + extra_keywords=extra_keywords[:15], + background_context="\n\n".join(context_parts[:3]), + ) + except Exception: + return empty + + async def _probe_tree_index(self, query: str) -> List[str]: + """LLM-driven file discovery via compiled tree root summaries (PageIndex). 
+ + Loads all cached document trees, presents their root summaries to the + LLM, and asks it to select the most relevant 1-3 documents. For + selected trees, optionally drills one level deeper into children. + + Returns file paths of the most relevant documents. + """ + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if not tree_cache.exists(): + return [] + + try: + from sirchmunk.learnings.tree_indexer import DocumentTree + + trees: List[DocumentTree] = [] + for tree_file in sorted(tree_cache.glob("*.json"))[:50]: + try: + t = DocumentTree.from_json( + tree_file.read_text(encoding="utf-8") + ) + if t.root and t.file_path: + trees.append(t) + except Exception: + continue + + if not trees: + return [] + + # If few trees, return all without LLM + if len(trees) <= 2: + return [t.file_path for t in trees if Path(t.file_path).exists()] + + # LLM-driven selection among tree roots + listing = "\n".join( + f"[{i}] {Path(t.file_path).name}: {(t.root.summary or '')[:200]}" + for i, t in enumerate(trees) + ) + prompt = ( + f'Given the query: "{query}"\n\n' + f"Select the 1-3 most relevant documents (by index number):\n{listing}\n\n" + f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" + ) + resp = await self.llm.achat([{"role": "user", "content": prompt}]) + self.llm_usages.append(resp.usage) + + selected_indices: List[int] = [] + try: + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + selected_indices = [ + idx for idx in json.loads(m.group()) + if isinstance(idx, int) and 0 <= idx < len(trees) + ] + except (json.JSONDecodeError, TypeError): + pass + + if not selected_indices: + selected_indices = list(range(min(2, len(trees)))) + + result_paths: List[str] = [] + for idx in selected_indices: + fp = trees[idx].file_path + if Path(fp).exists(): + result_paths.append(fp) + + if result_paths: + await self._logger.info( + f"[Probe:TreeIndex] LLM selected {len(result_paths)} documents " + f"from {len(trees)} tree indices" ) - return file_paths + return result_paths except Exception: return [] + async def _probe_compile_hints(self, keywords: List[str]) -> CompileHints: + """Zero-LLM enrichment from compile manifest and tree cache. + + Scans the compile manifest for clusters whose patterns overlap with + the query keywords, and scans cached tree root summaries for keyword + matches. No LLM calls — only local JSON reads and in-memory DB lookups. 
+ """ + empty = CompileHints([], []) + if not keywords: + return empty + + kw_lower = {k.lower() for k in keywords} + file_paths: List[str] = [] + extra_keywords: List[str] = [] + seen_paths: set = set() + seen_kw: set = set(kw_lower) + + # --- Cluster pattern matching via manifest --- + manifest_path = self.work_path / ".cache" / "compile" / "manifest.json" + if manifest_path.exists(): + try: + from sirchmunk.learnings.compiler import CompileManifest + manifest = CompileManifest.from_json( + manifest_path.read_text(encoding="utf-8") + ) + cluster_ids: set = set() + for entry in manifest.files.values(): + cluster_ids.update(entry.cluster_ids) + + for cid in list(cluster_ids)[:50]: + try: + c = await self.knowledge_storage.get(cid) + except Exception: + continue + if not c: + continue + cluster_patterns = [ + p.lower() for p in (getattr(c, "patterns", []) or []) if p + ] + if kw_lower & set(cluster_patterns): + for ev in getattr(c, "evidences", []): + fp = str(getattr(ev, "file_or_url", "")) + if fp and fp not in seen_paths and Path(fp).exists(): + seen_paths.add(fp) + file_paths.append(fp) + for p in cluster_patterns: + if p not in seen_kw: + seen_kw.add(p) + extra_keywords.append(p) + except Exception: + pass + + # --- Tree root summary scanning (keyword substring match) --- + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if tree_cache.exists(): + try: + from sirchmunk.learnings.tree_indexer import DocumentTree + for tree_file in sorted(tree_cache.glob("*.json"))[:100]: + try: + tree = DocumentTree.from_json( + tree_file.read_text(encoding="utf-8") + ) + except Exception: + continue + if not tree.root or not tree.file_path: + continue + summary_lower = (tree.root.summary or "").lower() + if any(kw in summary_lower for kw in kw_lower): + fp = tree.file_path + if fp not in seen_paths and Path(fp).exists(): + seen_paths.add(fp) + file_paths.append(fp) + except Exception: + pass + + return CompileHints( + file_paths=file_paths[:15], + extra_keywords=extra_keywords[:10], + ) + @staticmethod async def _async_noop(default=None): """No-op coroutine used as placeholder in gather().""" @@ -2744,6 +3188,20 @@ def _merge_file_paths( return merged + def _get_tree_indexer(self): + """Lazily construct a DocumentTreeIndexer for search-time tree navigation.""" + from sirchmunk.learnings.tree_indexer import DocumentTreeIndexer + + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if not tree_cache.exists(): + return None + _cb = getattr(self._logger, 'log_callback', None) + return DocumentTreeIndexer( + llm=self.llm, + cache_dir=tree_cache, + log_callback=_cb, + ) + async def _build_cluster( self, query: str, @@ -2755,7 +3213,9 @@ async def _build_cluster( """Build a KnowledgeCluster via knowledge_base.build(). Constructs the Request wrapper and delegates to the knowledge - base for parallel Monte Carlo evidence sampling. + base for parallel Monte Carlo evidence sampling. When compiled + tree indices exist, passes a ``tree_indexer`` so that evidence + extraction can navigate to relevant sections before sampling. 
""" try: request = Request( @@ -2775,6 +3235,7 @@ async def _build_cluster( top_k_files=top_k_files, top_k_snippets=top_k_snippets, verbose=self.verbose, + tree_indexer=self._get_tree_indexer(), ) self.llm_usages.extend(self.knowledge_base.llm_usages) self.knowledge_base.llm_usages.clear() @@ -2789,6 +3250,47 @@ async def _build_cluster( await self._logger.warning(f"[Phase 3] knowledge_base.build() failed: {exc}") return None + async def _gather_graph_context(self, cluster: KnowledgeCluster) -> str: + """Enrich answer context with knowledge from graph neighbours. + + Traverses the cluster's ``related_clusters`` edges (sorted by weight), + fetches the top neighbours, and returns a joined summary string that + can be appended to the cluster content before answer generation. + """ + edges = sorted( + getattr(cluster, "related_clusters", []) or [], + key=lambda e: getattr(e, "weight", 0), + reverse=True, + ) + if not edges: + return "" + + parts: List[str] = [] + for edge in edges[:3]: + tid = getattr(edge, "target_cluster_id", None) + if not tid: + continue + try: + neighbour = await self.knowledge_storage.get(tid) + except Exception: + continue + if not neighbour: + continue + content = neighbour.content + if isinstance(content, list): + content = "\n".join(content) + name = getattr(neighbour, "name", "") or "" + snippet = str(content or "")[:300] + if snippet: + parts.append(f"- {name}: {snippet}") + + if not parts: + return "" + await self._logger.info( + f"[Phase 3.5] Graph context: {len(parts)} neighbour summaries" + ) + return "Related knowledge:\n" + "\n".join(parts) + # ------------------------------------------------------------------ # Phase 4: Answer generation # ------------------------------------------------------------------ From 645847766859fb846c7b7ca899d529a07b1e904c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 13 Apr 2026 21:00:33 +0800 Subject: [PATCH 04/56] improve search pipeline for compile mode --- src/sirchmunk/learnings/compiler.py | 59 ++++- src/sirchmunk/llm/prompts.py | 25 +++ src/sirchmunk/search.py | 336 +++++++++++++++++++++++----- 3 files changed, 366 insertions(+), 54 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 3c2b0da..4ccd5da 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -380,11 +380,14 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: await self._log.info("[Compile] Phase 4: Building cross-references") report.cross_refs_built = await self._build_cross_references(results) - # Phase 5: persist manifest + # Phase 5: persist manifest + document catalog manifest.last_compile_at = datetime.now(timezone.utc).isoformat() self._save_manifest(manifest) self._storage.force_sync() + # Generate document catalog for search-time routing + self._build_document_catalog(manifest) + report.elapsed_seconds = time.monotonic() - t0 await self._log.info( f"[Compile] Done in {report.elapsed_seconds:.1f}s — " @@ -838,3 +841,57 @@ def _load_manifest(self) -> CompileManifest: def _save_manifest(self, manifest: CompileManifest) -> None: self._manifest_path.write_text(manifest.to_json(), encoding="utf-8") + + # ------------------------------------------------------------------ # + # Document catalog for search-time routing # + # ------------------------------------------------------------------ # + + def _build_document_catalog(self, manifest: CompileManifest) -> None: + """Generate a lightweight catalog mapping files to their tree root 
summaries. + + The catalog is consumed by FAST search to fuse query analysis with + LLM-driven document routing in a single prompt. Each entry carries + the filename and a truncated root summary (≤250 chars). + """ + tree_cache = self._compile_dir / "trees" + entries: List[Dict[str, str]] = [] + + for file_path, entry in manifest.files.items(): + summary = "" + if entry.has_tree and tree_cache.exists(): + tree_file = tree_cache / f"{entry.file_hash}.json" + if tree_file.exists(): + try: + tree = DocumentTree.from_json( + tree_file.read_text(encoding="utf-8"), + ) + if tree.root and tree.root.summary: + summary = tree.root.summary[:250] + except Exception: + pass + + if not summary: + # Fallback: use first cluster's description + for cid in entry.cluster_ids[:1]: + try: + import asyncio + loop = asyncio.get_event_loop() + if loop.is_running(): + break + c = loop.run_until_complete(self._storage.get(cid)) + if c and c.description: + summary = str(c.description[0])[:250] + except Exception: + break + + entries.append({ + "path": file_path, + "name": Path(file_path).name, + "summary": summary, + }) + + catalog_path = self._compile_dir / "document_catalog.json" + catalog_path.write_text( + json.dumps(entries, ensure_ascii=False, indent=2), + encoding="utf-8", + ) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index b3ded32..7b07b1c 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -389,6 +389,31 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: """ +FAST_QUERY_ANALYSIS_WITH_CATALOG = """Classify the user query, extract search terms, AND select the most relevant document(s) from the compiled index. + +### User Query +{user_input} + +### Compiled Document Index +{document_listing} + +### Output +Return JSON only, no extra text: +{{"type": "search", "primary": ["compound phrase"], "fallback": ["term1", "term2"], "idf": {{"compound phrase": 8.0, "term1": 2.5}}, "primary_alt": [], "fallback_alt": [], "file_hints": [], "intent": "...", "selected_docs": [0, 2], "doc_confidence": "high"}} + +Rules: +- **type**: "search" if the query requires retrieving information from files or documents; "chat" if it is a greeting, small talk, or conversational message — set primary/fallback to empty arrays, put a brief reply in "response". "summary" if the user wants to summarize entire documents. +- **primary**: 1 compound phrase (2-3 words) most likely to appear **verbatim** in the target document. +- **fallback**: 1-3 single-word atomic terms. Tried only if primary misses. +- **primary_alt / fallback_alt**: Cross-lingual equivalents (Chinese↔English). Only the most critical 1-2 terms. +- **file_hints**: filename fragments or glob patterns ONLY if clearly implied; empty array otherwise. +- **intent**: one sentence describing the query intent. +- **idf**: IDF weight (1.0-10.0) for EVERY keyword. Higher for rare terms. +- **selected_docs**: Index numbers (from the Compiled Document Index above) of the 1-3 most relevant documents for this query. Consider BOTH the filename and the summary. Choose documents whose content is most likely to answer the query. +- **doc_confidence**: "high" if you are very confident the selected documents contain the answer; "medium" if likely but uncertain; "low" if guessing. +""" + + ROI_RESULT_SUMMARY = """ ### Task Analyze the provided {text_content} and generate a concise summary in the form of a Markdown Briefing. 
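
A sketch of how the catalog written by `_build_document_catalog()` feeds the fused
FAST prompt defined above; the work path is a placeholder, while
`FAST_QUERY_ANALYSIS_WITH_CATALOG` and the entry keys `path` / `name` / `summary`
come from this patch:

```python
import json
from pathlib import Path

from sirchmunk.llm.prompts import FAST_QUERY_ANALYSIS_WITH_CATALOG

catalog_path = Path("~/.sirchmunk/.cache/compile/document_catalog.json").expanduser()
entries = json.loads(catalog_path.read_text(encoding="utf-8"))

# Same listing format that the FAST pipeline builds before its fused Step 1 call
listing = "\n".join(
    f"[{i}] {e['name']}: {e['summary'][:200]}" for i, e in enumerate(entries)
)
prompt = FAST_QUERY_ANALYSIS_WITH_CATALOG.format(
    user_input="What changed in Q3?", document_listing=listing,
)
```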
diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 63d3ba8..9d192da 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -20,6 +20,7 @@ KEYWORD_QUERY_PLACEHOLDER, generate_keyword_extraction_prompt, FAST_QUERY_ANALYSIS, + FAST_QUERY_ANALYSIS_WITH_CATALOG, ROI_RESULT_SUMMARY, SEARCH_RESULT_SUMMARY, DOC_SUMMARY, @@ -1940,9 +1941,26 @@ async def _search_fast( soft_hit = await self._try_soft_reuse(query, paths) # ============================================================== - # Step 1: LLM query analysis only (dir scan deferred until needed) + # Step 1: Fused LLM query analysis + document routing + # When a compiled document catalog exists, the LLM sees all + # document summaries and selects the most relevant ones in the + # same call that extracts keywords (zero extra LLM cost). # ============================================================== - prompt = FAST_QUERY_ANALYSIS.format(user_input=query) + catalog = self._load_document_catalog() + catalog_routed_files: List[str] = [] + catalog_confidence: str = "low" + + if catalog: + listing = "\n".join( + f"[{i}] {e['name']}: {e['summary'][:200]}" + for i, e in enumerate(catalog) + ) + prompt = FAST_QUERY_ANALYSIS_WITH_CATALOG.format( + user_input=query, document_listing=listing, + ) + else: + prompt = FAST_QUERY_ANALYSIS.format(user_input=query) + resp = await self.llm.achat( messages=[{"role": "user", "content": prompt}], stream=False, @@ -1957,6 +1975,21 @@ async def _search_fast( query_type = analysis.get("type", "search") file_hints = analysis.get("file_hints", []) + # Extract catalog-routed files from the fused response + if catalog: + selected_indices = analysis.get("selected_docs", []) + catalog_confidence = analysis.get("doc_confidence", "low") + for idx in selected_indices: + if isinstance(idx, int) and 0 <= idx < len(catalog): + fp = catalog[idx]["path"] + if Path(fp).exists(): + catalog_routed_files.append(fp) + if catalog_routed_files: + await self._logger.info( + f"[FAST:Step1] Catalog routing ({catalog_confidence}): " + f"{[Path(p).name for p in catalog_routed_files]}" + ) + if query_type == "chat": chat_reply = analysis.get("response", "") if chat_reply: @@ -2017,6 +2050,7 @@ async def _search_fast( # ============================================================== # Step 1.5: Compile-aware enrichment (P2 + P4, zero LLM calls) + # Catalog-routed files from the fused Step 1 are merged here. 
# ============================================================== all_kw_set = set(primary + fallback) @@ -2037,13 +2071,26 @@ async def _search_fast( keyword_idfs.setdefault(kw, 0.5) compile_hint_files: List[str] = [] + # Catalog-routed files get highest priority + seen_hint_paths: set = set() + for fp in catalog_routed_files: + if fp not in seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) if soft_hit: - compile_hint_files.extend(soft_hit.file_paths) - compile_hint_files.extend(compile_hints.file_paths) + for fp in soft_hit.file_paths: + if fp not in seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) + for fp in compile_hints.file_paths: + if fp not in seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) if compile_hint_files: await self._logger.info( - f"[FAST:Step1.5] Compile hints: {len(compile_hint_files)} files, " + f"[FAST:Step1.5] Compile hints: {len(compile_hint_files)} files " + f"(catalog={len(catalog_routed_files)}, soft={len(soft_hit.file_paths) if soft_hit else 0}), " f"{len(compile_hints.extra_keywords)} extra keywords" ) @@ -2053,7 +2100,9 @@ async def _search_fast( # ============================================================== # Step 2: rga cascade — primary first, fallback only if needed - # Dir scan runs only when enabled, for fallback when rga misses. + # When catalog routing has high confidence, catalog-routed files + # are used directly (skipping rga) to avoid noise from unrelated + # files. Otherwise rga runs first and catalog acts as fallback. # ============================================================== context.add_search(query) include_patterns = list(include or []) @@ -2070,7 +2119,19 @@ async def _search_fast( used_level = "primary" evidence = "" - if primary: + # High-confidence catalog routing: skip rga, use catalog directly + if catalog_routed_files and catalog_confidence == "high": + used_level = "catalog_route" + await self._logger.info( + f"[FAST:Step2] High-confidence catalog routing → " + f"{[Path(p).name for p in catalog_routed_files[:top_k_files]]}" + ) + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in catalog_routed_files[:top_k_files] + ] + + if not best_files and primary: best_files = await self._fast_find_best_file( primary, top_k=top_k_files, keyword_idfs=keyword_idfs, **rga_kwargs ) @@ -2084,7 +2145,7 @@ async def _search_fast( fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, **rga_kwargs ) - # --- Fallback: compile-hint files when rga misses (P2+P4) --- + # --- Fallback: compile-hint files when rga misses (catalog + P2 + P4) --- if not best_files and compile_hint_files: used_level = "compile_hint" await self._logger.info( @@ -2128,52 +2189,59 @@ async def _search_fast( ) # ============================================================== - # Step 3: Context sampling around grep hits (no LLM) - # Multi-file evidence aggregation + # Step 2.5 + Step 3: Tree navigation (1 LLM call) runs in + # parallel with rga evidence sampling (0 LLM). The merged + # result is higher quality than either alone. 
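A stripped-down sketch of that concurrency pattern, assuming two coroutines with the same output shapes as the helpers used below (both arguments are stand-ins, not the real methods):

import asyncio

async def collect_evidence(navigate_tree, sample_rga) -> str:
    """Run LLM tree navigation and zero-LLM rga sampling concurrently."""
    tree_ev, rga_ev = await asyncio.gather(navigate_tree(), sample_rga())
    # Tree evidence goes first because navigated sections are the more precise source.
    return "\n\n---\n\n".join(part for part in (tree_ev, rga_ev) if part)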
# ============================================================== - evidence_parts = [] - total_evidence_chars = 0 - for bf in best_files: - if total_evidence_chars >= self._FAST_MAX_EVIDENCE_CHARS: - break - - file_path = bf["path"] - fname = Path(file_path).name - ext = Path(file_path).suffix.lower() - - # Small file short-circuit: read full content instead of grep sampling - ev = None - if ext in self._FAST_TEXT_EXTENSIONS: - try: - file_size = Path(file_path).stat().st_size - if file_size < self._FAST_SMALL_FILE_THRESHOLD: - full_text = Path(file_path).read_text(errors="replace") - if len(full_text) < self._FAST_SMALL_FILE_THRESHOLD: - ev = f"[{fname}]\n{full_text}" - await self._logger.info( - f"[FAST] Small file short-circuit: reading full content of {fname} " - f"({len(full_text)} chars)" - ) - except Exception: - pass # Fall through to normal evidence extraction - - # Normal path: grep-based evidence sampling - if ev is None: - ev = await self._fast_sample_evidence(file_path, bf.get("matches", [])) - if ev: - remaining = self._FAST_MAX_EVIDENCE_CHARS - total_evidence_chars - chunk = ev[:remaining] - evidence_parts.append(chunk) - total_evidence_chars += len(chunk) - context.mark_file_read(file_path) - - evidence = "\n\n---\n\n".join(evidence_parts) + async def _rga_evidence() -> str: + """Collect rga-based evidence from best_files (zero LLM).""" + parts: List[str] = [] + chars = 0 + for bf in best_files: + if chars >= self._FAST_MAX_EVIDENCE_CHARS: + break + fp = bf["path"] + fn = Path(fp).name + ext = Path(fp).suffix.lower() + ev = None + if ext in self._FAST_TEXT_EXTENSIONS: + try: + sz = Path(fp).stat().st_size + if sz < self._FAST_SMALL_FILE_THRESHOLD: + full = Path(fp).read_text(errors="replace") + if len(full) < self._FAST_SMALL_FILE_THRESHOLD: + ev = f"[{fn}]\n{full}" + except Exception: + pass + if ev is None: + ev = await self._fast_sample_evidence(fp, bf.get("matches", [])) + if ev: + remaining = self._FAST_MAX_EVIDENCE_CHARS - chars + parts.append(ev[:remaining]) + chars += len(parts[-1]) + context.mark_file_read(fp) + return "\n\n---\n\n".join(parts) + + # Launch tree navigation for the primary file alongside rga + tree_nav_target = best_files[0]["path"] + rga_task = _rga_evidence() + tree_task = self._navigate_tree_for_evidence(tree_nav_target, query) + + rga_ev, tree_ev = await asyncio.gather(rga_task, tree_task) + + # Merge: tree evidence first (highest quality), then rga + evidence_parts_final: List[str] = [] + if tree_ev: + evidence_parts_final.append(tree_ev) + if rga_ev: + evidence_parts_final.append(rga_ev) + evidence = "\n\n---\n\n".join(evidence_parts_final) if not evidence or len(evidence.strip()) < 20: if llm_fallback: await self._logger.info( - "[FAST:Step3] No usable evidence, llm_fallback=True \u2192 LLM summary" + "[FAST:Step3] No usable evidence, llm_fallback=True → LLM summary" ) evidence = self._LLM_FALLBACK_EVIDENCE else: @@ -2181,7 +2249,8 @@ async def _search_fast( return _NO_RESULTS_MESSAGE, None, context await self._logger.info( - f"[FAST:Step3] Evidence: {len(evidence)} chars from {Path(file_path).name}" + f"[FAST:Step3] Evidence: {len(evidence)} chars " + f"(tree={'yes' if tree_ev else 'no'}, rga={'yes' if rga_ev else 'no'})" ) keywords_used = primary if used_level == "primary" else fallback @@ -2206,21 +2275,52 @@ async def _search_fast( answer, should_save, should_answer = self._parse_summary_response( answer_resp.content or "" ) + + # ============================================================== + # Step 5: Self-correction retry 
(conditional, ≤1 extra LLM call) + # When the answer gate rejects the first attempt, try alternative + # evidence sources before giving up. + # ============================================================== + if not should_answer: + retry_evidence = await self._fast_self_correct( + query, best_files, catalog_routed_files, context, + ) + if retry_evidence: + await self._logger.info( + f"[FAST:Step5] Retrying with {len(retry_evidence)} chars of alternative evidence" + ) + retry_prompt = ROI_RESULT_SUMMARY.format( + user_input=query, text_content=retry_evidence, + ) + retry_resp = await self.llm.achat( + messages=[{"role": "user", "content": retry_prompt}], + stream=True, + ) + self.llm_usages.append(retry_resp.usage) + if retry_resp.usage and isinstance(retry_resp.usage, dict): + context.add_llm_tokens( + retry_resp.usage.get("total_tokens", 0), usage=retry_resp.usage, + ) + answer, should_save, should_answer = self._parse_summary_response( + retry_resp.content or "" + ) + if not should_answer: if llm_fallback: await self._logger.info( - "[FAST:Step4] Summary gate rejected evidence, llm_fallback=True → LLM fallback" + "[FAST:Step5] Retry also rejected, llm_fallback=True → LLM fallback" ) answer, should_save = await self._summarise_fast_fallback(query, context) else: await self._logger.warning( - "[FAST:Step4] Summary gate rejected evidence and llm_fallback=False " + "[FAST:Step5] Evidence rejected after retry, llm_fallback=False " "→ returning no results" ) return _NO_RESULTS_MESSAGE, None, context + if not should_save: await self._logger.info("[FAST] Quality gate: low-quality answer, skipping cluster save") - await self._logger.success("[FAST] Search complete (2 LLM calls, no persist)") + await self._logger.success("[FAST] Search complete (no persist)") return answer, None, context cluster = self._build_fast_cluster( @@ -2234,7 +2334,7 @@ async def _search_fast( f"[FAST] Failed to save cluster with embedding: {exc}" ) - await self._logger.success("[FAST] Search complete (2 LLM calls)") + await self._logger.success("[FAST] Search complete") return answer, cluster, context # ---- FAST helpers ---- @@ -2634,6 +2734,136 @@ async def _fast_read_file_head( pass return "" + def _load_document_catalog(self) -> Optional[List[Dict[str, str]]]: + """Load the compiled document catalog for fused query+route prompt. + + Returns None when compile has not been run or catalog is missing. + """ + catalog_path = self.work_path / ".cache" / "compile" / "document_catalog.json" + if not catalog_path.exists(): + return None + try: + entries = json.loads(catalog_path.read_text(encoding="utf-8")) + if isinstance(entries, list) and entries: + return entries + except Exception: + pass + return None + + async def _navigate_tree_for_evidence( + self, file_path: str, query: str, + ) -> Optional[str]: + """LLM-driven tree navigation: select relevant sections and read leaf content. + + Uses 1 LLM call to drill into the compiled tree index for + *file_path*, returning concatenated leaf content as evidence. + Returns None when no tree cache is available. 
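Both this helper and the Step 5 retry above rely on the same contract: an evidence source returns None when it cannot help, and the caller walks an ordered list of sources until one produces something usable. A generic sketch of that cascade; the strategy callables are placeholders:

from typing import Awaitable, Callable, List, Optional

async def first_usable(
    strategies: List[Callable[[], Awaitable[Optional[str]]]],
    min_chars: int = 50,
) -> Optional[str]:
    """Try alternative evidence sources in priority order; keep the first non-trivial one."""
    for strategy in strategies:
        try:
            evidence = await strategy()
        except Exception:
            continue                      # a failing source never aborts the cascade
        if evidence and len(evidence.strip()) > min_chars:
            return evidence
    return None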
+ """ + indexer = self._get_tree_indexer() + if indexer is None: + return None + tree = indexer.load_tree(file_path) + if tree is None or tree.root is None: + return None + + try: + leaves = await indexer.navigate(tree, query, max_results=3) + except Exception: + return None + + if not leaves: + return None + + fname = Path(file_path).name + # Read leaf content from the original document via char_range + parts: List[str] = [] + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" + + for leaf in leaves: + start, end = leaf.char_range + if full_text and end > start: + segment = full_text[start:end] + else: + segment = leaf.summary or "" + if segment.strip(): + header = f"[{fname} → {leaf.title}]" + parts.append(f"{header}\n{segment[:3000]}") + + if not parts: + return None + + evidence = "\n\n".join(parts) + await self._logger.info( + f"[FAST:TreeNav] Extracted {len(parts)} sections, " + f"{len(evidence)} chars from {fname}" + ) + return evidence + + async def _fast_self_correct( + self, + query: str, + best_files: Optional[List[Dict[str, Any]]], + catalog_routed_files: List[str], + context: SearchContext, + ) -> Optional[str]: + """Attempt to gather alternative evidence when the first answer is rejected. + + Three strategies tried in order: + A) Tree-navigate a 2nd catalog-routed file not yet tried. + B) Retrieve the most semantically similar compiled cluster's content. + C) Tree-navigate the 2nd-best rga file if available. + + Returns alternative evidence string, or None if all strategies fail. + """ + first_file = best_files[0]["path"] if best_files else "" + + # Strategy A: 2nd catalog-routed file via tree navigation + for fp in catalog_routed_files: + if fp == first_file: + continue + tree_ev = await self._navigate_tree_for_evidence(fp, query) + if tree_ev and len(tree_ev.strip()) > 50: + context.mark_file_read(fp) + return tree_ev + + # Strategy B: cluster content from knowledge storage + if self.embedding_client and self.knowledge_storage: + try: + qe = self.embedding_client.encode(query) + if qe is not None: + vec = qe.tolist() if hasattr(qe, "tolist") else list(qe) + hits = await self.knowledge_storage.search_similar_clusters( + query_embedding=vec, top_k=2, similarity_threshold=0.50, + ) + if hits: + parts: List[str] = [] + for h in hits[:2]: + c = await self.knowledge_storage.get(h["id"]) + if c and c.content: + parts.append(str(c.content)[:3000]) + for ev in (c.evidences or [])[:3]: + for s in (ev.snippets or [])[:2]: + parts.append(s[:500]) + if parts: + return "\n\n---\n\n".join(parts) + except Exception: + pass + + # Strategy C: 2nd rga file via tree navigation + if best_files and len(best_files) > 1: + fp2 = best_files[1]["path"] + tree_ev = await self._navigate_tree_for_evidence(fp2, query) + if tree_ev and len(tree_ev.strip()) > 50: + context.mark_file_read(fp2) + return tree_ev + + return None + @staticmethod def _parse_fast_json(text: str) -> Dict[str, Any]: """Extract JSON from the FAST query analysis LLM response.""" From 1f6f799fd1c3bcb357ae5cfb5d436b90c1c0f647 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Tue, 14 Apr 2026 19:21:25 +0800 Subject: [PATCH 05/56] fix and enhance llm wiki and tree index for FAST search --- src/sirchmunk/llm/prompts.py | 39 ++++ src/sirchmunk/search.py | 394 +++++++++++++++++++++++++++++++++-- 2 files changed, 417 insertions(+), 16 deletions(-) diff --git 
a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 7b07b1c..8df111d 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -449,6 +449,45 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: true/false """ +ROI_RESULT_SUMMARY_WITH_CONTEXT = """ +### Task +Analyze the provided evidence and generate a concise summary in the form of a Markdown Briefing. +Leverage the document context below for better understanding of the source material's structure and purpose. + +### Constraints +1. **Language Continuity**: The output must be in the SAME language as the User Input. +2. **Format**: Use Markdown (headings, bullet points, and bold text) for high readability. +3. **Style**: Keep it professional, objective, and clear. Avoid fluff. + +### Document Context +{document_context} + +### Input Data +- **User Input**: {user_input} +- **Search Result Text**: {text_content} + +### Quality Evaluation +After generating the summary, make TWO decisions: +1) whether the query can be answered from the provided evidence; +2) whether this result is worth caching. + +Evaluate based on: +1. Does the search result contain substantial, relevant information for the user input? +2. Is the content meaningful and not just error messages or "no information found"? +3. Are there sufficient evidences and context to answer the user's query? + +- : output "true" only if the evidence is sufficient to answer the query. +- : output "true" only if the evidence is sufficient AND the result is worth caching. +- If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". + +### Output Format + +[Generate the Markdown Briefing here] + +true/false +true/false +""" + # --------------------------------------------------------------------------- # Knowledge Compile prompts diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 9d192da..2976f60 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -11,7 +11,7 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union from sirchmunk.base import BaseSearch from sirchmunk.learnings.knowledge_base import KnowledgeBase @@ -121,6 +121,22 @@ class CompileHints: extra_keywords: List[str] +@dataclass +class CompileArtifacts: + """Compile artifact availability context for adaptive activation in FAST mode. + + Created once at the start of ``_search_fast()`` via + ``_detect_compile_artifacts()`` and threaded through all pipeline steps. + Each step checks the relevant field and falls back gracefully when the + artifact is absent. 
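The adaptive activation described here is simply a guard on each field before an enhanced path is taken. A minimal illustration with a cut-down stand-in for the dataclass defined below:

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set

@dataclass
class _Artifacts:                       # illustrative stand-in, not the real CompileArtifacts
    catalog: Optional[List[Dict[str, str]]] = None
    tree_available_paths: Set[str] = field(default_factory=set)

def choose_prompt(artifacts: _Artifacts, plain: str, fused: str) -> str:
    # Fused query analysis only when a compiled catalog actually exists.
    return fused if artifacts.catalog else plain

def can_tree_navigate(artifacts: _Artifacts, path: str) -> bool:
    return path in artifacts.tree_available_paths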
+ """ + + catalog: Optional[List[Dict[str, str]]] + catalog_map: Dict[str, Dict[str, str]] # path -> catalog entry for O(1) lookup + tree_indexer: Optional[Any] # DocumentTreeIndexer (lazy import) + tree_available_paths: Set[str] # file paths that have cached tree indices + + class AgenticSearch(BaseSearch): def __init__( @@ -1893,6 +1909,32 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _FAST_MAX_EVIDENCE_CHARS = 15_000 _FAST_SMALL_FILE_THRESHOLD = 100_000 # 100K chars - read full file instead of grep sampling + # --- Wiki-enhanced ranking constants --- + _WIKI_BLEND_ALPHA = 0.7 + """TF-IDF weight in the hybrid score; Wiki weight = 1 - alpha.""" + _WIKI_MAX_SCORE = 10.0 + """Upper bound for the wiki relevance score.""" + _WIKI_CATALOG_KEYWORD_OVERLAP_MAX = 5.0 + """Maximum sub-score for catalog summary keyword overlap.""" + _WIKI_TREE_AVAILABILITY_BONUS = 2.0 + """Bonus for files that have a compiled tree index.""" + _WIKI_CATALOG_PRESENCE_FULL = 3.0 + """Catalog presence bonus for summaries > 100 chars.""" + _WIKI_CATALOG_PRESENCE_MEDIUM = 2.0 + """Catalog presence bonus for summaries > 30 chars.""" + _WIKI_CATALOG_PRESENCE_MINIMAL = 1.0 + """Catalog presence bonus for summaries > 0 chars.""" + _TREE_CACHE_SCAN_LIMIT = 200 + """Max tree JSON files to parse during artifact detection.""" + _CATALOG_LISTING_MAX_ENTRIES = 20 + """Max catalog entries in the enriched listing for Step 1.""" + _CATALOG_KEYWORD_MIN_LEN = 2 + """Minimum character length for a catalog keyword token.""" + _CATALOG_KEYWORD_MAX_LEN = 20 + """Maximum character length for a catalog keyword token.""" + _CATALOG_SUMMARY_TRUNCATE = 200 + """Max chars of catalog summary shown in the listing.""" + _LLM_FALLBACK_EVIDENCE = ( "[No relevant documents found]\n\n" "The search did not find relevant content in the available documents. " @@ -1928,6 +1970,15 @@ async def _search_fast( context = SearchContext() await self._logger.info(f"[FAST] Starting greedy search for: '{query[:80]}'") + # --- Adaptive compile artifact detection (one-shot, zero LLM) --- + artifacts = self._detect_compile_artifacts() + if artifacts.catalog or artifacts.tree_available_paths: + await self._logger.info( + f"[FAST:Artifacts] catalog={'yes' if artifacts.catalog else 'no'} " + f"({len(artifacts.catalog) if artifacts.catalog else 0} docs), " + f"trees={len(artifacts.tree_available_paths)}" + ) + # ============================================================== # Step 0: Cluster reuse — instant short-circuit (no LLM cost) # When reuse succeeds we return here; no persistence step runs. @@ -1946,15 +1997,12 @@ async def _search_fast( # document summaries and selects the most relevant ones in the # same call that extracts keywords (zero extra LLM cost). 
# ============================================================== - catalog = self._load_document_catalog() + catalog = artifacts.catalog catalog_routed_files: List[str] = [] catalog_confidence: str = "low" if catalog: - listing = "\n".join( - f"[{i}] {e['name']}: {e['summary'][:200]}" - for i, e in enumerate(catalog) - ) + listing = self._build_enriched_catalog_listing(catalog) prompt = FAST_QUERY_ANALYSIS_WITH_CATALOG.format( user_input=query, document_listing=listing, ) @@ -2118,6 +2166,7 @@ async def _search_fast( best_files: Optional[List[Dict[str, Any]]] = None used_level = "primary" evidence = "" + file_path: Optional[str] = None # set when best_files found # High-confidence catalog routing: skip rga, use catalog directly if catalog_routed_files and catalog_confidence == "high": @@ -2133,7 +2182,9 @@ async def _search_fast( if not best_files and primary: best_files = await self._fast_find_best_file( - primary, top_k=top_k_files, keyword_idfs=keyword_idfs, **rga_kwargs + primary, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + **rga_kwargs, ) if not best_files and fallback: @@ -2142,7 +2193,9 @@ async def _search_fast( "[FAST:Step2] Primary miss, trying fine-grained fallback" ) best_files = await self._fast_find_best_file( - fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, **rga_kwargs + fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + **rga_kwargs, ) # --- Fallback: compile-hint files when rga misses (catalog + P2 + P4) --- @@ -2183,9 +2236,13 @@ async def _search_fast( if best_files: file_path = best_files[0]["path"] match_objects = best_files[0].get("matches", []) + wiki_info = "" + if best_files[0].get("wiki_relevance") is not None: + wiki_info = f", wiki={best_files[0]['wiki_relevance']:.1f}" await self._logger.info( f"[FAST:Step2] Best file ({used_level}): {Path(file_path).name} " - f"({best_files[0].get('total_matches', 0)} hits, score={best_files[0].get('weighted_score', 0):.2f})" + f"({best_files[0].get('total_matches', 0)} hits, " + f"score={best_files[0].get('weighted_score', 0):.2f}{wiki_info})" ) # ============================================================== @@ -2248,20 +2305,35 @@ async def _rga_evidence() -> str: await self._logger.warning("[FAST:Step3] No usable evidence extracted") return _NO_RESULTS_MESSAGE, None, context + tree_available = file_path in artifacts.tree_available_paths if artifacts else False await self._logger.info( f"[FAST:Step3] Evidence: {len(evidence)} chars " - f"(tree={'yes' if tree_ev else 'no'}, rga={'yes' if rga_ev else 'no'})" + f"(tree={'yes' if tree_ev else 'no'}, rga={'yes' if rga_ev else 'no'}, " + f"tree_indexed={'yes' if tree_available else 'no'})" ) keywords_used = primary if used_level == "primary" else fallback # ============================================================== # Step 4: LLM answer from focused evidence (single call) + # Wiki-enhanced: inject document context when catalog available. 
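Prompt selection for that step reduces to one branch. A sketch with placeholder template parameters; the real templates are ROI_RESULT_SUMMARY and ROI_RESULT_SUMMARY_WITH_CONTEXT defined in prompts.py:

from typing import Optional

def build_answer_prompt(query: str, evidence: str, doc_context: Optional[str],
                        plain_tmpl: str, context_tmpl: str) -> str:
    """Use the context-aware template only when a catalog entry describes the source file."""
    if doc_context:
        return context_tmpl.format(
            user_input=query, text_content=evidence, document_context=doc_context,
        )
    return plain_tmpl.format(user_input=query, text_content=evidence)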
# ============================================================== - answer_prompt = ROI_RESULT_SUMMARY.format( - user_input=query, - text_content=evidence, - ) + doc_context = self._build_answer_context(file_path, artifacts) if best_files else None + if doc_context: + from sirchmunk.llm.prompts import ROI_RESULT_SUMMARY_WITH_CONTEXT + answer_prompt = ROI_RESULT_SUMMARY_WITH_CONTEXT.format( + user_input=query, + text_content=evidence, + document_context=doc_context, + ) + await self._logger.info( + f"[FAST:Step4] Wiki-enhanced answer generation with catalog context" + ) + else: + answer_prompt = ROI_RESULT_SUMMARY.format( + user_input=query, + text_content=evidence, + ) answer_resp = await self.llm.achat( messages=[{"role": "user", "content": answer_prompt}], stream=True, @@ -2324,7 +2396,7 @@ async def _rga_evidence() -> str: return answer, None, context cluster = self._build_fast_cluster( - query, answer, file_path, evidence, keywords_used, + query, answer, file_path or "", evidence, keywords_used, ) self._add_query_to_cluster(cluster, query) try: @@ -2475,6 +2547,86 @@ def _prune_by_score( # Cap at top_k return result[:top_k] + @staticmethod + def _compute_wiki_relevance( + file_path: str, + query: str, + keywords: List[str], + catalog_map: Dict[str, Dict[str, str]], + tree_available_paths: Set[str], + ) -> float: + """Compute wiki-based relevance score for a candidate file (0-10 scale). + + Uses three sub-scores derived from compile artifacts: + + 1. **Catalog summary overlap** (0-``_WIKI_CATALOG_KEYWORD_OVERLAP_MAX``): + proportion of query keywords that appear in the catalog entry's + summary. When *keywords* is empty, falls back to whole-query + substring matching against the summary to avoid returning 0 for + valid queries. + 2. **Tree availability bonus** (0-``_WIKI_TREE_AVAILABILITY_BONUS``): + a file with a compiled tree index likely has rich structure. + 3. **Catalog presence bonus** (0-``_WIKI_CATALOG_PRESENCE_FULL``): + files important enough to be in the catalog get a baseline boost. + + All scoring is pure text matching — no LLM, no embedding. + + Args: + file_path: Absolute path of the candidate file. + query: Original user query. + keywords: Extracted search keywords from FAST Step 1. + catalog_map: ``{path: catalog_entry}`` from CompileArtifacts. + tree_available_paths: Set of file paths with cached tree indices. + + Returns: + Float in [0, 10] representing wiki-derived relevance. + """ + cls = AgenticSearch # access class constants from static method + score = 0.0 + + entry = catalog_map.get(file_path) + + # Sub-score 1: Catalog summary keyword overlap + if entry: + summary_lower = (entry.get("summary", "") + " " + entry.get("name", "")).lower() + query_lower = query.lower() + matches = 0 + total = 0 + for kw in keywords: + if kw: + total += 1 + if kw.lower() in summary_lower: + matches += 1 + # Also check whole query as a substring + if len(query_lower) >= 2 and query_lower in summary_lower: + matches += 1 + total += 1 + # When keywords list is empty but query is non-empty, fall back to + # character-level overlap so the sub-score is not silently 0. 
+ if total == 0 and query_lower: + # Simple overlap: count how many query chars appear in summary + overlap = sum(1 for ch in query_lower if ch in summary_lower) + ratio = overlap / max(len(query_lower), 1) + score += ratio * cls._WIKI_CATALOG_KEYWORD_OVERLAP_MAX + elif total > 0: + score += (matches / total) * cls._WIKI_CATALOG_KEYWORD_OVERLAP_MAX + + # Sub-score 2: Tree availability bonus + if file_path in tree_available_paths: + score += cls._WIKI_TREE_AVAILABILITY_BONUS + + # Sub-score 3: Catalog presence bonus + if entry: + summary_len = len(entry.get("summary", "")) + if summary_len > 100: + score += cls._WIKI_CATALOG_PRESENCE_FULL + elif summary_len > 30: + score += cls._WIKI_CATALOG_PRESENCE_MEDIUM + elif summary_len > 0: + score += cls._WIKI_CATALOG_PRESENCE_MINIMAL + + return min(score, cls._WIKI_MAX_SCORE) + async def _fast_find_best_file( self, keywords: List[str], @@ -2484,9 +2636,23 @@ async def _fast_find_best_file( exclude: Optional[List[str]] = None, top_k: int = 1, keyword_idfs: Optional[Dict[str, float]] = None, + query: str = "", + artifacts: Optional["CompileArtifacts"] = None, ) -> Optional[List[Dict[str, Any]]]: """Search per keyword via rga and return the top-k best-matching files - ranked by IDF-weighted log-TF scoring. + ranked by IDF-weighted log-TF scoring, optionally enhanced with + wiki-derived relevance from compile artifacts. + + Args: + keywords: Search keywords from FAST Step 1. + paths: Search paths. + max_depth: Maximum directory depth for rga. + include: Glob patterns to include. + exclude: Glob patterns to exclude. + top_k: Number of top files to return. + keyword_idfs: Pre-computed IDF values for keywords. + query: Original user query (used for wiki relevance scoring). + artifacts: Compile artifacts for adaptive wiki-enhanced ranking. Returns: List of merged file dicts (path, matches, lines, total_matches, weighted_score) or None. @@ -2576,6 +2742,25 @@ async def _fast_find_best_file( score += idf * (1.0 + math.log(tf)) f["weighted_score"] = score + # --- Wiki-enhanced hybrid scoring (adaptive: only when artifacts exist) --- + if artifacts and artifacts.catalog_map: + # Normalize TF-IDF scores to [0, 10] to align with Wiki score range + max_tf_idf = max((f["weighted_score"] for f in merged), default=1.0) + if max_tf_idf <= 0: + max_tf_idf = 1.0 + for f in merged: + wiki_score = self._compute_wiki_relevance( + f["path"], query, keywords, + artifacts.catalog_map, artifacts.tree_available_paths, + ) + f["wiki_relevance"] = wiki_score + # Normalize TF-IDF to [0, 10] before blending + tf_idf_norm = (f["weighted_score"] / max_tf_idf) * self._WIKI_MAX_SCORE + f["weighted_score"] = ( + self._WIKI_BLEND_ALPHA * tf_idf_norm + + (1 - self._WIKI_BLEND_ALPHA) * wiki_score + ) + merged.sort(key=lambda f: f["weighted_score"], reverse=True) pruned = self._prune_by_score(merged, top_k=top_k) @@ -2750,6 +2935,183 @@ def _load_document_catalog(self) -> Optional[List[Dict[str, str]]]: pass return None + def _detect_compile_artifacts(self) -> CompileArtifacts: + """One-shot probe of all compile artifacts for adaptive FAST activation. + + Reads the document catalog and scans the tree cache directory to + determine which compile products are available. Called once at the + start of ``_search_fast()``; the result is passed to downstream + helpers so they can enable enhanced logic only when artifacts exist. + + Cost: one JSON read (catalog) + one directory listing (tree cache). 
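For reference, the hybrid ranking applied in _fast_find_best_file above is a scale-then-mix step. A self-contained version of that arithmetic; alpha mirrors _WIKI_BLEND_ALPHA, and the numbers in the usage comment are made up:

def hybrid_score(tf_idf: float, max_tf_idf: float, wiki: float,
                 alpha: float = 0.7, scale: float = 10.0) -> float:
    """Normalise TF-IDF onto the 0-10 wiki scale, then blend the two signals."""
    tf_idf_norm = (tf_idf / max(max_tf_idf, 1e-9)) * scale
    return alpha * tf_idf_norm + (1 - alpha) * wiki

# Example: the best rga hit (tf_idf == max_tf_idf) with a weak wiki signal of 2.0
# scores 0.7 * 10.0 + 0.3 * 2.0 = 7.6 on the shared 0-10 scale.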
+ Tree path results are cached in ``_tree_paths_cache`` so subsequent + calls within the same instance avoid re-parsing every JSON file. + Returns a ``CompileArtifacts`` with ``None``/empty fields when + compile has not been run. + """ + catalog = self._load_document_catalog() + catalog_map: Dict[str, Dict[str, str]] = {} + if catalog: + for entry in catalog: + p = entry.get("path", "") + if p: + catalog_map[p] = entry + + indexer = self._get_tree_indexer() + # Use cached tree paths when available to avoid re-parsing all JSONs + tree_paths: Set[str] = getattr(self, "_tree_paths_cache", None) or set() + if indexer is not None and not tree_paths: + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if tree_cache.exists(): + try: + from sirchmunk.learnings.tree_indexer import DocumentTree + for tf in sorted(tree_cache.glob("*.json"))[:self._TREE_CACHE_SCAN_LIMIT]: + try: + tree = DocumentTree.from_json( + tf.read_text(encoding="utf-8") + ) + if tree.file_path: + tree_paths.add(tree.file_path) + except Exception: + pass + except Exception: + pass + # Cache for future calls within this instance + self._tree_paths_cache = tree_paths + + return CompileArtifacts( + catalog=catalog, + catalog_map=catalog_map, + tree_indexer=indexer, + tree_available_paths=tree_paths, + ) + + @staticmethod + def _extract_catalog_keywords(summary: str, max_kw: int = 3) -> List[str]: + """Extract salient keywords from a catalog summary via simple heuristics. + + Uses word-length filtering, Chinese character detection, and CJK n-gram + extraction to pick the most informative tokens. For CJK-heavy text + (which does not use whitespace word boundaries), consecutive CJK + character runs are extracted as additional candidate tokens. + + No LLM or embedding involved. + + Args: + summary: Document summary text from the compiled catalog. + max_kw: Maximum number of keywords to return. + + Returns: + List of up to *max_kw* keywords. + """ + cls = AgenticSearch + if not summary: + return [] + import re as _re + + # Split on whitespace and common punctuation (incl. CJK punctuation) + tokens = _re.split( + r'[\s,;\uff0c\uff1b\u3001\u3002\uff1a:!?\uff01\uff1f()\[\]{}\u201c\u201d\u2018\u2019\u0022\u0027]+', + summary, + ) + + # For CJK text, also extract consecutive CJK character runs (2-6 chars) + # so that e.g. "停车位申请条件" yields ["停车位申请条件", "停车位", "申请条件", ...] + cjk_runs = _re.findall(r'[\u4e00-\u9fff\u3400-\u4dbf]{2,}', summary) + # Generate sub-phrases from long CJK runs (bigrams/trigrams/4-grams) + cjk_ngrams: List[str] = [] + for run in cjk_runs: + cjk_ngrams.append(run) + if len(run) > 4: + # Extract 2-4 char sub-phrases from each run + for n in (4, 3, 2): + for i in range(len(run) - n + 1): + cjk_ngrams.append(run[i:i + n]) + + tokens = tokens + cjk_ngrams + + # Filter: keep tokens with appropriate length and not purely numeric + candidates = [ + t for t in tokens + if len(t) >= cls._CATALOG_KEYWORD_MIN_LEN + and not t.isdigit() + and len(t) <= cls._CATALOG_KEYWORD_MAX_LEN + ] + # Prefer longer tokens (more specific) + candidates.sort(key=len, reverse=True) + # Deduplicate case-insensitively + seen: Set[str] = set() + result: List[str] = [] + for c in candidates: + lower = c.lower() + if lower not in seen: + seen.add(lower) + result.append(c) + if len(result) >= max_kw: + break + return result + + def _build_enriched_catalog_listing( + self, + catalog: List[Dict[str, str]], + max_entries: Optional[int] = None, + ) -> str: + """Build an enriched catalog listing with keywords for FAST Step 1. 
+ + Compared to the plain ``[i] name: summary[:200]`` format, this adds + extracted keywords to help the LLM make more informed document + selections. + + Args: + catalog: Entries from ``document_catalog.json``. + max_entries: Cap to prevent prompt overflow. + + Returns: + Formatted listing string for injection into the FAST query + analysis prompt. + """ + lines: List[str] = [] + _max = max_entries if max_entries is not None else self._CATALOG_LISTING_MAX_ENTRIES + _trunc = self._CATALOG_SUMMARY_TRUNCATE + for i, entry in enumerate(catalog[:_max]): + name = entry.get("name", "") + summary = entry.get("summary", "") + kws = AgenticSearch._extract_catalog_keywords(summary) + kw_str = ", ".join(kws) if kws else "" + if kw_str: + lines.append(f"[{i}] {name}: {summary[:_trunc]} [Keywords: {kw_str}]") + else: + lines.append(f"[{i}] {name}: {summary[:_trunc]}") + return "\n".join(lines) + + def _build_answer_context( + self, + best_file_path: str, + artifacts: CompileArtifacts, + ) -> Optional[str]: + """Build document context from catalog for wiki-enhanced answer generation. + + Returns a short context string describing the source document, or + None when no catalog entry exists for *best_file_path*. + + Args: + best_file_path: Path of the top-ranked file from Step 2. + artifacts: Compile artifact availability context. + + Returns: + Context string or None. + """ + if not artifacts.catalog_map: + return None + entry = artifacts.catalog_map.get(best_file_path) + if not entry: + return None + name = entry.get("name", Path(best_file_path).name) + summary = entry.get("summary", "") + if not summary: + return None + return f"Source Document: {name}\nDocument Overview: {summary}" + async def _navigate_tree_for_evidence( self, file_path: str, query: str, ) -> Optional[str]: From 077be35a63998958c0ab2225fe7297ffd2f2d3d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Tue, 14 Apr 2026 19:40:17 +0800 Subject: [PATCH 06/56] fix _extract_catalog_keywords for llm wiki --- src/sirchmunk/search.py | 49 +++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 2976f60..6880b3b 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -3005,47 +3005,64 @@ def _extract_catalog_keywords(summary: str, max_kw: int = 3) -> List[str]: List of up to *max_kw* keywords. """ cls = AgenticSearch - if not summary: + if max_kw <= 0: + return [] + summary_text = str(summary or "").strip() + if not summary_text: return [] import re as _re # Split on whitespace and common punctuation (incl. CJK punctuation) tokens = _re.split( - r'[\s,;\uff0c\uff1b\u3001\u3002\uff1a:!?\uff01\uff1f()\[\]{}\u201c\u201d\u2018\u2019\u0022\u0027]+', - summary, + r'[\s,;\uff0c\uff1b\u3001\u3002\uff1a:!?\uff01\uff1f()\[\]{}\u201c\u201d\u2018\u2019\u0022\u0027/\\|`~@#$%^&*=+<>]+', + summary_text, ) # For CJK text, also extract consecutive CJK character runs (2-6 chars) # so that e.g. "停车位申请条件" yields ["停车位申请条件", "停车位", "申请条件", ...] 
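The run-plus-n-gram expansion that comment describes can be shown in isolation. A small sketch, restricted to the basic CJK block for brevity:

import re
from typing import List

def expand_cjk(text: str, sizes=(4, 3, 2)) -> List[str]:
    """Emit each CJK run plus its 2-4 char sub-phrases so partial keywords still match."""
    grams: List[str] = []
    for run in re.findall(r"[\u4e00-\u9fff]{2,}", text):
        grams.append(run)
        if len(run) > 4:
            for n in sizes:
                grams.extend(run[i:i + n] for i in range(len(run) - n + 1))
    return grams

# expand_cjk("停车位申请条件") -> ['停车位申请条件', '停车位申', '车位申请', ..., '条件']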
- cjk_runs = _re.findall(r'[\u4e00-\u9fff\u3400-\u4dbf]{2,}', summary) + cjk_runs = _re.findall(r'[\u4e00-\u9fff\u3400-\u4dbf]{2,}', summary_text) # Generate sub-phrases from long CJK runs (bigrams/trigrams/4-grams) cjk_ngrams: List[str] = [] + max_ngram_per_run = 40 for run in cjk_runs: cjk_ngrams.append(run) if len(run) > 4: # Extract 2-4 char sub-phrases from each run + added = 0 for n in (4, 3, 2): for i in range(len(run) - n + 1): cjk_ngrams.append(run[i:i + n]) + added += 1 + if added >= max_ngram_per_run: + break + if added >= max_ngram_per_run: + break tokens = tokens + cjk_ngrams # Filter: keep tokens with appropriate length and not purely numeric candidates = [ t for t in tokens - if len(t) >= cls._CATALOG_KEYWORD_MIN_LEN + if t + and len(t) >= cls._CATALOG_KEYWORD_MIN_LEN and not t.isdigit() and len(t) <= cls._CATALOG_KEYWORD_MAX_LEN + and not _re.fullmatch(r"[_\-.]+", t) ] # Prefer longer tokens (more specific) candidates.sort(key=len, reverse=True) # Deduplicate case-insensitively seen: Set[str] = set() + chosen_norms: List[str] = [] result: List[str] = [] for c in candidates: lower = c.lower() if lower not in seen: + # Avoid noisy micro-fragments when a longer token already exists. + if len(lower) <= 4 and any(lower in kept for kept in chosen_norms): + continue seen.add(lower) + chosen_norms.append(lower) result.append(c) if len(result) >= max_kw: break @@ -3070,18 +3087,32 @@ def _build_enriched_catalog_listing( Formatted listing string for injection into the FAST query analysis prompt. """ + if not isinstance(catalog, list) or not catalog: + return "" lines: List[str] = [] _max = max_entries if max_entries is not None else self._CATALOG_LISTING_MAX_ENTRIES + if _max <= 0: + return "" _trunc = self._CATALOG_SUMMARY_TRUNCATE for i, entry in enumerate(catalog[:_max]): - name = entry.get("name", "") - summary = entry.get("summary", "") + if not isinstance(entry, dict): + continue + name = str(entry.get("name") or entry.get("path") or "") + summary = str(entry.get("summary") or "") + # Keep one-line prompt entries to avoid accidental prompt pollution. + name = " ".join(name.split()) + summary = " ".join(summary.split()) + if not name: + name = f"doc_{i}" kws = AgenticSearch._extract_catalog_keywords(summary) kw_str = ", ".join(kws) if kws else "" + shown_summary = summary[:_trunc] + if len(summary) > _trunc: + shown_summary += "..." 
if kw_str: - lines.append(f"[{i}] {name}: {summary[:_trunc]} [Keywords: {kw_str}]") + lines.append(f"[{i}] {name}: {shown_summary} [Keywords: {kw_str}]") else: - lines.append(f"[{i}] {name}: {summary[:_trunc]}") + lines.append(f"[{i}] {name}: {shown_summary}") return "\n".join(lines) def _build_answer_context( From a602197eef20bfb77e75c726aad62cf1b1b88408 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Tue, 14 Apr 2026 20:33:31 +0800 Subject: [PATCH 07/56] add tree guided sampling --- src/sirchmunk/search.py | 191 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 188 insertions(+), 3 deletions(-) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 6880b3b..577c276 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -1935,6 +1935,14 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _CATALOG_SUMMARY_TRUNCATE = 200 """Max chars of catalog summary shown in the listing.""" + # --- Tree-guided sampling constants --- + _TREE_SAMPLE_MAX_SECTIONS = 3 + """Max tree sections to include per file in tree-guided sampling.""" + _TREE_SAMPLE_SECTION_MAX_CHARS = 3000 + """Max chars per tree section.""" + _TREE_SAMPLE_RGA_SUPPLEMENT = True + """Whether to append rga evidence after tree sections as supplementary context.""" + _LLM_FALLBACK_EVIDENCE = ( "[No relevant documents found]\n\n" "The search did not find relevant content in the available documents. " @@ -2249,10 +2257,24 @@ async def _search_fast( # Step 2.5 + Step 3: Tree navigation (1 LLM call) runs in # parallel with rga evidence sampling (0 LLM). The merged # result is higher quality than either alone. + # Tree-guided sampling is integrated into _rga_evidence() for + # secondary files; the primary file gets a dedicated parallel + # tree_task to avoid blocking rga. # ============================================================== + # Track files already receiving parallel tree navigation to + # avoid duplicate LLM calls inside _rga_evidence(). + tree_nav_done: Set[str] = set() + tree_nav_target = best_files[0]["path"] + + if artifacts and tree_nav_target in artifacts.tree_available_paths: + tree_task = self._navigate_tree_for_evidence(tree_nav_target, query) + tree_nav_done.add(tree_nav_target) + else: + tree_task = self._async_noop(None) + async def _rga_evidence() -> str: - """Collect rga-based evidence from best_files (zero LLM).""" + """Collect evidence from best_files: tree-guided when available, rga fallback.""" parts: List[str] = [] chars = 0 for bf in best_files: @@ -2262,6 +2284,8 @@ async def _rga_evidence() -> str: fn = Path(fp).name ext = Path(fp).suffix.lower() ev = None + + # 1. Small file: read entirely (existing logic) if ext in self._FAST_TEXT_EXTENSIONS: try: sz = Path(fp).stat().st_size @@ -2271,8 +2295,35 @@ async def _rga_evidence() -> str: ev = f"[{fn}]\n{full}" except Exception: pass + + # 2. Tree-guided sampling (adaptive, skip files handled + # by the parallel tree_task to avoid duplicate LLM) + if ( + ev is None + and artifacts + and fp in artifacts.tree_available_paths + and fp not in tree_nav_done + ): + try: + tree_ev_inner = await self._tree_guided_sample( + fp, query, + match_objects=bf.get("matches", []), + max_chars=self._FAST_MAX_EVIDENCE_CHARS - chars, + artifacts=artifacts, + ) + if tree_ev_inner: + ev = tree_ev_inner + await self._logger.info( + f"[FAST:Step3] Tree-guided sample for {fn} " + f"({len(tree_ev_inner)} chars)" + ) + except Exception: + pass + + # 3. 
Fallback: rga sampling (existing logic) if ev is None: ev = await self._fast_sample_evidence(fp, bf.get("matches", [])) + if ev: remaining = self._FAST_MAX_EVIDENCE_CHARS - chars parts.append(ev[:remaining]) @@ -2281,9 +2332,7 @@ async def _rga_evidence() -> str: return "\n\n---\n\n".join(parts) # Launch tree navigation for the primary file alongside rga - tree_nav_target = best_files[0]["path"] rga_task = _rga_evidence() - tree_task = self._navigate_tree_for_evidence(tree_nav_target, query) rga_ev, tree_ev = await asyncio.gather(rga_task, tree_task) @@ -3143,6 +3192,142 @@ def _build_answer_context( return None return f"Source Document: {name}\nDocument Overview: {summary}" + async def _tree_guided_sample( + self, + file_path: str, + query: str, + *, + match_objects: Optional[List[Dict[str, Any]]] = None, + max_chars: int = 0, + artifacts: Optional["CompileArtifacts"] = None, + pre_navigated_leaves: Optional[List[Any]] = None, + ) -> Optional[str]: + """Tree-guided evidence sampling: use compiled tree index to locate + relevant sections, then read precise char_range content. + + Falls back to None when no tree index is available, letting callers + use their default sampling strategy (rga windows, Monte Carlo, etc.). + + This method is designed to be called from both FAST and DEEP modes: + - FAST: called inside _rga_evidence() per-file loop + - DEEP: called before/alongside Monte Carlo sampling + + Args: + file_path: Absolute path to the target file. + query: User query for LLM-driven branch selection. + match_objects: Optional rga match objects for hybrid evidence. + max_chars: Character budget for this file's evidence. + Uses ``_FAST_MAX_EVIDENCE_CHARS`` when 0. + artifacts: Compile artifact context; when None, probes lazily. + pre_navigated_leaves: Pre-computed leaf nodes from a prior + ``navigate()`` call. When provided the method skips the + LLM navigation step (avoids duplicate LLM calls). + + Returns: + Formatted evidence string with tree-navigated sections, or None + when tree index is unavailable (caller should fall back). 
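The core of that sampling is plain slicing over already-extracted text. A compact sketch with a simplified leaf shape (title plus char_range tuple); the real leaves are tree nodes with more fields:

from typing import List, Optional, Tuple

Leaf = Tuple[str, Tuple[int, int]]          # (title, (start_char, end_char))

def slice_sections(full_text: str, leaves: List[Leaf],
                   per_section: int = 3_000, budget: int = 15_000) -> Optional[str]:
    """Cut each navigated section out of the document, respecting both char caps."""
    parts: List[str] = []
    used = 0
    for title, (start, end) in leaves:
        segment = full_text[start:end][:per_section]
        if not segment.strip():
            continue
        chunk = f"[{title}]\n{segment}"
        if used + len(chunk) > budget:
            break
        parts.append(chunk)
        used += len(chunk)
    return "\n\n".join(parts) if parts else None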
+ """ + if max_chars <= 0: + max_chars = self._FAST_MAX_EVIDENCE_CHARS + + # --- Guard: tree availability --- + if artifacts is not None: + if file_path not in artifacts.tree_available_paths: + return None + else: + # Lazy probe when artifacts not provided (DEEP mode entry) + indexer = self._get_tree_indexer() + if indexer is None or not indexer.has_tree(file_path): + return None + + fname = Path(file_path).name + + # --- Obtain leaf nodes --- + leaves = pre_navigated_leaves + if leaves is None: + try: + indexer = self._get_tree_indexer() + if indexer is None: + return None + tree = indexer.load_tree(file_path) + if tree is None or tree.root is None: + return None + leaves = await indexer.navigate( + tree, query, + max_results=self._TREE_SAMPLE_MAX_SECTIONS, + ) + except Exception: + return None + + if not leaves: + return None + + # --- Read full text once for char_range slicing --- + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" + + # --- Extract tree sections --- + parts: List[str] = [] + total_chars = 0 + for leaf in leaves[: self._TREE_SAMPLE_MAX_SECTIONS]: + start, end = leaf.char_range + if full_text and end > start: + segment = full_text[start:end] + else: + segment = leaf.summary or "" + segment = segment[: self._TREE_SAMPLE_SECTION_MAX_CHARS] + if not segment.strip(): + continue + header = f"[{fname} \u2192 {leaf.title}]" + chunk = f"{header}\n{segment}" + if total_chars + len(chunk) > max_chars: + remaining = max_chars - total_chars + if remaining > 200: + parts.append(chunk[:remaining]) + total_chars += remaining + break + parts.append(chunk) + total_chars += len(chunk) + + # --- Optional rga supplement --- + if ( + self._TREE_SAMPLE_RGA_SUPPLEMENT + and match_objects + and total_chars < max_chars + ): + hit_lines: List[int] = [] + for m in match_objects: + ln = m.get("data", {}).get("line_number") + if isinstance(ln, int): + hit_lines.append(ln) + if hit_lines: + ext = Path(file_path).suffix.lower() + if ext in self._FAST_TEXT_EXTENSIONS: + rga_ctx = self._read_context_windows( + file_path, hit_lines, + window=self._FAST_CONTEXT_WINDOW, + max_chars=max_chars - total_chars, + ) + if rga_ctx: + rga_section = f"[{fname} \u2192 rga hits]\n{rga_ctx}" + parts.append(rga_section) + total_chars += len(rga_section) + + if not parts: + return None + + evidence = "\n\n".join(parts) + await self._logger.info( + f"[TreeSample] {fname}: " + f"{len(parts)} sections, {total_chars} chars " + f"(pre_nav={'yes' if pre_navigated_leaves else 'no'})" + ) + return evidence + async def _navigate_tree_for_evidence( self, file_path: str, query: str, ) -> Optional[str]: From 8233c35b8a79f8af02d3dbe68bf9880c723ea35c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Tue, 14 Apr 2026 23:38:01 +0800 Subject: [PATCH 08/56] fix compile quality and large-file processing --- src/sirchmunk/learnings/compiler.py | 146 ++++++++++++++++++++---- src/sirchmunk/learnings/tree_indexer.py | 21 +++- src/sirchmunk/search.py | 12 +- 3 files changed, 152 insertions(+), 27 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 4ccd5da..10b56a6 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -40,6 +40,16 @@ # Similarity threshold for merging into existing clusters during compile _MERGE_SIMILARITY_THRESHOLD = 0.75 +# Max chars for manifest-persisted document summary (used 
in Phase 2 & catalog) +_MANIFEST_SUMMARY_MAX_LEN = 250 + +# Preview window for direct LLM summarisation (no tree), ~4K tokens +_SUMMARY_PREVIEW_CHARS = 16_000 + +# Multi-section sampling for large documents without a tree index +_SUMMARY_SAMPLE_SECTIONS = 3 # Number of sections to sample for large docs +_SUMMARY_SAMPLE_SECTION_CHARS = 5_000 # Chars per sampled section + # --------------------------------------------------------------------------- # Data structures @@ -54,6 +64,7 @@ class FileManifestEntry: has_tree: bool cluster_ids: List[str] size_bytes: int + summary: str = "" # 新增:存储编译期生成的文档摘要 def to_dict(self) -> Dict[str, Any]: return { @@ -62,6 +73,7 @@ def to_dict(self) -> Dict[str, Any]: "has_tree": self.has_tree, "cluster_ids": self.cluster_ids, "size_bytes": self.size_bytes, + "summary": self.summary, } @classmethod @@ -72,6 +84,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "FileManifestEntry": has_tree=data.get("has_tree", False), cluster_ids=data.get("cluster_ids", []), size_bytes=data.get("size_bytes", 0), + summary=data.get("summary", ""), ) @@ -365,6 +378,7 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: has_tree=result.tree is not None, cluster_ids=result.cluster_ids, size_bytes=Path(result.path).stat().st_size if Path(result.path).exists() else 0, + summary=result.summary[:_MANIFEST_SUMMARY_MAX_LEN] if result.summary else "", ) # Phase 3: aggregate results into knowledge network @@ -516,8 +530,12 @@ async def _compile_single_file( entry.path, content, ) + # Enrich content with structural metadata for non-text types + metadata_prefix = self._extract_structured_metadata(entry.path, content) + enriched_content = metadata_prefix + content if metadata_prefix else content + result.summary = await self._extract_summary( - entry.path, content, result.tree, + entry.path, enriched_content, result.tree, ) result.topics = await self._extract_topics(result.summary) result.evidence = self._build_evidence(entry, content, result) @@ -539,11 +557,14 @@ async def _extract_summary( When a tree is available its root already contains an LLM-synthesized summary (produced by ``_synthesize_root_summary`` during tree build), so we reuse it directly — no redundant LLM call. + + For large documents without a tree, uses multi-section sampling + (beginning, middle, end) to capture the full scope of the document. """ if tree and tree.root and tree.root.summary: return tree.root.summary - preview = content[:16000] if len(content) > 16000 else content + preview = self._build_summary_preview(content) from sirchmunk.llm.prompts import COMPILE_DOC_SUMMARY prompt = COMPILE_DOC_SUMMARY.format( file_name=Path(file_path).name, @@ -552,6 +573,100 @@ async def _extract_summary( resp = await self._llm.achat([{"role": "user", "content": prompt}]) return resp.content.strip() + @staticmethod + def _build_summary_preview(content: str) -> str: + """Build a representative preview for LLM summarisation. + + For short documents (≤ _SUMMARY_PREVIEW_CHARS), returns the full + content. For large documents, samples the beginning, middle, and + end to capture the document's full scope within the token budget. + """ + if len(content) <= _SUMMARY_PREVIEW_CHARS: + return content + + section_size = _SUMMARY_SAMPLE_SECTION_CHARS + mid_start = max(section_size, (len(content) - section_size) // 2) + + head = content[:section_size] + middle = content[mid_start:mid_start + section_size] + tail = content[-section_size:] + + return ( + f"[Beginning of document]\n{head}\n\n" + f"[... 
content omitted ...]\n\n" + f"[Middle of document]\n{middle}\n\n" + f"[... content omitted ...]\n\n" + f"[End of document]\n{tail}" + ) + + @staticmethod + def _extract_structured_metadata(file_path: str, content: str) -> str: + """Extract structural metadata for non-text document types. + + For spreadsheets and presentations, prepend a structural overview + (sheet names, column headers, slide titles) so the LLM summariser + has better context than raw extracted text alone. + + Returns a metadata prefix string (may be empty for unsupported types). + """ + ext = Path(file_path).suffix.lower() + + if ext == ".xlsx": + return KnowledgeCompiler._extract_xlsx_metadata(file_path) + if ext == ".pptx": + return KnowledgeCompiler._extract_pptx_metadata(file_path) + + return "" + + @staticmethod + def _extract_xlsx_metadata(file_path: str) -> str: + """Extract structural metadata from Excel files. + + Reads sheet names, row counts, and column headers (first row) to + provide the LLM with a structural overview of the workbook. + Caps at 10 sheets and 15 columns per sheet for bounded output. + """ + try: + import openpyxl + wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True) + lines: List[str] = ["[Excel Workbook Structure]"] + for sheet_name in wb.sheetnames[:10]: # Cap at 10 sheets + ws = wb[sheet_name] + # Extract column headers (first row) + headers: List[str] = [] + for cell in ws.iter_rows(min_row=1, max_row=1, values_only=True): + headers = [str(h) for h in cell if h is not None] + break + row_count = ws.max_row or 0 + header_str = ", ".join(headers[:15]) if headers else "no headers" + lines.append(f"- Sheet '{sheet_name}': {row_count} rows, columns: [{header_str}]") + wb.close() + return "\n".join(lines) + "\n\n" + except Exception: + return "" + + @staticmethod + def _extract_pptx_metadata(file_path: str) -> str: + """Extract structural metadata from PowerPoint files. + + Reads slide count and titles (from the title placeholder) to give + the LLM a table-of-contents-like overview of the presentation. + Caps at 20 slides for bounded output. + """ + try: + from pptx import Presentation + prs = Presentation(file_path) + lines: List[str] = [f"[PowerPoint Structure: {len(prs.slides)} slides]"] + for i, slide in enumerate(prs.slides[:20], 1): # Cap at 20 slides + title = "" + if slide.shapes.title: + title = slide.shapes.title.text.strip() + if title: + lines.append(f"- Slide {i}: {title}") + return "\n".join(lines) + "\n\n" + except Exception: + return "" + def _build_evidence( self, entry: FileEntry, @@ -851,14 +966,19 @@ def _build_document_catalog(self, manifest: CompileManifest) -> None: The catalog is consumed by FAST search to fuse query analysis with LLM-driven document routing in a single prompt. Each entry carries - the filename and a truncated root summary (≤250 chars). + the filename and a truncated root summary (<= _MANIFEST_SUMMARY_MAX_LEN chars). + + Summary is sourced from the manifest (populated during Phase 2 compile), + with a tree-root fallback for backward compatibility. 
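Each catalog record stays deliberately small. A sketch of the fallback chain for one entry; the field names match document_catalog.json, the helper itself is illustrative:

from pathlib import Path
from typing import Dict

def catalog_record(path: str, manifest_summary: str,
                   tree_root_summary: str = "", max_len: int = 250) -> Dict[str, str]:
    """Prefer the manifest summary written in Phase 2; fall back to the tree root."""
    summary = (manifest_summary or tree_root_summary)[:max_len]
    return {"path": path, "name": Path(path).name, "summary": summary}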
""" tree_cache = self._compile_dir / "trees" entries: List[Dict[str, str]] = [] for file_path, entry in manifest.files.items(): - summary = "" - if entry.has_tree and tree_cache.exists(): + summary = entry.summary # Primary: manifest-persisted summary + + # Fallback: read from tree root if manifest summary is empty + if not summary and entry.has_tree and tree_cache.exists(): tree_file = tree_cache / f"{entry.file_hash}.json" if tree_file.exists(): try: @@ -866,24 +986,10 @@ def _build_document_catalog(self, manifest: CompileManifest) -> None: tree_file.read_text(encoding="utf-8"), ) if tree.root and tree.root.summary: - summary = tree.root.summary[:250] + summary = tree.root.summary[:_MANIFEST_SUMMARY_MAX_LEN] except Exception: pass - if not summary: - # Fallback: use first cluster's description - for cid in entry.cluster_ids[:1]: - try: - import asyncio - loop = asyncio.get_event_loop() - if loop.is_running(): - break - c = loop.run_until_complete(self._storage.get(cid)) - if c and c.description: - summary = str(c.description[0])[:250] - except Exception: - break - entries.append({ "path": file_path, "name": Path(file_path).name, diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 53ebf0b..8bd2983 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -21,6 +21,11 @@ # File-size threshold: skip tree indexing for small files _TREE_MIN_CHARS = 50_000 # 50 K characters +# Adaptive preview window for LLM structure analysis +_TREE_PREVIEW_MIN = 12_000 # Minimum preview window (chars) +_TREE_PREVIEW_MAX = 50_000 # Maximum preview window (~12K tokens) +_TREE_PREVIEW_RATIO = 0.15 # Fraction of document to preview + # Extensions eligible for tree indexing _TREE_EXTENSIONS = { ".pdf", ".docx", ".doc", ".md", ".markdown", @@ -260,7 +265,8 @@ async def _build_node( """Recursively build tree nodes via LLM structure analysis.""" from sirchmunk.llm.prompts import COMPILE_TREE_STRUCTURE - preview = text[:12000] if len(text) > 12000 else text + preview_size = self._compute_preview_size(len(text)) + preview = text[:preview_size] prompt = COMPILE_TREE_STRUCTURE.format( document_content=preview, max_sections=8, @@ -427,6 +433,19 @@ def _load_cache(self, file_hash: str) -> Optional[DocumentTree]: # Helpers # # ------------------------------------------------------------------ # + @staticmethod + def _compute_preview_size(text_len: int) -> int: + """Compute adaptive preview window size for LLM structure analysis. + + Scales with document length: at least *_TREE_PREVIEW_MIN* chars, + up to *_TREE_PREVIEW_MAX*, using *_TREE_PREVIEW_RATIO* of the + document length as the baseline. 
+ """ + return max( + _TREE_PREVIEW_MIN, + min(int(text_len * _TREE_PREVIEW_RATIO), _TREE_PREVIEW_MAX), + ) + @staticmethod def _count_nodes(node: TreeNode) -> int: return 1 + sum(DocumentTreeIndexer._count_nodes(c) for c in node.children) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 577c276..f38028f 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -1910,18 +1910,18 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _FAST_SMALL_FILE_THRESHOLD = 100_000 # 100K chars - read full file instead of grep sampling # --- Wiki-enhanced ranking constants --- - _WIKI_BLEND_ALPHA = 0.7 + _WIKI_BLEND_ALPHA = 0.85 """TF-IDF weight in the hybrid score; Wiki weight = 1 - alpha.""" _WIKI_MAX_SCORE = 10.0 """Upper bound for the wiki relevance score.""" _WIKI_CATALOG_KEYWORD_OVERLAP_MAX = 5.0 """Maximum sub-score for catalog summary keyword overlap.""" - _WIKI_TREE_AVAILABILITY_BONUS = 2.0 - """Bonus for files that have a compiled tree index.""" - _WIKI_CATALOG_PRESENCE_FULL = 3.0 + _WIKI_TREE_AVAILABILITY_BONUS = 0.5 + """Bonus for files that have a compiled tree index (weak signal).""" + _WIKI_CATALOG_PRESENCE_FULL = 2.0 """Catalog presence bonus for summaries > 100 chars.""" - _WIKI_CATALOG_PRESENCE_MEDIUM = 2.0 - """Catalog presence bonus for summaries > 30 chars.""" + _WIKI_CATALOG_PRESENCE_MEDIUM = 1.5 + """Catalog presence bonus for summaries > 30 chars (must be < FULL).""" _WIKI_CATALOG_PRESENCE_MINIMAL = 1.0 """Catalog presence bonus for summaries > 0 chars.""" _TREE_CACHE_SCAN_LIMIT = 200 From 1de1c98817c8e6c9beb29045debe38df4682cab7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Wed, 15 Apr 2026 00:09:58 +0800 Subject: [PATCH 09/56] adopt the latest compile processing --- src/sirchmunk/search.py | 117 ++++++++++++++++++++++++++++------------ 1 file changed, 82 insertions(+), 35 deletions(-) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index f38028f..5128702 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -135,6 +135,7 @@ class CompileArtifacts: catalog_map: Dict[str, Dict[str, str]] # path -> catalog entry for O(1) lookup tree_indexer: Optional[Any] # DocumentTreeIndexer (lazy import) tree_available_paths: Set[str] # file paths that have cached tree indices + manifest_map: Dict[str, Any] = field(default_factory=dict) # {path: FileManifestEntry} class AgenticSearch(BaseSearch): @@ -1426,6 +1427,7 @@ async def _search_deep( self._probe_knowledge_cache(query), self._load_spec_context(paths, stale_hours=spec_stale_hours), self._probe_tree_index(query), + self._probe_compile_hints(initial_keywords if initial_keywords else [query]), return_exceptions=True, ) @@ -1434,8 +1436,9 @@ async def _search_deep( knowledge_probe = phase1_results[2] if not isinstance(phase1_results[2], Exception) else KnowledgeProbeResult([], [], "") spec_context = phase1_results[3] if not isinstance(phase1_results[3], Exception) else "" tree_hits = phase1_results[4] if not isinstance(phase1_results[4], Exception) else [] + compile_hints = phase1_results[5] if not isinstance(phase1_results[5], Exception) else CompileHints([], []) - for i, label in enumerate(["keywords", "dir_scan", "knowledge", "spec_cache", "tree_index"]): + for i, label in enumerate(["keywords", "dir_scan", "knowledge", "spec_cache", "tree_index", "compile_hints"]): if isinstance(phase1_results[i], Exception): await self._logger.warning(f"[Phase 1] {label} probe failed: {phase1_results[i]}") @@ -1471,6 +1474,7 @@ async def 
_search_deep( f"dir_scan={'OK' if scan_result else 'N/A'}, " f"knowledge_files={len(knowledge_probe.file_paths)}, " f"tree_hits={len(tree_hits)}, " + f"compile_hints={len(compile_hints.file_paths)}, " f"soft_hit={'YES' if soft_hit else 'NO'}, " f"spec_cache={'YES' if spec_context else 'NO'}" ) @@ -1523,7 +1527,7 @@ async def _search_deep( if soft_hit: extra_knowledge_files = soft_hit.file_paths + extra_knowledge_files merged_files = self._merge_file_paths( - keyword_files=list(tree_hits) + keyword_files, + keyword_files=list(tree_hits) + compile_hints.file_paths + keyword_files, dir_scan_files=dir_scan_files, knowledge_hits=extra_knowledge_files, ) @@ -2285,22 +2289,9 @@ async def _rga_evidence() -> str: ext = Path(fp).suffix.lower() ev = None - # 1. Small file: read entirely (existing logic) - if ext in self._FAST_TEXT_EXTENSIONS: - try: - sz = Path(fp).stat().st_size - if sz < self._FAST_SMALL_FILE_THRESHOLD: - full = Path(fp).read_text(errors="replace") - if len(full) < self._FAST_SMALL_FILE_THRESHOLD: - ev = f"[{fn}]\n{full}" - except Exception: - pass - - # 2. Tree-guided sampling (adaptive, skip files handled - # by the parallel tree_task to avoid duplicate LLM) + # 1. Tree-guided sampling FIRST for tree-indexed files if ( - ev is None - and artifacts + artifacts and fp in artifacts.tree_available_paths and fp not in tree_nav_done ): @@ -2320,6 +2311,17 @@ async def _rga_evidence() -> str: except Exception: pass + # 2. Small file: read entirely (only if tree didn't provide evidence) + if ev is None and ext in self._FAST_TEXT_EXTENSIONS: + try: + sz = Path(fp).stat().st_size + if sz < self._FAST_SMALL_FILE_THRESHOLD: + full = Path(fp).read_text(errors="replace") + if len(full) < self._FAST_SMALL_FILE_THRESHOLD: + ev = f"[{fn}]\n{full}" + except Exception: + pass + # 3. 
Fallback: rga sampling (existing logic) if ev is None: ev = await self._fast_sample_evidence(fp, bf.get("matches", [])) @@ -2641,11 +2643,15 @@ def _compute_wiki_relevance( query_lower = query.lower() matches = 0 total = 0 + summary_tokens = cls._tokenize_for_matching(summary_lower) for kw in keywords: if kw: total += 1 - if kw.lower() in summary_lower: - matches += 1 + kw_low = kw.lower() + if kw_low in summary_tokens: + matches += 1 # Full token match + elif kw_low in summary_lower: + matches += 0.5 # Substring-only match (lower confidence) # Also check whole query as a substring if len(query_lower) >= 2 and query_lower in summary_lower: matches += 1 @@ -3006,25 +3012,43 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: if p: catalog_map[p] = entry + # Load manifest for rich metadata (size, has_tree, cluster_ids) + manifest_map: Dict[str, Any] = {} + manifest_path = self.work_path / ".cache" / "compile" / "manifest.json" + if manifest_path.exists(): + try: + from sirchmunk.learnings.compiler import CompileManifest + manifest = CompileManifest.from_json( + manifest_path.read_text(encoding="utf-8") + ) + manifest_map = manifest.files # {file_path: FileManifestEntry} + except Exception: + pass + indexer = self._get_tree_indexer() # Use cached tree paths when available to avoid re-parsing all JSONs tree_paths: Set[str] = getattr(self, "_tree_paths_cache", None) or set() - if indexer is not None and not tree_paths: - tree_cache = self.work_path / ".cache" / "compile" / "trees" - if tree_cache.exists(): - try: - from sirchmunk.learnings.tree_indexer import DocumentTree - for tf in sorted(tree_cache.glob("*.json"))[:self._TREE_CACHE_SCAN_LIMIT]: - try: - tree = DocumentTree.from_json( - tf.read_text(encoding="utf-8") - ) - if tree.file_path: - tree_paths.add(tree.file_path) - except Exception: - pass - except Exception: - pass + if not tree_paths: + # Prefer manifest-based detection (fast, O(1) per file) + if manifest_map: + tree_paths = {fp for fp, entry in manifest_map.items() if entry.has_tree} + # Fallback: scan tree cache directory (legacy path) + elif indexer is not None: + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if tree_cache.exists(): + try: + from sirchmunk.learnings.tree_indexer import DocumentTree + for tf in sorted(tree_cache.glob("*.json"))[:self._TREE_CACHE_SCAN_LIMIT]: + try: + tree = DocumentTree.from_json( + tf.read_text(encoding="utf-8") + ) + if tree.file_path: + tree_paths.add(tree.file_path) + except Exception: + pass + except Exception: + pass # Cache for future calls within this instance self._tree_paths_cache = tree_paths @@ -3033,8 +3057,31 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: catalog_map=catalog_map, tree_indexer=indexer, tree_available_paths=tree_paths, + manifest_map=manifest_map, ) + @staticmethod + def _tokenize_for_matching(text: str) -> Set[str]: + """Tokenize text into meaningful units for keyword matching. + + Splits on whitespace and CJK/Latin punctuation boundaries, then + generates 2-3 char n-grams for CJK-heavy tokens to handle + unsegmented Chinese text. Returns a set of lowercased tokens. 
+ """ + import re + tokens: Set[str] = set() + raw = re.split(r'[\s,;.!?,;。!?::、\u201c\u201d\u2018\u2019()()\[\]{}<>《》\-/]+', text.lower()) + for t in raw: + t = t.strip() + if not t: + continue + tokens.add(t) + if len(t) >= 2 and any('\u4e00' <= c <= '\u9fff' for c in t): + for n in (2, 3): + for i in range(len(t) - n + 1): + tokens.add(t[i:i + n]) + return tokens + @staticmethod def _extract_catalog_keywords(summary: str, max_kw: int = 3) -> List[str]: """Extract salient keywords from a catalog summary via simple heuristics. From 938ced1e657d58f0cc042631c97f52de5fbfe328 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Wed, 15 Apr 2026 16:02:45 +0800 Subject: [PATCH 10/56] refactor tree indexing with toc --- src/sirchmunk/learnings/compiler.py | 45 ++- src/sirchmunk/learnings/toc_extractor.py | 391 +++++++++++++++++++++++ src/sirchmunk/learnings/tree_indexer.py | 154 ++++++++- src/sirchmunk/search.py | 111 +++++++ 4 files changed, 697 insertions(+), 4 deletions(-) create mode 100644 src/sirchmunk/learnings/toc_extractor.py diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 10b56a6..fac9b79 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -65,6 +65,8 @@ class FileManifestEntry: cluster_ids: List[str] size_bytes: int summary: str = "" # 新增:存储编译期生成的文档摘要 + has_explicit_toc: bool = False # Whether a native TOC was extracted from the file + tree_node_count: int = 0 # Number of nodes in the tree index (quality metric) def to_dict(self) -> Dict[str, Any]: return { @@ -74,6 +76,8 @@ def to_dict(self) -> Dict[str, Any]: "cluster_ids": self.cluster_ids, "size_bytes": self.size_bytes, "summary": self.summary, + "has_explicit_toc": self.has_explicit_toc, + "tree_node_count": self.tree_node_count, } @classmethod @@ -85,6 +89,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "FileManifestEntry": cluster_ids=data.get("cluster_ids", []), size_bytes=data.get("size_bytes", 0), summary=data.get("summary", ""), + has_explicit_toc=data.get("has_explicit_toc", False), + tree_node_count=data.get("tree_node_count", 0), ) @@ -147,6 +153,8 @@ class FileCompileResult: evidence: Optional[EvidenceUnit] = None cluster_ids: List[str] = field(default_factory=list) error: Optional[str] = None + has_explicit_toc: bool = False # Whether TOC was extracted from native structure + tree_node_count: int = 0 # Number of nodes in the tree index @dataclass @@ -379,6 +387,8 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: cluster_ids=result.cluster_ids, size_bytes=Path(result.path).stat().st_size if Path(result.path).exists() else 0, summary=result.summary[:_MANIFEST_SUMMARY_MAX_LEN] if result.summary else "", + has_explicit_toc=result.has_explicit_toc, + tree_node_count=result.tree_node_count, ) # Phase 3: aggregate results into knowledge network @@ -525,11 +535,26 @@ async def _compile_single_file( and DocumentTreeIndexer.should_build_tree(entry.path, len(content)) ) + # Phase 0.5: TOC extraction (zero LLM calls) + toc_entries = None + if use_tree: + from sirchmunk.learnings.toc_extractor import TOCExtractor + toc_entries = TOCExtractor.extract(entry.path, content) + if toc_entries: + await self._log.info( + f"[Compile] Extracted TOC with {len(toc_entries)} entries " + f"for {Path(entry.path).name}" + ) + if use_tree: result.tree = await self._tree_indexer.build_tree( - entry.path, content, + entry.path, content, toc_entries=toc_entries, ) + # Record TOC / tree metrics on the result for manifest persistence + 
result.has_explicit_toc = toc_entries is not None and len(toc_entries) > 0 + result.tree_node_count = self._count_tree_nodes(result.tree) + # Enrich content with structural metadata for non-text types metadata_prefix = self._extract_structured_metadata(entry.path, content) enriched_content = metadata_prefix + content if metadata_prefix else content @@ -940,6 +965,24 @@ def _add_edge( WeakSemanticEdge(target_cluster_id=target_id, weight=weight, source=source) ) + @staticmethod + def _count_tree_nodes(tree: Optional[DocumentTree]) -> int: + """Count total nodes in a DocumentTree (recursive). + + Args: + tree: The tree to count, or None. + + Returns: + Total node count, or 0 if tree is None. + """ + if tree is None or tree.root is None: + return 0 + + def _count(node: Any) -> int: + return 1 + sum(_count(c) for c in node.children) + + return _count(tree.root) + # ------------------------------------------------------------------ # # Manifest I/O # # ------------------------------------------------------------------ # diff --git a/src/sirchmunk/learnings/toc_extractor.py b/src/sirchmunk/learnings/toc_extractor.py new file mode 100644 index 0000000..85f3b8e --- /dev/null +++ b/src/sirchmunk/learnings/toc_extractor.py @@ -0,0 +1,391 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +TOC (Table of Contents) extractor — pure local operations, zero LLM calls. + +Extracts hierarchical table-of-contents structures from various document +formats (PDF, Markdown, DOCX, HTML) using native format features (bookmarks, +heading styles, heading tags). The extracted TOCEntry list is consumed by +the tree indexer to accelerate tree construction. +""" + +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Optional + +# Minimum number of TOC entries required to form a meaningful structure +_MIN_TOC_ENTRIES = 3 + +# Known heading-style prefixes across locales (English, Chinese, etc.) +_HEADING_STYLE_PREFIXES = ("Heading", "heading", "\u6807\u9898") # "标题" = Chinese + + +@dataclass +class TOCEntry: + """Single entry in an extracted table of contents.""" + + title: str + level: int # 0=root, 1=section, 2=subsection + char_start: int # Character offset in extracted text + char_end: Optional[int] = None + page_start: Optional[int] = None + page_end: Optional[int] = None + children: List["TOCEntry"] = field(default_factory=list) + + +class TOCExtractor: + """Extract TOC structure from documents using native format features. + + All methods are static — no instance state required. Each extraction + method handles one file format and returns a flat or nested list of + ``TOCEntry`` objects. The main ``extract()`` entry point dispatches + by file extension and resolves character positions against the + extracted text content. + + Design constraints: + - Pure local operations, zero LLM calls + - Exceptions handled internally; failure returns None + """ + + @staticmethod + def extract(file_path: str, content: str) -> Optional[List[TOCEntry]]: + """Main entry point: extract TOC entries from a file. + + Dispatches to format-specific extractors based on file extension, + then resolves character positions in the extracted text content. + + Args: + file_path: Absolute path to the source file. + content: Extracted text content of the file. + + Returns: + List of TOCEntry with resolved char positions, or None if + the file format is unsupported or fewer than _MIN_TOC_ENTRIES + entries are found. 
+ """ + ext = Path(file_path).suffix.lower() + + entries: Optional[List[TOCEntry]] = None + if ext == ".pdf": + entries = TOCExtractor._extract_pdf_toc(file_path) + elif ext in (".md", ".markdown"): + entries = TOCExtractor._extract_markdown_toc(content) + elif ext in (".docx",): + entries = TOCExtractor._extract_docx_toc(file_path) + elif ext in (".html", ".htm"): + entries = TOCExtractor._extract_html_toc(content) + else: + return None + + if not entries: + return None + + # Flatten nested children for total count check + total = TOCExtractor._count_entries(entries) + if total < _MIN_TOC_ENTRIES: + return None + + # Resolve character positions in extracted text + entries = TOCExtractor._resolve_char_positions(entries, content) + return entries + + @staticmethod + def _extract_pdf_toc(file_path: str) -> Optional[List[TOCEntry]]: + """Extract TOC from PDF bookmarks/outline using pypdf. + + Recursively parses the nested bookmark structure from + ``PdfReader.outline``. + + Args: + file_path: Path to the PDF file. + + Returns: + List of TOCEntry with page_start populated, or None on failure. + """ + try: + from pypdf import PdfReader + + reader = PdfReader(file_path) + outline = reader.outline + if not outline: + return None + + entries: List[TOCEntry] = [] + TOCExtractor._parse_pdf_outline(reader, outline, entries, level=1) + return entries if entries else None + except Exception: + return None + + @staticmethod + def _parse_pdf_outline( + reader: "PdfReader", + outline_items: List, + entries: List[TOCEntry], + level: int, + ) -> None: + """Recursively parse pypdf outline items into TOCEntry list. + + Args: + reader: PdfReader instance for page number resolution. + outline_items: Nested list of outline Destination objects. + entries: Accumulator list to append entries to. + level: Current nesting level (1=top-level section). + """ + for item in outline_items: + if isinstance(item, list): + # Nested list means sub-bookmarks — attach to last entry + if entries: + sub_entries: List[TOCEntry] = [] + TOCExtractor._parse_pdf_outline( + reader, item, sub_entries, level=level + 1, + ) + entries[-1].children.extend(sub_entries) + else: + TOCExtractor._parse_pdf_outline( + reader, item, entries, level=level, + ) + else: + # Single bookmark destination + try: + title = item.title if hasattr(item, "title") else str(item) + page_num = None + try: + page_num = reader.get_destination_page_number(item) + except Exception: + pass + entry = TOCEntry( + title=title.strip(), + level=level, + char_start=0, + page_start=page_num, + ) + entries.append(entry) + except Exception: + continue + + @staticmethod + def _extract_markdown_toc(content: str) -> Optional[List[TOCEntry]]: + """Extract TOC from Markdown heading syntax (# / ## / ###). + + Matches ATX-style headings: lines beginning with 1-6 '#' characters + followed by whitespace and the heading text. + + Args: + content: Markdown text content. + + Returns: + List of TOCEntry with level derived from '#' count, or None. + """ + try: + pattern = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE) + matches = pattern.findall(content) + if not matches: + return None + + entries: List[TOCEntry] = [] + for hashes, title in matches: + entries.append(TOCEntry( + title=title.strip(), + level=len(hashes), + char_start=0, + )) + return entries if entries else None + except Exception: + return None + + @staticmethod + def _extract_docx_toc(file_path: str) -> Optional[List[TOCEntry]]: + """Extract TOC from DOCX heading styles using python-docx. 
+ + Reads paragraphs with heading style names (English ``Heading``, + Chinese ``\u6807\u9898``, etc.), extracting the heading level from the style + name suffix (e.g., ``Heading 1`` -> level 1). + + Args: + file_path: Path to the DOCX file. + + Returns: + List of TOCEntry with level from heading style, or None. + """ + try: + import docx + + doc = docx.Document(file_path) + entries: List[TOCEntry] = [] + for para in doc.paragraphs: + style_name = para.style.name or "" + # Match heading styles across locales ("Heading 1", "标题 1", etc.) + matched_prefix = "" + for prefix in _HEADING_STYLE_PREFIXES: + if style_name.startswith(prefix): + matched_prefix = prefix + break + if not matched_prefix: + continue + level_str = style_name[len(matched_prefix):].strip() + try: + level = int(level_str) if level_str else 1 + except ValueError: + level = 1 + title = para.text.strip() + if title: + entries.append(TOCEntry( + title=title, + level=level, + char_start=0, + )) + return entries if entries else None + except Exception: + return None + + @staticmethod + def _extract_html_toc(content: str) -> Optional[List[TOCEntry]]: + """Extract TOC from HTML heading tags (

<h1> through <h6>).
+
+        Uses regex to match heading tags and strips inner HTML tags
+        from the title text.
+
+        Args:
+            content: HTML text content.
+
+        Returns:
+            List of TOCEntry with level from tag number, or None.
+        """
+        try:
+            pattern = re.compile(
+                r"<h([1-6])[^>]*>(.*?)</h\1>",
+                re.IGNORECASE | re.DOTALL,
+            )
+            matches = pattern.findall(content)
+            if not matches:
+                return None
+
+            entries: List[TOCEntry] = []
+            for level_str, raw_title in matches:
+                # Strip HTML tags from title
+                title = re.sub(r"<[^>]+>", "", raw_title).strip()
+                if title:
+                    entries.append(TOCEntry(
+                        title=title,
+                        level=int(level_str),
+                        char_start=0,
+                    ))
+            return entries if entries else None
+        except Exception:
+            return None
+
+    @staticmethod
+    def _resolve_char_positions(
+        entries: List[TOCEntry],
+        content: str,
+    ) -> List[TOCEntry]:
+        """Resolve character start/end positions for TOC entries in content.
+
+        Searches for each entry's title in the content text using
+        case-insensitive matching, progressing forward to avoid duplicate
+        matches. Sets char_end to the start of the next entry (or
+        len(content) for the last entry).
+
+        Also recurses into children to resolve their positions.
+
+        Args:
+            entries: Flat list of TOCEntry to resolve.
+            content: Full extracted text to search within.
+
+        Returns:
+            The same list with char_start and char_end populated.
+        """
+        if not content or not entries:
+            return entries
+
+        content_lower = content.lower()
+        search_from = 0
+
+        # Collect all entries in document order (top-level + children)
+        flat: List[TOCEntry] = []
+        TOCExtractor._flatten_entries(entries, flat)
+
+        # Pass 1: resolve char_start for each entry
+        for entry in flat:
+            title_lower = entry.title.lower().strip()
+            if not title_lower:
+                entry.char_start = search_from
+                continue
+            # Normalise whitespace for fuzzy matching (PDF extracts may
+            # insert extra spaces inside headings).
+            title_normalised = re.sub(r"\s+", " ", title_lower)
+            pos = content_lower.find(title_normalised, search_from)
+            if pos < 0:
+                # Retry with the original (un-normalised) title
+                pos = content_lower.find(title_lower, search_from)
+            if pos >= 0:
+                entry.char_start = pos
+                search_from = pos + len(title_lower)
+            else:
+                # Title not found after search_from; try from beginning
+                pos = content_lower.find(title_normalised)
+                if pos < 0:
+                    pos = content_lower.find(title_lower)
+                if pos >= 0:
+                    entry.char_start = pos
+                    # Do NOT reset search_from to avoid breaking order
+                else:
+                    # Last resort: place at current search frontier
+                    entry.char_start = search_from
+
+        # Pass 2: resolve char_end as start of next entry (or len(content))
+        for i in range(len(flat) - 1):
+            flat[i].char_end = flat[i + 1].char_start
+        if flat:
+            flat[-1].char_end = len(content)
+
+        return entries
+
+    @staticmethod
+    def _flatten_entries(
+        entries: List[TOCEntry],
+        flat: List[TOCEntry],
+    ) -> None:
+        """Flatten nested TOCEntry tree into document-order list.
+
+        Args:
+            entries: Nested entry list.
+            flat: Accumulator for flattened output.
+        """
+        for entry in entries:
+            flat.append(entry)
+            if entry.children:
+                TOCExtractor._flatten_entries(entry.children, flat)
+
+    @staticmethod
+    def _count_entries(entries: List[TOCEntry]) -> int:
+        """Count total entries including nested children.
+
+        Args:
+            entries: Nested entry list.
+
+        Returns:
+            Total number of entries in the tree.
+ """ + count = 0 + for entry in entries: + count += 1 + if entry.children: + count += TOCExtractor._count_entries(entry.children) + return count + @staticmethod + def _count_entries(entries: List[TOCEntry]) -> int: + """Count total entries including nested children. + + Args: + entries: Nested entry list. + + Returns: + Total number of entries in the tree. + """ + count = 0 + for entry in entries: + count += 1 + if entry.children: + count += TOCExtractor._count_entries(entry.children) + return count diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 8bd2983..abf5459 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -19,7 +19,18 @@ from sirchmunk.utils.file_utils import get_fast_hash # File-size threshold: skip tree indexing for small files -_TREE_MIN_CHARS = 50_000 # 50 K characters +_TREE_MIN_CHARS = 20_000 # 20 K characters (lowered from 50K for broader coverage) + +# Adaptive depth thresholds: (min_chars, max_depth) — evaluated top-down; +# **must** be sorted by min_chars descending so the first match wins. +_TREE_ADAPTIVE_DEPTH_THRESHOLDS: tuple = ( + (100_000, 4), + (50_000, 3), + (20_000, 2), +) + +# Summary snippet length extracted from section content (chars) +_TOC_NODE_SUMMARY_MAX_CHARS = 300 # Adaptive preview window for LLM structure analysis _TREE_PREVIEW_MIN = 12_000 # Minimum preview window (chars) @@ -153,9 +164,14 @@ async def build_tree( max_depth: int = 4, force_rebuild: bool = False, total_pages: Optional[int] = None, + toc_entries: Optional[List[Any]] = None, ) -> Optional[DocumentTree]: """Build a tree index for a document. + When *toc_entries* are provided (from TOCExtractor), uses the + TOC-accelerated path that skips recursive LLM analysis and builds + the tree directly from extracted headings. + Returns None when the document is too small or unstructured. """ file_hash = get_fast_hash(file_path) @@ -175,12 +191,34 @@ async def build_tree( if ext not in _TREE_EXTENSIONS: return None + # Use adaptive depth based on document length + effective_depth = self._compute_adaptive_depth(len(content)) + await self._log.info( f"[TreeIndexer] Building tree for {Path(file_path).name} " - f"({len(content)} chars, depth={max_depth})" + f"({len(content)} chars, depth={effective_depth})" ) - root = await self._build_node(content, level=0, max_depth=max_depth) + # TOC-accelerated path: skip recursive LLM analysis + if toc_entries: + root = await self._build_tree_from_toc(toc_entries, content) + if root is not None: + tree = DocumentTree( + file_path=file_path, + file_hash=file_hash, + created_at=datetime.now(timezone.utc).isoformat(), + total_chars=len(content), + total_pages=total_pages, + root=root, + ) + self._save_cache(file_hash, tree) + await self._log.info( + f"[TreeIndexer] Built tree from TOC: {self._count_nodes(root)} nodes" + ) + return tree + + # Fallback: existing recursive LLM path (with adaptive depth) + root = await self._build_node(content, level=0, max_depth=effective_depth) if root is None: return None @@ -258,6 +296,116 @@ def has_tree(self, file_path: str) -> bool: # Internals # # ------------------------------------------------------------------ # + async def _build_tree_from_toc( + self, + toc_entries: List[Any], + content: str, + ) -> Optional[TreeNode]: + """Build tree directly from extracted TOC entries, avoiding recursive LLM. + + Each TOCEntry becomes a TreeNode with char_range from the entry positions. 
+ Only the root summary requires an LLM call (_synthesize_root_summary). + + Args: + toc_entries: List of TOCEntry from toc_extractor. + content: Full extracted text of the document. + + Returns: + Root TreeNode, or None if no children could be created. + """ + seen_ids: set = set() + children = self._toc_entries_to_nodes( + toc_entries, content, len(content), seen_ids, fallback_level=1, + ) + + if not children: + return None + + root_summary = await self._synthesize_root_summary(children) + return TreeNode( + node_id=self._unique_node_id(0, seen_ids), + title="Document", + summary=root_summary, + char_range=(0, len(content)), + level=0, + children=children, + ) + + @staticmethod + def _toc_entries_to_nodes( + entries: List[Any], + content: str, + parent_end: int, + seen_ids: set, + fallback_level: int, + ) -> List["TreeNode"]: + """Recursively convert TOCEntry trees into TreeNode trees. + + Handles arbitrary nesting depth and guards against invalid + char_start / char_end values. + """ + nodes: List[TreeNode] = [] + content_len = len(content) + for entry in entries: + start = max(0, min(entry.char_start, content_len)) + end = entry.char_end if entry.char_end and entry.char_end > start else parent_end + end = min(end, content_len) + + section_text = content[start:min(start + _TOC_NODE_SUMMARY_MAX_CHARS, end)] + nid = DocumentTreeIndexer._unique_node_id(start, seen_ids) + level = entry.level if entry.level > 0 else fallback_level + + child_nodes: List[TreeNode] = [] + if entry.children: + child_nodes = DocumentTreeIndexer._toc_entries_to_nodes( + entry.children, content, end, seen_ids, + fallback_level=level + 1, + ) + + node = TreeNode( + node_id=nid, + title=entry.title, + summary=section_text.strip(), + char_range=(start, end), + level=level, + children=child_nodes, + ) + nodes.append(node) + return nodes + + @staticmethod + def _unique_node_id(start: int, seen_ids: set) -> str: + """Generate a unique node_id based on char offset, appending a + disambiguator when collisions occur.""" + base = f"N{start:06d}" + if base not in seen_ids: + seen_ids.add(base) + return base + suffix = 1 + while f"{base}_{suffix}" in seen_ids: + suffix += 1 + nid = f"{base}_{suffix}" + seen_ids.add(nid) + return nid + + @staticmethod + def _compute_adaptive_depth(content_length: int) -> int: + """Compute max tree depth based on document length. + + Longer documents get deeper trees for finer-grained navigation. + Uses _TREE_ADAPTIVE_DEPTH_THRESHOLDS for threshold-based selection. + + Args: + content_length: Character count of the document. + + Returns: + Maximum tree depth (2-4). + """ + for threshold, depth in _TREE_ADAPTIVE_DEPTH_THRESHOLDS: + if content_length >= threshold: + return depth + return 2 # minimum depth + async def _build_node( self, text: str, level: int, max_depth: int, offset: int = 0, diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 5128702..a9323fa 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -138,6 +138,38 @@ class CompileArtifacts: manifest_map: Dict[str, Any] = field(default_factory=dict) # {path: FileManifestEntry} +class _TreeNavCache: + """Per-search-session cache for tree navigation results. + + Avoids duplicate LLM navigation calls for the same file+query pair. + Created at the start of each ``_search_fast()`` invocation and reset + per search session. 
+ """ + + __slots__ = ("_store",) + + def __init__(self) -> None: + self._store: Dict[str, Optional[List[Any]]] = {} + + @staticmethod + def _key(file_path: str, query: str) -> str: + import hashlib + return hashlib.md5(f"{file_path}:{query}".encode()).hexdigest() + + def get(self, file_path: str, query: str) -> Optional[List[Any]]: + """Retrieve cached navigation leaves for a file+query pair.""" + key = self._key(file_path, query) + return self._store.get(key) + + def has(self, file_path: str, query: str) -> bool: + """Check whether a cached result exists.""" + return self._key(file_path, query) in self._store + + def put(self, file_path: str, query: str, leaves: Optional[List[Any]]) -> None: + """Store navigation leaves for a file+query pair.""" + self._store[self._key(file_path, query)] = leaves + + class AgenticSearch(BaseSearch): def __init__( @@ -1518,6 +1550,29 @@ async def _search_deep( f"dir_scan_files={len(dir_scan_files)}" ) + # --- Phase 2.5: Parallel tree pre-navigation for top tree hits --- + _pre_nav_evidence: Dict[str, str] = {} + if tree_hits: + _nav_fps = [fp for fp in tree_hits[:self._DEEP_PRE_NAV_MAX_FILES]] + if _nav_fps: + _nav_results = await asyncio.gather( + *[self._tree_guided_sample( + fp, query, max_chars=self._FAST_MAX_EVIDENCE_CHARS, + ) for fp in _nav_fps], + return_exceptions=True, + ) + for fp, nav_res in zip(_nav_fps, _nav_results): + if isinstance(nav_res, Exception): + await self._logger.warning( + f"[Phase 2.5] Tree pre-nav failed for {Path(fp).name}: {nav_res}" + ) + elif isinstance(nav_res, str) and nav_res: + _pre_nav_evidence[fp] = nav_res + if _pre_nav_evidence: + await self._logger.info( + f"[Phase 2.5] Pre-navigated {len(_pre_nav_evidence)} tree files" + ) + # ============================================================== # Phase 3: Merge file paths + build KnowledgeCluster # P1 tree hits get highest priority; P2 soft-hit files next @@ -1547,6 +1602,17 @@ async def _search_deep( # ============================================================== graph_ctx = "" if cluster: + # Merge pre-navigated tree evidence into cluster content + if _pre_nav_evidence and cluster.content: + pre_nav_parts = [] + for fp, ev in _pre_nav_evidence.items(): + pre_nav_parts.append(f"[Tree evidence: {Path(fp).name}]\n{ev}") + if pre_nav_parts: + pre_nav_ctx = "\n\n".join(pre_nav_parts) + if isinstance(cluster.content, list): + cluster.content = "\n".join(cluster.content) + cluster.content = f"{cluster.content}\n\n{pre_nav_ctx}" + graph_ctx = await self._gather_graph_context(cluster) if graph_ctx and cluster.content: if isinstance(cluster.content, list): @@ -1946,6 +2012,10 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Max chars per tree section.""" _TREE_SAMPLE_RGA_SUPPLEMENT = True """Whether to append rga evidence after tree sections as supplementary context.""" + _TREE_ROOT_HINTS_MAX_FILES = 10 + """Maximum number of tree roots to include in FAST Step 1 hints.""" + _DEEP_PRE_NAV_MAX_FILES = 3 + """Maximum number of tree files to pre-navigate in DEEP Phase 2.5.""" _LLM_FALLBACK_EVIDENCE = ( "[No relevant documents found]\n\n" @@ -1982,6 +2052,9 @@ async def _search_fast( context = SearchContext() await self._logger.info(f"[FAST] Starting greedy search for: '{query[:80]}'") + # Reset per-session tree navigation cache + self._tree_nav_cache = _TreeNavCache() + # --- Adaptive compile artifact detection (one-shot, zero LLM) --- artifacts = self._detect_compile_artifacts() if artifacts.catalog or artifacts.tree_available_paths: @@ -2013,6 
+2086,11 @@ async def _search_fast( catalog_routed_files: List[str] = [] catalog_confidence: str = "low" + # Build tree root hints for enhanced query analysis + tree_hints = "" + if artifacts and artifacts.tree_available_paths: + tree_hints = self._build_tree_root_hints(artifacts) + if catalog: listing = self._build_enriched_catalog_listing(catalog) prompt = FAST_QUERY_ANALYSIS_WITH_CATALOG.format( @@ -2021,6 +2099,10 @@ async def _search_fast( else: prompt = FAST_QUERY_ANALYSIS.format(user_input=query) + # Append tree structure hints to the prompt when available + if tree_hints: + prompt = prompt + tree_hints + resp = await self.llm.achat( messages=[{"role": "user", "content": prompt}], stream=False, @@ -3060,6 +3142,35 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: manifest_map=manifest_map, ) + def _build_tree_root_hints(self, artifacts: CompileArtifacts) -> str: + """Build tree root summary hints for FAST Step 1 query analysis. + + Loads root summaries from cached trees and formats them as context + for the LLM to understand document-level structure. + + Args: + artifacts: Compile artifact context with tree metadata. + + Returns: + Formatted hint string, or empty string when no trees are available. + """ + if not artifacts.tree_available_paths: + return "" + indexer = artifacts.tree_indexer + if indexer is None: + return "" + hints: List[str] = [] + for i, fp in enumerate(sorted(artifacts.tree_available_paths)): + if i >= self._TREE_ROOT_HINTS_MAX_FILES: + break + tree = indexer.load_tree(fp) + if tree and tree.root and tree.root.summary: + name = Path(fp).name + hints.append(f"[{i}] {name}: {tree.root.summary[:150]}") + if not hints: + return "" + return "\nDocument structure hints:\n" + "\n".join(hints) + "\n" + @staticmethod def _tokenize_for_matching(text: str) -> Set[str]: """Tokenize text into meaningful units for keyword matching. 
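A quick sketch of how the pieces introduced by this patch fit together: TOCExtractor.extract() runs first as a pure-local pass (PDF bookmarks, Markdown/HTML headings, DOCX heading styles), and its entries are handed to DocumentTreeIndexer.build_tree(), which only falls back to recursive LLM structure analysis (with adaptive depth) when no usable TOC is found. The sketch below is illustrative, not part of the patch: it assumes an already-constructed indexer and plain-text content, since DocumentTreeIndexer construction and binary-format text extraction are handled elsewhere in the repo.

```python
# Illustrative sketch only: `indexer` construction and text extraction are
# assumptions; adapt to the actual DocumentTreeIndexer setup in this repo.
from pathlib import Path

from sirchmunk.learnings.toc_extractor import TOCExtractor
from sirchmunk.learnings.tree_indexer import DocumentTreeIndexer


async def index_one(path: str, indexer: DocumentTreeIndexer) -> None:
    # Simplification: read the file as plain text (fine for .md/.html).
    content = Path(path).read_text(errors="replace")

    # Phase 0.5: native TOC extraction, zero LLM calls. Returns None for
    # unsupported formats or when fewer than 3 headings are found.
    toc_entries = TOCExtractor.extract(path, content)

    # When toc_entries is non-empty, build_tree() takes the TOC-accelerated
    # path; otherwise it falls back to recursive LLM analysis with a depth
    # chosen adaptively from the document length (2-4 levels).
    tree = await indexer.build_tree(path, content, toc_entries=toc_entries)
    if tree and tree.root:
        print(f"{Path(path).name}: {len(tree.root.children)} top-level sections")
```
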
From 29c0909b166ca1ee08a8464f4e50a4aad8ff43ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Wed, 15 Apr 2026 21:18:22 +0800 Subject: [PATCH 11/56] enhance compile for excel and add embedding fallback for rga keywords retrieval --- src/sirchmunk/learnings/compiler.py | 265 +++++++++++++++++++++-- src/sirchmunk/learnings/summary_index.py | 255 ++++++++++++++++++++++ src/sirchmunk/search.py | 77 +++++++ 3 files changed, 578 insertions(+), 19 deletions(-) create mode 100644 src/sirchmunk/learnings/summary_index.py diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index fac9b79..2f8983a 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -41,7 +41,7 @@ _MERGE_SIMILARITY_THRESHOLD = 0.75 # Max chars for manifest-persisted document summary (used in Phase 2 & catalog) -_MANIFEST_SUMMARY_MAX_LEN = 250 +_MANIFEST_SUMMARY_MAX_LEN = 500 # Preview window for direct LLM summarisation (no tree), ~4K tokens _SUMMARY_PREVIEW_CHARS = 16_000 @@ -50,6 +50,13 @@ _SUMMARY_SAMPLE_SECTIONS = 3 # Number of sections to sample for large docs _SUMMARY_SAMPLE_SECTION_CHARS = 5_000 # Chars per sampled section +# Excel table-level adaptive sampling constants +_XLSX_TOTAL_ROW_BUDGET = 100 # Total sampled rows budget across all sheets +_XLSX_MIN_ROWS_PER_SHEET = 3 # Minimum sampled rows per sheet +_XLSX_MAX_ROWS_PER_SHEET = 50 # Maximum sampled rows per sheet +_XLSX_MAX_SHEETS = 10 # Maximum number of sheets to process +_XLSX_MAX_COLS_DISPLAY = 20 # Maximum columns to display per sheet + # --------------------------------------------------------------------------- # Data structures @@ -67,6 +74,7 @@ class FileManifestEntry: summary: str = "" # 新增:存储编译期生成的文档摘要 has_explicit_toc: bool = False # Whether a native TOC was extracted from the file tree_node_count: int = 0 # Number of nodes in the tree index (quality metric) + has_xlsx_digest: bool = False # Whether a pre-compiled Excel evidence digest exists def to_dict(self) -> Dict[str, Any]: return { @@ -78,6 +86,7 @@ def to_dict(self) -> Dict[str, Any]: "summary": self.summary, "has_explicit_toc": self.has_explicit_toc, "tree_node_count": self.tree_node_count, + "has_xlsx_digest": self.has_xlsx_digest, } @classmethod @@ -91,6 +100,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "FileManifestEntry": summary=data.get("summary", ""), has_explicit_toc=data.get("has_explicit_toc", False), tree_node_count=data.get("tree_node_count", 0), + has_xlsx_digest=data.get("has_xlsx_digest", False), ) @@ -155,6 +165,7 @@ class FileCompileResult: error: Optional[str] = None has_explicit_toc: bool = False # Whether TOC was extracted from native structure tree_node_count: int = 0 # Number of nodes in the tree index + has_xlsx_digest: bool = False # Whether a pre-compiled Excel evidence digest exists @dataclass @@ -389,6 +400,7 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: summary=result.summary[:_MANIFEST_SUMMARY_MAX_LEN] if result.summary else "", has_explicit_toc=result.has_explicit_toc, tree_node_count=result.tree_node_count, + has_xlsx_digest=result.has_xlsx_digest, ) # Phase 3: aggregate results into knowledge network @@ -412,6 +424,9 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: # Generate document catalog for search-time routing self._build_document_catalog(manifest) + # Phase: Build summary index for embedding+BM25 fallback (optional, non-blocking) + await self._build_summary_index(manifest) + report.elapsed_seconds = time.monotonic() - t0 await 
self._log.info( f"[Compile] Done in {report.elapsed_seconds:.1f}s — " @@ -556,8 +571,16 @@ async def _compile_single_file( result.tree_node_count = self._count_tree_nodes(result.tree) # Enrich content with structural metadata for non-text types - metadata_prefix = self._extract_structured_metadata(entry.path, content) - enriched_content = metadata_prefix + content if metadata_prefix else content + ext = Path(entry.path).suffix.lower() + evidence_digest = "" + + if ext in (".xlsx", ".xls"): + # Excel: use adaptive sampling for both metadata and evidence + metadata_prefix, evidence_digest = self._extract_xlsx_sampling(entry.path) + enriched_content = metadata_prefix + content if metadata_prefix else content + else: + metadata_prefix = self._extract_structured_metadata(entry.path, content) + enriched_content = metadata_prefix + content if metadata_prefix else content result.summary = await self._extract_summary( entry.path, enriched_content, result.tree, @@ -565,6 +588,19 @@ async def _compile_single_file( result.topics = await self._extract_topics(result.summary) result.evidence = self._build_evidence(entry, content, result) + # Persist Excel evidence digest for search-time consumption + if evidence_digest.strip(): + try: + digest_dir = self._compile_dir / "xlsx_digests" + digest_dir.mkdir(parents=True, exist_ok=True) + file_hash = get_fast_hash(entry.path) or "" + if file_hash: + digest_path = digest_dir / f"{file_hash}.txt" + digest_path.write_text(evidence_digest, encoding="utf-8") + result.has_xlsx_digest = True + except Exception: + pass + except Exception as exc: result.error = str(exc) await self._log.warning(f"[Compile] Failed: {entry.path}: {exc}") @@ -637,38 +673,159 @@ def _extract_structured_metadata(file_path: str, content: str) -> str: ext = Path(file_path).suffix.lower() if ext == ".xlsx": - return KnowledgeCompiler._extract_xlsx_metadata(file_path) + metadata, _evidence = KnowledgeCompiler._extract_xlsx_sampling(file_path) + return metadata if ext == ".pptx": return KnowledgeCompiler._extract_pptx_metadata(file_path) return "" @staticmethod - def _extract_xlsx_metadata(file_path: str) -> str: - """Extract structural metadata from Excel files. + def _compute_xlsx_sample_rows(total_rows: int, num_sheets: int, sheet_rows: int) -> int: + """Compute adaptive sample row count per sheet. + + Strategy: + - Divides _XLSX_TOTAL_ROW_BUDGET equally across sheets + - Small sheets (<=budget) are fully sampled + - Large sheets are capped at budget + - Result clamped to [_XLSX_MIN_ROWS_PER_SHEET, _XLSX_MAX_ROWS_PER_SHEET] + """ + budget_per_sheet = max(1, _XLSX_TOTAL_ROW_BUDGET // max(1, num_sheets)) + n = min(sheet_rows, budget_per_sheet) + return max(_XLSX_MIN_ROWS_PER_SHEET, min(_XLSX_MAX_ROWS_PER_SHEET, n)) - Reads sheet names, row counts, and column headers (first row) to - provide the LLM with a structural overview of the workbook. - Caps at 10 sheets and 15 columns per sheet for bounded output. + @staticmethod + def _extract_xlsx_sampling(file_path: str) -> Tuple[str, str]: + """Extract structural metadata AND sampled content from Excel workbook. + + Performs table-level intelligent sampling with adaptive row counts + based on workbook size and sheet complexity. 
+ + Returns: + (metadata_prefix, evidence_digest) + - metadata_prefix: injected into summary generation context + - evidence_digest: structured text usable directly as search evidence """ try: import openpyxl wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True) - lines: List[str] = ["[Excel Workbook Structure]"] - for sheet_name in wb.sheetnames[:10]: # Cap at 10 sheets + + sheet_names = wb.sheetnames[:_XLSX_MAX_SHEETS] + num_sheets = len(sheet_names) + + # Phase 1: Collect sheet statistics + sheet_stats: List[Dict[str, Any]] = [] + for sheet_name in sheet_names: ws = wb[sheet_name] - # Extract column headers (first row) + row_count = ws.max_row or 0 + col_count = ws.max_column or 0 + # Read headers (first row) headers: List[str] = [] - for cell in ws.iter_rows(min_row=1, max_row=1, values_only=True): - headers = [str(h) for h in cell if h is not None] + for row in ws.iter_rows(min_row=1, max_row=1, values_only=True): + headers = [str(h) for h in row if h is not None] break - row_count = ws.max_row or 0 - header_str = ", ".join(headers[:15]) if headers else "no headers" - lines.append(f"- Sheet '{sheet_name}': {row_count} rows, columns: [{header_str}]") + sheet_stats.append({ + "name": sheet_name, + "rows": row_count, + "cols": col_count, + "headers": headers[:_XLSX_MAX_COLS_DISPLAY], + "ws": ws, + }) + + # Phase 2: Calculate total rows for adaptive sampling + total_rows = sum(s["rows"] for s in sheet_stats) + + meta_lines: List[str] = ["[Excel Workbook Structure]"] + evidence_lines: List[str] = [] + + for stat in sheet_stats: + ws = stat["ws"] + sheet_name = stat["name"] + row_count = stat["rows"] + col_count = stat["cols"] + headers = stat["headers"] + header_str = ", ".join(headers) if headers else "no headers" + + # Metadata line + meta_lines.append( + f"- Sheet '{sheet_name}': {row_count} rows, {col_count} columns, " + f"headers: [{header_str}]" + ) + + # Adaptive sampling + sample_n = KnowledgeCompiler._compute_xlsx_sample_rows( + total_rows, num_sheets, row_count + ) + + evidence_lines.append( + f"[Sheet '{sheet_name}' ({row_count} rows, {col_count} columns)]" + ) + evidence_lines.append(f"Columns: {header_str}") + + # Sample rows + if row_count <= sample_n: + evidence_lines.append(f"(Full content - {row_count} rows)") + else: + evidence_lines.append(f"Sample rows (top {sample_n} of {row_count}):") + + # Build table header + display_headers = headers[:_XLSX_MAX_COLS_DISPLAY] + if display_headers: + evidence_lines.append("| " + " | ".join(display_headers) + " |") + evidence_lines.append("|" + "|".join(["---"] * len(display_headers)) + "|") + + # Read sample rows (skip header row) + numeric_cols: Dict[int, List[float]] = {} # col_index -> numeric values + sampled = 0 + for row in ws.iter_rows( + min_row=2, + max_row=min(row_count, sample_n + 1), + values_only=True, + ): + cells: List[str] = [] + for ci, cell_val in enumerate(row): + if ci >= _XLSX_MAX_COLS_DISPLAY: + break + str_val = str(cell_val) if cell_val is not None else "" + cells.append(str_val[:50]) # truncate long cell values + # Track numeric values for statistics + if isinstance(cell_val, (int, float)) and cell_val == cell_val: + numeric_cols.setdefault(ci, []).append(float(cell_val)) + if cells: + evidence_lines.append("| " + " | ".join(cells) + " |") + sampled += 1 + + # Statistics for numeric columns + stat_parts: List[str] = [] + for ci, values in numeric_cols.items(): + if len(values) >= 2 and ci < len(display_headers): + col_name = display_headers[ci] + stat_parts.append( + f"{col_name} range 
[{min(values):.4g}-{max(values):.4g}]" + ) + if stat_parts: + evidence_lines.append(f"Statistics: {', '.join(stat_parts[:5])}") + + evidence_lines.append("") # blank line between sheets + wb.close() - return "\n".join(lines) + "\n\n" + + metadata = "\n".join(meta_lines) + "\n\n" + evidence = "\n".join(evidence_lines) + return metadata, evidence + except Exception: - return "" + return "", "" + + @staticmethod + def _extract_xlsx_metadata(file_path: str) -> str: + """Extract structural metadata from Excel files (legacy wrapper). + + Delegates to _extract_xlsx_sampling and returns only the metadata prefix + for backward compatibility. + """ + metadata, _evidence = KnowledgeCompiler._extract_xlsx_sampling(file_path) + return metadata @staticmethod def _extract_pptx_metadata(file_path: str) -> str: @@ -983,6 +1140,76 @@ def _count(node: Any) -> int: return _count(tree.root) + # ------------------------------------------------------------------ # + # Summary index for embedding + BM25 fallback # + # ------------------------------------------------------------------ # + + async def _build_summary_index(self, manifest: CompileManifest) -> None: + """Build summary embedding + BM25 index for fallback search. + + Creates a lightweight index mapping each compiled file to: + - Its summary text + - Pre-computed embedding vector (384-dim, if EmbeddingUtil available) + - Tokenized summary with term frequencies (via TokenizerUtil) + + The index is saved to .cache/compile/summary_index.json and consumed + by search.py as a last-resort fallback when rga keyword search fails. + + Skips gracefully if dependencies (EmbeddingUtil/TokenizerUtil) are unavailable. + """ + try: + from sirchmunk.utils.tokenizer_util import TokenizerUtil + from sirchmunk.learnings.summary_index import CompileSummaryIndex, SummaryIndexEntry + + entries: List[SummaryIndexEntry] = [] + summaries: List[str] = [] + + for file_path, entry in manifest.files.items(): + if entry.summary: + entries.append(SummaryIndexEntry( + file_path=file_path, + summary=entry.summary, + )) + summaries.append(entry.summary) + + if not entries: + return + + # Tokenize summaries + compute TF (always available) + tokenizer = TokenizerUtil() + for idx, entry in enumerate(entries): + tokens = tokenizer.segment(entry.summary) + entry.tokens = tokens + entry.token_freqs = {} + for t in tokens: + entry.token_freqs[t] = entry.token_freqs.get(t, 0) + 1 + + # Compute embeddings (optional — requires EmbeddingUtil) + try: + from sirchmunk.utils.embedding_util import EmbeddingUtil + embedding_util = EmbeddingUtil() + embedding_util.start_loading() + # Wait up to 60 seconds for model load + await embedding_util._ensure_model_async(timeout=60) + + if embedding_util.is_ready(): + embeddings = await embedding_util.embed(summaries) + for i, emb in enumerate(embeddings): + entries[i].embedding = emb + await self._log.info( + f"Summary index: computed embeddings for {len(entries)} entries" + ) + except Exception as emb_exc: + await self._log.warning( + f"Summary index: embedding computation skipped: {emb_exc}" + ) + + index = CompileSummaryIndex(entries) + index.save(self._compile_dir / "summary_index.json") + + except Exception as exc: + await self._log.warning(f"Failed to build summary index: {exc}") + # ------------------------------------------------------------------ # # Manifest I/O # # ------------------------------------------------------------------ # diff --git a/src/sirchmunk/learnings/summary_index.py b/src/sirchmunk/learnings/summary_index.py new file mode 100644 
index 0000000..7ec355a --- /dev/null +++ b/src/sirchmunk/learnings/summary_index.py @@ -0,0 +1,255 @@ +"""Compile-time summary index for embedding + BM25 fallback retrieval. + +This module provides a lightweight, file-level index that combines: +- Semantic similarity via pre-computed embeddings (384-dim MiniLM) +- Lexical matching via BM25 scoring (TokenizerUtil segmentation) + +Used ONLY as a fallback when rga keyword search returns zero results. +""" + +import json +import math +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + + +@dataclass +class SummaryIndexEntry: + """Single file entry in the summary index.""" + file_path: str + summary: str + embedding: Optional[List[float]] = None # 384-dim, pre-normalized + tokens: Optional[List[str]] = None # TokenizerUtil.segment() output + token_freqs: Optional[Dict[str, int]] = None # pre-computed term frequencies + + +class CompileSummaryIndex: + """Pre-computed summary index for hybrid embedding + BM25 fallback search. + + This index is built at compile time and loaded at search time. + It provides a fallback retrieval mechanism when rga keyword search + returns zero results, combining semantic similarity (embedding cosine) + with lexical matching (BM25). + + The fusion algorithm uses Sigmoid Z-Score normalization: + 1. Compute raw scores from both channels + 2. Z-Score normalize each channel independently + 3. Weighted combination: alpha * z_embedding + (1-alpha) * z_bm25 + 4. Sigmoid activation for final score + """ + + # BM25 parameters (Okapi BM25 standard defaults) + _BM25_K1: float = 1.5 + _BM25_B: float = 0.75 + + # Fusion parameters + _DEFAULT_ALPHA: float = 0.5 # embedding weight; (1-alpha) = BM25 weight + + # Z-Score fallback for missing channel + _MISSING_CHANNEL_Z: float = -3.0 # ~0.1 percentile + + def __init__(self, entries: List[SummaryIndexEntry]) -> None: + self._entries = entries + self._num_docs = len(entries) + self._avg_doc_len = self._compute_avg_doc_len() + self._doc_freqs: Dict[str, int] = self._compute_doc_freqs() + + def _compute_avg_doc_len(self) -> float: + """Compute average document length (in tokens) across all entries.""" + lengths = [len(e.tokens or []) for e in self._entries] + return sum(lengths) / max(1, len(lengths)) + + def _compute_doc_freqs(self) -> Dict[str, int]: + """Compute document frequency for each unique token.""" + df: Dict[str, int] = {} + for entry in self._entries: + if entry.token_freqs: + for token in entry.token_freqs: + df[token] = df.get(token, 0) + 1 + return df + + @classmethod + def load(cls, index_path: Path) -> Optional["CompileSummaryIndex"]: + """Load index from JSON file. 
Returns None on failure.""" + try: + if not index_path.exists(): + return None + data = json.loads(index_path.read_text(encoding="utf-8")) + entries = [] + for item in data.get("entries", []): + entries.append(SummaryIndexEntry( + file_path=item["file_path"], + summary=item.get("summary", ""), + embedding=item.get("embedding"), + tokens=item.get("tokens"), + token_freqs=item.get("token_freqs"), + )) + if not entries: + return None + return cls(entries) + except Exception as exc: + logger.warning("Failed to load summary index from %s: %s", index_path, exc) + return None + + def save(self, index_path: Path) -> None: + """Persist index to JSON file.""" + index_path.parent.mkdir(parents=True, exist_ok=True) + data = { + "version": 1, + "num_entries": len(self._entries), + "entries": [ + { + "file_path": e.file_path, + "summary": e.summary, + "embedding": e.embedding, + "tokens": e.tokens, + "token_freqs": e.token_freqs, + } + for e in self._entries + ], + } + index_path.write_text( + json.dumps(data, ensure_ascii=False), + encoding="utf-8", + ) + logger.info("Summary index saved: %d entries -> %s", len(self._entries), index_path) + + def search( + self, + query_embedding: Optional[List[float]], + query_tokens: List[str], + top_k: int = 5, + alpha: float = _DEFAULT_ALPHA, + ) -> List[Tuple[str, float]]: + """Hybrid search combining embedding cosine similarity and BM25. + + Uses Sigmoid Z-Score fusion: + 1. Compute raw embedding cosine sim and BM25 score per document + 2. Z-Score normalize each channel + 3. Weighted linear combination + 4. Sigmoid activation + + Args: + query_embedding: 384-dim query vector (None to use BM25 only). + query_tokens: Tokenized query from TokenizerUtil.segment(). + top_k: Maximum number of results. + alpha: Embedding weight in [0, 1]. BM25 weight = 1 - alpha. + + Returns: + List of (file_path, fusion_score) sorted descending by score. + """ + if not self._entries: + return [] + + # Compute raw scores + emb_scores: List[Optional[float]] = [] + bm25_scores: List[float] = [] + + has_embedding = query_embedding is not None + + for entry in self._entries: + # Embedding channel + if has_embedding and entry.embedding: + emb_scores.append(self._cosine_similarity(query_embedding, entry.embedding)) + else: + emb_scores.append(None) + + # BM25 channel + bm25_scores.append(self._bm25_score(query_tokens, entry)) + + # Z-Score normalization + z_emb = self._z_score_normalize(emb_scores) + z_bm25 = self._z_score_normalize(bm25_scores) + + # Sigmoid fusion + results: List[Tuple[str, float]] = [] + for i, entry in enumerate(self._entries): + z_e = z_emb[i] if z_emb[i] is not None else self._MISSING_CHANNEL_Z + z_b = z_bm25[i] if z_bm25[i] is not None else self._MISSING_CHANNEL_Z + + combined = alpha * z_e + (1.0 - alpha) * z_b + score = 1.0 / (1.0 + math.exp(-combined)) + results.append((entry.file_path, score)) + + # Sort descending and return top_k + results.sort(key=lambda x: x[1], reverse=True) + return results[:top_k] + + def _bm25_score(self, query_tokens: List[str], entry: SummaryIndexEntry) -> float: + """Compute BM25 score for a single document. 
+ + Uses standard Okapi BM25 formula: + score = sum over query terms: + IDF(t) * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * dl / avgdl)) + """ + if not query_tokens or not entry.token_freqs: + return 0.0 + + dl = len(entry.tokens or []) + score = 0.0 + + for token in query_tokens: + tf = entry.token_freqs.get(token, 0) + if tf == 0: + continue + + # IDF: log((N - df + 0.5) / (df + 0.5) + 1) + df = self._doc_freqs.get(token, 0) + idf = math.log((self._num_docs - df + 0.5) / (df + 0.5) + 1.0) + + # TF component + tf_component = (tf * (self._BM25_K1 + 1.0)) / ( + tf + self._BM25_K1 * (1.0 - self._BM25_B + self._BM25_B * dl / max(1.0, self._avg_doc_len)) + ) + + score += idf * tf_component + + return score + + @staticmethod + def _cosine_similarity(a: List[float], b: List[float]) -> float: + """Compute cosine similarity between two vectors. + + When embeddings are pre-normalized (L2 norm = 1), this reduces + to a simple dot product. + """ + if len(a) != len(b): + return 0.0 + dot = sum(x * y for x, y in zip(a, b)) + # Clamp to [-1, 1] for numerical safety + return max(-1.0, min(1.0, dot)) + + @staticmethod + def _z_score_normalize(scores: List[Optional[float]]) -> List[Optional[float]]: + """Z-Score normalize a list of scores, preserving None entries. + + None entries remain None (handled as _MISSING_CHANNEL_Z at fusion). + """ + valid = [s for s in scores if s is not None] + if len(valid) < 2: + # Not enough data points for meaningful normalization + return scores + + mean = sum(valid) / len(valid) + variance = sum((s - mean) ** 2 for s in valid) / len(valid) + std = math.sqrt(variance) if variance > 0 else 1.0 + + if std < 1e-9: + # All scores identical — return zeros + return [0.0 if s is not None else None for s in scores] + + return [(s - mean) / std if s is not None else None for s in scores] + + @property + def num_entries(self) -> int: + """Number of indexed documents.""" + return self._num_docs + + @property + def has_embeddings(self) -> bool: + """Whether any entry has a pre-computed embedding.""" + return any(e.embedding is not None for e in self._entries) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index a9323fa..9b7bf47 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -136,6 +136,7 @@ class CompileArtifacts: tree_indexer: Optional[Any] # DocumentTreeIndexer (lazy import) tree_available_paths: Set[str] # file paths that have cached tree indices manifest_map: Dict[str, Any] = field(default_factory=dict) # {path: FileManifestEntry} + summary_index: Optional[Any] = None # CompileSummaryIndex (lazy-loaded) class _TreeNavCache: @@ -1998,6 +1999,8 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Max tree JSON files to parse during artifact detection.""" _CATALOG_LISTING_MAX_ENTRIES = 20 """Max catalog entries in the enriched listing for Step 1.""" + _ENABLE_EMBEDDING_FALLBACK: bool = True + """Enable embedding + BM25 hybrid fallback when rga returns zero results.""" _CATALOG_KEYWORD_MIN_LEN = 2 """Minimum character length for a catalog keyword token.""" _CATALOG_KEYWORD_MAX_LEN = 20 @@ -2371,6 +2374,22 @@ async def _rga_evidence() -> str: ext = Path(fp).suffix.lower() ev = None + # 0. 
Excel digest priority (pre-compiled evidence) + if artifacts and artifacts.manifest_map: + manifest_entry = artifacts.manifest_map.get(fp) + if manifest_entry and getattr(manifest_entry, 'has_xlsx_digest', False): + digest_path = ( + self.work_path / ".cache" / "compile" / "xlsx_digests" + / f"{manifest_entry.file_hash}.txt" + ) + if digest_path.exists(): + try: + digest_content = digest_path.read_text(encoding="utf-8") + if digest_content.strip(): + ev = f"[{fn} - Pre-compiled Evidence]\n{digest_content}" + except Exception: + pass + # 1. Tree-guided sampling FIRST for tree-indexed files if ( artifacts @@ -2857,6 +2876,53 @@ async def _fast_find_best_file( await self._logger.warning( f"[FAST] filename search failed: {exc}" ) + + # Layer 4: Embedding + BM25 hybrid fallback + # Triggered ONLY when layers 1-3 all return empty results + if (not all_raw + and self._ENABLE_EMBEDDING_FALLBACK + and artifacts is not None + and artifacts.summary_index is not None): + try: + query_emb = None + query_tokens: List[str] = [] + + # Compute query embedding (if embedding client available) + if (self.embedding_client + and self.embedding_client.is_ready() + and artifacts.summary_index.has_embeddings): + query_emb = (await self.embedding_client.embed([query]))[0] + + # Tokenize query for BM25 + from sirchmunk.utils.tokenizer_util import TokenizerUtil + _tokenizer = TokenizerUtil() + query_tokens = _tokenizer.segment(query) + + if query_emb is not None or query_tokens: + results = artifacts.summary_index.search( + query_embedding=query_emb, + query_tokens=query_tokens, + top_k=top_k or 3, + ) + + for file_path, score in results: + if Path(file_path).exists(): + all_raw.append({ + "path": file_path, + "matches": [], + "weighted_score": score * self._WIKI_MAX_SCORE, + }) + + if all_raw: + await self._logger.info( + f"[FAST] Embedding+BM25 fallback found {len(all_raw)} candidates" + ) + except Exception as exc: + await self._logger.warning( + f"[FAST] Embedding+BM25 fallback failed: {exc}" + ) + + if not all_raw: return None merged = GrepRetriever.merge_results(all_raw, limit=20) @@ -3134,12 +3200,23 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: # Cache for future calls within this instance self._tree_paths_cache = tree_paths + # Load summary index for embedding fallback (optional) + summary_index = None + summary_index_path = self.work_path / ".cache" / "compile" / "summary_index.json" + if summary_index_path.exists(): + try: + from sirchmunk.learnings.summary_index import CompileSummaryIndex + summary_index = CompileSummaryIndex.load(summary_index_path) + except Exception: + pass + return CompileArtifacts( catalog=catalog, catalog_map=catalog_map, tree_indexer=indexer, tree_available_paths=tree_paths, manifest_map=manifest_map, + summary_index=summary_index, ) def _build_tree_root_hints(self, artifacts: CompileArtifacts) -> str: From d1f1fd4a0c425689f8fb8cef0cc037a9d463f3bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Wed, 15 Apr 2026 21:35:05 +0800 Subject: [PATCH 12/56] fix storage --- src/sirchmunk/storage/knowledge_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sirchmunk/storage/knowledge_storage.py b/src/sirchmunk/storage/knowledge_storage.py index e62c1cf..0f09071 100644 --- a/src/sirchmunk/storage/knowledge_storage.py +++ b/src/sirchmunk/storage/knowledge_storage.py @@ -124,7 +124,7 @@ def _load_from_parquet(self): # Detect parquet columns to handle schema evolution try: pq_cols = self.db.fetch_all( - f"SELECT column_name FROM 
parquet_schema('{self.parquet_file}')" + f"SELECT name FROM parquet_schema('{self.parquet_file}')" ) pq_col_names = {row[0] for row in pq_cols} except Exception: From caf8e052f112ed17d5963b7b969ef8b18a1e14c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Thu, 16 Apr 2026 15:56:01 +0800 Subject: [PATCH 13/56] add financebench --- benchmarks/financebench/README.md | 103 +++++++ benchmarks/financebench/analyze_results.py | 272 +++++++++++++++++ benchmarks/financebench/config.py | 99 +++++++ benchmarks/financebench/data_loader.py | 108 +++++++ benchmarks/financebench/evaluate.py | 323 +++++++++++++++++++++ benchmarks/financebench/run_benchmark.py | 239 +++++++++++++++ benchmarks/financebench/runner.py | 279 ++++++++++++++++++ 7 files changed, 1423 insertions(+) create mode 100644 benchmarks/financebench/README.md create mode 100644 benchmarks/financebench/analyze_results.py create mode 100644 benchmarks/financebench/config.py create mode 100644 benchmarks/financebench/data_loader.py create mode 100644 benchmarks/financebench/evaluate.py create mode 100644 benchmarks/financebench/run_benchmark.py create mode 100644 benchmarks/financebench/runner.py diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md new file mode 100644 index 0000000..23bd67d --- /dev/null +++ b/benchmarks/financebench/README.md @@ -0,0 +1,103 @@ +# FinanceBench Benchmark + +FinanceBench evaluation pipeline for **Sirchmunk AgenticSearch**. + +## Overview + +[FinanceBench](https://arxiv.org/abs/2311.11944) is an open-book financial QA benchmark +with **150 expert-annotated questions** across **40+ US public companies** (10-K/10-Q filings). + +### Evaluation Modes + +| Mode | Description | +|------|-------------| +| `singleDoc` | Each question searches only its target PDF (standard) | +| `sharedCorpus` | All questions search the full 41-PDF corpus | + +### Metrics + +- **3-Class Scoring**: Correct / Hallucination / Refusal (per FinanceBench paper) +- **EM / F1**: Exact Match and token-level F1 with financial value normalisation +- **Evidence Recall**: Retrieved pages vs gold evidence pages + +## Quick Start + +### 1. Setup + +```bash +cd benchmarks/financebench + +# Copy and edit the config file +cp .env.example .env.financebench +# Edit .env.financebench — set your LLM_API_KEY at minimum + +# Download FinanceBench data +# Place financebench_open_source.jsonl in ./data/ +# Place PDF corpus (41 files) in ./data/pdfs/ +``` + +### 2. Run + +```bash +# Run full benchmark (150 questions) +python run_benchmark.py + +# Run with custom config and question limit +python run_benchmark.py --env .env.financebench --limit 20 +``` + +### 3. 
Analyze + +```bash +# Analyze a completed run +python analyze_results.py output/results_YYYYMMDD_HHMMSS.jsonl + +# Show more error cases +python analyze_results.py output/results_*.jsonl --max-errors 50 +``` + +## Data Format + +The dataset file `financebench_open_source.jsonl` contains one JSON object per line: + +```json +{ + "financebench_id": "financebench_id_00001", + "question": "What is the FY2018 capital expenditure amount for 3M?", + "answer": "$1,577.00", + "doc_name": "3M_2018_10K", + "company": "3M", + "question_type": "fact-based-w-numerical-answer", + "question_reasoning": "retrieve", + "evidence": [{"evidence_text": "...", "evidence_page_num": 42}] +} +``` + +## File Structure + +``` +benchmarks/financebench/ +├── .env.example # Config template (copy to .env.financebench) +├── config.py # FinanceBenchConfig dataclass +├── data_loader.py # Dataset + PDF corpus loader +├── evaluate.py # EM/F1/3-class scoring + aggregation +├── runner.py # Async batch runner (AgenticSearch) +├── run_benchmark.py # CLI entry point +├── analyze_results.py # Post-hoc analysis tool +├── data/ +│ ├── financebench_open_source.jsonl +│ └── pdfs/ # 41 SEC-filing PDFs +├── output/ # Results + metrics (auto-created) +└── logs/ # Run logs (auto-created) +``` + +## SOTA Reference + +| System | Accuracy | Coverage | +|--------|----------|----------| +| Mafin 2.5 (SOTA) | 98.7% | 100% | +| Fintool | 98.0% | 66.7% | +| Quantly | 94.0% | 100% | +| GPT-4 (zero-shot) | 29.3% | 100% | + +> Mafin 2.5 uses PageIndex + Agentic Vectorless RAG 3.0 architecture. diff --git a/benchmarks/financebench/analyze_results.py b/benchmarks/financebench/analyze_results.py new file mode 100644 index 0000000..24d2b64 --- /dev/null +++ b/benchmarks/financebench/analyze_results.py @@ -0,0 +1,272 @@ +"""Analyze FinanceBench benchmark results. + +Read a JSONL results file produced by ``run_benchmark.py`` and print a +comprehensive analysis report including per-type breakdowns, per-company +accuracy, error cases, and a SOTA comparison table. + +Usage: + python analyze_results.py output/results_YYYYMMDD_HHMMSS.jsonl + python analyze_results.py output/results_*.jsonl --max-errors 30 +""" +from __future__ import annotations + +import argparse +import json +import sys +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple + +from evaluate import compute_metrics + + +# --------------------------------------------------------------------------- +# Data loading +# --------------------------------------------------------------------------- + + +def load_results(path: str) -> List[Dict[str, Any]]: + """Load a JSONL results file into a list of dicts. + + Args: + path: Path to a ``.jsonl`` file where each line is a JSON object. + + Returns: + List of result dicts. + + Raises: + FileNotFoundError: If *path* does not exist. + json.JSONDecodeError: If a line contains invalid JSON. 
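+
+    Note:
+        The current implementation reports a missing file and exits via
+        ``sys.exit(1)``, and skips malformed lines with a warning, so the
+        exceptions above are not raised in practice.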
+ """ + p = Path(path) + if not p.exists(): + print(f"ERROR: file not found — {path}", file=sys.stderr) + sys.exit(1) + + results: list[dict] = [] + with open(p, encoding="utf-8") as f: + for lineno, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + results.append(json.loads(line)) + except json.JSONDecodeError as exc: + print(f"WARNING: skipping malformed line {lineno}: {exc}", file=sys.stderr) + return results + + +# --------------------------------------------------------------------------- +# Pretty-print helpers +# --------------------------------------------------------------------------- + + +def print_breakdown(title: str, breakdown: Dict[str, Dict[str, Any]]) -> None: + """Pretty-print a metrics breakdown table. + + Args: + title: Section header text. + breakdown: ``{group_name: {accuracy, hallucination_rate, ...}}``. + """ + print(f"\n=== Breakdown by {title} ===\n") + header = f" {'Group':<30} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'N':>4}" + print(header) + print(" " + "-" * (len(header) - 2)) + + for group, m in sorted(breakdown.items(), key=lambda kv: -kv[1].get("accuracy", 0)): + acc = m.get("accuracy", 0) + hal = m.get("hallucination_rate", 0) + ref = m.get("refusal_rate", 0) + n = m.get("n", 0) + print(f" {group:<30} {acc:>5.1f} {hal:>7.1f} {ref:>7.1f} {n:>4}") + + +def _compute_company_breakdown( + results: List[Dict[str, Any]], +) -> List[Tuple[str, float, int, int, int]]: + """Group results by company and return sorted by accuracy ascending. + + Returns: + List of ``(company, accuracy, correct, total, halluc)`` tuples, + sorted by accuracy ascending (worst first). + """ + groups: dict[str, list[dict]] = defaultdict(list) + for r in results: + company = r.get("company", "unknown") or "unknown" + groups[company].append(r) + + rows: list[tuple[str, float, int, int, int]] = [] + for company, items in groups.items(): + n = len(items) + correct = sum(1 for r in items if r.get("classification") == "correct") + halluc = sum(1 for r in items if r.get("classification") == "hallucination") + acc = (correct / n * 100) if n else 0.0 + rows.append((company, acc, correct, n, halluc)) + + rows.sort(key=lambda x: x[1]) # worst first + return rows + + +def print_company_breakdown(results: List[Dict[str, Any]], top_n: int = 10) -> None: + """Print per-company accuracy table, showing worst *top_n* companies. + + Args: + results: List of per-question result dicts. + top_n: Number of worst-performing companies to display. + """ + rows = _compute_company_breakdown(results) + if not rows: + return + + print(f"\n=== Worst {top_n} Companies by Accuracy ===\n") + header = f" {'Company':<40} {'Acc%':>6} {'Correct':>8} {'Hallu':>6} {'N':>4}" + print(header) + print(" " + "-" * (len(header) - 2)) + + for company, acc, correct, n, halluc in rows[:top_n]: + print(f" {company:<40} {acc:>5.1f} {correct:>8} {halluc:>6} {n:>4}") + + +def print_error_cases(results: List[Dict[str, Any]], max_show: int = 20) -> None: + """Print detailed listing of error cases (hallucination + refusal). + + Args: + results: List of per-question result dicts. + max_show: Maximum number of error cases to display. 
+ """ + errors = [r for r in results if r.get("classification") != "correct"] + if not errors: + print("\n=== Error Cases ===\n None — perfect score!") + return + + print(f"\n=== Error Cases ({len(errors)} total, showing up to {max_show}) ===\n") + + for i, r in enumerate(errors[:max_show], 1): + fb_id = r.get("financebench_id", "?") + cls = r.get("classification", "?") + question = r.get("question", "")[:100] + pred = r.get("prediction", "")[:80] + gold = r.get("gold_answer", "")[:80] + company = r.get("company", "") + em = r.get("em", False) + f1 = r.get("f1", 0.0) + + print(f" [{i:>2}] {fb_id} [{cls.upper()}]") + print(f" Company: {company}") + print(f" Question: {question}{'...' if len(r.get('question', '')) > 100 else ''}") + print(f" Predicted: {pred}{'...' if len(r.get('prediction', '')) > 80 else ''}") + print(f" Gold: {gold}{'...' if len(r.get('gold_answer', '')) > 80 else ''}") + print(f" EM={em} F1={f1:.3f}") + if r.get("error"): + print(f" Error: {r['error'][:120]}") + print() + + if len(errors) > max_show: + print(f" ... and {len(errors) - max_show} more error(s) not shown.\n") + + +def print_comparison_with_sota(metrics: Dict[str, Any]) -> None: + """Compare with published SOTA results on FinanceBench. + + Reference baselines from the FinanceBench leaderboard and recent papers. + """ + print("\n=== Comparison with SOTA ===\n") + header = f" {'System':<30} {'Accuracy':>10} {'Coverage':>10}" + print(header) + print(" " + "-" * (len(header) - 2)) + print(f" {'Mafin 2.5 (SOTA)':<30} {'98.7%':>10} {'100%':>10}") + print(f" {'Fintool':<30} {'98.0%':>10} {'66.7%':>10}") + print(f" {'Quantly':<30} {'94.0%':>10} {'100%':>10}") + print(f" {'GPT-4 (zero-shot)':<30} {'29.3%':>10} {'100%':>10}") + + acc = metrics.get("accuracy", 0) + n = metrics.get("n", 0) + coverage = min(100.0, n / 150.0 * 100) + print(f" {'Sirchmunk (This Run)':<30} {f'{acc:.1f}%':>10} {f'{coverage:.0f}%':>10}") + print(f"\n (This run evaluated {n} questions)") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + """Parse CLI arguments and generate a full analysis report.""" + parser = argparse.ArgumentParser( + description="Analyze FinanceBench benchmark results from a JSONL file", + ) + parser.add_argument( + "results_file", + help="Path to the results JSONL file produced by run_benchmark.py", + ) + parser.add_argument( + "--max-errors", + type=int, + default=20, + help="Maximum number of error cases to display (default: 20)", + ) + parser.add_argument( + "--top-companies", + type=int, + default=10, + help="Number of worst-performing companies to show (default: 10)", + ) + args = parser.parse_args() + + # Load + results = load_results(args.results_file) + if not results: + print("ERROR: no results loaded.", file=sys.stderr) + sys.exit(1) + + # Compute metrics + metrics = compute_metrics(results) + + # --- Overall summary --- + n = metrics.get("n", 0) + acc = metrics.get("accuracy", 0) + hallu = metrics.get("hallucination_rate", 0) + refuse = metrics.get("refusal_rate", 0) + avg_em = metrics.get("avg_em", 0) + avg_f1 = metrics.get("avg_f1", 0) + ev_recall = metrics.get("evidence_recall") + avg_latency = metrics.get("avg_latency", 0) + + print(f"\n{'=' * 60}") + print(f" FinanceBench Analysis ({n} questions)") + print(f"{'=' * 60}") + print(f" Accuracy: {acc:.1f}%") + print(f" Hallucination Rate: {hallu:.1f}%") + print(f" Refusal Rate: {refuse:.1f}%") + print(f" Avg EM: 
{avg_em:.3f}") + print(f" Avg F1: {avg_f1:.3f}") + if metrics.get("avg_evidence_recall") is not None: + print(f" Evidence Recall: {metrics['avg_evidence_recall']:.3f}") + else: + print(f" Evidence Recall: N/A (page-level telemetry unavailable)") + print(f" Avg Latency: {avg_latency:.1f}s") + + # --- Breakdowns --- + if "by_question_type" in metrics: + print_breakdown("Question Type", metrics["by_question_type"]) + + if "by_question_reasoning" in metrics: + print_breakdown("Question Reasoning", metrics["by_question_reasoning"]) + + # --- Per-company breakdown (worst performers) --- + print_company_breakdown(results, top_n=args.top_companies) + + # --- Error cases --- + print_error_cases(results, max_show=args.max_errors) + + # --- SOTA comparison --- + print_comparison_with_sota(metrics) + + print(f"\n{'=' * 60}") + print(f" Source: {args.results_file}") + print(f"{'=' * 60}\n") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/financebench/config.py b/benchmarks/financebench/config.py new file mode 100644 index 0000000..f2e0fdb --- /dev/null +++ b/benchmarks/financebench/config.py @@ -0,0 +1,99 @@ +"""FinanceBench benchmark configuration.""" +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class FinanceBenchConfig: + """All settings for a FinanceBench evaluation run.""" + + # LLM + llm_api_key: str = "" + llm_base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1" + llm_model: str = "qwen3.5-plus" + llm_timeout: int = 120 + + # Data paths + data_dir: str = "./data" + pdf_dir: str = "./data/pdfs" + output_dir: str = "./output" + + # Dataset + limit: int = 0 # 0 = all 150 + seed: int = 42 + + # Search + mode: str = "FAST" + top_k_files: int = 5 + max_token_budget: int = 128000 + enable_dir_scan: bool = True + + # Evaluation + eval_mode: str = "singleDoc" # singleDoc / sharedCorpus + enable_llm_judge: bool = True # TODO: LLM Judge not yet implemented, reserved for future use + extract_answer: bool = True + + # Concurrency + max_concurrent: int = 3 + request_delay: float = 0.5 + + @classmethod + def from_env(cls, env_path: str = ".env.financebench") -> "FinanceBenchConfig": + """Load config from .env file with ``os.environ`` fallback.""" + # Read .env file + env_vars: dict[str, str] = {} + p = Path(env_path) + if p.exists(): + for line in p.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" in line: + k, v = line.split("=", 1) + env_vars[k.strip()] = v.strip() + + def _get(key: str, default: str = "") -> str: + return env_vars.get(key, os.environ.get(key, default)) + + def _bool(key: str, default: bool = False) -> bool: + v = _get(key, str(default)).lower() + return v in ("true", "1", "yes") + + def _int(key: str, default: int = 0) -> int: + try: + return int(_get(key, str(default))) + except (ValueError, TypeError): + return default + + def _float(key: str, default: float = 0.0) -> float: + try: + return float(_get(key, str(default))) + except (ValueError, TypeError): + return default + + return cls( + llm_api_key=_get("LLM_API_KEY"), + llm_base_url=_get( + "LLM_BASE_URL", + "https://dashscope.aliyuncs.com/compatible-mode/v1", + ), + llm_model=_get("LLM_MODEL_NAME", "qwen3.5-plus"), + llm_timeout=_int("LLM_TIMEOUT", 120), + data_dir=_get("FB_DATA_DIR", "./data"), + pdf_dir=_get("FB_PDF_DIR", "./data/pdfs"), + output_dir=_get("FB_OUTPUT_DIR", "./output"), + limit=_int("FB_LIMIT", 0), + seed=_int("FB_SEED", 42), + 
mode=_get("FB_MODE", "FAST"), + top_k_files=_int("FB_TOP_K_FILES", 5), + max_token_budget=_int("FB_MAX_TOKEN_BUDGET", 128000), + enable_dir_scan=_bool("FB_ENABLE_DIR_SCAN", True), + eval_mode=_get("FB_EVAL_MODE", "singleDoc"), + enable_llm_judge=_bool("FB_ENABLE_LLM_JUDGE", True), + extract_answer=_bool("FB_EXTRACT_ANSWER", True), + max_concurrent=_int("FB_MAX_CONCURRENT", 3), + request_delay=_float("FB_REQUEST_DELAY", 0.5), + ) diff --git a/benchmarks/financebench/data_loader.py b/benchmarks/financebench/data_loader.py new file mode 100644 index 0000000..7770865 --- /dev/null +++ b/benchmarks/financebench/data_loader.py @@ -0,0 +1,108 @@ +"""FinanceBench dataset loader.""" +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple + + +class FinanceBenchLoader: + """Load and validate FinanceBench JSONL data. + + Expects: + - ``data_dir/financebench_open_source.jsonl`` – 150 QA rows + - ``data_dir/financebench_document_information.jsonl`` – doc metadata (optional) + - ``pdf_dir/`` – corpus of 41 SEC-filing PDFs named by ``doc_name`` + """ + + _QUESTIONS_FILE = "financebench_open_source.jsonl" + _DOC_INFO_FILE = "financebench_document_information.jsonl" + + def __init__(self, data_dir: str, pdf_dir: str) -> None: + self._data_dir = Path(data_dir) + self._pdf_dir = Path(pdf_dir) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def load_questions(self) -> List[Dict[str, Any]]: + """Load the 150 open-source questions from JSONL. + + Raises: + FileNotFoundError: If the questions file is missing. + """ + path = self._data_dir / self._QUESTIONS_FILE + if not path.exists(): + raise FileNotFoundError(f"Questions file not found: {path}") + items: list[dict] = [] + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + items.append(json.loads(line)) + return items + + def load_doc_info(self) -> Dict[str, Dict[str, Any]]: + """Load document metadata, keyed by ``doc_name``. + + Returns an empty dict when the file is absent (it is optional). + """ + path = self._data_dir / self._DOC_INFO_FILE + if not path.exists(): + return {} + result: dict[str, dict] = {} + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + obj = json.loads(line) + doc_name = obj.get("doc_name", "") + if doc_name: + result[doc_name] = obj + return result + + def get_pdf_path(self, doc_name: str) -> Optional[str]: + """Resolve *doc_name* to a PDF file path. + + Resolution order: + 1. ``/.pdf`` + 2. ``/`` (file with no extension) + 3. Case-insensitive stem match across ``pdf_dir`` + """ + candidates = [ + self._pdf_dir / f"{doc_name}.pdf", + self._pdf_dir / doc_name, + ] + for c in candidates: + if c.exists(): + return str(c) + # Case-insensitive fallback + if self._pdf_dir.exists(): + lower = doc_name.lower() + for f in self._pdf_dir.iterdir(): + if f.stem.lower() == lower: + return str(f) + return None + + def get_unique_docs(self, questions: List[Dict[str, Any]]) -> Set[str]: + """Extract the unique set of ``doc_name`` values from *questions*.""" + return {q["doc_name"] for q in questions if "doc_name" in q} + + def validate_corpus( + self, questions: List[Dict[str, Any]] + ) -> Tuple[int, List[str]]: + """Check PDF availability for all referenced documents. + + Returns: + A tuple of ``(found_count, missing_doc_names)``. 
+ """ + docs = self.get_unique_docs(questions) + missing: list[str] = [] + found = 0 + for doc in sorted(docs): + if self.get_pdf_path(doc): + found += 1 + else: + missing.append(doc) + return found, missing diff --git a/benchmarks/financebench/evaluate.py b/benchmarks/financebench/evaluate.py new file mode 100644 index 0000000..688cf41 --- /dev/null +++ b/benchmarks/financebench/evaluate.py @@ -0,0 +1,323 @@ +"""FinanceBench evaluation metrics. + +Implements the three-class scoring scheme from the FinanceBench paper +(Islam et al., 2023): **correct**, **hallucination**, **refusal**. + +Financial-value normalisation handles currency symbols, thousand separators, +trailing zeros, and percentage signs so that ``$1,577.00`` matches ``1577``. +""" +from __future__ import annotations + +import re +from collections import Counter, defaultdict +from typing import Any, Dict, List + +# ------------------------------------------------------------------ +# Constants +# ------------------------------------------------------------------ + +_REFUSAL_PHRASES: list[str] = [ + "i cannot", + "i can't", + "i could not", + "i couldn't", + "no results found", + "unable to", + "not able to", + "i don't know", + "i do not know", + "information is not available", + "not enough information", + "cannot determine", + "cannot be determined", + "insufficient data", + "no relevant information", + "data not found", + "unknown", +] + +_F1_CORRECT_THRESHOLD: float = 0.8 + +# Markdown / wrapper patterns compiled once +_RE_BOLD = re.compile(r"\*\*(.+?)\*\*") +_RE_ITALIC = re.compile(r"\*(.+?)\*") +_RE_QUOTES = re.compile(r'^["\u201c\u201d\']+|["\u201c\u201d\']+$') +_RE_ANSWER_PREFIX = re.compile( + r"^(the\s+(short\s+)?answer\s+is\s*:?\s*|answer\s*:\s*|short\s+answer\s*:\s*)", + re.IGNORECASE, +) +# Financial value helpers +_RE_DOLLAR = re.compile(r"^\$\s*") +_RE_THOUSAND_SEP = re.compile(r",(\d{3})") +_RE_TRAILING_ZEROS = re.compile(r"\.0+$") + + +# ------------------------------------------------------------------ +# Normalisation +# ------------------------------------------------------------------ + + +def normalize_answer(answer: str) -> str: + """Normalise an answer string for comparison. + + Steps: + 1. Strip Markdown bold / italic. + 2. Strip surrounding quotes. + 3. Strip trailing punctuation (``.``, ``:``). + 4. Remove common LLM wrapper phrases. + 5. Financial value normalisation (currency, commas, trailing zeros). + 6. Lowercase. + """ + s = answer.strip() + if not s: + return "" + + # 1. Markdown + s = _RE_BOLD.sub(r"\1", s) + s = _RE_ITALIC.sub(r"\1", s) + + # 2. Quotes + s = _RE_QUOTES.sub("", s).strip() + + # 3. Trailing punctuation + s = s.rstrip(".:") + + # 4. Wrapper phrases + s = _RE_ANSWER_PREFIX.sub("", s).strip() + + # 5. Financial normalisation + s = _normalize_financial_value(s) + + # 6. Lowercase + return s.lower().strip() + + +def _normalize_financial_value(text: str) -> str: + """Normalise financial figures for robust comparison. 
+ + - ``$1,577.00`` → ``1577`` + - ``15.3%`` → ``15.3%`` + - ``$1577`` → ``1577`` + - ``1,577`` → ``1577`` + """ + s = text.strip() + + # Detect if value looks numeric (possibly with $ / % / commas) + stripped_for_check = _RE_DOLLAR.sub("", s) + stripped_for_check = stripped_for_check.replace(",", "").rstrip("%").strip() + try: + float(stripped_for_check) + except ValueError: + return s # Not a numeric value – return as-is + + # Remove dollar sign + s = _RE_DOLLAR.sub("", s) + + # Remember and temporarily strip percentage + has_pct = s.endswith("%") + if has_pct: + s = s[:-1].strip() + + # Remove thousand-separator commas + s = s.replace(",", "") + + # Remove trailing decimal zeros: 1577.00 → 1577, 15.30 → 15.3 + if "." in s: + s = s.rstrip("0").rstrip(".") + + # Re-attach percentage + if has_pct: + s = s + "%" + + return s + + +# ------------------------------------------------------------------ +# Matching helpers +# ------------------------------------------------------------------ + + +def exact_match(prediction: str, gold: str) -> bool: + """Return ``True`` when normalised strings are identical.""" + return normalize_answer(prediction) == normalize_answer(gold) + + +def f1_score(prediction: str, gold: str) -> float: + """Compute token-level F1 between *prediction* and *gold*. + + Tokenisation is simple whitespace splitting after normalisation. + Each token is further normalised as a financial value so that + ``$1577`` matches ``1577`` at the token level. + Returns 0.0 when either side is empty. + """ + pred_tokens = [_normalize_financial_value(t) for t in normalize_answer(prediction).split()] + gold_tokens = [_normalize_financial_value(t) for t in normalize_answer(gold).split()] + if not pred_tokens or not gold_tokens: + return 0.0 + + common = Counter(pred_tokens) & Counter(gold_tokens) + num_common = sum(common.values()) + if num_common == 0: + return 0.0 + + precision = num_common / len(pred_tokens) + recall = num_common / len(gold_tokens) + return 2 * precision * recall / (precision + recall) + + +# ------------------------------------------------------------------ +# Three-class classification +# ------------------------------------------------------------------ + + +def classify_answer( + prediction: str, + gold: str, + *, + is_no_result: bool = False, + f1_threshold: float = _F1_CORRECT_THRESHOLD, +) -> str: + """Classify a prediction into ``correct``, ``refusal``, or ``hallucination``. + + Classification logic (faithful to FinanceBench paper): + 1. If the system explicitly refused (``is_no_result=True``) or the + prediction contains a refusal phrase → **refusal**. + 2. If EM passes or token-level F1 ≥ *f1_threshold* → **correct**. + 3. Otherwise → **hallucination**. + """ + norm_pred = normalize_answer(prediction) + + # --- Refusal --- + if is_no_result: + return "refusal" + pred_lower = norm_pred.lower() + for phrase in _REFUSAL_PHRASES: + if phrase in pred_lower: + return "refusal" + + # --- Correct --- + if exact_match(prediction, gold): + return "correct" + if f1_score(prediction, gold) >= f1_threshold: + return "correct" + + # --- Hallucination --- + return "hallucination" + + +# ------------------------------------------------------------------ +# Evidence recall +# ------------------------------------------------------------------ + + +def evidence_recall( + retrieved_pages: List[int], + gold_evidence: List[Dict[str, Any]], +) -> float: + """Compute page-level evidence recall. + + ``gold_evidence`` entries carry ``evidence_page_num`` (0-indexed). 
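+    For example, retrieved pages ``[10, 42]`` against gold evidence on page 42
+    give a recall of 1.0, and against gold pages 42 and 57 a recall of 0.5.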
+ Returns 1.0 when there is no gold evidence (vacuously true). + """ + if not gold_evidence: + return 1.0 + + gold_pages = { + int(e["evidence_page_num"]) + for e in gold_evidence + if "evidence_page_num" in e + } + if not gold_pages: + return 1.0 + + retrieved_set = set(retrieved_pages) + hits = gold_pages & retrieved_set + return len(hits) / len(gold_pages) + + +# ------------------------------------------------------------------ +# Aggregate metrics +# ------------------------------------------------------------------ + + +def compute_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: + """Aggregate per-question results into benchmark-level metrics. + + Expected keys per result dict: ``classification``, ``em``, ``f1``, + ``elapsed``, ``telemetry``, ``question_type``, ``question_reasoning``, + ``evidence_recall`` (optional). + + Returns a dict with overall stats plus breakdowns by *question_type* + and *question_reasoning*. + """ + n = len(results) + if n == 0: + return {"n": 0} + + # --- Overall counts --- + correct = sum(1 for r in results if r.get("classification") == "correct") + halluc = sum(1 for r in results if r.get("classification") == "hallucination") + refusal = sum(1 for r in results if r.get("classification") == "refusal") + + em_sum = sum(1 for r in results if r.get("em")) + f1_sum = sum(r.get("f1", 0.0) for r in results) + + latencies = [r["elapsed"] for r in results if "elapsed" in r] + avg_latency = sum(latencies) / len(latencies) if latencies else 0.0 + + token_counts = [ + r.get("telemetry", {}).get("total_tokens", 0) for r in results + ] + avg_tokens = sum(token_counts) / len(token_counts) if token_counts else 0 + + ev_recalls = [r["evidence_recall"] for r in results if r.get("evidence_recall") is not None] + avg_ev_recall = sum(ev_recalls) / len(ev_recalls) if ev_recalls else None + + overall = { + "n": n, + "accuracy": round(correct / n * 100, 2), + "hallucination_rate": round(halluc / n * 100, 2), + "refusal_rate": round(refusal / n * 100, 2), + "correct": correct, + "hallucination": halluc, + "refusal": refusal, + "avg_em": em_sum / n, + "avg_f1": f1_sum / n, + "avg_latency": round(avg_latency, 2), + "avg_tokens": round(avg_tokens, 1), + } + if avg_ev_recall is not None: + overall["evidence_recall"] = round(avg_ev_recall, 4) + + # --- Breakdowns --- + overall["by_question_type"] = _breakdown(results, "question_type") + overall["by_question_reasoning"] = _breakdown(results, "question_reasoning") + + return overall + + +def _breakdown(results: List[Dict[str, Any]], key: str) -> Dict[str, Dict[str, Any]]: + """Compute per-group accuracy / hallucination / refusal breakdown.""" + groups: dict[str, list[dict]] = defaultdict(list) + for r in results: + group = r.get(key, "unknown") + groups[group].append(r) + + out: dict[str, dict] = {} + for group, items in sorted(groups.items()): + g_n = len(items) + g_correct = sum(1 for r in items if r.get("classification") == "correct") + g_halluc = sum( + 1 for r in items if r.get("classification") == "hallucination" + ) + g_refusal = sum(1 for r in items if r.get("classification") == "refusal") + out[group] = { + "n": g_n, + "accuracy": round(g_correct / g_n * 100, 2) if g_n else 0.0, + "hallucination_rate": round(g_halluc / g_n * 100, 2) if g_n else 0.0, + "refusal_rate": round(g_refusal / g_n * 100, 2) if g_n else 0.0, + "correct": g_correct, + "hallucination": g_halluc, + "refusal": g_refusal, + } + return out diff --git a/benchmarks/financebench/run_benchmark.py b/benchmarks/financebench/run_benchmark.py new file mode 
100644 index 0000000..28e99d7 --- /dev/null +++ b/benchmarks/financebench/run_benchmark.py @@ -0,0 +1,239 @@ +"""FinanceBench benchmark entry point. + +Usage: + cd benchmarks/financebench + python run_benchmark.py [--env .env.financebench] [--limit N] + +Examples: + # Run all 150 questions with default config + python run_benchmark.py + + # Run a quick sanity check with 10 questions + python run_benchmark.py --limit 10 + + # Use a custom .env file + python run_benchmark.py --env .env.custom --limit 20 +""" +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import random +import sys +import time +from datetime import datetime +from pathlib import Path +from typing import List + +from config import FinanceBenchConfig +from data_loader import FinanceBenchLoader +from evaluate import compute_metrics +from runner import run_batch + +# --------------------------------------------------------------------------- +# Logging +# --------------------------------------------------------------------------- + + +def setup_logging(output_dir: str) -> str: + """Configure logging to file + console. + + Creates a timestamped log file under ``logs/`` (relative to *output_dir*'s + parent, i.e. the benchmark root directory). + + Returns: + Absolute path to the log file. + """ + log_dir = Path("logs") + log_dir.mkdir(parents=True, exist_ok=True) + + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + log_path = log_dir / f"benchmark_{ts}.log" + + root_logger = logging.getLogger("financebench") + root_logger.setLevel(logging.DEBUG) + + # File handler – DEBUG level, full detail + fh = logging.FileHandler(str(log_path), encoding="utf-8") + fh.setLevel(logging.DEBUG) + fh.setFormatter( + logging.Formatter( + "%(asctime)s %(name)-28s %(levelname)-7s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + + # Console handler – INFO level, concise + ch = logging.StreamHandler(sys.stdout) + ch.setLevel(logging.INFO) + ch.setFormatter( + logging.Formatter("%(asctime)s %(levelname)-7s %(message)s", datefmt="%H:%M:%S") + ) + + root_logger.addHandler(fh) + root_logger.addHandler(ch) + + return str(log_path.resolve()) + + +# --------------------------------------------------------------------------- +# Summary printing +# --------------------------------------------------------------------------- + + +def _print_summary( + results: List[dict], + metrics: dict, + total_time: float, + results_path: Path, + metrics_path: Path, + log_path: str, +) -> None: + """Print a human-readable run summary to stdout.""" + n = len(results) + acc = metrics.get("accuracy", 0) + hallu = metrics.get("hallucination_rate", 0) + refuse = metrics.get("refusal_rate", 0) + avg_em = metrics.get("avg_em", 0) + avg_f1 = metrics.get("avg_f1", 0) + ev_recall = metrics.get("evidence_recall") + avg_latency = metrics.get("avg_latency", 0) + + print("\n" + "=" * 60) + print(f"FinanceBench Results ({n} questions)") + print("=" * 60) + print(f" Accuracy: {acc:.1f}%") + print(f" Hallucination Rate: {hallu:.1f}%") + print(f" Refusal Rate: {refuse:.1f}%") + print(f" Avg EM: {avg_em:.3f}") + print(f" Avg F1: {avg_f1:.3f}") + if ev_recall is not None: + print(f" Evidence Recall: {ev_recall:.3f}") + else: + print(f" Evidence Recall: N/A (page-level telemetry unavailable)") + print(f" Avg Latency: {avg_latency:.1f}s") + print(f" Total Time: {total_time:.1f}s") + print(f"\n Results: {results_path}") + print(f" Metrics: {metrics_path}") + print(f" Log: {log_path}") + + # Breakdown by question_type + by_qt = 
metrics.get("by_question_type") + if by_qt: + print(f"\n {'Question Type':<25} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'N':>4}") + print(" " + "-" * 52) + for qt, m in sorted(by_qt.items()): + qt_acc = m.get("accuracy", 0) + qt_hal = m.get("hallucination_rate", 0) + qt_ref = m.get("refusal_rate", 0) + qt_n = m.get("n", 0) + print(f" {qt:<25} {qt_acc:>5.1f} {qt_hal:>7.1f} {qt_ref:>7.1f} {qt_n:>4}") + + print("=" * 60) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + """Parse CLI arguments, run the benchmark, and save results.""" + parser = argparse.ArgumentParser( + description="Run FinanceBench benchmark against Sirchmunk AgenticSearch", + ) + parser.add_argument( + "--env", + default=".env.financebench", + help="Path to .env config file (default: .env.financebench)", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Override FB_LIMIT — number of questions to evaluate", + ) + args = parser.parse_args() + + # 1. Load config + cfg = FinanceBenchConfig.from_env(args.env) + if args.limit is not None: + cfg.limit = args.limit + + # 2. Setup logging + log_path = setup_logging(cfg.output_dir) + logger = logging.getLogger("financebench") + + # 3. Load data + loader = FinanceBenchLoader(cfg.data_dir, cfg.pdf_dir) + questions = loader.load_questions() + logger.info("Loaded %d questions from %s", len(questions), cfg.data_dir) + + # 4. Validate corpus + found, missing = loader.validate_corpus(questions) + logger.info("PDF corpus: %d found, %d missing", found, len(missing)) + if missing: + preview = missing[:10] + suffix = "..." if len(missing) > 10 else "" + logger.warning("Missing PDFs: %s%s", preview, suffix) + + # 5. Apply limit / seed + if cfg.limit > 0 and cfg.limit < len(questions): + random.seed(cfg.seed) + questions = random.sample(questions, cfg.limit) + logger.info("Sampled %d questions (seed=%d)", len(questions), cfg.seed) + + # 6. Print run config + logger.info( + "Config: mode=%s, eval_mode=%s, extract_answer=%s, " + "llm_judge=%s, concurrent=%d, model=%s", + cfg.mode, + cfg.eval_mode, + cfg.extract_answer, + cfg.enable_llm_judge, + cfg.max_concurrent, + cfg.llm_model, + ) + + # 7. Run benchmark + t0 = time.time() + results = asyncio.run(run_batch(questions, cfg)) + total_time = time.time() - t0 + + # 8. Compute metrics + metrics = compute_metrics(results) + metrics["total_time_seconds"] = round(total_time, 2) + metrics["num_questions"] = len(questions) + metrics["config"] = { + "mode": cfg.mode, + "eval_mode": cfg.eval_mode, + "model": cfg.llm_model, + "top_k_files": cfg.top_k_files, + "extract_answer": cfg.extract_answer, + } + + # 9. Save results (JSONL) + metrics (JSON) + out_dir = Path(cfg.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + results_path = out_dir / f"results_{ts}.jsonl" + metrics_path = out_dir / f"metrics_{ts}.json" + + with open(results_path, "w", encoding="utf-8") as f: + for r in results: + f.write(json.dumps(r, ensure_ascii=False) + "\n") + + with open(metrics_path, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=2, ensure_ascii=False) + + logger.info("Results saved to %s", results_path) + logger.info("Metrics saved to %s", metrics_path) + + # 10. 
Print summary + _print_summary(results, metrics, total_time, results_path, metrics_path, log_path) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/financebench/runner.py b/benchmarks/financebench/runner.py new file mode 100644 index 0000000..72d0a0b --- /dev/null +++ b/benchmarks/financebench/runner.py @@ -0,0 +1,279 @@ +"""Run AgenticSearch on FinanceBench questions. + +Supports two evaluation modes: +- **singleDoc**: each question searches only its target PDF directory. +- **sharedCorpus**: all questions search the full PDF corpus. + +After search, an optional LLM extraction step converts the verbose +briefing into a short factoid answer suitable for EM/F1. +""" +from __future__ import annotations + +import asyncio +import json as json_mod +import logging +import time +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List + +from config import FinanceBenchConfig +from data_loader import FinanceBenchLoader +from evaluate import ( + classify_answer, + compute_metrics, + exact_match, + evidence_recall, + f1_score, + normalize_answer, +) + +logger = logging.getLogger("financebench.runner") + +# ------------------------------------------------------------------ +# Answer extraction prompt (financial domain) +# ------------------------------------------------------------------ + +_EXTRACT_PROMPT = """\ +Given the financial question and a verbose response, extract ONLY the short factoid answer. +Rules: +- Output ONLY the answer value/phrase (1-20 words). No explanation. +- If the response says it cannot find the answer, output: unknown +- For monetary values, keep the currency format (e.g., $1,577.00) +- For percentages, keep the % sign (e.g., 15.3%) +- For yes/no questions, output: yes or no + +Question: {question} +Response: {response} + +Short answer:""" + + +# NOTE: _normalize_prediction removed — use evaluate.normalize_answer instead. + + +# ------------------------------------------------------------------ +# LLM short-answer extraction +# ------------------------------------------------------------------ + + +async def _extract_short_answer( + question: str, + verbose: str, + llm: Any, +) -> str: + """Use *llm* to distil *verbose* into a short factoid answer.""" + prompt = _EXTRACT_PROMPT.format(question=question, response=verbose[:4000]) + try: + resp = await llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=False, + ) + return resp.content.strip() + except Exception: + logger.warning("Short-answer extraction failed; falling back to raw answer.") + return verbose + + +# ------------------------------------------------------------------ +# Page extraction helper +# ------------------------------------------------------------------ + + +def _try_extract_pages(telemetry: Dict[str, Any]) -> List[int]: + """Best-effort extraction of retrieved page numbers from telemetry. + + Current limitation: Sirchmunk's ``read_file_ids`` contains plain file + paths without page-level suffixes, so this function will typically + return an empty list. When empty, callers should treat evidence + recall as *unavailable* (``None``) rather than zero. 
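+
+    Example (illustrative of the assumed ``_page_`` suffix convention)::
+
+        >>> _try_extract_pages({"read_file_ids": ["3M_2018_10K_page_42"]})
+        [42]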
+ """ + pages: list[int] = [] + for fid in telemetry.get("read_file_ids", []): + # Convention: page indices may be embedded in file IDs + if isinstance(fid, str) and "_page_" in fid: + try: + pages.append(int(fid.rsplit("_page_", 1)[-1])) + except (ValueError, IndexError): + pass + return pages + + +# ------------------------------------------------------------------ +# Single question execution +# ------------------------------------------------------------------ + + +async def run_single( + entry: Dict[str, Any], + loader: FinanceBenchLoader, + searcher: Any, + llm: Any, + cfg: FinanceBenchConfig, + semaphore: asyncio.Semaphore, +) -> Dict[str, Any]: + """Execute one FinanceBench question end-to-end.""" + fb_id = entry.get("financebench_id", "") + question = entry["question"] + gold = entry.get("answer", "") + gold_evidence = entry.get("evidence", []) + + async with semaphore: + t0 = time.time() + error: str | None = None + raw_answer = "" + answer = "" + telemetry: dict[str, Any] = {} + retrieved_pages: list[int] = [] + + try: + # Determine search paths based on eval mode + if cfg.eval_mode == "singleDoc": + pdf_path = loader.get_pdf_path(entry.get("doc_name", "")) + if pdf_path: + search_paths = [pdf_path] # pass the single PDF file directly + else: + logger.warning("PDF not found for %s, falling back to full corpus", entry.get("doc_name", "")) + search_paths = [cfg.pdf_dir] + else: + search_paths = [cfg.pdf_dir] + + result = await searcher.search( + query=question, + paths=search_paths, + mode=cfg.mode, + top_k_files=cfg.top_k_files, + max_token_budget=cfg.max_token_budget, + enable_dir_scan=cfg.enable_dir_scan, + return_context=True, + ) + + raw_answer = getattr(result, "answer", "") or str(result) + + # Collect telemetry + read_files = list(getattr(result, "read_file_ids", None) or set()) + telemetry = { + "read_file_ids": read_files, + "total_tokens": getattr(result, "total_llm_tokens", 0), + "loop_count": getattr(result, "loop_count", 0), + "llm_calls": len(getattr(result, "llm_usages", None) or []), + "num_files_read": len(read_files), + } + retrieved_pages = _try_extract_pages(telemetry) + + # Answer extraction + if cfg.extract_answer and raw_answer: + answer = await _extract_short_answer(question, raw_answer, llm) + answer = normalize_answer(answer) + else: + answer = normalize_answer(raw_answer) + + except Exception as exc: + error = str(exc) + logger.error("Error on %s: %s", fb_id, error) + + elapsed = time.time() - t0 + + # Delay between requests + if cfg.request_delay > 0: + await asyncio.sleep(cfg.request_delay) + + # --- Evaluation --- + is_no_result = not answer or answer.lower() in ("unknown", "") + em = exact_match(answer, gold) + f1 = f1_score(answer, gold) + classification = classify_answer(answer, gold, is_no_result=is_no_result) + if retrieved_pages: # only compute when page-level data is available + ev_recall = evidence_recall(retrieved_pages, gold_evidence) + else: + ev_recall = None # mark as unavailable, avoid false 0 + + return { + "financebench_id": fb_id, + "question": question, + "prediction": answer, + "raw_prediction": raw_answer, + "gold_answer": gold, + "company": entry.get("company", ""), + "doc_name": entry.get("doc_name", ""), + "question_type": entry.get("question_type", ""), + "question_reasoning": entry.get("question_reasoning", ""), + "elapsed": round(elapsed, 2), + "telemetry": telemetry, + "classification": classification, + "em": em, + "f1": round(f1, 4), + "evidence_recall": round(ev_recall, 4) if ev_recall is not None else None, + "error": 
error, + } + + +# ------------------------------------------------------------------ +# Batch execution +# ------------------------------------------------------------------ + + +async def run_batch( + samples: List[Dict[str, Any]], + cfg: FinanceBenchConfig, +) -> List[Dict[str, Any]]: + """Run all *samples* concurrently and persist results incrementally.""" + from sirchmunk.llm.openai_chat import OpenAIChat + from sirchmunk.search import AgenticSearch + + llm = OpenAIChat( + api_key=cfg.llm_api_key, + base_url=cfg.llm_base_url, + model=cfg.llm_model, + ) + searcher = AgenticSearch(llm=llm, reuse_knowledge=False, verbose=False) + loader = FinanceBenchLoader(data_dir=cfg.data_dir, pdf_dir=cfg.pdf_dir) + semaphore = asyncio.Semaphore(cfg.max_concurrent) + + # Prepare output directory / file + out_dir = Path(cfg.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + out_path = out_dir / f"financebench_{ts}.jsonl" + + results: list[dict] = [] + completed = 0 + total = len(samples) + + async def _run_and_record(entry: Dict[str, Any]) -> Dict[str, Any]: + nonlocal completed + res = await run_single(entry, loader, searcher, llm, cfg, semaphore) + # Incremental save + with open(out_path, "a", encoding="utf-8") as fp: + fp.write(json_mod.dumps(res, ensure_ascii=False) + "\n") + completed += 1 + status = res["classification"] + logger.info( + "[%d/%d] %s %s EM=%s F1=%.2f %.1fs", + completed, + total, + res["financebench_id"], + status, + res["em"], + res["f1"], + res["elapsed"], + ) + return res + + tasks = [asyncio.create_task(_run_and_record(s)) for s in samples] + results = await asyncio.gather(*tasks) + + # Write aggregate metrics + metrics = compute_metrics(list(results)) + metrics_path = out_dir / f"financebench_{ts}_metrics.json" + with open(metrics_path, "w", encoding="utf-8") as fp: + json_mod.dump(metrics, fp, indent=2, ensure_ascii=False) + logger.info("Metrics saved to %s", metrics_path) + logger.info( + "Accuracy=%.2f%% Hallucination=%.2f%% Refusal=%.2f%%", + metrics.get("accuracy", 0), + metrics.get("hallucination_rate", 0), + metrics.get("refusal_rate", 0), + ) + + return list(results) From 4a0a01796bc51ae0869a8b78617aa0759cdc4c12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Thu, 16 Apr 2026 16:27:03 +0800 Subject: [PATCH 14/56] add llm judge for financebench --- benchmarks/financebench/analyze_results.py | 62 ++- benchmarks/financebench/config.py | 3 +- benchmarks/financebench/evaluate.py | 60 ++- benchmarks/financebench/judge.py | 420 +++++++++++++++++++++ benchmarks/financebench/run_benchmark.py | 37 +- benchmarks/financebench/runner.py | 37 +- 6 files changed, 597 insertions(+), 22 deletions(-) create mode 100644 benchmarks/financebench/judge.py diff --git a/benchmarks/financebench/analyze_results.py b/benchmarks/financebench/analyze_results.py index 24d2b64..a804284 100644 --- a/benchmarks/financebench/analyze_results.py +++ b/benchmarks/financebench/analyze_results.py @@ -69,16 +69,34 @@ def print_breakdown(title: str, breakdown: Dict[str, Dict[str, Any]]) -> None: breakdown: ``{group_name: {accuracy, hallucination_rate, ...}}``. 
""" print(f"\n=== Breakdown by {title} ===\n") - header = f" {'Group':<30} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'N':>4}" - print(header) - print(" " + "-" * (len(header) - 2)) - for group, m in sorted(breakdown.items(), key=lambda kv: -kv[1].get("accuracy", 0)): - acc = m.get("accuracy", 0) - hal = m.get("hallucination_rate", 0) - ref = m.get("refusal_rate", 0) - n = m.get("n", 0) - print(f" {group:<30} {acc:>5.1f} {hal:>7.1f} {ref:>7.1f} {n:>4}") + # Determine if judge data is available + has_judge = any(m.get("llm_judge_accuracy") is not None for m in breakdown.values()) + + if has_judge: + header = f" {'Group':<30} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'Judge%':>7} {'N':>4}" + print(header) + print(" " + "-" * (len(header) - 2)) + + for group, m in sorted(breakdown.items(), key=lambda kv: -kv[1].get("accuracy", 0)): + acc = m.get("accuracy", 0) + hal = m.get("hallucination_rate", 0) + ref = m.get("refusal_rate", 0) + n = m.get("n", 0) + jdg = m.get("llm_judge_accuracy") + jdg_str = f"{jdg:>6.1f}" if jdg is not None else " N/A" + print(f" {group:<30} {acc:>5.1f} {hal:>7.1f} {ref:>7.1f} {jdg_str} {n:>4}") + else: + header = f" {'Group':<30} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'N':>4}" + print(header) + print(" " + "-" * (len(header) - 2)) + + for group, m in sorted(breakdown.items(), key=lambda kv: -kv[1].get("accuracy", 0)): + acc = m.get("accuracy", 0) + hal = m.get("hallucination_rate", 0) + ref = m.get("refusal_rate", 0) + n = m.get("n", 0) + print(f" {group:<30} {acc:>5.1f} {hal:>7.1f} {ref:>7.1f} {n:>4}") def _compute_company_breakdown( @@ -183,6 +201,12 @@ def print_comparison_with_sota(metrics: Dict[str, Any]) -> None: n = metrics.get("n", 0) coverage = min(100.0, n / 150.0 * 100) print(f" {'Sirchmunk (This Run)':<30} {f'{acc:.1f}%':>10} {f'{coverage:.0f}%':>10}") + + # Show Judge Accuracy in SOTA table if available + judge_acc = metrics.get("llm_judge_accuracy") + if judge_acc is not None: + print(f" {'Sirchmunk (Judge Acc)':<30} {f'{judge_acc:.1f}%':>10} {f'{coverage:.0f}%':>10}") + print(f"\n (This run evaluated {n} questions)") @@ -247,6 +271,12 @@ def main() -> None: print(f" Evidence Recall: N/A (page-level telemetry unavailable)") print(f" Avg Latency: {avg_latency:.1f}s") + # LLM Judge independent metrics + if metrics.get("llm_judge_accuracy") is not None: + print(f"\n --- LLM Judge (Independent Evaluation) ---") + print(f" Judge Accuracy: {metrics['llm_judge_accuracy']:.1f}%") + print(f" Judge Correct: {metrics['llm_judge_correct']}/{metrics['llm_judge_count']}") + # --- Breakdowns --- if "by_question_type" in metrics: print_breakdown("Question Type", metrics["by_question_type"]) @@ -260,6 +290,20 @@ def main() -> None: # --- Error cases --- print_error_cases(results, max_show=args.max_errors) + # --- Judge-Rule Discrepancies --- + discrepancies = [r for r in results + if r.get("llm_judge_correct") is not None + and r.get("classification") != "correct" + and r.get("llm_judge_correct") is True] + if discrepancies: + print(f"\n=== Judge-Rule Discrepancies ({len(discrepancies)} cases) ===") + print(" (Cases where LLM Judge says correct but EM/F1 says wrong)") + for r in discrepancies[:10]: + print(f" {r.get('financebench_id', 'N/A')}: pred='{r.get('prediction', '')[:50]}' gold='{r.get('gold_answer', '')[:50]}'") + print(f" classification={r.get('classification')}, judge_reasoning={r.get('llm_judge_reasoning', '')[:80]}") + if len(discrepancies) > 10: + print(f" ... 
and {len(discrepancies) - 10} more discrepancy(ies) not shown.") + # --- SOTA comparison --- print_comparison_with_sota(metrics) diff --git a/benchmarks/financebench/config.py b/benchmarks/financebench/config.py index f2e0fdb..f51ea36 100644 --- a/benchmarks/financebench/config.py +++ b/benchmarks/financebench/config.py @@ -33,8 +33,9 @@ class FinanceBenchConfig: # Evaluation eval_mode: str = "singleDoc" # singleDoc / sharedCorpus - enable_llm_judge: bool = True # TODO: LLM Judge not yet implemented, reserved for future use + enable_llm_judge: bool = True # Use LLM to judge semantic equivalence (independent metric) extract_answer: bool = True + judge_f1_threshold: float = 0.8 # F1 threshold for 'correct' classification # Concurrency max_concurrent: int = 3 diff --git a/benchmarks/financebench/evaluate.py b/benchmarks/financebench/evaluate.py index 688cf41..3e78636 100644 --- a/benchmarks/financebench/evaluate.py +++ b/benchmarks/financebench/evaluate.py @@ -34,6 +34,26 @@ "no relevant information", "data not found", "unknown", + "i'm not able to", + "i am not able to", + "the document does not contain", + "the document doesn't contain", + "this information is not disclosed", + "not disclosed", + "could not find", + "couldn't find", + "no mention of", + "no information about", + "not provided in", + "not found in the document", + "i was unable to", + "unable to determine", + "unable to find", + "unable to locate", + "there is no data", + "no data available", + "not available in", + "not specified", ] _F1_CORRECT_THRESHOLD: float = 0.8 @@ -99,16 +119,29 @@ def _normalize_financial_value(text: str) -> str: - ``15.3%`` → ``15.3%`` - ``$1577`` → ``1577`` - ``1,577`` → ``1577`` + - ``($500)`` → ``-500`` + - ``-$500`` → ``-500`` """ s = text.strip() + # Handle accounting bracket notation for negatives: ($500) → -$500 + if s.startswith("(") and s.endswith(")"): + s = "-" + s[1:-1] + + # Handle negative sign: remember it, strip it for processing + negative = False + if s.startswith("-"): + negative = True + s = s[1:] + # Detect if value looks numeric (possibly with $ / % / commas) stripped_for_check = _RE_DOLLAR.sub("", s) stripped_for_check = stripped_for_check.replace(",", "").rstrip("%").strip() try: float(stripped_for_check) except ValueError: - return s # Not a numeric value – return as-is + # Not a numeric value – restore negative sign and return as-is + return ("-" + s) if negative else s # Remove dollar sign s = _RE_DOLLAR.sub("", s) @@ -129,6 +162,10 @@ def _normalize_financial_value(text: str) -> str: if has_pct: s = s + "%" + # Re-attach negative sign + if negative and not s.startswith("-"): + s = "-" + s + return s @@ -289,6 +326,18 @@ def compute_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: if avg_ev_recall is not None: overall["evidence_recall"] = round(avg_ev_recall, 4) + # --- LLM Judge metrics (independent dimension, NOT fallback) --- + judge_results = [r for r in results if r.get("llm_judge_correct") is not None] + if judge_results: + judge_correct = sum(1 for r in judge_results if r["llm_judge_correct"]) + overall["llm_judge_accuracy"] = round(judge_correct / len(judge_results) * 100, 2) + overall["llm_judge_count"] = len(judge_results) + overall["llm_judge_correct"] = judge_correct + else: + overall["llm_judge_accuracy"] = None + overall["llm_judge_count"] = 0 + overall["llm_judge_correct"] = 0 + # --- Breakdowns --- overall["by_question_type"] = _breakdown(results, "question_type") overall["by_question_reasoning"] = _breakdown(results, "question_reasoning") @@ 
-311,7 +360,7 @@ def _breakdown(results: List[Dict[str, Any]], key: str) -> Dict[str, Dict[str, A 1 for r in items if r.get("classification") == "hallucination" ) g_refusal = sum(1 for r in items if r.get("classification") == "refusal") - out[group] = { + group_dict: dict[str, Any] = { "n": g_n, "accuracy": round(g_correct / g_n * 100, 2) if g_n else 0.0, "hallucination_rate": round(g_halluc / g_n * 100, 2) if g_n else 0.0, @@ -320,4 +369,11 @@ def _breakdown(results: List[Dict[str, Any]], key: str) -> Dict[str, Dict[str, A "hallucination": g_halluc, "refusal": g_refusal, } + # LLM Judge breakdown + g_judge = [r for r in items if r.get("llm_judge_correct") is not None] + if g_judge: + g_jc = sum(1 for r in g_judge if r["llm_judge_correct"]) + group_dict["llm_judge_accuracy"] = round(g_jc / len(g_judge) * 100, 2) + group_dict["llm_judge_count"] = len(g_judge) + out[group] = group_dict return out diff --git a/benchmarks/financebench/judge.py b/benchmarks/financebench/judge.py new file mode 100644 index 0000000..e52b6e6 --- /dev/null +++ b/benchmarks/financebench/judge.py @@ -0,0 +1,420 @@ +"""LLM-based semantic equivalence judge for FinanceBench. + +The judge evaluates whether a model's prediction is semantically +equivalent to the gold answer, operating as an **independent** +evaluation dimension alongside EM/F1 — not as a fallback. + +This provides a more nuanced correctness signal for financial QA, +where formatting differences (e.g., $1.5B vs $1,500M) can cause +EM/F1 to undercount correct answers. +""" + +from __future__ import annotations + +import json +import logging +import re +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + + +_JUDGE_PROMPT = """\ +You are an expert financial analyst and auditor evaluating answer correctness \ +with **zero tolerance for numerical or factual errors**. + +Question: {question} +Gold Answer: {gold} +Model Prediction: {prediction} + +Task: Determine if the model's prediction is **semantically equivalent** \ +to the gold answer in the context of this financial question. + +═══════════════════════════════════════════════ +EQUIVALENT — only when ALL of the following hold: +═══════════════════════════════════════════════ + +1. **Numerical precision (ZERO TOLERANCE)**: + - Values must be mathematically identical after unit conversion. + - $1.5B = $1,500M = $1,500,000K = $1,500,000,000 ✓ + - $1,577 ≠ $1,580 ✗ (rounding is NOT acceptable) + - 15.3% = 15.30% = 0.153 ✓ but 15.3% ≠ 15% ✗ + - $1.5M ≠ $1.5B ✗ (unit mismatch is a critical error) + +2. **Negative / bracket notation**: + - ($500) = -$500 = -500 ✓ + - ($500) ≠ $500 ✗ (sign matters) + +3. **Time period / fiscal year**: + - FY2018 = fiscal year 2018 = 2018 ✓ + - FY2018 ≠ FY2019 ✗ (different fiscal year — NEVER equivalent) + - Q3 2019 ≠ Q4 2019 ✗ (different quarter) + - "year ended December 2018" = FY2018 ✓ + +4. **Currency formatting**: + - $1,577.00 = $1577 = 1577 ✓ (same value, format differs) + +5. **Financial term equivalences (accepted)**: + - net income = net profit ✓ + - CAPEX = capital expenditure ✓ + - EPS = earnings per share ✓ + - EBITDA = earnings before interest, taxes, depreciation and amortization ✓ + - YoY = year-over-year ✓ + - COGS = cost of goods sold ✓ + - D&A = depreciation and amortization ✓ + +6. **Financial term distinctions (NOT interchangeable)**: + - revenue ≠ net revenue ≠ gross revenue (unless context is clear) + - operating income ≠ net income + - gross profit ≠ net profit + - total assets ≠ net assets + +7. 
**Prediction with extra context**: + - If prediction contains the correct answer with additional supporting \ + detail, treat as equivalent (e.g., "Revenue was $1,577M in FY2018" \ + vs "$1,577M" — equivalent, provided the value is correct). + +═══════════════════════════════════════════════ +NOT EQUIVALENT — if ANY of the following hold: +═══════════════════════════════════════════════ + +1. Different numerical values (even slightly: $1,577 ≠ $1,580) +2. Different time periods or fiscal years +3. Different companies or entities +4. Opposite trend direction (increased ≠ decreased, growth ≠ decline) +5. Unit mismatch ($1.5M ≠ $1.5B) +6. Missing or wrong sign (positive ≠ negative) +7. Prediction is vague or hedging where gold is precise +8. Prediction is a refusal or states it cannot find the answer +9. Near-approximate values that are not mathematically equal after unit conversion + +═══════════════════════════════════════════════ +CONSERVATIVE JUDGMENT POLICY +═══════════════════════════════════════════════ + +- **When in doubt, judge as NOT equivalent.** Financial accuracy demands \ + precision; a false positive (incorrectly marking wrong answer as correct) \ + is far worse than a false negative. +- If you are less than 80% confident the answers are equivalent, \ + judge as NOT equivalent. +- Set confidence to reflect your actual certainty (0.0 = no idea, \ + 1.0 = absolutely certain). + +═══════════════════════════════════════════════ +FEW-SHOT EXAMPLES +═══════════════════════════════════════════════ + +Example 1 — EQUIVALENT (format difference): + Gold: "$1,577" | Prediction: "$1,577.00 million" + → {{"equivalent": true, "confidence": 0.95, "reasoning": "Same value $1,577M, trailing zeros are formatting."}} + +Example 2 — EQUIVALENT (abbreviation): + Gold: "$1.5 billion" | Prediction: "$1,500M" + → {{"equivalent": true, "confidence": 0.97, "reasoning": "$1.5B = $1,500M, correct unit conversion."}} + +Example 3 — NOT EQUIVALENT (different value): + Gold: "$1,577" | Prediction: "$1,580" + → {{"equivalent": false, "confidence": 0.99, "reasoning": "Values differ: 1577 ≠ 1580. No rounding tolerance."}} + +Example 4 — NOT EQUIVALENT (different fiscal year): + Gold: "FY2018" | Prediction: "FY2019" + → {{"equivalent": false, "confidence": 1.0, "reasoning": "Different fiscal years."}} + +Example 5 — NOT EQUIVALENT (unit mismatch): + Gold: "$1.5 million" | Prediction: "$1.5 billion" + → {{"equivalent": false, "confidence": 1.0, "reasoning": "Unit mismatch: million ≠ billion."}} + +Example 6 — EQUIVALENT (negative notation): + Gold: "-$500" | Prediction: "($500)" + → {{"equivalent": true, "confidence": 0.98, "reasoning": "Same negative value, bracket = negative."}} + +Respond ONLY with a JSON object (no markdown, no extra text): +{{"equivalent": true or false, "confidence": 0.0 to 1.0, "reasoning": "brief explanation"}}""" + + +# Refusal detection phrases (subset for quick judge-side check) +_REFUSAL_INDICATORS: frozenset[str] = frozenset( + { + "i cannot", + "i can't", + "unable to", + "not able to", + "i don't know", + "i do not know", + "unknown", + "no results found", + "cannot determine", + "insufficient data", + "data not found", + "could not find", + "couldn't find", + "unable to determine", + "unable to find", + } +) + + +class FinanceBenchLLMJudge: + """LLM-based judge for semantic equivalence in financial QA. + + Operates as an independent evaluation dimension — NOT as a + fallback for EM/F1. Each question gets a separate judge verdict + that is tracked in its own metrics. 
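+
+    Example (illustrative; ``llm`` is any client exposing an async ``achat``)::
+
+        judge = FinanceBenchLLMJudge(llm)
+        verdict = await judge.judge("$1,500M", "$1.5 billion", question=question)
+        if verdict["equivalent"]:
+            ...  # semantically equivalent per the judge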
+ """ + + _CONFIDENCE_THRESHOLD: float = 0.7 + _MAX_RETRIES: int = 2 + + def __init__(self, llm: Any) -> None: + self._llm = llm + self._cache: Dict[tuple, Dict[str, Any]] = {} + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def judge( + self, + prediction: str, + gold_answer: str, + question: str = "", + ) -> Dict[str, Any]: + """Judge whether prediction is semantically equivalent to gold. + + Args: + prediction: Model's answer text. + gold_answer: Ground-truth answer text. + question: The original question (for context). + + Returns: + { + "equivalent": bool, + "confidence": float (0-1), + "reasoning": str, + "cached": bool, + "error": Optional[str] + } + """ + # --- Refusal short-circuit (saves LLM call) --- + if self._is_refusal(prediction): + return { + "equivalent": False, + "confidence": 1.0, + "reasoning": "Prediction is a refusal — skipped LLM judge.", + "cached": False, + "error": None, + } + + # --- Quick exact-match shortcut --- + from evaluate import normalize_answer + + if normalize_answer(prediction) == normalize_answer(gold_answer): + return { + "equivalent": True, + "confidence": 1.0, + "reasoning": "Normalized exact match", + "cached": False, + "error": None, + } + + # --- Check cache (key includes question for context-sensitivity) --- + cache_key = ( + question.strip().lower(), + prediction.strip().lower(), + gold_answer.strip().lower(), + ) + if cache_key in self._cache: + result = dict(self._cache[cache_key]) + result["cached"] = True + return result + + # --- Call LLM with retry --- + prompt = _JUDGE_PROMPT.format( + question=question or "N/A", + gold=gold_answer, + prediction=prediction, + ) + + result: Dict[str, Any] | None = None + last_error: str | None = None + + for attempt in range(1, self._MAX_RETRIES + 1): + try: + resp = await self._llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=False, + ) + raw = resp.content.strip() + result = self._parse_response(raw) + if result.get("error") is None: + break # success + last_error = result.get("error") + except Exception as e: + last_error = str(e) + logger.warning( + "LLM Judge call failed (attempt %d/%d): %s", + attempt, + self._MAX_RETRIES, + e, + ) + result = None + + if result is None or result.get("error") is not None: + result = { + "equivalent": False, + "confidence": 0.0, + "reasoning": f"Judge error after {self._MAX_RETRIES} attempts: {last_error}", + "error": last_error, + } + + # --- Apply confidence threshold (conservative) --- + if ( + result.get("error") is None + and result["equivalent"] + and result["confidence"] < self._CONFIDENCE_THRESHOLD + ): + result["equivalent"] = False + result["reasoning"] = ( + f"Overridden to NOT equivalent: confidence " + f"{result['confidence']:.2f} < threshold " + f"{self._CONFIDENCE_THRESHOLD} — conservative policy. 
" + f"Original reasoning: {result['reasoning']}" + ) + + result.setdefault("cached", False) + result.setdefault("error", None) + + # Cache successful results only + if result["error"] is None: + self._cache[cache_key] = { + k: v for k, v in result.items() if k != "cached" + } + + return result + + # ------------------------------------------------------------------ + # Parsing + # ------------------------------------------------------------------ + + def _parse_response(self, raw: str) -> Dict[str, Any]: + """Parse LLM JSON response with robust fallback heuristics.""" + # --- Try direct JSON parse --- + parsed = self._try_parse_json(raw) + if parsed is not None: + return self._validated_result(parsed, raw) + + # --- Fallback: keyword detection (conservative) --- + lower = raw.lower() + + # Look for explicit true/false patterns with word boundaries + true_match = re.search( + r'"equivalent"\s*:\s*true\b', lower + ) + false_match = re.search( + r'"equivalent"\s*:\s*false\b', lower + ) + + if false_match and not true_match: + return { + "equivalent": False, + "confidence": 0.5, + "reasoning": f"Keyword fallback (NOT equivalent): {raw[:200]}", + } + elif true_match and not false_match: + # Conservative: lower confidence for keyword-only parse + return { + "equivalent": True, + "confidence": 0.5, + "reasoning": f"Keyword fallback (equivalent): {raw[:200]}", + } + + # --- Cannot parse → conservative default --- + logger.warning("Cannot parse judge response: %s", raw[:200]) + return { + "equivalent": False, + "confidence": 0.0, + "reasoning": f"Unparseable response: {raw[:200]}", + "error": "parse_error", + } + + def _try_parse_json(self, raw: str) -> Optional[Dict[str, Any]]: + """Attempt multiple JSON extraction strategies.""" + strategies = [ + raw.strip(), + # Strip markdown code fences + re.sub(r"```(?:json)?\s*\n?", "", raw).strip().rstrip("`").strip(), + # Extract first {...} block + self._extract_json_block(raw), + ] + + for text in strategies: + if not text: + continue + # Fix common LLM JSON quirks + text = self._fix_json_quirks(text) + try: + return json.loads(text) + except (json.JSONDecodeError, ValueError): + continue + return None + + @staticmethod + def _extract_json_block(raw: str) -> Optional[str]: + """Extract the first {...} JSON object from raw text.""" + match = re.search(r"\{[^{}]*\}", raw, re.DOTALL) + return match.group(0) if match else None + + @staticmethod + def _fix_json_quirks(text: str) -> str: + """Fix common non-standard JSON from LLMs.""" + # Replace single quotes with double quotes (basic heuristic) + # Only if the text doesn't already have double quotes for keys + if "'" in text and '"' not in text: + text = text.replace("'", '"') + # Remove trailing commas before closing braces + text = re.sub(r",\s*}", "}", text) + text = re.sub(r",\s*]", "]", text) + return text + + def _validated_result( + self, obj: Dict[str, Any], raw: str + ) -> Dict[str, Any]: + """Build a validated result dict from parsed JSON, clamping values.""" + equivalent = bool(obj.get("equivalent", False)) + + # Clamp confidence to [0.0, 1.0] + try: + confidence = float(obj.get("confidence", 0.0)) + except (ValueError, TypeError): + confidence = 0.0 + confidence = max(0.0, min(1.0, confidence)) + + reasoning = str(obj.get("reasoning", "")) + + return { + "equivalent": equivalent, + "confidence": confidence, + "reasoning": reasoning, + } + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + 
@staticmethod + def _is_refusal(text: str) -> bool: + """Quick check whether *text* looks like a refusal / non-answer.""" + if not text or not text.strip(): + return True + lower = text.strip().lower() + if lower in ("unknown", "n/a", "none", ""): + return True + for phrase in _REFUSAL_INDICATORS: + if phrase in lower: + return True + return False + + @property + def cache_size(self) -> int: + """Return the number of cached judge results.""" + return len(self._cache) diff --git a/benchmarks/financebench/run_benchmark.py b/benchmarks/financebench/run_benchmark.py index 28e99d7..ef0df8f 100644 --- a/benchmarks/financebench/run_benchmark.py +++ b/benchmarks/financebench/run_benchmark.py @@ -115,6 +115,13 @@ def _print_summary( print(f" Evidence Recall: N/A (page-level telemetry unavailable)") print(f" Avg Latency: {avg_latency:.1f}s") print(f" Total Time: {total_time:.1f}s") + + # LLM Judge independent metrics + if metrics.get("llm_judge_accuracy") is not None: + print(f"\n --- LLM Judge (Independent) ---") + print(f" Judge Accuracy: {metrics['llm_judge_accuracy']:.1f}%") + print(f" Judge Correct: {metrics['llm_judge_correct']}/{metrics['llm_judge_count']}") + print(f"\n Results: {results_path}") print(f" Metrics: {metrics_path}") print(f" Log: {log_path}") @@ -122,14 +129,28 @@ def _print_summary( # Breakdown by question_type by_qt = metrics.get("by_question_type") if by_qt: - print(f"\n {'Question Type':<25} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'N':>4}") - print(" " + "-" * 52) - for qt, m in sorted(by_qt.items()): - qt_acc = m.get("accuracy", 0) - qt_hal = m.get("hallucination_rate", 0) - qt_ref = m.get("refusal_rate", 0) - qt_n = m.get("n", 0) - print(f" {qt:<25} {qt_acc:>5.1f} {qt_hal:>7.1f} {qt_ref:>7.1f} {qt_n:>4}") + # Determine if judge data is available + has_judge = any(m.get("llm_judge_accuracy") is not None for m in by_qt.values()) + if has_judge: + print(f"\n {'Question Type':<25} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'Judge%':>7} {'N':>4}") + print(" " + "-" * 59) + for qt, m in sorted(by_qt.items()): + qt_acc = m.get("accuracy", 0) + qt_hal = m.get("hallucination_rate", 0) + qt_ref = m.get("refusal_rate", 0) + qt_n = m.get("n", 0) + qt_judge = m.get("llm_judge_accuracy") + qt_judge_str = f"{qt_judge:>6.1f}" if qt_judge is not None else " N/A" + print(f" {qt:<25} {qt_acc:>5.1f} {qt_hal:>7.1f} {qt_ref:>7.1f} {qt_judge_str} {qt_n:>4}") + else: + print(f"\n {'Question Type':<25} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'N':>4}") + print(" " + "-" * 52) + for qt, m in sorted(by_qt.items()): + qt_acc = m.get("accuracy", 0) + qt_hal = m.get("hallucination_rate", 0) + qt_ref = m.get("refusal_rate", 0) + qt_n = m.get("n", 0) + print(f" {qt:<25} {qt_acc:>5.1f} {qt_hal:>7.1f} {qt_ref:>7.1f} {qt_n:>4}") print("=" * 60) diff --git a/benchmarks/financebench/runner.py b/benchmarks/financebench/runner.py index 72d0a0b..64d709f 100644 --- a/benchmarks/financebench/runner.py +++ b/benchmarks/financebench/runner.py @@ -111,6 +111,7 @@ async def run_single( llm: Any, cfg: FinanceBenchConfig, semaphore: asyncio.Semaphore, + judge: Any = None, ) -> Dict[str, Any]: """Execute one FinanceBench question end-to-end.""" fb_id = entry.get("financebench_id", "") @@ -188,6 +189,25 @@ async def run_single( else: ev_recall = None # mark as unavailable, avoid false 0 + # LLM Judge — independent evaluation dimension + # Skip judge for refusals (no point calling LLM on non-answers) + llm_judge_correct = None + llm_judge_reasoning = None + if judge is not None and classification != "refusal": + try: + 
judge_result = await judge.judge( + prediction=answer, + gold_answer=gold, + question=question, + ) + llm_judge_correct = judge_result.get("equivalent", False) + llm_judge_reasoning = judge_result.get("reasoning", "") + except Exception as e: + logger.warning("LLM Judge failed for %s: %s", fb_id, e) + elif judge is not None and classification == "refusal": + llm_judge_correct = False + llm_judge_reasoning = "Skipped: prediction classified as refusal" + return { "financebench_id": fb_id, "question": question, @@ -204,6 +224,8 @@ async def run_single( "em": em, "f1": round(f1, 4), "evidence_recall": round(ev_recall, 4) if ev_recall is not None else None, + "llm_judge_correct": llm_judge_correct, # None if judge disabled + "llm_judge_reasoning": llm_judge_reasoning, "error": error, } @@ -230,6 +252,13 @@ async def run_batch( loader = FinanceBenchLoader(data_dir=cfg.data_dir, pdf_dir=cfg.pdf_dir) semaphore = asyncio.Semaphore(cfg.max_concurrent) + # Initialise LLM Judge (uses the same test model) + judge = None + if cfg.enable_llm_judge: + from judge import FinanceBenchLLMJudge + judge = FinanceBenchLLMJudge(llm=llm) + logger.info("LLM Judge enabled (independent evaluation dimension)") + # Prepare output directory / file out_dir = Path(cfg.output_dir) out_dir.mkdir(parents=True, exist_ok=True) @@ -242,14 +271,17 @@ async def run_batch( async def _run_and_record(entry: Dict[str, Any]) -> Dict[str, Any]: nonlocal completed - res = await run_single(entry, loader, searcher, llm, cfg, semaphore) + res = await run_single(entry, loader, searcher, llm, cfg, semaphore, judge=judge) # Incremental save with open(out_path, "a", encoding="utf-8") as fp: fp.write(json_mod.dumps(res, ensure_ascii=False) + "\n") completed += 1 status = res["classification"] + judge_tag = "" + if res.get("llm_judge_correct") is not None: + judge_tag = " [judge:\u2713]" if res["llm_judge_correct"] else " [judge:\u2717]" logger.info( - "[%d/%d] %s %s EM=%s F1=%.2f %.1fs", + "[%d/%d] %s %s EM=%s F1=%.2f %.1fs%s", completed, total, res["financebench_id"], @@ -257,6 +289,7 @@ async def _run_and_record(entry: Dict[str, Any]) -> Dict[str, Any]: res["em"], res["f1"], res["elapsed"], + judge_tag, ) return res From 613c099af0653fc95934f18ffc7f8569cd232404 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Thu, 16 Apr 2026 17:10:22 +0800 Subject: [PATCH 15/56] Adapt older knowledge cluster data structure --- src/sirchmunk/storage/knowledge_storage.py | 138 +++++++++++++++------ 1 file changed, 102 insertions(+), 36 deletions(-) diff --git a/src/sirchmunk/storage/knowledge_storage.py b/src/sirchmunk/storage/knowledge_storage.py index 0f09071..c74e05a 100644 --- a/src/sirchmunk/storage/knowledge_storage.py +++ b/src/sirchmunk/storage/knowledge_storage.py @@ -107,9 +107,11 @@ def _load_from_parquet(self): variable-length ``FLOAT[]`` from Parquet's list encoding, breaking ``list_cosine_similarity`` which requires matching fixed-size types. - Handles schema evolution gracefully: if the parquet file has fewer - columns than the current schema (e.g., missing ``merge_count``), - missing columns are filled with defaults instead of failing. + Handles schema evolution gracefully with adaptive column mapping: + - Forward compatible: old parquet (more cols) → new table (fewer cols), + extra columns in parquet are silently ignored. + - Backward compatible: new parquet (fewer cols) → old table (more cols), + missing columns are filled with defaults. 
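+
+        Example (illustrative; the table, column, and file names are made up):
+        if the table has columns ``(id, text, merge_count)`` and the parquet
+        only has ``(id, text)``, the generated statement is roughly::
+
+            INSERT INTO knowledge_clusters (id, text, merge_count)
+            SELECT id, text, 0 AS merge_count
+            FROM read_parquet('knowledge_clusters.parquet')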
Also records the file's modification time so that ``_check_and_reload()`` can detect external changes later. @@ -121,37 +123,62 @@ def _load_from_parquet(self): self.db.drop_table(self.table_name, if_exists=True) # Create table with explicit schema (preserves FLOAT[384]) self._create_table() - # Detect parquet columns to handle schema evolution - try: - pq_cols = self.db.fetch_all( - f"SELECT name FROM parquet_schema('{self.parquet_file}')" - ) - pq_col_names = {row[0] for row in pq_cols} - except Exception: - pq_col_names = None - - if pq_col_names is not None: - # Build column-by-column SELECT with defaults for missing cols - schema_cols = list(self._get_schema_columns()) - select_parts = [] - for col_name in schema_cols: - if col_name in pq_col_names: - select_parts.append(col_name) - elif col_name == "merge_count": - select_parts.append("0 AS merge_count") - else: - select_parts.append(f"NULL AS {col_name}") - select_clause = ", ".join(select_parts) - self.db.execute( - f"INSERT INTO {self.table_name} " - f"SELECT {select_clause} FROM read_parquet('{self.parquet_file}')" + + # Adaptive column mapping: detect parquet & table columns + parquet_cols = self._get_parquet_columns(self.parquet_file) + table_cols = self._get_table_columns() + + if not parquet_cols or not table_cols: + logger.warning( + "Could not detect columns for adaptive mapping, " + "skipping parquet load" ) else: - # Fallback: try direct SELECT * (works when schemas match) - self.db.execute( - f"INSERT INTO {self.table_name} " - f"SELECT * FROM read_parquet('{self.parquet_file}')" - ) + parquet_col_set = set(parquet_cols) + table_col_set = set(table_cols) + # Compute common columns (preserve table column order) + common_cols = [c for c in table_cols if c in parquet_col_set] + + if not common_cols: + logger.warning( + "No common columns between parquet and table, " + "skipping parquet load" + ) + else: + # Log column mismatches as warnings + ignored_cols = parquet_col_set - table_col_set + missing_cols = table_col_set - parquet_col_set + if ignored_cols: + logger.warning( + "Parquet has extra columns (ignored): %s", + ignored_cols, + ) + if missing_cols: + logger.warning( + "Table has extra columns (filled with defaults): %s", + missing_cols, + ) + + # Build INSERT with explicit column lists + # For common cols: select directly from parquet + # For missing cols (in table but not in parquet): use defaults + insert_cols = list(table_cols) # all table columns + select_parts = [] + for col_name in table_cols: + if col_name in parquet_col_set: + select_parts.append(col_name) + elif col_name == "merge_count": + select_parts.append("0 AS merge_count") + else: + select_parts.append(f"NULL AS {col_name}") + + cols_str = ", ".join(insert_cols) + select_clause = ", ".join(select_parts) + self.db.execute( + f"INSERT INTO {self.table_name} ({cols_str}) " + f"SELECT {select_clause} " + f"FROM read_parquet('{self.parquet_file}')" + ) count = self.db.get_table_count(self.table_name) # Record mtime for stale-detection @@ -163,10 +190,13 @@ def _load_from_parquet(self): self._parquet_loaded_mtime = 0.0 logger.info("Created new knowledge clusters table") except Exception as e: - logger.error(f"Failed to load from parquet: {e}") - # Try to recreate table - self.db.drop_table(self.table_name, if_exists=True) - self._create_table() + logger.warning(f"Failed to load from parquet (non-blocking): {e}") + # Try to recreate table so retrieval can still work + try: + self.db.drop_table(self.table_name, if_exists=True) + self._create_table() + except 
Exception as recreate_err: + logger.warning(f"Failed to recreate table after load failure: {recreate_err}") self._parquet_loaded_mtime = 0.0 def _get_schema_columns(self) -> List[str]: @@ -181,6 +211,42 @@ def _get_schema_columns(self) -> List[str]: "embedding_text_hash", ] + def _get_parquet_columns(self, parquet_path: str) -> List[str]: + """Get column names from a parquet file's schema. + + Uses DuckDB's ``parquet_schema()`` function. The returned metadata + rows use a ``name`` field (not ``column_name``). + + Returns: + Ordered list of column names, or empty list on failure. + """ + try: + rows = self.db.fetch_all( + f"SELECT name FROM parquet_schema('{parquet_path}') " + f"WHERE name != 'duckdb_schema'" + ) + return [row[0] for row in rows] + except Exception as e: + logger.warning(f"Failed to read parquet schema: {e}") + return [] + + def _get_table_columns(self) -> List[str]: + """Get column names from the current DuckDB table. + + Returns: + Ordered list of column names, or empty list on failure. + """ + try: + rows = self.db.fetch_all( + "SELECT column_name FROM information_schema.columns " + f"WHERE table_name = '{self.table_name}' " + "ORDER BY ordinal_position" + ) + return [row[0] for row in rows] + except Exception as e: + logger.warning(f"Failed to read table columns: {e}") + return [] + def _check_and_reload(self): """Check if the parquet file was modified externally and reload if so. From 6858418583efdee30ee5fb3da1e5047300cdae1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Thu, 16 Apr 2026 18:08:46 +0800 Subject: [PATCH 16/56] update finance bench readme --- benchmarks/financebench/README.md | 53 ++++++++++++++++++++++------- benchmarks/financebench/evaluate.py | 4 +-- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md index 23bd67d..da6752f 100644 --- a/benchmarks/financebench/README.md +++ b/benchmarks/financebench/README.md @@ -20,33 +20,62 @@ with **150 expert-annotated questions** across **40+ US public companies** (10-K - **EM / F1**: Exact Match and token-level F1 with financial value normalisation - **Evidence Recall**: Retrieved pages vs gold evidence pages -## Quick Start +## Prerequisites + +### 1. Install Sirchmunk -### 1. Setup +Make sure Sirchmunk is installed and accessible: ```bash -cd benchmarks/financebench +pip install -e . +``` -# Copy and edit the config file -cp .env.example .env.financebench -# Edit .env.financebench — set your LLM_API_KEY at minimum +### 2. Prepare Corpus + +Download the FinanceBench dataset (PDF files and JSONL) and place them in the appropriate directory. +Update the paths in your `.env.financebench`: + +- `FB_PDF_DIR` — path to the directory containing the 10-K/10-Q PDF files +- `FB_QUESTIONS_FILE` — path to `financebench_open_source.jsonl` + +### 3. Initialize Workspace + +Initialize the Sirchmunk workspace pointing to the PDF corpus directory: + +```bash +sirchmunk init +``` -# Download FinanceBench data -# Place financebench_open_source.jsonl in ./data/ -# Place PDF corpus (41 files) in ./data/pdfs/ +### 4. Compile Knowledge Base + +Compile the corpus to build the knowledge base for retrieval: + +```bash +sirchmunk compile --paths /path/to/financebench/pdf_files ``` -### 2. Run +> **Note:** The compile step may take some time depending on the corpus size and your LLM provider's rate limits. For FinanceBench's ~41 PDFs (10-K/10-Q filings), expect 10-30 minutes. + +### 5. 
Configure Environment + +```bash +cp .env.example .env.financebench +# Edit .env.financebench with your API keys and paths +``` + +## Quick Start + +### 1. Run ```bash # Run full benchmark (150 questions) python run_benchmark.py # Run with custom config and question limit -python run_benchmark.py --env .env.financebench --limit 20 +python run_benchmark.py --env .env.custom --limit 20 ``` -### 3. Analyze +### 2. Analyze ```bash # Analyze a completed run diff --git a/benchmarks/financebench/evaluate.py b/benchmarks/financebench/evaluate.py index 3e78636..e22bf07 100644 --- a/benchmarks/financebench/evaluate.py +++ b/benchmarks/financebench/evaluate.py @@ -349,11 +349,11 @@ def _breakdown(results: List[Dict[str, Any]], key: str) -> Dict[str, Dict[str, A """Compute per-group accuracy / hallucination / refusal breakdown.""" groups: dict[str, list[dict]] = defaultdict(list) for r in results: - group = r.get(key, "unknown") + group = r.get(key) or "unknown" groups[group].append(r) out: dict[str, dict] = {} - for group, items in sorted(groups.items()): + for group, items in sorted(groups.items(), key=lambda x: (x[0] is None, x[0] or "")): g_n = len(items) g_correct = sum(1 for r in items if r.get("classification") == "correct") g_halluc = sum( From 9441ef23af039db48e0e035aa4d8e4dac9223498 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Thu, 16 Apr 2026 19:20:44 +0800 Subject: [PATCH 17/56] refactor config for finbench --- .gitignore | 3 +- benchmarks/financebench/README.md | 27 +++++++++-- benchmarks/financebench/config.py | 59 ++++++++++++++++++------ benchmarks/financebench/run_benchmark.py | 14 ++++++ benchmarks/financebench/runner.py | 3 +- 5 files changed, 86 insertions(+), 20 deletions(-) diff --git a/.gitignore b/.gitignore index dbd34eb..f79f03f 100644 --- a/.gitignore +++ b/.gitignore @@ -270,4 +270,5 @@ benchmarks/*/data/ benchmarks/*/.env* benchmarks/*/logs/ benchmarks/*/results/ -benchmarks/*/output/ \ No newline at end of file +benchmarks/*/output/ +benchmarks/*/.work/ diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md index da6752f..4751508 100644 --- a/benchmarks/financebench/README.md +++ b/benchmarks/financebench/README.md @@ -40,21 +40,26 @@ Update the paths in your `.env.financebench`: ### 3. Initialize Workspace -Initialize the Sirchmunk workspace pointing to the PDF corpus directory: +Initialize the Sirchmunk workspace with an experiment-isolated work path: ```bash -sirchmunk init +cd benchmarks/financebench +sirchmunk init --work-path ./.work ``` +This creates a `.work/` directory under the experiment folder, keeping knowledge base +and cache isolated from the default `~/.sirchmunk`. + ### 4. Compile Knowledge Base -Compile the corpus to build the knowledge base for retrieval: +Compile the PDF corpus into the experiment workspace: ```bash -sirchmunk compile --paths /path/to/financebench/pdf_files +sirchmunk compile --work-path ./.work --paths ``` -> **Note:** The compile step may take some time depending on the corpus size and your LLM provider's rate limits. For FinanceBench's ~41 PDFs (10-K/10-Q filings), expect 10-30 minutes. +> **Note:** The compile step may take some time depending on the corpus size. +> For FinanceBench's ~41 PDFs (10-K/10-Q filings), expect 10-30 minutes. ### 5. Configure Environment @@ -65,6 +70,18 @@ cp .env.example .env.financebench ## Quick Start +### Configuration Priority + +Configuration loads in this order (later overrides earlier): + +1. 
**Dataclass defaults** — hard-coded in `FinanceBenchConfig` +2. **Platform .env** — `.work/.env` (created by `sirchmunk init`) +3. **Experiment .env** — `.env.financebench` +4. **Command-line** — `--limit N`, `--env ` + +To reuse platform LLM config, leave `LLM_*` commented in `.env.financebench`. +To override, uncomment and set different values. + ### 1. Run ```bash diff --git a/benchmarks/financebench/config.py b/benchmarks/financebench/config.py index f51ea36..5c390ce 100644 --- a/benchmarks/financebench/config.py +++ b/benchmarks/financebench/config.py @@ -6,6 +6,27 @@ from pathlib import Path +def _parse_env_file(path: str) -> dict[str, str]: + """Parse a .env file into a dict, handling comments, blank lines, and quotes.""" + result: dict[str, str] = {} + p = Path(path) + if not p.exists(): + return result + for line in p.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" not in line: + continue + k, v = line.split("=", 1) + v = v.strip() + # Strip surrounding quotes + if len(v) >= 2 and v[0] == v[-1] and v[0] in ('"', "'"): + v = v[1:-1] + result[k.strip()] = v + return result + + @dataclass class FinanceBenchConfig: """All settings for a FinanceBench evaluation run.""" @@ -41,23 +62,34 @@ class FinanceBenchConfig: max_concurrent: int = 3 request_delay: float = 0.5 + # Experiment isolation + work_path: str = "./.work" # Isolated workspace for this experiment + @classmethod def from_env(cls, env_path: str = ".env.financebench") -> "FinanceBenchConfig": - """Load config from .env file with ``os.environ`` fallback.""" - # Read .env file - env_vars: dict[str, str] = {} - p = Path(env_path) - if p.exists(): - for line in p.read_text(encoding="utf-8").splitlines(): - line = line.strip() - if not line or line.startswith("#"): - continue - if "=" in line: - k, v = line.split("=", 1) - env_vars[k.strip()] = v.strip() + """Load config with layer inheritance. + + Priority (highest to lowest): + 1. Experiment .env (.env.financebench) + 2. Platform .env (/.env, if exists) + 3. os.environ + 4. 
Dataclass defaults + """ + # Step 0: Pre-read experiment env to determine work_path + experiment_vars = _parse_env_file(env_path) + work_path = experiment_vars.get( + "FB_WORK_PATH", os.environ.get("FB_WORK_PATH", "./.work") + ) + + # Step 1: Load platform-level env (/.env) + platform_env_path = Path(work_path) / ".env" + platform_vars = _parse_env_file(str(platform_env_path)) + + # Step 2: Merge — experiment > platform > os.environ > defaults + merged = {**platform_vars, **experiment_vars} def _get(key: str, default: str = "") -> str: - return env_vars.get(key, os.environ.get(key, default)) + return merged.get(key, os.environ.get(key, default)) def _bool(key: str, default: bool = False) -> bool: v = _get(key, str(default)).lower() @@ -97,4 +129,5 @@ def _float(key: str, default: float = 0.0) -> float: extract_answer=_bool("FB_EXTRACT_ANSWER", True), max_concurrent=_int("FB_MAX_CONCURRENT", 3), request_delay=_float("FB_REQUEST_DELAY", 0.5), + work_path=work_path, ) diff --git a/benchmarks/financebench/run_benchmark.py b/benchmarks/financebench/run_benchmark.py index ef0df8f..c9f5b26 100644 --- a/benchmarks/financebench/run_benchmark.py +++ b/benchmarks/financebench/run_benchmark.py @@ -187,6 +187,20 @@ def main() -> None: log_path = setup_logging(cfg.output_dir) logger = logging.getLogger("financebench") + # Print config source info + work_env = Path(cfg.work_path) / ".env" + logger.info("=" * 50) + logger.info("FinanceBench Configuration") + logger.info("=" * 50) + logger.info(" Experiment env : %s", args.env) + logger.info(" Platform env : %s (%s)", work_env, "found" if work_env.exists() else "not found") + logger.info(" Work path : %s", Path(cfg.work_path).resolve()) + logger.info(" LLM : %s @ %s", cfg.llm_model, cfg.llm_base_url) + logger.info(" Eval mode : %s", cfg.eval_mode) + logger.info(" Search mode : %s, Top-K: %d", cfg.mode, cfg.top_k_files) + logger.info(" LLM Judge : %s", "enabled" if cfg.enable_llm_judge else "disabled") + logger.info("=" * 50) + # 3. Load data loader = FinanceBenchLoader(cfg.data_dir, cfg.pdf_dir) questions = loader.load_questions() diff --git a/benchmarks/financebench/runner.py b/benchmarks/financebench/runner.py index 64d709f..b95f7ca 100644 --- a/benchmarks/financebench/runner.py +++ b/benchmarks/financebench/runner.py @@ -248,7 +248,8 @@ async def run_batch( base_url=cfg.llm_base_url, model=cfg.llm_model, ) - searcher = AgenticSearch(llm=llm, reuse_knowledge=False, verbose=False) + work_path = str(Path(cfg.work_path).resolve()) + searcher = AgenticSearch(llm=llm, work_path=work_path, reuse_knowledge=False, verbose=False) loader = FinanceBenchLoader(data_dir=cfg.data_dir, pdf_dir=cfg.pdf_dir) semaphore = asyncio.Semaphore(cfg.max_concurrent) From f1f86fab5ce18c245ad46d2de817eb66bab444b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 10:41:50 +0800 Subject: [PATCH 18/56] refactor financebench readme --- benchmarks/financebench/README.md | 161 +++++++++++++++++++++++++----- 1 file changed, 134 insertions(+), 27 deletions(-) diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md index 4751508..d6c95b0 100644 --- a/benchmarks/financebench/README.md +++ b/benchmarks/financebench/README.md @@ -22,65 +22,172 @@ with **150 expert-annotated questions** across **40+ US public companies** (10-K ## Prerequisites -### 1. 
Install Sirchmunk +### Step 1: Install Sirchmunk -Make sure Sirchmunk is installed and accessible: +Install Sirchmunk from the repository root so that the `sirchmunk` CLI is available: ```bash +# From repository root pip install -e . ``` -### 2. Prepare Corpus +Verify the installation: -Download the FinanceBench dataset (PDF files and JSONL) and place them in the appropriate directory. -Update the paths in your `.env.financebench`: +```bash +sirchmunk --version +``` + +### Step 2: Prepare Dataset + +Download the [FinanceBench](https://huggingface.co/datasets/PatronusAI/financebench) +dataset and place the files under `benchmarks/financebench/data/`: -- `FB_PDF_DIR` — path to the directory containing the 10-K/10-Q PDF files -- `FB_QUESTIONS_FILE` — path to `financebench_open_source.jsonl` +``` +data/ +├── financebench_open_source.jsonl # 150 expert-annotated QA pairs +└── pdfs/ # 41 SEC-filing PDFs (10-K / 10-Q) + ├── 3M_2018_10K.pdf + ├── AMCOR_2023_10K.pdf + └── ... +``` -### 3. Initialize Workspace +Each PDF filename must match the `doc_name` field in the JSONL file. -Initialize the Sirchmunk workspace with an experiment-isolated work path: +### Step 3: Initialize Experiment Workspace + +Initialize an isolated workspace for this experiment. This keeps the knowledge base +and cache separate from the default `~/.sirchmunk`: ```bash cd benchmarks/financebench sirchmunk init --work-path ./.work ``` -This creates a `.work/` directory under the experiment folder, keeping knowledge base -and cache isolated from the default `~/.sirchmunk`. +This creates a `.work/` directory containing a **platform .env** file (`.work/.env`). + +**Configure the platform .env** (`.work/.env`): -### 4. Compile Knowledge Base +This file controls the LLM provider used by Sirchmunk's search engine. +You **must** set valid LLM credentials here before proceeding. -Compile the PDF corpus into the experiment workspace: +| Variable | Required | Description | Example | +|----------|----------|-------------|---------| +| `LLM_API_KEY` | **Yes** | API key for the LLM provider | `sk-xxx` | +| `LLM_BASE_URL` | **Yes** | LLM API endpoint | `https://dashscope.aliyuncs.com/compatible-mode/v1` | +| `LLM_MODEL_NAME` | **Yes** | Model name for search & QA | `qwen3.5-plus` | +| `LLM_TIMEOUT` | No | Request timeout in seconds | `120` | ```bash -sirchmunk compile --work-path ./.work --paths +# Edit the platform .env +vi .work/.env ``` -> **Note:** The compile step may take some time depending on the corpus size. -> For FinanceBench's ~41 PDFs (10-K/10-Q filings), expect 10-30 minutes. +### Step 4: Compile Knowledge Base -### 5. Configure Environment +Compile the PDF corpus into the experiment workspace so that Sirchmunk can search it: + +```bash +sirchmunk compile --work-path ./.work --paths ./data/pdfs +``` + +> **Note:** This step parses, chunks, and indexes all PDFs. +> For FinanceBench's ~41 PDFs (10-K/10-Q filings), expect 10–30 minutes. + +### Step 5: Configure Experiment + +Create the **experiment .env** from the template: ```bash cp .env.example .env.financebench -# Edit .env.financebench with your API keys and paths ``` -## Quick Start +**Configure the experiment .env** (`.env.financebench`): + +This file controls FinanceBench-specific evaluation parameters. 
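+
+The tables below describe each variable. As a rough, illustrative starting
+point (all values are placeholders; adjust the paths to your own layout), a
+minimal `.env.financebench` could look like this:
+
+```bash
+# Dataset locations (see Dataset Paths below)
+FB_DATA_DIR=./data
+FB_PDF_DIR=./data/pdfs
+FB_OUTPUT_DIR=./output
+
+# Evaluate a small sample first; 0 means all 150 questions
+FB_LIMIT=20
+
+# Keep the independent LLM Judge metric enabled
+FB_ENABLE_LLM_JUDGE=true
+
+# Leave LLM_* commented out to inherit the platform config from .work/.env
+```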
+ +#### Dataset Paths -### Configuration Priority +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_WORK_PATH` | No | Isolated workspace path | `./.work` | +| `FB_DATA_DIR` | **Yes** | Directory containing `financebench_open_source.jsonl` | `./data` | +| `FB_PDF_DIR` | **Yes** | Directory containing the 41 PDF files | `./data/pdfs` | +| `FB_OUTPUT_DIR` | No | Results output directory | `./output` | -Configuration loads in this order (later overrides earlier): +#### Dataset Settings -1. **Dataclass defaults** — hard-coded in `FinanceBenchConfig` -2. **Platform .env** — `.work/.env` (created by `sirchmunk init`) -3. **Experiment .env** — `.env.financebench` -4. **Command-line** — `--limit N`, `--env ` +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_LIMIT` | No | Number of questions to evaluate (`0` = all 150) | `0` | +| `FB_SEED` | No | Random seed for reproducibility | `42` | + +#### Search Settings + +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_MODE` | No | Search mode: `FAST` or `DEEP` | `FAST` | +| `FB_TOP_K_FILES` | No | Max files returned per search | `5` | +| `FB_MAX_TOKEN_BUDGET` | No | Token budget for search context | `128000` | +| `FB_ENABLE_DIR_SCAN` | No | Enable directory-level scanning | `true` | + +#### Evaluation Settings + +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_EVAL_MODE` | No | `singleDoc` (per-PDF) or `sharedCorpus` (all PDFs) | `singleDoc` | +| `FB_ENABLE_LLM_JUDGE` | No | Enable LLM Judge for semantic equivalence | `true` | +| `FB_EXTRACT_ANSWER` | No | Extract short answer from verbose response | `true` | + +#### Concurrency Settings + +| Variable | Required | Description | Default | +|----------|----------|-------------|---------| +| `FB_MAX_CONCURRENT` | No | Max concurrent evaluation requests | `3` | +| `FB_REQUEST_DELAY` | No | Delay between requests in seconds | `0.5` | + +**Optional LLM Override**: If you want this experiment to use a **different** LLM +than the platform config, uncomment the `LLM_*` lines in `.env.financebench`. +Otherwise, the experiment inherits LLM settings from `.work/.env`. + +```bash +# Edit the experiment .env +vi .env.financebench +``` + +## Configuration Architecture + +Configuration loads with layered inheritance (highest priority wins): + +``` +Priority (highest → lowest): +┌──────────────────────────────────┐ +│ Command-line args │ ← --limit N, --env +├──────────────────────────────────┤ +│ .env.financebench (experiment) │ ← FB_* params + optional LLM override +├──────────────────────────────────┤ +│ .work/.env (platform) │ ← LLM_API_KEY, LLM_MODEL_NAME, etc. +├──────────────────────────────────┤ +│ Environment variables │ ← os.environ fallback +├──────────────────────────────────┤ +│ Defaults │ ← Hard-coded in FinanceBenchConfig +└──────────────────────────────────┘ +``` -To reuse platform LLM config, leave `LLM_*` commented in `.env.financebench`. -To override, uncomment and set different values. +### What Goes Where? 
+ +| Setting | Platform `.work/.env` | Experiment `.env.financebench` | +|---------|:---------------------:|:------------------------------:| +| LLM API Key | ✅ (required) | Only if overriding | +| LLM Model | ✅ (required) | Only if overriding | +| LLM Base URL | ✅ (required) | Only if overriding | +| LLM Timeout | Optional | Only if overriding | +| PDF directory | — | ✅ (required) | +| Data directory | — | ✅ (required) | +| Output directory | — | Optional | +| Eval mode | — | Optional | +| Search mode | — | Optional | +| LLM Judge | — | Optional | +| Concurrency | — | Optional | ### 1. Run From 0e46ef5641adb5e3857f527c04f34fe2128ef834 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 10:52:09 +0800 Subject: [PATCH 19/56] update readme for finbench --- benchmarks/financebench/README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md index d6c95b0..e294c7b 100644 --- a/benchmarks/financebench/README.md +++ b/benchmarks/financebench/README.md @@ -60,7 +60,7 @@ and cache separate from the default `~/.sirchmunk`: ```bash cd benchmarks/financebench -sirchmunk init --work-path ./.work +sirchmunk init --work-path .work ``` This creates a `.work/` directory containing a **platform .env** file (`.work/.env`). @@ -70,12 +70,12 @@ This creates a `.work/` directory containing a **platform .env** file (`.work/.e This file controls the LLM provider used by Sirchmunk's search engine. You **must** set valid LLM credentials here before proceeding. -| Variable | Required | Description | Example | -|----------|----------|-------------|---------| -| `LLM_API_KEY` | **Yes** | API key for the LLM provider | `sk-xxx` | +| Variable | Required | Description | Example | +|----------|----------|-------------|-----------------------------------------------------| +| `LLM_API_KEY` | **Yes** | API key for the LLM provider | `sk-xxx` | | `LLM_BASE_URL` | **Yes** | LLM API endpoint | `https://dashscope.aliyuncs.com/compatible-mode/v1` | -| `LLM_MODEL_NAME` | **Yes** | Model name for search & QA | `qwen3.5-plus` | -| `LLM_TIMEOUT` | No | Request timeout in seconds | `120` | +| `LLM_MODEL_NAME` | **Yes** | Model name for search & QA | `qwen3.6-plus` | +| `LLM_TIMEOUT` | No | Request timeout in seconds | `120` | ```bash # Edit the platform .env @@ -87,7 +87,7 @@ vi .work/.env Compile the PDF corpus into the experiment workspace so that Sirchmunk can search it: ```bash -sirchmunk compile --work-path ./.work --paths ./data/pdfs +sirchmunk compile --work-path .work --paths data/pdfs ``` > **Note:** This step parses, chunks, and indexes all PDFs. From 2cf5c378cbd42d1aba82ded34db994a62f56b487 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 13:59:49 +0800 Subject: [PATCH 20/56] enhance tree indexes usage for search pipeline --- benchmarks/financebench/README.md | 2 +- src/sirchmunk/search.py | 291 +++++++++++++++++++++++++++++- 2 files changed, 285 insertions(+), 8 deletions(-) diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md index e294c7b..95d04e7 100644 --- a/benchmarks/financebench/README.md +++ b/benchmarks/financebench/README.md @@ -82,7 +82,7 @@ You **must** set valid LLM credentials here before proceeding. 
vi .work/.env ``` -### Step 4: Compile Knowledge Base +### Step 4: Knowledge Compiling Compile the PDF corpus into the experiment workspace so that Sirchmunk can search it: diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 9b7bf47..2e44449 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -1427,6 +1427,9 @@ async def _search_deep( ) _llm_usage_start = len(self.llm_usages) + # --- Adaptive compile artifact detection (shared with FAST) --- + artifacts = self._detect_compile_artifacts() + # ============================================================== # Phase 0a: Direct document analysis (intent-gated short-circuit) # ============================================================== @@ -1460,7 +1463,9 @@ async def _search_deep( self._probe_knowledge_cache(query), self._load_spec_context(paths, stale_hours=spec_stale_hours), self._probe_tree_index(query), - self._probe_compile_hints(initial_keywords if initial_keywords else [query]), + self._probe_compile_hints([query]), # query-level hints; keyword-level runs post-Phase 1 + self._probe_summary_index(query, artifacts), # GAP 2: zero-LLM BM25 + self._probe_catalog_for_deep(query, artifacts), # GAP 4: zero-LLM keyword overlap return_exceptions=True, ) @@ -1470,8 +1475,10 @@ async def _search_deep( spec_context = phase1_results[3] if not isinstance(phase1_results[3], Exception) else "" tree_hits = phase1_results[4] if not isinstance(phase1_results[4], Exception) else [] compile_hints = phase1_results[5] if not isinstance(phase1_results[5], Exception) else CompileHints([], []) + summary_index_hits = phase1_results[6] if not isinstance(phase1_results[6], Exception) else [] + catalog_deep_hits = phase1_results[7] if not isinstance(phase1_results[7], Exception) else [] - for i, label in enumerate(["keywords", "dir_scan", "knowledge", "spec_cache", "tree_index", "compile_hints"]): + for i, label in enumerate(["keywords", "dir_scan", "knowledge", "spec_cache", "tree_index", "compile_hints", "summary_index", "catalog_deep"]): if isinstance(phase1_results[i], Exception): await self._logger.warning(f"[Phase 1] {label} probe failed: {phase1_results[i]}") @@ -1508,6 +1515,8 @@ async def _search_deep( f"knowledge_files={len(knowledge_probe.file_paths)}, " f"tree_hits={len(tree_hits)}, " f"compile_hints={len(compile_hints.file_paths)}, " + f"summary_index={len(summary_index_hits)}, " + f"catalog_deep={len(catalog_deep_hits)}, " f"soft_hit={'YES' if soft_hit else 'NO'}, " f"spec_cache={'YES' if spec_context else 'NO'}" ) @@ -1583,7 +1592,7 @@ async def _search_deep( if soft_hit: extra_knowledge_files = soft_hit.file_paths + extra_knowledge_files merged_files = self._merge_file_paths( - keyword_files=list(tree_hits) + compile_hints.file_paths + keyword_files, + keyword_files=list(tree_hits) + catalog_deep_hits + compile_hints.file_paths + summary_index_hits + keyword_files, dir_scan_files=dir_scan_files, knowledge_hits=extra_knowledge_files, ) @@ -1627,6 +1636,22 @@ async def _search_deep( answer: str = "" should_save: bool = True + # Inject catalog context for wiki-enhanced answer (GAP 4) + if artifacts and artifacts.catalog_map and cluster and cluster.content: + _catalog_ctx_parts = [] + for fp in (cluster.search_results or merged_files)[:3]: + ctx = self._build_answer_context(fp, artifacts) + if ctx: + _catalog_ctx_parts.append(ctx) + if _catalog_ctx_parts: + _catalog_context = "\n".join(_catalog_ctx_parts) + if isinstance(cluster.content, list): + cluster.content = "\n".join(cluster.content) + cluster.content = 
f"{cluster.content}\n\n[Document Context]\n{_catalog_context}" + await self._logger.info( + f"[Phase 4] Injected catalog context for {len(_catalog_ctx_parts)} documents" + ) + if cluster and cluster.content: await self._logger.info("[Phase 4] Evidence sufficient, generating summary") answer, should_save, should_answer = await self._summarise_cluster(query, cluster) @@ -2007,6 +2032,10 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Maximum character length for a catalog keyword token.""" _CATALOG_SUMMARY_TRUNCATE = 200 """Max chars of catalog summary shown in the listing.""" + _SUMMARY_INDEX_TOP_K = 3 + """Maximum files returned by proactive summary index BM25 probe.""" + _DEEP_CATALOG_TOP_K = 3 + """Maximum files returned by catalog keyword-overlap probe in DEEP mode.""" # --- Tree-guided sampling constants --- _TREE_SAMPLE_MAX_SECTIONS = 3 @@ -2019,6 +2048,8 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Maximum number of tree roots to include in FAST Step 1 hints.""" _DEEP_PRE_NAV_MAX_FILES = 3 """Maximum number of tree files to pre-navigate in DEEP Phase 2.5.""" + _FAST_TREE_PROBE_MAX_FILES = 2 + """Maximum files returned by active tree probing in FAST mode.""" _LLM_FALLBACK_EVIDENCE = ( "[No relevant documents found]\n\n" @@ -2106,10 +2137,33 @@ async def _search_fast( if tree_hints: prompt = prompt + tree_hints - resp = await self.llm.achat( + # Step 1 LLM call + compile hints + tree probe run in parallel + # (GAP 3: hints前置化, GAP 1: 树导航主动化) + _step1_llm_task = self.llm.achat( messages=[{"role": "user", "content": prompt}], stream=False, ) + _compile_hints_task = self._probe_compile_hints([query]) + _tree_probe_task = self._probe_tree_for_fast(query, artifacts) + + _parallel_results = await asyncio.gather( + _step1_llm_task, _compile_hints_task, _tree_probe_task, + return_exceptions=True, + ) + resp = _parallel_results[0] + _early_compile_hints = _parallel_results[1] + _tree_probed_files = _parallel_results[2] + + if isinstance(resp, Exception): + await self._logger.warning(f"[FAST:Step1] LLM call failed: {resp}") + return f"Search analysis failed: {resp}", None, context + if isinstance(_early_compile_hints, Exception): + await self._logger.warning(f"[FAST:Step1] Compile hints pre-fetch failed: {_early_compile_hints}") + _early_compile_hints = CompileHints([], []) + if isinstance(_tree_probed_files, Exception): + await self._logger.warning(f"[FAST:Step1] Tree probe failed: {_tree_probed_files}") + _tree_probed_files = [] + self.llm_usages.append(resp.usage) if resp.usage and isinstance(resp.usage, dict): context.add_llm_tokens( @@ -2207,8 +2261,9 @@ async def _search_fast( all_kw_set.add(p) keyword_idfs.setdefault(p, 0.6) - # P4: compile hints from manifest + tree cache - compile_hints = await self._probe_compile_hints(primary + fallback) + # P4: compile hints — pre-fetched (query-level) + keyword-level supplement + _kw_compile_hints = await self._probe_compile_hints(primary + fallback) + compile_hints = self._merge_compile_hints(_early_compile_hints, _kw_compile_hints) for kw in compile_hints.extra_keywords: if kw not in all_kw_set: fallback.append(kw) @@ -2222,6 +2277,17 @@ async def _search_fast( if fp not in seen_hint_paths: seen_hint_paths.add(fp) compile_hint_files.append(fp) + # Active tree probe files: second priority (GAP 1) + for fp in (_tree_probed_files or []): + if fp not in seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) + # Summary index BM25 files: proactive zero-LLM discovery 
(GAP 2) + _summary_hint_files = await self._probe_summary_index(query, artifacts) + for fp in _summary_hint_files: + if fp not in seen_hint_paths: + seen_hint_paths.add(fp) + compile_hint_files.append(fp) if soft_hit: for fp in soft_hit.file_paths: if fp not in seen_hint_paths: @@ -2235,7 +2301,10 @@ async def _search_fast( if compile_hint_files: await self._logger.info( f"[FAST:Step1.5] Compile hints: {len(compile_hint_files)} files " - f"(catalog={len(catalog_routed_files)}, soft={len(soft_hit.file_paths) if soft_hit else 0}), " + f"(catalog={len(catalog_routed_files)}, " + f"tree={len(_tree_probed_files) if _tree_probed_files else 0}, " + f"summary={len(_summary_hint_files)}, " + f"soft={len(soft_hit.file_paths) if soft_hit else 0}), " f"{len(compile_hints.extra_keywords)} extra keywords" ) @@ -4105,6 +4174,214 @@ async def _probe_compile_hints(self, keywords: List[str]) -> CompileHints: extra_keywords=extra_keywords[:10], ) + @staticmethod + def _merge_compile_hints(base: "CompileHints", supplement: "CompileHints") -> "CompileHints": + """Merge two CompileHints, deduplicating file paths and keywords.""" + seen_fps = set(base.file_paths) + merged_fps = list(base.file_paths) + for fp in supplement.file_paths: + if fp not in seen_fps: + seen_fps.add(fp) + merged_fps.append(fp) + seen_kws = set(base.extra_keywords) + merged_kws = list(base.extra_keywords) + for kw in supplement.extra_keywords: + if kw not in seen_kws: + seen_kws.add(kw) + merged_kws.append(kw) + return CompileHints(file_paths=merged_fps[:15], extra_keywords=merged_kws[:10]) + + async def _probe_summary_index( + self, + query: str, + artifacts: Optional["CompileArtifacts"] = None, + ) -> List[str]: + """Zero-LLM file discovery via compile-time summary index (BM25 only). + + Uses the pre-built summary index's BM25 channel to find files whose + summaries are lexically similar to the query. No LLM or embedding + calls — pure local computation. + + Args: + query: User query string. + artifacts: Compile artifacts (uses summary_index field). + + Returns: + File paths of top-k matching documents, or empty list. + """ + if artifacts is None or artifacts.summary_index is None: + return [] + + try: + from sirchmunk.utils.tokenizer_util import TokenizerUtil + _tokenizer = TokenizerUtil() + query_tokens = _tokenizer.segment(query) + + if not query_tokens: + return [] + + # BM25-only search: pass query_embedding=None to skip embedding channel + results = artifacts.summary_index.search( + query_embedding=None, + query_tokens=query_tokens, + top_k=self._SUMMARY_INDEX_TOP_K, + ) + + file_paths = [ + fp for fp, score in results + if score > 0.0 and Path(fp).exists() + ] + + if file_paths: + await self._logger.info( + f"[SummaryIndex:BM25] Found {len(file_paths)} files " + f"from {artifacts.summary_index.num_entries} indexed docs" + ) + return file_paths + except Exception as exc: + await self._logger.warning(f"[SummaryIndex:BM25] Probe failed: {exc}") + return [] + + async def _probe_catalog_for_deep( + self, + query: str, + artifacts: Optional["CompileArtifacts"] = None, + ) -> List[str]: + """Zero-LLM file discovery via document catalog keyword overlap. + + Scores each catalog entry by counting query token overlap with the + document summary. Returns top-k file paths sorted by overlap score. + + Args: + query: User query string. + artifacts: Compile artifacts (uses catalog field). + + Returns: + File paths of top-k matching documents, or empty list. 
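+
+        Example (illustrative; assumes the tokenizer lowercases the query and
+        splits it into word-like tokens): a query tokenised to
+        {"3m", "capex", "fy2018"} against a catalog summary containing
+        "3m" and "fy2018" overlaps on 2 of 3 tokens, giving that file a
+        score of 2 / 3 ≈ 0.67.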
+ """ + if not artifacts or not artifacts.catalog: + return [] + + try: + query_tokens = self._tokenize_for_matching(query.lower()) + if not query_tokens: + return [] + + scored: List[Tuple[str, float]] = [] + for entry in artifacts.catalog: + fp = entry.get("path", "") + if not fp or not Path(fp).exists(): + continue + summary = (entry.get("summary", "") or "").lower() + name = (entry.get("name", "") or "").lower() + doc_tokens = self._tokenize_for_matching(f"{name} {summary}") + overlap = len(query_tokens & doc_tokens) + if overlap > 0: + # Normalize by query length to avoid bias toward long summaries + score = overlap / max(1, len(query_tokens)) + scored.append((fp, score)) + + if not scored: + return [] + + scored.sort(key=lambda x: x[1], reverse=True) + result_paths = [fp for fp, _ in scored[:self._DEEP_CATALOG_TOP_K]] + + if result_paths: + await self._logger.info( + f"[DEEP:CatalogProbe] Found {len(result_paths)} files " + f"from {len(artifacts.catalog)} catalog entries" + ) + return result_paths + except Exception as exc: + await self._logger.warning(f"[DEEP:CatalogProbe] Failed: {exc}") + return [] + + async def _probe_tree_for_fast( + self, query: str, artifacts: Optional["CompileArtifacts"] = None, + ) -> List[str]: + """Active tree-based file discovery for FAST mode (1 LLM call). + + Lightweight wrapper around tree root selection logic. When compiled + tree indices are available and cover more than 2 files, asks the LLM + to select the most relevant 1-2 documents from root summaries. + + Returns file paths of selected documents, or empty list when trees + are unavailable or cover too few files to justify an LLM call. + """ + if not artifacts or len(artifacts.tree_available_paths) <= 2: + return [] + + tree_cache = self.work_path / ".cache" / "compile" / "trees" + if not tree_cache.exists(): + return [] + + try: + from sirchmunk.learnings.tree_indexer import DocumentTree + + trees: List[DocumentTree] = [] + for tree_file in sorted(tree_cache.glob("*.json"))[:self._TREE_CACHE_SCAN_LIMIT]: + try: + t = DocumentTree.from_json( + tree_file.read_text(encoding="utf-8") + ) + if t.root and t.file_path and Path(t.file_path).exists(): + trees.append(t) + except Exception: + continue + + if not trees: + return [] + + # Few trees: return all without LLM + if len(trees) <= self._FAST_TREE_PROBE_MAX_FILES: + return [t.file_path for t in trees] + + # LLM-driven selection among tree roots + listing = "\n".join( + f"[{i}] {Path(t.file_path).name}: {(t.root.summary or '')[:200]}" + for i, t in enumerate(trees) + ) + prompt = ( + f'Given the query: "{query}"\n\n' + f"Select the 1-{self._FAST_TREE_PROBE_MAX_FILES} most relevant documents:\n" + f"{listing}\n\n" + f"Return ONLY a JSON array of index numbers, e.g. 
[0, 2]" + ) + resp = await self.llm.achat([{"role": "user", "content": prompt}]) + self.llm_usages.append(resp.usage) + + selected_indices: List[int] = [] + try: + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + selected_indices = [ + idx for idx in json.loads(m.group()) + if isinstance(idx, int) and 0 <= idx < len(trees) + ] + except (json.JSONDecodeError, TypeError): + pass + + if not selected_indices: + selected_indices = list(range(min(self._FAST_TREE_PROBE_MAX_FILES, len(trees)))) + + result_paths = [ + trees[idx].file_path + for idx in selected_indices[:self._FAST_TREE_PROBE_MAX_FILES] + if Path(trees[idx].file_path).exists() + ] + + if result_paths: + await self._logger.info( + f"[FAST:TreeProbe] Selected {len(result_paths)} files " + f"from {len(trees)} tree indices" + ) + return result_paths + except Exception as exc: + await self._logger.warning(f"[FAST:TreeProbe] Failed: {exc}") + return [] + @staticmethod async def _async_noop(default=None): """No-op coroutine used as placeholder in gather().""" From c0b0db5b1654ce510a185c6f0c80dae18b408d3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 14:11:45 +0800 Subject: [PATCH 21/56] fix issues --- src/sirchmunk/search.py | 205 ++++++++++++++++++---------------------- 1 file changed, 93 insertions(+), 112 deletions(-) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 2e44449..7506c0c 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2050,6 +2050,10 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Maximum number of tree files to pre-navigate in DEEP Phase 2.5.""" _FAST_TREE_PROBE_MAX_FILES = 2 """Maximum files returned by active tree probing in FAST mode.""" + _DEEP_TREE_PROBE_MAX_FILES = 3 + """Maximum files returned by tree index probing in DEEP mode.""" + _TREE_ROOT_HINT_TRUNCATE = 150 + """Max chars of tree root summary in Step 1 structure hints.""" _LLM_FALLBACK_EVIDENCE = ( "[No relevant documents found]\n\n" @@ -3312,7 +3316,7 @@ def _build_tree_root_hints(self, artifacts: CompileArtifacts) -> str: tree = indexer.load_tree(fp) if tree and tree.root and tree.root.summary: name = Path(fp).name - hints.append(f"[{i}] {name}: {tree.root.summary[:150]}") + hints.append(f"[{i}] {name}: {tree.root.summary[:self._TREE_ROOT_HINT_TRUNCATE]}") if not hints: return "" return "\nDocument structure hints:\n" + "\n".join(hints) + "\n" @@ -4017,80 +4021,110 @@ def _collect_cluster(c: KnowledgeCluster) -> None: except Exception: return empty - async def _probe_tree_index(self, query: str) -> List[str]: - """LLM-driven file discovery via compiled tree root summaries (PageIndex). + def _load_cached_trees(self) -> list: + """Load DocumentTree objects from the tree cache directory. - Loads all cached document trees, presents their root summaries to the - LLM, and asks it to select the most relevant 1-3 documents. For - selected trees, optionally drills one level deeper into children. - - Returns file paths of the most relevant documents. + Returns a list of ``DocumentTree`` instances whose file paths exist + on disk. Returns an empty list when the tree cache is absent or + contains no valid entries. 
""" tree_cache = self.work_path / ".cache" / "compile" / "trees" if not tree_cache.exists(): return [] - try: from sirchmunk.learnings.tree_indexer import DocumentTree - trees: List[DocumentTree] = [] - for tree_file in sorted(tree_cache.glob("*.json"))[:50]: + trees = [] + for tree_file in sorted(tree_cache.glob("*.json"))[:self._TREE_CACHE_SCAN_LIMIT]: try: t = DocumentTree.from_json( tree_file.read_text(encoding="utf-8") ) - if t.root and t.file_path: + if t.root and t.file_path and Path(t.file_path).exists(): trees.append(t) except Exception: continue + return trees + except Exception: + return [] - if not trees: - return [] + async def _llm_select_from_trees( + self, query: str, trees: list, max_select: int, + ) -> List[str]: + """LLM-driven file selection from tree root summaries. - # If few trees, return all without LLM - if len(trees) <= 2: - return [t.file_path for t in trees if Path(t.file_path).exists()] + Presents root summaries to the LLM and returns the selected file + paths. When the number of trees is at most *max_select*, returns + all paths without an LLM call. - # LLM-driven selection among tree roots - listing = "\n".join( - f"[{i}] {Path(t.file_path).name}: {(t.root.summary or '')[:200]}" - for i, t in enumerate(trees) - ) - prompt = ( - f'Given the query: "{query}"\n\n' - f"Select the 1-3 most relevant documents (by index number):\n{listing}\n\n" - f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" - ) - resp = await self.llm.achat([{"role": "user", "content": prompt}]) - self.llm_usages.append(resp.usage) + Args: + query: User query string. + trees: List of ``DocumentTree`` objects (pre-loaded). + max_select: Maximum number of files to select. - selected_indices: List[int] = [] - try: - raw = resp.content.strip() - m = re.search(r"\[[\d\s,]+\]", raw) - if m: - selected_indices = [ - idx for idx in json.loads(m.group()) - if isinstance(idx, int) and 0 <= idx < len(trees) - ] - except (json.JSONDecodeError, TypeError): - pass + Returns: + Selected file paths, or empty list. + """ + if not trees: + return [] + if len(trees) <= max_select: + return [t.file_path for t in trees] - if not selected_indices: - selected_indices = list(range(min(2, len(trees)))) + listing = "\n".join( + f"[{i}] {Path(t.file_path).name}: " + f"{(t.root.summary or '')[:self._CATALOG_SUMMARY_TRUNCATE]}" + for i, t in enumerate(trees) + ) + prompt = ( + f'Given the query: "{query}"\n\n' + f"Select the 1-{max_select} most relevant documents " + f"(by index number):\n{listing}\n\n" + f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" + ) + resp = await self.llm.achat([{"role": "user", "content": prompt}]) + self.llm_usages.append(resp.usage) - result_paths: List[str] = [] - for idx in selected_indices: - fp = trees[idx].file_path - if Path(fp).exists(): - result_paths.append(fp) + selected_indices: List[int] = [] + try: + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + selected_indices = [ + idx for idx in json.loads(m.group()) + if isinstance(idx, int) and 0 <= idx < len(trees) + ] + except (json.JSONDecodeError, TypeError): + pass - if result_paths: + if not selected_indices: + selected_indices = list(range(min(max_select, len(trees)))) + + return [ + trees[idx].file_path + for idx in selected_indices[:max_select] + if Path(trees[idx].file_path).exists() + ] + + async def _probe_tree_index(self, query: str) -> List[str]: + """LLM-driven file discovery via compiled tree root summaries (PageIndex). 
+ + Loads all cached document trees, presents their root summaries to the + LLM, and asks it to select the most relevant documents. Returns file + paths of the most relevant documents. + """ + try: + trees = self._load_cached_trees() + if not trees: + return [] + result = await self._llm_select_from_trees( + query, trees, max_select=self._DEEP_TREE_PROBE_MAX_FILES, + ) + if result: await self._logger.info( - f"[Probe:TreeIndex] LLM selected {len(result_paths)} documents " + f"[Probe:TreeIndex] LLM selected {len(result)} documents " f"from {len(trees)} tree indices" ) - return result_paths + return result except Exception: return [] @@ -4302,9 +4336,9 @@ async def _probe_tree_for_fast( ) -> List[str]: """Active tree-based file discovery for FAST mode (1 LLM call). - Lightweight wrapper around tree root selection logic. When compiled - tree indices are available and cover more than 2 files, asks the LLM - to select the most relevant 1-2 documents from root summaries. + When compiled tree indices are available and cover more than 2 files, + asks the LLM to select the most relevant 1-2 documents from root + summaries. Delegates to the shared ``_llm_select_from_trees`` helper. Returns file paths of selected documents, or empty list when trees are unavailable or cover too few files to justify an LLM call. @@ -4312,72 +4346,19 @@ async def _probe_tree_for_fast( if not artifacts or len(artifacts.tree_available_paths) <= 2: return [] - tree_cache = self.work_path / ".cache" / "compile" / "trees" - if not tree_cache.exists(): - return [] - try: - from sirchmunk.learnings.tree_indexer import DocumentTree - - trees: List[DocumentTree] = [] - for tree_file in sorted(tree_cache.glob("*.json"))[:self._TREE_CACHE_SCAN_LIMIT]: - try: - t = DocumentTree.from_json( - tree_file.read_text(encoding="utf-8") - ) - if t.root and t.file_path and Path(t.file_path).exists(): - trees.append(t) - except Exception: - continue - + trees = self._load_cached_trees() if not trees: return [] - - # Few trees: return all without LLM - if len(trees) <= self._FAST_TREE_PROBE_MAX_FILES: - return [t.file_path for t in trees] - - # LLM-driven selection among tree roots - listing = "\n".join( - f"[{i}] {Path(t.file_path).name}: {(t.root.summary or '')[:200]}" - for i, t in enumerate(trees) - ) - prompt = ( - f'Given the query: "{query}"\n\n' - f"Select the 1-{self._FAST_TREE_PROBE_MAX_FILES} most relevant documents:\n" - f"{listing}\n\n" - f"Return ONLY a JSON array of index numbers, e.g. 
[0, 2]" + result = await self._llm_select_from_trees( + query, trees, max_select=self._FAST_TREE_PROBE_MAX_FILES, ) - resp = await self.llm.achat([{"role": "user", "content": prompt}]) - self.llm_usages.append(resp.usage) - - selected_indices: List[int] = [] - try: - raw = resp.content.strip() - m = re.search(r"\[[\d\s,]+\]", raw) - if m: - selected_indices = [ - idx for idx in json.loads(m.group()) - if isinstance(idx, int) and 0 <= idx < len(trees) - ] - except (json.JSONDecodeError, TypeError): - pass - - if not selected_indices: - selected_indices = list(range(min(self._FAST_TREE_PROBE_MAX_FILES, len(trees)))) - - result_paths = [ - trees[idx].file_path - for idx in selected_indices[:self._FAST_TREE_PROBE_MAX_FILES] - if Path(trees[idx].file_path).exists() - ] - - if result_paths: + if result: await self._logger.info( - f"[FAST:TreeProbe] Selected {len(result_paths)} files " + f"[FAST:TreeProbe] Selected {len(result)} files " f"from {len(trees)} tree indices" ) - return result_paths + return result except Exception as exc: await self._logger.warning(f"[FAST:TreeProbe] Failed: {exc}") return [] From e8184d06de2f8361fe6374fe31002d7404b039a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 14:39:21 +0800 Subject: [PATCH 22/56] update tree index --- src/sirchmunk/learnings/tree_indexer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index abf5459..26787eb 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -19,7 +19,7 @@ from sirchmunk.utils.file_utils import get_fast_hash # File-size threshold: skip tree indexing for small files -_TREE_MIN_CHARS = 20_000 # 20 K characters (lowered from 50K for broader coverage) +_TREE_MIN_CHARS = 10_000 # 10 K characters (lowered from 20K for broader coverage) # Adaptive depth thresholds: (min_chars, max_depth) — evaluated top-down; # **must** be sorted by min_chars descending so the first match wins. From dc27ed9eeccea93a720f2dc00d4749c238a7a440 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 16:18:52 +0800 Subject: [PATCH 23/56] update finbench readme --- benchmarks/financebench/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md index 95d04e7..9c23648 100644 --- a/benchmarks/financebench/README.md +++ b/benchmarks/financebench/README.md @@ -91,7 +91,7 @@ sirchmunk compile --work-path .work --paths data/pdfs ``` > **Note:** This step parses, chunks, and indexes all PDFs. -> For FinanceBench's ~41 PDFs (10-K/10-Q filings), expect 10–30 minutes. +> For FinanceBench's all PDFs, expect hours of processing time, depending on your LLM speed and compute resources. ### Step 5: Configure Experiment From 8723b85c9394d7d1cecd7af2e7d348b3e8526757 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 16:59:00 +0800 Subject: [PATCH 24/56] update finbench readme --- benchmarks/financebench/README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/benchmarks/financebench/README.md b/benchmarks/financebench/README.md index 9c23648..9bb0134 100644 --- a/benchmarks/financebench/README.md +++ b/benchmarks/financebench/README.md @@ -93,6 +93,18 @@ sirchmunk compile --work-path .work --paths data/pdfs > **Note:** This step parses, chunks, and indexes all PDFs. 
> For FinanceBench's all PDFs, expect hours of processing time, depending on your LLM speed and compute resources. +#### Shallow Compile (Recommended for First Run) + +Use `--shallow` to skip tree indexing and only generate Summary + Topics. +This reduces LLM calls dramatically and achieves **5–9× speedup**: + +```bash +sirchmunk compile --work-path .work --paths data/pdfs --shallow +``` + +> **Tip:** `--shallow` is ideal for quickly compiling a large corpus on the first pass. +> You can run a normal (full) compile later to incrementally add tree indexes. + ### Step 5: Configure Experiment Create the **experiment .env** from the template: From ca9a609d61564277c9384d9fd193d0685546a7b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 19:15:48 +0800 Subject: [PATCH 25/56] update should answer thres --- src/sirchmunk/llm/prompts.py | 6 +- src/sirchmunk/search.py | 193 +++++++++++++++++++++++++++++++++-- 2 files changed, 188 insertions(+), 11 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 8df111d..27338a2 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -189,7 +189,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. Is the content meaningful and not just error messages or "no information found"? 3. Are there sufficient evidences and context to answer the user's query? -- : output "true" only if the evidence is sufficient to answer the query. +- : output "true" if the evidence contains relevant information that can help answer the query, even if it requires reasoning, computation, or interpretation. Only output "false" if the evidence is clearly irrelevant or contains no useful information for the query. - : output "true" only if the evidence is sufficient AND the result is worth caching. - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". @@ -437,7 +437,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. Is the content meaningful and not just error messages or "no information found"? 3. Are there sufficient evidences and context to answer the user's query? -- : output "true" only if the evidence is sufficient to answer the query. +- : output "true" if the evidence contains relevant information that can help answer the query, even if it requires reasoning, computation, or interpretation. Only output "false" if the evidence is clearly irrelevant or contains no useful information for the query. - : output "true" only if the evidence is sufficient AND the result is worth caching. - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". @@ -476,7 +476,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. Is the content meaningful and not just error messages or "no information found"? 3. Are there sufficient evidences and context to answer the user's query? -- : output "true" only if the evidence is sufficient to answer the query. +- : output "true" if the evidence contains relevant information that can help answer the query, even if it requires reasoning, computation, or interpretation. Only output "false" if the evidence is clearly irrelevant or contains no useful information for the query. - : output "true" only if the evidence is sufficient AND the result is worth caching. - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". 
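The search.py diff below pairs this relaxed SHOULD_ANSWER gate with a multi-factor acceptance check (LLM verdict, evidence length, keyword coverage, numeric-data detection). For orientation, here is a minimal standalone sketch of the length + keyword-coverage override path — the stop-word list is abbreviated and the function names are illustrative assumptions, not the patched `AgenticSearch` methods; the default thresholds mirror the constants introduced in this patch:

```python
import re

# Abbreviated stop-word list -- the real list added in search.py is longer.
_STOP_WORDS = {
    "the", "is", "a", "an", "of", "in", "for", "to", "and", "or",
    "what", "how", "which", "does",
}


def keyword_coverage(query: str, evidence: str) -> float:
    """Fraction of non-stop-word query tokens present in the evidence text."""
    tokens = re.findall(r"\b[a-z0-9]{2,}\b", query.lower())
    keywords = [t for t in tokens if t not in _STOP_WORDS]
    if not keywords:
        return 0.0
    evidence_lower = evidence.lower()
    return sum(1 for kw in keywords if kw in evidence_lower) / len(keywords)


def accept_evidence(
    query: str, evidence: str, llm_should_answer: bool,
    min_len: int = 1500, min_coverage: float = 0.6,
) -> bool:
    """LLM verdict wins; otherwise fall back to length + coverage heuristics."""
    if llm_should_answer:
        return True
    return len(evidence) >= min_len and keyword_coverage(query, evidence) >= min_coverage


if __name__ == "__main__":
    # The LLM gate says no, but long, keyword-dense evidence still passes.
    q = "What was FY2023 gross margin for the company?"
    ev = "Gross margin for FY2023 was 41.2%, up from 39.8% in FY2022. " * 40
    print(accept_evidence(q, ev, llm_should_answer=False))  # -> True
```

The patched implementation additionally checks for numeric/financial patterns before rejecting evidence; see `_detect_numeric_evidence` and `_evaluate_evidence_acceptance` in the diff below.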
diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 7506c0c..a900978 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -86,6 +86,14 @@ # Soft-similarity threshold for gradient cluster reuse (P2) _SOFT_SIM_THRESHOLD = 0.65 +# Common English stop-words filtered out during keyword coverage computation. +_STOP_WORDS: frozenset = frozenset({ + "the", "is", "a", "an", "of", "in", "for", "to", "and", "or", + "what", "how", "which", "does", "was", "were", "has", "have", "had", + "do", "did", "are", "be", "been", "by", "with", "from", "this", + "that", "it", "its", "on", "at", "as", "not", "no", +}) + @dataclass class SoftClusterHit: @@ -951,6 +959,105 @@ def _parse_summary_response(llm_response: str) -> Tuple[str, bool, bool]: return summary, should_save, should_answer + # ------------------------------------------------------------------ + # Multi-factor evidence acceptance helpers + # ------------------------------------------------------------------ + + @staticmethod + def _compute_keyword_coverage(query: str, evidence: str) -> float: + """Compute the fraction of query keywords found in the evidence text. + + Tokenises *query* into lowercase alpha-numeric words (length >= 2), + removes common English stop-words, then checks presence in + lower-cased *evidence*. + + Returns: + Coverage ratio in [0.0, 1.0]. Returns 0.0 when no valid + keywords can be extracted from *query*. + """ + tokens = re.findall(r'\b[a-z0-9]{2,}\b', query.lower()) + keywords = [t for t in tokens if t not in _STOP_WORDS] + if not keywords: + return 0.0 + evidence_lower = evidence.lower() + matched = sum(1 for kw in keywords if kw in evidence_lower) + return matched / len(keywords) + + @staticmethod + def _detect_numeric_evidence(query: str, evidence: str) -> bool: + """Detect whether *evidence* contains structured numeric data relevant to *query*. + + Returns True when *query* implies a numeric/financial intent AND + *evidence* contains numeric patterns (currency amounts, percentages, + financial figures). + """ + query_lower = query.lower() + has_intent = any( + kw in query_lower + for kw in AgenticSearch._NUMERIC_INTENT_KEYWORDS + ) + if not has_intent: + return False + has_numeric = bool( + re.search( + r'[\$\u20ac\u00a3]\s?\d' + r'|(? Tuple[bool, str]: + """Multi-factor decision on whether to accept retrieved evidence. + + Combines the LLM's own SHOULD_ANSWER judgment with heuristic + signals (evidence length, keyword coverage, numeric-data presence) + to reduce false-negative rejections of valid evidence. + + Returns: + A tuple of (*accept*, *reason*) where *accept* is the final + boolean decision and *reason* is a human-readable string + documenting which factor(s) determined the outcome. 
+ """ + # Factor 1: LLM direct acceptance + if llm_should_answer: + return True, "llm_accepted" + + # Factor 2: Heuristic override — length + keyword coverage + evidence_len = len(evidence) if evidence else 0 + kw_coverage = ( + AgenticSearch._compute_keyword_coverage(query, evidence) + if evidence else 0.0 + ) + + if ( + evidence_len >= AgenticSearch._EVIDENCE_MIN_ACCEPT_LENGTH + and kw_coverage >= AgenticSearch._EVIDENCE_KEYWORD_COVERAGE_THRESHOLD + ): + return True, ( + f"heuristic_override(len={evidence_len}, " + f"kw_coverage={kw_coverage:.2f})" + ) + + # Factor 3: Numeric evidence detection + if AgenticSearch._detect_numeric_evidence(query, evidence or ""): + return True, ( + f"numeric_evidence(len={evidence_len}, " + f"kw_coverage={kw_coverage:.2f})" + ) + + # All factors negative + return False, ( + f"rejected(llm=false, len={evidence_len}, " + f"kw_coverage={kw_coverage:.2f}, numeric=false)" + ) + @staticmethod def _extract_and_validate_multi_level_keywords( llm_resp: str, @@ -1655,7 +1762,17 @@ async def _search_deep( if cluster and cluster.content: await self._logger.info("[Phase 4] Evidence sufficient, generating summary") answer, should_save, should_answer = await self._summarise_cluster(query, cluster) - if not should_answer: + + # --- Multi-factor evidence acceptance --- + cluster_evidence = str(cluster.content) if cluster and cluster.content else "" + accepted, accept_reason = self._evaluate_evidence_acceptance( + query, cluster_evidence, should_answer, + ) + await self._logger.info( + f"[Phase 4] Evidence acceptance: {accepted} ({accept_reason})" + ) + + if not accepted: if llm_fallback: await self._logger.info( "[Phase 4] Summary gate rejected evidence, llm_fallback=True → LLM fallback" @@ -1703,7 +1820,17 @@ async def _search_deep( # Final DEEP decision is always made in the summary call. 
answer, should_save, should_answer = await self._summarise_cluster(query, cluster) - if not should_answer: + + # --- Multi-factor evidence acceptance --- + final_cluster_evidence = str(cluster.content) if cluster and cluster.content else "" + final_accepted, final_reason = self._evaluate_evidence_acceptance( + query, final_cluster_evidence, should_answer, + ) + await self._logger.info( + f"[Phase 4] Final evidence acceptance: {final_accepted} ({final_reason})" + ) + + if not final_accepted: if llm_fallback: await self._logger.info( "[Phase 4] Final summary gate rejected evidence, llm_fallback=True → LLM fallback" @@ -2055,6 +2182,25 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _TREE_ROOT_HINT_TRUNCATE = 150 """Max chars of tree root summary in Step 1 structure hints.""" + # --- Self-correction expanded sampling --- + _SELF_CORRECT_EXPANDED_NAV_RESULTS: int = 6 + """Expanded tree navigation leaf count for same-file re-sampling (default nav uses 3).""" + _SELF_CORRECT_EXPANDED_SECTIONS: int = 5 + """Expanded tree sample sections for same-file re-sampling (default uses 3).""" + + # --- Evidence acceptance thresholds --- + _EVIDENCE_MIN_ACCEPT_LENGTH: int = 1500 + """Minimum evidence character length for heuristic override.""" + _EVIDENCE_KEYWORD_COVERAGE_THRESHOLD: float = 0.6 + """Minimum keyword coverage ratio for heuristic override.""" + _NUMERIC_INTENT_KEYWORDS: frozenset = frozenset({ + "revenue", "margin", "ratio", "ebitda", "income", "profit", "loss", + "cash", "debt", "equity", "eps", "dpo", "growth", "rate", + "percentage", "amount", "total", "net", "gross", "cost", "expense", + "sales", "fy", "fiscal", + }) + """Keywords indicating numeric/financial intent in a query.""" + _LLM_FALLBACK_EVIDENCE = ( "[No relevant documents found]\n\n" "The search did not find relevant content in the available documents. " @@ -2573,12 +2719,20 @@ async def _rga_evidence() -> str: answer_resp.content or "" ) + # --- Multi-factor evidence acceptance (P2+P3+P4) --- + accepted, accept_reason = self._evaluate_evidence_acceptance( + query, evidence, should_answer, + ) + await self._logger.info( + f"[FAST:Step4] Evidence acceptance: {accepted} ({accept_reason})" + ) + # ============================================================== # Step 5: Self-correction retry (conditional, ≤1 extra LLM call) # When the answer gate rejects the first attempt, try alternative # evidence sources before giving up. 
# ============================================================== - if not should_answer: + if not accepted: retry_evidence = await self._fast_self_correct( query, best_files, catalog_routed_files, context, ) @@ -2598,11 +2752,19 @@ async def _rga_evidence() -> str: context.add_llm_tokens( retry_resp.usage.get("total_tokens", 0), usage=retry_resp.usage, ) - answer, should_save, should_answer = self._parse_summary_response( + answer, should_save, retry_should_answer = self._parse_summary_response( retry_resp.content or "" ) + retry_accepted, retry_reason = self._evaluate_evidence_acceptance( + query, retry_evidence, retry_should_answer, + ) + await self._logger.info( + f"[FAST:Step5] Retry evidence acceptance: {retry_accepted} ({retry_reason})" + ) + if retry_accepted: + accepted = True - if not should_answer: + if not accepted: if llm_fallback: await self._logger.info( "[FAST:Step5] Retry also rejected, llm_fallback=True → LLM fallback" @@ -3637,7 +3799,7 @@ async def _tree_guided_sample( return evidence async def _navigate_tree_for_evidence( - self, file_path: str, query: str, + self, file_path: str, query: str, *, max_results: int = 3, ) -> Optional[str]: """LLM-driven tree navigation: select relevant sections and read leaf content. @@ -3653,7 +3815,7 @@ async def _navigate_tree_for_evidence( return None try: - leaves = await indexer.navigate(tree, query, max_results=3) + leaves = await indexer.navigate(tree, query, max_results=max_results) except Exception: return None @@ -3699,7 +3861,8 @@ async def _fast_self_correct( ) -> Optional[str]: """Attempt to gather alternative evidence when the first answer is rejected. - Three strategies tried in order: + Four strategies tried in order: + D) Re-sample the same primary file with expanded parameters (deeper sampling). A) Tree-navigate a 2nd catalog-routed file not yet tried. B) Retrieve the most semantically similar compiled cluster's content. C) Tree-navigate the 2nd-best rga file if available. @@ -3708,6 +3871,20 @@ async def _fast_self_correct( """ first_file = best_files[0]["path"] if best_files else "" + # Strategy D: Re-sample the SAME primary file with expanded parameters. + # The file was correct but the initial sampling may have missed key sections. 
+ if first_file: + expanded_tree_ev = await self._navigate_tree_for_evidence( + first_file, query, + max_results=self._SELF_CORRECT_EXPANDED_NAV_RESULTS, + ) + if expanded_tree_ev and len(expanded_tree_ev.strip()) > 50: + await self._logger.info( + "[FAST:SelfCorrect] Strategy D succeeded: " + "expanded same-file tree navigation" + ) + return expanded_tree_ev + # Strategy A: 2nd catalog-routed file via tree navigation for fp in catalog_routed_files: if fp == first_file: From 34c181eaf1a01a10bf824908c6f18c1131d7078e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 20:22:44 +0800 Subject: [PATCH 26/56] fix eval for finbench in runner --- benchmarks/financebench/runner.py | 71 ++++++++++++++++++++++++++++--- src/sirchmunk/search.py | 6 +-- 2 files changed, 67 insertions(+), 10 deletions(-) diff --git a/benchmarks/financebench/runner.py b/benchmarks/financebench/runner.py index b95f7ca..86404f2 100644 --- a/benchmarks/financebench/runner.py +++ b/benchmarks/financebench/runner.py @@ -12,10 +12,11 @@ import asyncio import json as json_mod import logging +import re import time from datetime import datetime from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from config import FinanceBenchConfig from data_loader import FinanceBenchLoader @@ -38,7 +39,12 @@ Given the financial question and a verbose response, extract ONLY the short factoid answer. Rules: - Output ONLY the answer value/phrase (1-20 words). No explanation. -- If the response says it cannot find the answer, output: unknown +- If the response contains ANY concrete data (dollar amounts, percentages, numbers, + company names, yes/no conclusions), extract that data even if the response also + expresses uncertainty or says it could not find a "complete" answer. +- A partial answer with real data is ALWAYS better than "unknown". +- Output "unknown" ONLY when the response contains absolutely no useful factual + information (e.g., a pure apology with zero data points). - For monetary values, keep the currency format (e.g., $1,577.00) - For percentages, keep the % sign (e.g., 15.3%) - For yes/no questions, output: yes or no @@ -48,6 +54,16 @@ Short answer:""" +# Regex pattern for extracting financial numeric data as fallback +_NUMERIC_EXTRACTION_PATTERN = ( + r'\$[\d,]+(?:\.\d+)?\s*(?:million|billion|mn|bn|M|B|K)?' + r'|\d+(?:,\d{3})+(?:\.\d+)?\s*(?:million|billion|mn|bn|M|B|%)?' + r'|\d+(?:\.\d+)?\s*(?:million|billion|mn|bn|M|B|%)' +) + +# Sentinel values indicating extraction found no useful answer +_UNKNOWN_SENTINELS = frozenset({"unknown", "n/a", ""}) + # NOTE: _normalize_prediction removed — use evaluate.normalize_answer instead. @@ -57,12 +73,26 @@ # ------------------------------------------------------------------ -async def _extract_short_answer( +def _extract_numeric_fallback(text: str) -> Optional[str]: + """Extract financial figures from *text* using regex patterns. + + Looks for currency amounts ($xxx), percentages, and large numbers + with units (million, billion, etc.). + + Returns the first match or ``None``. 
+ """ + match = re.search(_NUMERIC_EXTRACTION_PATTERN, text) + if match: + return match.group(0).strip() + return None + + +async def _llm_extract( question: str, verbose: str, llm: Any, -) -> str: - """Use *llm* to distil *verbose* into a short factoid answer.""" +) -> Optional[str]: + """Layer-1: use LLM to distil *verbose* into a short factoid answer.""" prompt = _EXTRACT_PROMPT.format(question=question, response=verbose[:4000]) try: resp = await llm.achat( @@ -71,8 +101,35 @@ async def _extract_short_answer( ) return resp.content.strip() except Exception: - logger.warning("Short-answer extraction failed; falling back to raw answer.") - return verbose + logger.warning("LLM extraction failed; will try regex fallback.") + return None + + +async def _extract_short_answer( + question: str, + verbose: str, + llm: Any, +) -> str: + """Extract a concise answer from verbose LLM analysis. + + Uses a three-layer extraction strategy: + 1. LLM-based extraction with improved prompt + 2. Regex-based numeric/financial data extraction as fallback + 3. Returns 'unknown' only when no useful data is found + """ + # Layer 1: LLM extraction + answer = await _llm_extract(question, verbose, llm) + if answer and answer.strip().lower() not in _UNKNOWN_SENTINELS: + return answer.strip() + + # Layer 2: Regex fallback for numeric/financial data + numeric_answer = _extract_numeric_fallback(verbose) + if numeric_answer: + logger.info("Regex fallback extracted: %s", numeric_answer) + return numeric_answer + + # Layer 3: No useful data found + return "unknown" # ------------------------------------------------------------------ diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index a900978..c2f30f4 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2189,9 +2189,9 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Expanded tree sample sections for same-file re-sampling (default uses 3).""" # --- Evidence acceptance thresholds --- - _EVIDENCE_MIN_ACCEPT_LENGTH: int = 1500 + _EVIDENCE_MIN_ACCEPT_LENGTH: int = 800 """Minimum evidence character length for heuristic override.""" - _EVIDENCE_KEYWORD_COVERAGE_THRESHOLD: float = 0.6 + _EVIDENCE_KEYWORD_COVERAGE_THRESHOLD: float = 0.5 """Minimum keyword coverage ratio for heuristic override.""" _NUMERIC_INTENT_KEYWORDS: frozenset = frozenset({ "revenue", "margin", "ratio", "ebitda", "income", "profit", "loss", @@ -3231,7 +3231,7 @@ async def _fast_sample_evidence( # Diagnostic logging when falling back to snippet mode if not hit_lines and match_objects: - await self._logger.warning( + await self._logger.info( f"[FAST] No line_number in {len(match_objects)} match(es) for {fname}, " f"falling back to snippet mode" ) From 2b4714efd270376c7b53a41a131facab858f02ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 21:03:07 +0800 Subject: [PATCH 27/56] refactor metrics as LLM judge for finbench --- benchmarks/financebench/evaluate.py | 264 ++++------------------- benchmarks/financebench/judge.py | 195 +++++++++++++++-- benchmarks/financebench/run_benchmark.py | 67 ++---- benchmarks/financebench/runner.py | 240 +++++---------------- 4 files changed, 300 insertions(+), 466 deletions(-) diff --git a/benchmarks/financebench/evaluate.py b/benchmarks/financebench/evaluate.py index e22bf07..d9614f3 100644 --- a/benchmarks/financebench/evaluate.py +++ b/benchmarks/financebench/evaluate.py @@ -1,63 +1,21 @@ -"""FinanceBench evaluation metrics. 
+"""FinanceBench evaluation metrics — LLM Judge driven. -Implements the three-class scoring scheme from the FinanceBench paper -(Islam et al., 2023): **correct**, **hallucination**, **refusal**. +All correctness evaluation (Accuracy, Coverage) is driven by the LLM Judge. +This module aggregates per-question judge results into benchmark-level metrics. -Financial-value normalisation handles currency symbols, thousand separators, -trailing zeros, and percentage signs so that ``$1,577.00`` matches ``1577``. +The ``normalize_answer`` helper is retained for quick short-circuit checks +inside the judge (exact-match bypass before calling the LLM). """ from __future__ import annotations import re -from collections import Counter, defaultdict +from collections import defaultdict from typing import Any, Dict, List # ------------------------------------------------------------------ # Constants # ------------------------------------------------------------------ -_REFUSAL_PHRASES: list[str] = [ - "i cannot", - "i can't", - "i could not", - "i couldn't", - "no results found", - "unable to", - "not able to", - "i don't know", - "i do not know", - "information is not available", - "not enough information", - "cannot determine", - "cannot be determined", - "insufficient data", - "no relevant information", - "data not found", - "unknown", - "i'm not able to", - "i am not able to", - "the document does not contain", - "the document doesn't contain", - "this information is not disclosed", - "not disclosed", - "could not find", - "couldn't find", - "no mention of", - "no information about", - "not provided in", - "not found in the document", - "i was unable to", - "unable to determine", - "unable to find", - "unable to locate", - "there is no data", - "no data available", - "not available in", - "not specified", -] - -_F1_CORRECT_THRESHOLD: float = 0.8 - # Markdown / wrapper patterns compiled once _RE_BOLD = re.compile(r"\*\*(.+?)\*\*") _RE_ITALIC = re.compile(r"\*(.+?)\*") @@ -169,109 +127,6 @@ def _normalize_financial_value(text: str) -> str: return s -# ------------------------------------------------------------------ -# Matching helpers -# ------------------------------------------------------------------ - - -def exact_match(prediction: str, gold: str) -> bool: - """Return ``True`` when normalised strings are identical.""" - return normalize_answer(prediction) == normalize_answer(gold) - - -def f1_score(prediction: str, gold: str) -> float: - """Compute token-level F1 between *prediction* and *gold*. - - Tokenisation is simple whitespace splitting after normalisation. - Each token is further normalised as a financial value so that - ``$1577`` matches ``1577`` at the token level. - Returns 0.0 when either side is empty. 
- """ - pred_tokens = [_normalize_financial_value(t) for t in normalize_answer(prediction).split()] - gold_tokens = [_normalize_financial_value(t) for t in normalize_answer(gold).split()] - if not pred_tokens or not gold_tokens: - return 0.0 - - common = Counter(pred_tokens) & Counter(gold_tokens) - num_common = sum(common.values()) - if num_common == 0: - return 0.0 - - precision = num_common / len(pred_tokens) - recall = num_common / len(gold_tokens) - return 2 * precision * recall / (precision + recall) - - -# ------------------------------------------------------------------ -# Three-class classification -# ------------------------------------------------------------------ - - -def classify_answer( - prediction: str, - gold: str, - *, - is_no_result: bool = False, - f1_threshold: float = _F1_CORRECT_THRESHOLD, -) -> str: - """Classify a prediction into ``correct``, ``refusal``, or ``hallucination``. - - Classification logic (faithful to FinanceBench paper): - 1. If the system explicitly refused (``is_no_result=True``) or the - prediction contains a refusal phrase → **refusal**. - 2. If EM passes or token-level F1 ≥ *f1_threshold* → **correct**. - 3. Otherwise → **hallucination**. - """ - norm_pred = normalize_answer(prediction) - - # --- Refusal --- - if is_no_result: - return "refusal" - pred_lower = norm_pred.lower() - for phrase in _REFUSAL_PHRASES: - if phrase in pred_lower: - return "refusal" - - # --- Correct --- - if exact_match(prediction, gold): - return "correct" - if f1_score(prediction, gold) >= f1_threshold: - return "correct" - - # --- Hallucination --- - return "hallucination" - - -# ------------------------------------------------------------------ -# Evidence recall -# ------------------------------------------------------------------ - - -def evidence_recall( - retrieved_pages: List[int], - gold_evidence: List[Dict[str, Any]], -) -> float: - """Compute page-level evidence recall. - - ``gold_evidence`` entries carry ``evidence_page_num`` (0-indexed). - Returns 1.0 when there is no gold evidence (vacuously true). - """ - if not gold_evidence: - return 1.0 - - gold_pages = { - int(e["evidence_page_num"]) - for e in gold_evidence - if "evidence_page_num" in e - } - if not gold_pages: - return 1.0 - - retrieved_set = set(retrieved_pages) - hits = gold_pages & retrieved_set - return len(hits) / len(gold_pages) - - # ------------------------------------------------------------------ # Aggregate metrics # ------------------------------------------------------------------ @@ -280,100 +135,75 @@ def evidence_recall( def compute_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: """Aggregate per-question results into benchmark-level metrics. - Expected keys per result dict: ``classification``, ``em``, ``f1``, - ``elapsed``, ``telemetry``, ``question_type``, ``question_reasoning``, - ``evidence_recall`` (optional). + All correctness evaluation is driven by LLM Judge results stored in + each result dict (``judge_correct``, ``coverage``). - Returns a dict with overall stats plus breakdowns by *question_type* - and *question_reasoning*. + Returns a dict with overall stats plus breakdown by *question_type*. 
""" n = len(results) if n == 0: return {"n": 0} - # --- Overall counts --- - correct = sum(1 for r in results if r.get("classification") == "correct") - halluc = sum(1 for r in results if r.get("classification") == "hallucination") - refusal = sum(1 for r in results if r.get("classification") == "refusal") + # --- Accuracy (Judge) --- + judge_correct = sum(1 for r in results if r.get("judge_correct")) - em_sum = sum(1 for r in results if r.get("em")) - f1_sum = sum(r.get("f1", 0.0) for r in results) + # --- Coverage (Judge) --- + coverage_true = sum(1 for r in results if r.get("coverage")) + # --- Latency --- latencies = [r["elapsed"] for r in results if "elapsed" in r] avg_latency = sum(latencies) / len(latencies) if latencies else 0.0 + total_time = sum(latencies) - token_counts = [ + # --- Token usage --- + search_tokens = sum( r.get("telemetry", {}).get("total_tokens", 0) for r in results - ] - avg_tokens = sum(token_counts) / len(token_counts) if token_counts else 0 - - ev_recalls = [r["evidence_recall"] for r in results if r.get("evidence_recall") is not None] - avg_ev_recall = sum(ev_recalls) / len(ev_recalls) if ev_recalls else None + ) + judge_tokens = sum(r.get("judge_tokens", 0) for r in results) + total_tokens = search_tokens + judge_tokens + avg_tokens_per_question = total_tokens / n if n else 0 - overall = { + overall: Dict[str, Any] = { "n": n, - "accuracy": round(correct / n * 100, 2), - "hallucination_rate": round(halluc / n * 100, 2), - "refusal_rate": round(refusal / n * 100, 2), - "correct": correct, - "hallucination": halluc, - "refusal": refusal, - "avg_em": em_sum / n, - "avg_f1": f1_sum / n, + "accuracy": round(judge_correct / n * 100, 2), + "coverage": round(coverage_true / n * 100, 2), "avg_latency": round(avg_latency, 2), - "avg_tokens": round(avg_tokens, 1), + "total_time_seconds": round(total_time, 2), + "token_usage": { + "total_tokens": total_tokens, + "search_tokens": search_tokens, + "judge_tokens": judge_tokens, + "avg_tokens_per_question": round(avg_tokens_per_question, 1), + }, + "judge_correct": judge_correct, + "coverage_true": coverage_true, + "by_question_type": _breakdown(results, "question_type"), } - if avg_ev_recall is not None: - overall["evidence_recall"] = round(avg_ev_recall, 4) - - # --- LLM Judge metrics (independent dimension, NOT fallback) --- - judge_results = [r for r in results if r.get("llm_judge_correct") is not None] - if judge_results: - judge_correct = sum(1 for r in judge_results if r["llm_judge_correct"]) - overall["llm_judge_accuracy"] = round(judge_correct / len(judge_results) * 100, 2) - overall["llm_judge_count"] = len(judge_results) - overall["llm_judge_correct"] = judge_correct - else: - overall["llm_judge_accuracy"] = None - overall["llm_judge_count"] = 0 - overall["llm_judge_correct"] = 0 - - # --- Breakdowns --- - overall["by_question_type"] = _breakdown(results, "question_type") - overall["by_question_reasoning"] = _breakdown(results, "question_reasoning") return overall -def _breakdown(results: List[Dict[str, Any]], key: str) -> Dict[str, Dict[str, Any]]: - """Compute per-group accuracy / hallucination / refusal breakdown.""" +def _breakdown( + results: List[Dict[str, Any]], key: str +) -> Dict[str, Dict[str, Any]]: + """Compute per-group accuracy / coverage breakdown.""" groups: dict[str, list[dict]] = defaultdict(list) for r in results: group = r.get(key) or "unknown" groups[group].append(r) out: dict[str, dict] = {} - for group, items in sorted(groups.items(), key=lambda x: (x[0] is None, x[0] or "")): + for group, 
items in sorted( + groups.items(), key=lambda x: (x[0] is None, x[0] or "") + ): g_n = len(items) - g_correct = sum(1 for r in items if r.get("classification") == "correct") - g_halluc = sum( - 1 for r in items if r.get("classification") == "hallucination" - ) - g_refusal = sum(1 for r in items if r.get("classification") == "refusal") - group_dict: dict[str, Any] = { + g_correct = sum(1 for r in items if r.get("judge_correct")) + g_coverage = sum(1 for r in items if r.get("coverage")) + out[group] = { "n": g_n, "accuracy": round(g_correct / g_n * 100, 2) if g_n else 0.0, - "hallucination_rate": round(g_halluc / g_n * 100, 2) if g_n else 0.0, - "refusal_rate": round(g_refusal / g_n * 100, 2) if g_n else 0.0, - "correct": g_correct, - "hallucination": g_halluc, - "refusal": g_refusal, + "coverage": round(g_coverage / g_n * 100, 2) if g_n else 0.0, + "judge_count": g_n, + "judge_correct": g_correct, } - # LLM Judge breakdown - g_judge = [r for r in items if r.get("llm_judge_correct") is not None] - if g_judge: - g_jc = sum(1 for r in g_judge if r["llm_judge_correct"]) - group_dict["llm_judge_accuracy"] = round(g_jc / len(g_judge) * 100, 2) - group_dict["llm_judge_count"] = len(g_judge) - out[group] = group_dict return out diff --git a/benchmarks/financebench/judge.py b/benchmarks/financebench/judge.py index e52b6e6..8140669 100644 --- a/benchmarks/financebench/judge.py +++ b/benchmarks/financebench/judge.py @@ -1,12 +1,11 @@ -"""LLM-based semantic equivalence judge for FinanceBench. +"""LLM-based judge for FinanceBench evaluation. -The judge evaluates whether a model's prediction is semantically -equivalent to the gold answer, operating as an **independent** -evaluation dimension alongside EM/F1 — not as a fallback. +The judge drives **all** evaluation decisions: +- **Accuracy**: whether the prediction is semantically equivalent to the gold answer. +- **Coverage**: whether the prediction contains any information relevant to the question. -This provides a more nuanced correctness signal for financial QA, -where formatting differences (e.g., $1.5B vs $1,500M) can cause -EM/F1 to undercount correct answers. +This replaces the previous EM/F1 rule-driven pipeline with a single LLM-based +evaluation authority, providing more nuanced correctness signals for financial QA. """ from __future__ import annotations @@ -155,19 +154,50 @@ class FinanceBenchLLMJudge: - """LLM-based judge for semantic equivalence in financial QA. + """LLM-based judge driving all FinanceBench evaluation. - Operates as an independent evaluation dimension — NOT as a - fallback for EM/F1. Each question gets a separate judge verdict - that is tracked in its own metrics. + Provides two evaluation axes: + - ``judge()``: semantic equivalence (Accuracy). + - ``judge_coverage()``: information relevance (Coverage). + + Token usage from every LLM call is tracked and returned. """ _CONFIDENCE_THRESHOLD: float = 0.7 _MAX_RETRIES: int = 2 + # Coverage evaluation prompt + _COVERAGE_PROMPT: str = """\ +You are evaluating whether a system's response contains ANY useful information \ +relevant to the given financial question. + +Question: {question} +System Response: {prediction} + +Task: Determine if the response contains relevant, useful information. + +═══════════════════════════════════════════════ +HAS COVERAGE (has_coverage = true) — when ANY of: +═══════════════════════════════════════════════ +1. Contains specific financial data (dollar amounts, percentages, ratios) +2. Contains relevant factual statements about the company or topic +3. 
Contains partial but concrete information related to the question +4. Provides a direct answer (even if potentially incorrect) + +═══════════════════════════════════════════════ +NO COVERAGE (has_coverage = false) — when ALL of: +═══════════════════════════════════════════════ +1. Response is a refusal ("I cannot", "No results found", etc.) +2. Response contains no concrete data related to the question +3. Response is empty, purely apologetic, or only contains generic filler + +Respond ONLY with a JSON object (no markdown, no extra text): +{{"has_coverage": true or false, "confidence": 0.0 to 1.0, "reasoning": "brief explanation"}}""" + def __init__(self, llm: Any) -> None: self._llm = llm self._cache: Dict[tuple, Dict[str, Any]] = {} + self._total_tokens_used: int = 0 # ------------------------------------------------------------------ # Public API @@ -192,7 +222,8 @@ async def judge( "confidence": float (0-1), "reasoning": str, "cached": bool, - "error": Optional[str] + "error": Optional[str], + "tokens_used": int, } """ # --- Refusal short-circuit (saves LLM call) --- @@ -203,6 +234,7 @@ async def judge( "reasoning": "Prediction is a refusal — skipped LLM judge.", "cached": False, "error": None, + "tokens_used": 0, } # --- Quick exact-match shortcut --- @@ -215,6 +247,7 @@ async def judge( "reasoning": "Normalized exact match", "cached": False, "error": None, + "tokens_used": 0, } # --- Check cache (key includes question for context-sensitivity) --- @@ -237,6 +270,7 @@ async def judge( result: Dict[str, Any] | None = None last_error: str | None = None + tokens_used: int = 0 for attempt in range(1, self._MAX_RETRIES + 1): try: @@ -244,6 +278,7 @@ async def judge( messages=[{"role": "user", "content": prompt}], stream=False, ) + tokens_used = self._extract_tokens(resp) raw = resp.content.strip() result = self._parse_response(raw) if result.get("error") is None: @@ -283,6 +318,8 @@ async def judge( result.setdefault("cached", False) result.setdefault("error", None) + result["tokens_used"] = tokens_used + self._total_tokens_used += tokens_used # Cache successful results only if result["error"] is None: @@ -414,6 +451,140 @@ def _is_refusal(text: str) -> bool: return True return False + async def judge_coverage( + self, + prediction: str, + question: str, + ) -> Dict[str, Any]: + """Evaluate whether *prediction* contains relevant information for *question*. 
+ + Returns: + { + "has_coverage": bool, + "confidence": float (0-1), + "reasoning": str, + "tokens_used": int, + "error": Optional[str], + } + """ + # --- Refusal short-circuit --- + if self._is_refusal(prediction): + return { + "has_coverage": False, + "confidence": 1.0, + "reasoning": "Explicit refusal detected.", + "tokens_used": 0, + "error": None, + } + + prompt = self._COVERAGE_PROMPT.format( + question=question or "N/A", + prediction=prediction[:4000], + ) + + result: Dict[str, Any] | None = None + last_error: str | None = None + tokens_used: int = 0 + + for attempt in range(1, self._MAX_RETRIES + 1): + try: + resp = await self._llm.achat( + messages=[{"role": "user", "content": prompt}], + stream=False, + ) + tokens_used = self._extract_tokens(resp) + raw = resp.content.strip() + result = self._parse_coverage_response(raw) + if result.get("error") is None: + break + last_error = result.get("error") + except Exception as e: + last_error = str(e) + logger.warning( + "LLM Coverage judge failed (attempt %d/%d): %s", + attempt, + self._MAX_RETRIES, + e, + ) + result = None + + if result is None or result.get("error") is not None: + result = { + "has_coverage": False, + "confidence": 0.0, + "reasoning": f"Coverage judge error after {self._MAX_RETRIES} attempts: {last_error}", + "error": last_error, + } + + result.setdefault("error", None) + result["tokens_used"] = tokens_used + self._total_tokens_used += tokens_used + return result + + # ------------------------------------------------------------------ + # Coverage response parsing + # ------------------------------------------------------------------ + + def _parse_coverage_response(self, raw: str) -> Dict[str, Any]: + """Parse LLM JSON response for coverage evaluation.""" + parsed = self._try_parse_json(raw) + if parsed is not None: + has_coverage = bool(parsed.get("has_coverage", False)) + try: + confidence = float(parsed.get("confidence", 0.0)) + except (ValueError, TypeError): + confidence = 0.0 + confidence = max(0.0, min(1.0, confidence)) + reasoning = str(parsed.get("reasoning", "")) + return { + "has_coverage": has_coverage, + "confidence": confidence, + "reasoning": reasoning, + } + + # Fallback: keyword detection + lower = raw.lower() + true_match = re.search(r'"has_coverage"\s*:\s*true\b', lower) + false_match = re.search(r'"has_coverage"\s*:\s*false\b', lower) + + if false_match and not true_match: + return { + "has_coverage": False, + "confidence": 0.5, + "reasoning": f"Keyword fallback (no coverage): {raw[:200]}", + } + elif true_match and not false_match: + return { + "has_coverage": True, + "confidence": 0.5, + "reasoning": f"Keyword fallback (has coverage): {raw[:200]}", + } + + logger.warning("Cannot parse coverage response: %s", raw[:200]) + return { + "has_coverage": False, + "confidence": 0.0, + "reasoning": f"Unparseable response: {raw[:200]}", + "error": "parse_error", + } + + # ------------------------------------------------------------------ + # Token tracking + # ------------------------------------------------------------------ + + @staticmethod + def _extract_tokens(resp: Any) -> int: + """Extract total token count from an LLM response.""" + usage = getattr(resp, "usage", None) + if isinstance(usage, dict): + return int(usage.get("total_tokens", 0)) + return 0 + + @property + def total_tokens_used(self) -> int: + """Cumulative tokens consumed by all judge calls.""" + return self._total_tokens_used + @property def cache_size(self) -> int: """Return the number of cached judge results.""" diff --git 
a/benchmarks/financebench/run_benchmark.py b/benchmarks/financebench/run_benchmark.py index c9f5b26..cf7b30a 100644 --- a/benchmarks/financebench/run_benchmark.py +++ b/benchmarks/financebench/run_benchmark.py @@ -94,33 +94,28 @@ def _print_summary( """Print a human-readable run summary to stdout.""" n = len(results) acc = metrics.get("accuracy", 0) - hallu = metrics.get("hallucination_rate", 0) - refuse = metrics.get("refusal_rate", 0) - avg_em = metrics.get("avg_em", 0) - avg_f1 = metrics.get("avg_f1", 0) - ev_recall = metrics.get("evidence_recall") + cov = metrics.get("coverage", 0) avg_latency = metrics.get("avg_latency", 0) + token_usage = metrics.get("token_usage", {}) + total_tokens = token_usage.get("total_tokens", 0) + search_tokens = token_usage.get("search_tokens", 0) + judge_tokens = token_usage.get("judge_tokens", 0) + avg_tokens_q = token_usage.get("avg_tokens_per_question", 0) + print("\n" + "=" * 60) print(f"FinanceBench Results ({n} questions)") print("=" * 60) - print(f" Accuracy: {acc:.1f}%") - print(f" Hallucination Rate: {hallu:.1f}%") - print(f" Refusal Rate: {refuse:.1f}%") - print(f" Avg EM: {avg_em:.3f}") - print(f" Avg F1: {avg_f1:.3f}") - if ev_recall is not None: - print(f" Evidence Recall: {ev_recall:.3f}") - else: - print(f" Evidence Recall: N/A (page-level telemetry unavailable)") + print(f" Accuracy (Judge): {acc:.1f}%") + print(f" Coverage (Judge): {cov:.1f}%") print(f" Avg Latency: {avg_latency:.1f}s") print(f" Total Time: {total_time:.1f}s") - # LLM Judge independent metrics - if metrics.get("llm_judge_accuracy") is not None: - print(f"\n --- LLM Judge (Independent) ---") - print(f" Judge Accuracy: {metrics['llm_judge_accuracy']:.1f}%") - print(f" Judge Correct: {metrics['llm_judge_correct']}/{metrics['llm_judge_count']}") + print(f"\n --- Token Usage ---") + print(f" Total Tokens: {total_tokens:>,}") + print(f" Search Tokens: {search_tokens:>,}") + print(f" Judge Tokens: {judge_tokens:>,}") + print(f" Avg per Question: {avg_tokens_q:>,.0f}") print(f"\n Results: {results_path}") print(f" Metrics: {metrics_path}") @@ -129,28 +124,13 @@ def _print_summary( # Breakdown by question_type by_qt = metrics.get("by_question_type") if by_qt: - # Determine if judge data is available - has_judge = any(m.get("llm_judge_accuracy") is not None for m in by_qt.values()) - if has_judge: - print(f"\n {'Question Type':<25} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'Judge%':>7} {'N':>4}") - print(" " + "-" * 59) - for qt, m in sorted(by_qt.items()): - qt_acc = m.get("accuracy", 0) - qt_hal = m.get("hallucination_rate", 0) - qt_ref = m.get("refusal_rate", 0) - qt_n = m.get("n", 0) - qt_judge = m.get("llm_judge_accuracy") - qt_judge_str = f"{qt_judge:>6.1f}" if qt_judge is not None else " N/A" - print(f" {qt:<25} {qt_acc:>5.1f} {qt_hal:>7.1f} {qt_ref:>7.1f} {qt_judge_str} {qt_n:>4}") - else: - print(f"\n {'Question Type':<25} {'Acc%':>6} {'Hallu%':>7} {'Refuse%':>8} {'N':>4}") - print(" " + "-" * 52) - for qt, m in sorted(by_qt.items()): - qt_acc = m.get("accuracy", 0) - qt_hal = m.get("hallucination_rate", 0) - qt_ref = m.get("refusal_rate", 0) - qt_n = m.get("n", 0) - print(f" {qt:<25} {qt_acc:>5.1f} {qt_hal:>7.1f} {qt_ref:>7.1f} {qt_n:>4}") + print(f"\n {'Question Type':<28} {'Acc%':>6} {'Cover%':>7} {'N':>5}") + print(" " + "-" * 48) + for qt, m in sorted(by_qt.items()): + qt_acc = m.get("accuracy", 0) + qt_cov = m.get("coverage", 0) + qt_n = m.get("n", 0) + print(f" {qt:<28} {qt_acc:>5.1f} {qt_cov:>7.1f} {qt_n:>5}") print("=" * 60) @@ -222,11 +202,9 @@ def main() -> None: 
# 6. Print run config logger.info( - "Config: mode=%s, eval_mode=%s, extract_answer=%s, " - "llm_judge=%s, concurrent=%d, model=%s", + "Config: mode=%s, eval_mode=%s, llm_judge=%s, concurrent=%d, model=%s", cfg.mode, cfg.eval_mode, - cfg.extract_answer, cfg.enable_llm_judge, cfg.max_concurrent, cfg.llm_model, @@ -246,7 +224,6 @@ def main() -> None: "eval_mode": cfg.eval_mode, "model": cfg.llm_model, "top_k_files": cfg.top_k_files, - "extract_answer": cfg.extract_answer, } # 9. Save results (JSONL) + metrics (JSON) diff --git a/benchmarks/financebench/runner.py b/benchmarks/financebench/runner.py index 86404f2..7e2f115 100644 --- a/benchmarks/financebench/runner.py +++ b/benchmarks/financebench/runner.py @@ -4,157 +4,24 @@ - **singleDoc**: each question searches only its target PDF directory. - **sharedCorpus**: all questions search the full PDF corpus. -After search, an optional LLM extraction step converts the verbose -briefing into a short factoid answer suitable for EM/F1. +All evaluation (Accuracy + Coverage) is driven by LLM Judge. """ from __future__ import annotations import asyncio import json as json_mod import logging -import re import time from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from config import FinanceBenchConfig from data_loader import FinanceBenchLoader -from evaluate import ( - classify_answer, - compute_metrics, - exact_match, - evidence_recall, - f1_score, - normalize_answer, -) +from evaluate import compute_metrics logger = logging.getLogger("financebench.runner") -# ------------------------------------------------------------------ -# Answer extraction prompt (financial domain) -# ------------------------------------------------------------------ - -_EXTRACT_PROMPT = """\ -Given the financial question and a verbose response, extract ONLY the short factoid answer. -Rules: -- Output ONLY the answer value/phrase (1-20 words). No explanation. -- If the response contains ANY concrete data (dollar amounts, percentages, numbers, - company names, yes/no conclusions), extract that data even if the response also - expresses uncertainty or says it could not find a "complete" answer. -- A partial answer with real data is ALWAYS better than "unknown". -- Output "unknown" ONLY when the response contains absolutely no useful factual - information (e.g., a pure apology with zero data points). -- For monetary values, keep the currency format (e.g., $1,577.00) -- For percentages, keep the % sign (e.g., 15.3%) -- For yes/no questions, output: yes or no - -Question: {question} -Response: {response} - -Short answer:""" - -# Regex pattern for extracting financial numeric data as fallback -_NUMERIC_EXTRACTION_PATTERN = ( - r'\$[\d,]+(?:\.\d+)?\s*(?:million|billion|mn|bn|M|B|K)?' - r'|\d+(?:,\d{3})+(?:\.\d+)?\s*(?:million|billion|mn|bn|M|B|%)?' - r'|\d+(?:\.\d+)?\s*(?:million|billion|mn|bn|M|B|%)' -) - -# Sentinel values indicating extraction found no useful answer -_UNKNOWN_SENTINELS = frozenset({"unknown", "n/a", ""}) - - -# NOTE: _normalize_prediction removed — use evaluate.normalize_answer instead. - - -# ------------------------------------------------------------------ -# LLM short-answer extraction -# ------------------------------------------------------------------ - - -def _extract_numeric_fallback(text: str) -> Optional[str]: - """Extract financial figures from *text* using regex patterns. 
- - Looks for currency amounts ($xxx), percentages, and large numbers - with units (million, billion, etc.). - - Returns the first match or ``None``. - """ - match = re.search(_NUMERIC_EXTRACTION_PATTERN, text) - if match: - return match.group(0).strip() - return None - - -async def _llm_extract( - question: str, - verbose: str, - llm: Any, -) -> Optional[str]: - """Layer-1: use LLM to distil *verbose* into a short factoid answer.""" - prompt = _EXTRACT_PROMPT.format(question=question, response=verbose[:4000]) - try: - resp = await llm.achat( - messages=[{"role": "user", "content": prompt}], - stream=False, - ) - return resp.content.strip() - except Exception: - logger.warning("LLM extraction failed; will try regex fallback.") - return None - - -async def _extract_short_answer( - question: str, - verbose: str, - llm: Any, -) -> str: - """Extract a concise answer from verbose LLM analysis. - - Uses a three-layer extraction strategy: - 1. LLM-based extraction with improved prompt - 2. Regex-based numeric/financial data extraction as fallback - 3. Returns 'unknown' only when no useful data is found - """ - # Layer 1: LLM extraction - answer = await _llm_extract(question, verbose, llm) - if answer and answer.strip().lower() not in _UNKNOWN_SENTINELS: - return answer.strip() - - # Layer 2: Regex fallback for numeric/financial data - numeric_answer = _extract_numeric_fallback(verbose) - if numeric_answer: - logger.info("Regex fallback extracted: %s", numeric_answer) - return numeric_answer - - # Layer 3: No useful data found - return "unknown" - - -# ------------------------------------------------------------------ -# Page extraction helper -# ------------------------------------------------------------------ - - -def _try_extract_pages(telemetry: Dict[str, Any]) -> List[int]: - """Best-effort extraction of retrieved page numbers from telemetry. - - Current limitation: Sirchmunk's ``read_file_ids`` contains plain file - paths without page-level suffixes, so this function will typically - return an empty list. When empty, callers should treat evidence - recall as *unavailable* (``None``) rather than zero. 
- """ - pages: list[int] = [] - for fid in telemetry.get("read_file_ids", []): - # Convention: page indices may be embedded in file IDs - if isinstance(fid, str) and "_page_" in fid: - try: - pages.append(int(fid.rsplit("_page_", 1)[-1])) - except (ValueError, IndexError): - pass - return pages - # ------------------------------------------------------------------ # Single question execution @@ -174,24 +41,24 @@ async def run_single( fb_id = entry.get("financebench_id", "") question = entry["question"] gold = entry.get("answer", "") - gold_evidence = entry.get("evidence", []) async with semaphore: t0 = time.time() error: str | None = None raw_answer = "" - answer = "" telemetry: dict[str, Any] = {} - retrieved_pages: list[int] = [] try: # Determine search paths based on eval mode if cfg.eval_mode == "singleDoc": pdf_path = loader.get_pdf_path(entry.get("doc_name", "")) if pdf_path: - search_paths = [pdf_path] # pass the single PDF file directly + search_paths = [pdf_path] else: - logger.warning("PDF not found for %s, falling back to full corpus", entry.get("doc_name", "")) + logger.warning( + "PDF not found for %s, falling back to full corpus", + entry.get("doc_name", ""), + ) search_paths = [cfg.pdf_dir] else: search_paths = [cfg.pdf_dir] @@ -217,14 +84,6 @@ async def run_single( "llm_calls": len(getattr(result, "llm_usages", None) or []), "num_files_read": len(read_files), } - retrieved_pages = _try_extract_pages(telemetry) - - # Answer extraction - if cfg.extract_answer and raw_answer: - answer = await _extract_short_answer(question, raw_answer, llm) - answer = normalize_answer(answer) - else: - answer = normalize_answer(raw_answer) except Exception as exc: error = str(exc) @@ -236,39 +95,42 @@ async def run_single( if cfg.request_delay > 0: await asyncio.sleep(cfg.request_delay) - # --- Evaluation --- - is_no_result = not answer or answer.lower() in ("unknown", "") - em = exact_match(answer, gold) - f1 = f1_score(answer, gold) - classification = classify_answer(answer, gold, is_no_result=is_no_result) - if retrieved_pages: # only compute when page-level data is available - ev_recall = evidence_recall(retrieved_pages, gold_evidence) - else: - ev_recall = None # mark as unavailable, avoid false 0 - - # LLM Judge — independent evaluation dimension - # Skip judge for refusals (no point calling LLM on non-answers) - llm_judge_correct = None - llm_judge_reasoning = None - if judge is not None and classification != "refusal": + # --- LLM Judge evaluation (Accuracy + Coverage) --- + judge_correct = False + judge_reasoning = "" + judge_tokens = 0 + has_coverage = False + coverage_reasoning = "" + + if judge is not None: + # Accuracy evaluation try: judge_result = await judge.judge( - prediction=answer, + prediction=raw_answer, gold_answer=gold, question=question, ) - llm_judge_correct = judge_result.get("equivalent", False) - llm_judge_reasoning = judge_result.get("reasoning", "") + judge_correct = judge_result.get("equivalent", False) + judge_reasoning = judge_result.get("reasoning", "") + judge_tokens += judge_result.get("tokens_used", 0) + except Exception as e: + logger.warning("LLM Judge (accuracy) failed for %s: %s", fb_id, e) + + # Coverage evaluation + try: + coverage_result = await judge.judge_coverage( + prediction=raw_answer, + question=question, + ) + has_coverage = coverage_result.get("has_coverage", False) + coverage_reasoning = coverage_result.get("reasoning", "") + judge_tokens += coverage_result.get("tokens_used", 0) except Exception as e: - logger.warning("LLM Judge failed for 
%s: %s", fb_id, e) - elif judge is not None and classification == "refusal": - llm_judge_correct = False - llm_judge_reasoning = "Skipped: prediction classified as refusal" + logger.warning("LLM Judge (coverage) failed for %s: %s", fb_id, e) return { "financebench_id": fb_id, "question": question, - "prediction": answer, "raw_prediction": raw_answer, "gold_answer": gold, "company": entry.get("company", ""), @@ -277,12 +139,11 @@ async def run_single( "question_reasoning": entry.get("question_reasoning", ""), "elapsed": round(elapsed, 2), "telemetry": telemetry, - "classification": classification, - "em": em, - "f1": round(f1, 4), - "evidence_recall": round(ev_recall, 4) if ev_recall is not None else None, - "llm_judge_correct": llm_judge_correct, # None if judge disabled - "llm_judge_reasoning": llm_judge_reasoning, + "judge_correct": judge_correct, + "judge_reasoning": judge_reasoning, + "coverage": has_coverage, + "coverage_reasoning": coverage_reasoning, + "judge_tokens": judge_tokens, "error": error, } @@ -310,12 +171,12 @@ async def run_batch( loader = FinanceBenchLoader(data_dir=cfg.data_dir, pdf_dir=cfg.pdf_dir) semaphore = asyncio.Semaphore(cfg.max_concurrent) - # Initialise LLM Judge (uses the same test model) + # Initialise LLM Judge judge = None if cfg.enable_llm_judge: from judge import FinanceBenchLLMJudge judge = FinanceBenchLLMJudge(llm=llm) - logger.info("LLM Judge enabled (independent evaluation dimension)") + logger.info("LLM Judge enabled (drives Accuracy + Coverage)") # Prepare output directory / file out_dir = Path(cfg.output_dir) @@ -334,20 +195,16 @@ async def _run_and_record(entry: Dict[str, Any]) -> Dict[str, Any]: with open(out_path, "a", encoding="utf-8") as fp: fp.write(json_mod.dumps(res, ensure_ascii=False) + "\n") completed += 1 - status = res["classification"] - judge_tag = "" - if res.get("llm_judge_correct") is not None: - judge_tag = " [judge:\u2713]" if res["llm_judge_correct"] else " [judge:\u2717]" + acc_tag = "\u2713" if res["judge_correct"] else "\u2717" + cov_tag = "cov" if res["coverage"] else "no-cov" logger.info( - "[%d/%d] %s %s EM=%s F1=%.2f %.1fs%s", + "[%d/%d] %s [acc:%s] [%s] %.1fs", completed, total, res["financebench_id"], - status, - res["em"], - res["f1"], + acc_tag, + cov_tag, res["elapsed"], - judge_tag, ) return res @@ -361,10 +218,9 @@ async def _run_and_record(entry: Dict[str, Any]) -> Dict[str, Any]: json_mod.dump(metrics, fp, indent=2, ensure_ascii=False) logger.info("Metrics saved to %s", metrics_path) logger.info( - "Accuracy=%.2f%% Hallucination=%.2f%% Refusal=%.2f%%", + "Accuracy=%.2f%% Coverage=%.2f%%", metrics.get("accuracy", 0), - metrics.get("hallucination_rate", 0), - metrics.get("refusal_rate", 0), + metrics.get("coverage", 0), ) return list(results) From a184e862d5bab63ce1eb1ae7781c93d733703377 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 21:08:02 +0800 Subject: [PATCH 28/56] update config --- benchmarks/financebench/config.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/benchmarks/financebench/config.py b/benchmarks/financebench/config.py index 5c390ce..68fe2a1 100644 --- a/benchmarks/financebench/config.py +++ b/benchmarks/financebench/config.py @@ -54,9 +54,7 @@ class FinanceBenchConfig: # Evaluation eval_mode: str = "singleDoc" # singleDoc / sharedCorpus - enable_llm_judge: bool = True # Use LLM to judge semantic equivalence (independent metric) - extract_answer: bool = True - judge_f1_threshold: float = 0.8 # F1 threshold for 'correct' classification 
+ enable_llm_judge: bool = True # LLM Judge drives Accuracy + Coverage evaluation # Concurrency max_concurrent: int = 3 @@ -126,7 +124,6 @@ def _float(key: str, default: float = 0.0) -> float: enable_dir_scan=_bool("FB_ENABLE_DIR_SCAN", True), eval_mode=_get("FB_EVAL_MODE", "singleDoc"), enable_llm_judge=_bool("FB_ENABLE_LLM_JUDGE", True), - extract_answer=_bool("FB_EXTRACT_ANSWER", True), max_concurrent=_int("FB_MAX_CONCURRENT", 3), request_delay=_float("FB_REQUEST_DELAY", 0.5), work_path=work_path, From eb43fdd6a718cf6d22b492a3b877d30f9e85ffdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Fri, 17 Apr 2026 22:54:47 +0800 Subject: [PATCH 29/56] refactor doc extractor --- requirements/core.txt | 1 + src/sirchmunk/learnings/compiler.py | 13 +- src/sirchmunk/learnings/toc_extractor.py | 846 +++++++++++++++++----- src/sirchmunk/utils/document_extractor.py | 398 ++++++++++ src/sirchmunk/utils/file_utils.py | 26 +- 5 files changed, 1101 insertions(+), 183 deletions(-) create mode 100644 src/sirchmunk/utils/document_extractor.py diff --git a/requirements/core.txt b/requirements/core.txt index 1848a37..6cff25b 100644 --- a/requirements/core.txt +++ b/requirements/core.txt @@ -5,6 +5,7 @@ openai genson pillow pypdf +pdfminer.six pandas parquet numpy diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 2f8983a..4e31441 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -32,7 +32,8 @@ ) from sirchmunk.storage.knowledge_storage import KnowledgeStorage from sirchmunk.utils import LogCallback, create_logger -from sirchmunk.utils.file_utils import fast_extract, get_fast_hash +from sirchmunk.utils.document_extractor import DocumentExtractor +from sirchmunk.utils.file_utils import get_fast_hash # Concurrency cap for LLM-heavy file processing _DEFAULT_CONCURRENCY = 3 @@ -539,7 +540,9 @@ async def _compile_single_file( try: await self._log.info(f"[Compile] Processing: {Path(entry.path).name}") - extraction = await fast_extract(file_path=entry.path) + extraction = await DocumentExtractor.extract( + entry.path, DocumentExtractor.ENHANCED, + ) content = extraction.content if not content or len(content.strip()) < 100: result.error = "Insufficient text content" @@ -550,11 +553,13 @@ async def _compile_single_file( and DocumentTreeIndexer.should_build_tree(entry.path, len(content)) ) - # Phase 0.5: TOC extraction (zero LLM calls) + # Phase 0.5: TOC extraction (layers 1-3 are zero LLM calls) toc_entries = None if use_tree: from sirchmunk.learnings.toc_extractor import TOCExtractor - toc_entries = TOCExtractor.extract(entry.path, content) + toc_entries = await TOCExtractor.extract( + entry.path, content, + ) if toc_entries: await self._log.info( f"[Compile] Extracted TOC with {len(toc_entries)} entries " diff --git a/src/sirchmunk/learnings/toc_extractor.py b/src/sirchmunk/learnings/toc_extractor.py index 85f3b8e..1516cfd 100644 --- a/src/sirchmunk/learnings/toc_extractor.py +++ b/src/sirchmunk/learnings/toc_extractor.py @@ -1,220 +1,589 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """ -TOC (Table of Contents) extractor — pure local operations, zero LLM calls. +TOC (Table of Contents) extractor — multi-layer fallback strategy. Extracts hierarchical table-of-contents structures from various document -formats (PDF, Markdown, DOCX, HTML) using native format features (bookmarks, -heading styles, heading tags). 
The extracted TOCEntry list is consumed by -the tree indexer to accelerate tree construction. +formats (PDF, Markdown, DOCX, HTML) using a layered approach: + + Layer 1 — pypdf native outline (highest confidence, zero cost) + Layer 2 — pdfminer.six detailed parsing (fallback for pypdf) + Layer 3 — Text heading pattern detection (for documents without bookmarks) + Layer 4 — LLM-assisted inference (optional, last resort) + +The extracted TOCEntry list is consumed by the tree indexer to accelerate +tree construction. """ +import json +import logging import re from dataclasses import dataclass, field from pathlib import Path -from typing import List, Optional +from typing import Any, ClassVar, List, Optional -# Minimum number of TOC entries required to form a meaningful structure -_MIN_TOC_ENTRIES = 3 +logger = logging.getLogger(__name__) # Known heading-style prefixes across locales (English, Chinese, etc.) _HEADING_STYLE_PREFIXES = ("Heading", "heading", "\u6807\u9898") # "标题" = Chinese +# --------------------------------------------------------------------------- +# Data models +# --------------------------------------------------------------------------- + + @dataclass class TOCEntry: - """Single entry in an extracted table of contents.""" + """Single entry in an extracted table of contents. + + Attributes: + title: Section title text. + level: Heading depth (1 = top-level section, 2 = subsection, …). + char_start: Character offset in the extracted full text. + char_end: End character offset (exclusive), or None if unresolved. + page_start: 1-indexed page number, or None if unknown. + page_end: End page number (inclusive), or None. + children: Nested sub-entries forming a tree. + source: Which extraction layer produced this entry + ("pypdf", "pdfminer", "heading", "markdown", "docx", + "html", "llm"). + """ title: str - level: int # 0=root, 1=section, 2=subsection - char_start: int # Character offset in extracted text + level: int # 1=section, 2=subsection, … + char_start: int = 0 char_end: Optional[int] = None page_start: Optional[int] = None page_end: Optional[int] = None children: List["TOCEntry"] = field(default_factory=list) + source: str = "" -class TOCExtractor: - """Extract TOC structure from documents using native format features. - - All methods are static — no instance state required. Each extraction - method handles one file format and returns a flat or nested list of - ``TOCEntry`` objects. The main ``extract()`` entry point dispatches - by file extension and resolves character positions against the - extracted text content. - - Design constraints: - - Pure local operations, zero LLM calls - - Exceptions handled internally; failure returns None +@dataclass +class TocResult: + """Complete TOC extraction result with quality metadata. + + Attributes: + entries: Ordered list of TOCEntry objects. + source: Primary extraction method that produced the result. + confidence: Estimated quality score (0.0–1.0). + page_count: Total pages in the source document, if known. """ - @staticmethod - def extract(file_path: str, content: str) -> Optional[List[TOCEntry]]: - """Main entry point: extract TOC entries from a file. - - Dispatches to format-specific extractors based on file extension, - then resolves character positions in the extracted text content. - - Args: - file_path: Absolute path to the source file. - content: Extracted text content of the file. 
+ entries: List[TOCEntry] = field(default_factory=list) + source: str = "" + confidence: float = 0.0 + page_count: Optional[int] = None - Returns: - List of TOCEntry with resolved char positions, or None if - the file format is unsupported or fewer than _MIN_TOC_ENTRIES - entries are found. - """ - ext = Path(file_path).suffix.lower() - entries: Optional[List[TOCEntry]] = None - if ext == ".pdf": - entries = TOCExtractor._extract_pdf_toc(file_path) - elif ext in (".md", ".markdown"): - entries = TOCExtractor._extract_markdown_toc(content) - elif ext in (".docx",): - entries = TOCExtractor._extract_docx_toc(file_path) - elif ext in (".html", ".htm"): - entries = TOCExtractor._extract_html_toc(content) - else: - return None +# --------------------------------------------------------------------------- +# Layer 1: pypdf native outline +# --------------------------------------------------------------------------- - if not entries: - return None - # Flatten nested children for total count check - total = TOCExtractor._count_entries(entries) - if total < _MIN_TOC_ENTRIES: - return None +class PypdfOutlineExtractor: + """Layer 1: Extract TOC from PDF native outline/bookmarks using pypdf. - # Resolve character positions in extracted text - entries = TOCExtractor._resolve_char_positions(entries, content) - return entries + Highest confidence (0.9) — relies on the PDF producer embedding + explicit bookmarks. Zero external cost. + """ @staticmethod - def _extract_pdf_toc(file_path: str) -> Optional[List[TOCEntry]]: - """Extract TOC from PDF bookmarks/outline using pypdf. - - Recursively parses the nested bookmark structure from - ``PdfReader.outline``. + def extract(file_path: str | Path) -> TocResult: + """Extract TOC from PDF outline. Args: file_path: Path to the PDF file. Returns: - List of TOCEntry with page_start populated, or None on failure. + TocResult with entries and page_count populated, + or an empty TocResult on failure. """ try: from pypdf import PdfReader - reader = PdfReader(file_path) + reader = PdfReader(str(file_path)) outline = reader.outline + page_count = len(reader.pages) + if not outline: - return None + return TocResult(source="pypdf", page_count=page_count) entries: List[TOCEntry] = [] - TOCExtractor._parse_pdf_outline(reader, outline, entries, level=1) - return entries if entries else None - except Exception: - return None + PypdfOutlineExtractor._parse_outline( + reader, outline, entries, level=1, + ) + + if not entries: + return TocResult(source="pypdf", page_count=page_count) + + return TocResult( + entries=entries, + source="pypdf", + confidence=0.9, + page_count=page_count, + ) + except Exception as exc: + logger.debug("pypdf outline extraction failed: %s", exc) + return TocResult(source="pypdf") @staticmethod - def _parse_pdf_outline( - reader: "PdfReader", - outline_items: List, + def _parse_outline( + reader: Any, + outline_items: list, entries: List[TOCEntry], level: int, ) -> None: - """Recursively parse pypdf outline items into TOCEntry list. - - Args: - reader: PdfReader instance for page number resolution. - outline_items: Nested list of outline Destination objects. - entries: Accumulator list to append entries to. - level: Current nesting level (1=top-level section). 
- """ + """Recursively parse pypdf outline items into TOCEntry list.""" for item in outline_items: if isinstance(item, list): - # Nested list means sub-bookmarks — attach to last entry + # Nested list → sub-bookmarks; attach to last entry if entries: - sub_entries: List[TOCEntry] = [] - TOCExtractor._parse_pdf_outline( - reader, item, sub_entries, level=level + 1, + sub: List[TOCEntry] = [] + PypdfOutlineExtractor._parse_outline( + reader, item, sub, level=level + 1, ) - entries[-1].children.extend(sub_entries) + entries[-1].children.extend(sub) else: - TOCExtractor._parse_pdf_outline( + PypdfOutlineExtractor._parse_outline( reader, item, entries, level=level, ) else: - # Single bookmark destination try: title = item.title if hasattr(item, "title") else str(item) - page_num = None + page_num: Optional[int] = None try: - page_num = reader.get_destination_page_number(item) + # get_destination_page_number returns 0-indexed + raw = reader.get_destination_page_number(item) + if raw is not None: + page_num = raw + 1 # convert to 1-indexed except Exception: pass - entry = TOCEntry( + entries.append(TOCEntry( title=title.strip(), level=level, char_start=0, page_start=page_num, - ) - entries.append(entry) + source="pypdf", + )) except Exception: continue - @staticmethod - def _extract_markdown_toc(content: str) -> Optional[List[TOCEntry]]: - """Extract TOC from Markdown heading syntax (# / ## / ###). - Matches ATX-style headings: lines beginning with 1-6 '#' characters - followed by whitespace and the heading text. +# --------------------------------------------------------------------------- +# Layer 2: pdfminer.six detailed parsing +# --------------------------------------------------------------------------- + + +class PdfminerOutlineExtractor: + """Layer 2: Extract TOC using pdfminer.six for more detailed parsing. + + Falls back here when pypdf yields insufficient entries. + Confidence 0.85 — pdfminer exposes more detail but requires + manual page-number resolution. + """ + + @staticmethod + def extract(file_path: str | Path) -> TocResult: + """Extract TOC using pdfminer's outline parser. Args: - content: Markdown text content. + file_path: Path to the PDF file. Returns: - List of TOCEntry with level derived from '#' count, or None. + TocResult with entries populated, or empty on failure. 
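+
+        A minimal usage sketch (the file name below is illustrative)::
+
+            result = PdfminerOutlineExtractor.extract("annual_report.pdf")
+            for entry in result.entries:
+                print(entry.level, entry.title, entry.page_start)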
""" try: - pattern = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE) - matches = pattern.findall(content) - if not matches: + from pdfminer.pdfdocument import PDFDocument, PDFNoOutlines + from pdfminer.pdfpage import PDFPage + from pdfminer.pdfparser import PDFParser + from pdfminer.psparser import LIT + + fp = open(str(file_path), "rb") + try: + parser = PDFParser(fp) + document = PDFDocument(parser) + + # Build page-object-id → 1-indexed page number mapping + pages = list(PDFPage.create_pages(document)) + page_count = len(pages) + objid_to_pagenum = { + page.pageid: idx + 1 + for idx, page in enumerate(pages) + } + + entries: List[TOCEntry] = [] + try: + for level, title, dest, action, _se in document.get_outlines(): + page_num = PdfminerOutlineExtractor._resolve_page( + dest, action, objid_to_pagenum, document, + ) + entries.append(TOCEntry( + title=str(title).strip() if title else "", + level=level, + char_start=0, + page_start=page_num, + source="pdfminer", + )) + except PDFNoOutlines: + pass + + if not entries: + return TocResult(source="pdfminer", page_count=page_count) + + return TocResult( + entries=entries, + source="pdfminer", + confidence=0.85, + page_count=page_count, + ) + finally: + fp.close() + except Exception as exc: + logger.debug("pdfminer outline extraction failed: %s", exc) + return TocResult(source="pdfminer") + + @staticmethod + def _resolve_page( + dest: Any, + action: Any, + objid_to_pagenum: dict, + document: Any, + ) -> Optional[int]: + """Resolve a pdfminer destination/action to a 1-indexed page number.""" + try: + from pdfminer.pdfparser import PDFStream + from pdfminer.pdftypes import resolve1 + + # Try dest first + target = dest + if target is None and action is not None: + # GoTo action: action dict may have a 'D' key + if isinstance(action, dict): + target = action.get("D") + + if target is None: return None - entries: List[TOCEntry] = [] - for hashes, title in matches: + # Resolve indirect objects + target = resolve1(target) + + if isinstance(target, list) and len(target) > 0: + page_ref = resolve1(target[0]) + if hasattr(page_ref, "objid"): + return objid_to_pagenum.get(page_ref.objid) + elif hasattr(target, "objid"): + return objid_to_pagenum.get(target.objid) + except Exception: + pass + return None + + +# --------------------------------------------------------------------------- +# Layer 3: Text heading pattern detection +# --------------------------------------------------------------------------- + + +class HeadingTocExtractor: + """Layer 3: Infer TOC from document text structure (heading patterns). + + Handles Markdown headings, numbered sections, and common structural + keywords. Confidence 0.6 — heuristic-based, lower precision. + """ + + # Regex for Markdown ATX headings: # Title, ## Subtitle, … + _MD_HEADING_RE: ClassVar[re.Pattern] = re.compile( + r"^(#{1,6})\s+(.+)$", re.MULTILINE, + ) + + # Regex for numbered section patterns: "1.", "1.1", "1.1.1", … + _NUMBERED_RE: ClassVar[re.Pattern] = re.compile( + r"^(\d+(?:\.\d+)*)[.\s]+(.+)$", re.MULTILINE, + ) + + # Common structural keywords (case-insensitive) + _STRUCTURAL_KEYWORDS: ClassVar[tuple] = ( + "ITEM", "PART", "CHAPTER", "SECTION", "ARTICLE", + "APPENDIX", "EXHIBIT", "SCHEDULE", "ANNEX", + ) + + # Max characters for a candidate heading line + _MAX_HEADING_LINE_LEN: ClassVar[int] = 120 + + @staticmethod + def extract(content: str, mime_type: str = "") -> TocResult: + """Infer TOC from text content by detecting heading patterns. + + Tries strategies in order: + 1. 
Markdown ATX headings (``#`` syntax) + 2. Numbered section patterns (``1.``, ``1.1``, …) + 3. Structural keyword detection (ITEM, PART, CHAPTER, …) + + Args: + content: Full extracted text of the document. + mime_type: Optional MIME type hint (unused currently). + + Returns: + TocResult with char_position-based entries. + """ + if not content or len(content.strip()) < 50: + return TocResult(source="heading") + + # Strategy 1: Markdown headings + entries = HeadingTocExtractor._extract_markdown_headings(content) + if entries: + return TocResult( + entries=entries, + source="heading", + confidence=0.7, + ) + + # Strategy 2: Numbered sections + entries = HeadingTocExtractor._extract_numbered_sections(content) + if entries: + return TocResult( + entries=entries, + source="heading", + confidence=0.6, + ) + + # Strategy 3: Structural keywords + heuristic + entries = HeadingTocExtractor._extract_structural_headings(content) + if entries: + return TocResult( + entries=entries, + source="heading", + confidence=0.5, + ) + + return TocResult(source="heading") + + @staticmethod + def _extract_markdown_headings(content: str) -> List[TOCEntry]: + """Extract headings from Markdown ATX syntax (# / ## / ###).""" + matches = list(HeadingTocExtractor._MD_HEADING_RE.finditer(content)) + if not matches: + return [] + + entries: List[TOCEntry] = [] + for m in matches: + hashes, title = m.group(1), m.group(2).strip() + if title: entries.append(TOCEntry( - title=title.strip(), + title=title, level=len(hashes), - char_start=0, + char_start=m.start(), + source="heading", )) - return entries if entries else None - except Exception: - return None + return entries + + @staticmethod + def _extract_numbered_sections(content: str) -> List[TOCEntry]: + """Extract numbered section headings (1., 1.1, 1.1.1, …).""" + matches = list(HeadingTocExtractor._NUMBERED_RE.finditer(content)) + if not matches: + return [] + + entries: List[TOCEntry] = [] + for m in matches: + number_part = m.group(1) + title_part = m.group(2).strip() + # Line length check — skip long lines (likely not headings) + line_len = m.end() - m.start() + if line_len > HeadingTocExtractor._MAX_HEADING_LINE_LEN: + continue + if not title_part: + continue + level = number_part.count(".") + 1 + entries.append(TOCEntry( + title=f"{number_part} {title_part}", + level=level, + char_start=m.start(), + source="heading", + )) + return entries + + @staticmethod + def _extract_structural_headings(content: str) -> List[TOCEntry]: + """Detect common structural keywords as section boundaries.""" + # Build pattern: ITEM 1, PART I, CHAPTER 1, etc. 
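+        # Illustrative matches (not exhaustive): a line such as
+        # "ITEM 1A. Risk Factors" is captured at level 2, while
+        # "PART II" or "CHAPTER 3" land at level 1; the remaining
+        # keywords default to level 3.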
+ kw_pattern = "|".join(HeadingTocExtractor._STRUCTURAL_KEYWORDS) + pattern = re.compile( + rf"^({kw_pattern})\s+(\w+[\w .:\-]*)$", + re.MULTILINE | re.IGNORECASE, + ) + matches = list(pattern.finditer(content)) + if not matches: + return [] + + entries: List[TOCEntry] = [] + for m in matches: + keyword = m.group(1).upper() + rest = m.group(2).strip() + title = f"{keyword} {rest}" + # Determine level based on keyword + if keyword in ("PART", "CHAPTER"): + level = 1 + elif keyword in ("ITEM", "SECTION", "ARTICLE"): + level = 2 + else: + level = 3 + entries.append(TOCEntry( + title=title, + level=level, + char_start=m.start(), + source="heading", + )) + return entries + + +# --------------------------------------------------------------------------- +# Layer 4: LLM-assisted inference (optional) +# --------------------------------------------------------------------------- + + +class LlmTocExtractor: + """Layer 4: Use LLM to infer TOC from document content. + + This is the last-resort fallback. Requires an ``llm_caller`` that + supports ``await llm_caller.achat(messages)``. If no caller is + provided, returns an empty result immediately. + + Confidence 0.7 — LLM may hallucinate structure. + """ + + # Maximum characters sent to the LLM to stay within token limits + _MAX_CONTENT_CHARS: ClassVar[int] = 8_000 + + _PROMPT_TEMPLATE: ClassVar[str] = ( + "Analyze the following document excerpt and extract its " + "hierarchical table of contents (TOC) structure.\n\n" + "Return a JSON array where each element has:\n" + ' - "title": section title text\n' + ' - "level": integer heading depth (1=top, 2=sub, 3=subsub)\n\n' + "Only include actual section/chapter headings, not every paragraph.\n" + "Return ONLY the JSON array, no other text.\n\n" + "Document excerpt:\n---\n{content}\n---" + ) + + @staticmethod + async def extract( + content: str, + llm_caller: Any | None = None, + ) -> TocResult: + """Infer TOC using LLM analysis. + + Args: + content: Full extracted text of the document. + llm_caller: An object with ``achat(messages)`` method. + If None, returns an empty result. + + Returns: + TocResult with LLM-inferred entries. + """ + if llm_caller is None: + return TocResult(source="llm") + + if not content or len(content.strip()) < 100: + return TocResult(source="llm") + + try: + # Truncate content to fit token budget + truncated = content[:LlmTocExtractor._MAX_CONTENT_CHARS] + prompt = LlmTocExtractor._PROMPT_TEMPLATE.format(content=truncated) + + resp = await llm_caller.achat([{"role": "user", "content": prompt}]) + raw = resp.content.strip() + + entries = LlmTocExtractor._parse_response(raw, content) + if not entries: + return TocResult(source="llm") + + return TocResult( + entries=entries, + source="llm", + confidence=0.7, + ) + except Exception as exc: + logger.debug("LLM TOC extraction failed: %s", exc) + return TocResult(source="llm") @staticmethod - def _extract_docx_toc(file_path: str) -> Optional[List[TOCEntry]]: - """Extract TOC from DOCX heading styles using python-docx. 
+ def _parse_response(raw: str, content: str) -> List[TOCEntry]: + """Parse LLM JSON response into TOCEntry list with char_positions.""" + # Strip markdown code fences if present + cleaned = raw.strip() + if cleaned.startswith("```"): + lines = cleaned.split("\n") + # Remove first and last fence lines + lines = [l for l in lines if not l.strip().startswith("```")] + cleaned = "\n".join(lines) + + try: + items = json.loads(cleaned) + except (json.JSONDecodeError, TypeError): + return [] + + if not isinstance(items, list): + return [] + + content_lower = content.lower() + search_from = 0 + entries: List[TOCEntry] = [] + + for item in items: + if not isinstance(item, dict): + continue + title = str(item.get("title", "")).strip() + level = int(item.get("level", 1)) + if not title: + continue + + # Try to locate title in content for char_position + pos = content_lower.find(title.lower(), search_from) + if pos >= 0: + char_start = pos + search_from = pos + len(title) + else: + # Fallback: try from beginning + pos = content_lower.find(title.lower()) + char_start = pos if pos >= 0 else search_from + + entries.append(TOCEntry( + title=title, + level=max(1, min(level, 6)), + char_start=char_start, + source="llm", + )) + + return entries - Reads paragraphs with heading style names (English ``Heading``, - Chinese ``\u6807\u9898``, etc.), extracting the heading level from the style - name suffix (e.g., ``Heading 1`` -> level 1). + +# --------------------------------------------------------------------------- +# Format-specific extractors (non-PDF) +# --------------------------------------------------------------------------- + + +class DocxTocExtractor: + """Extract TOC from DOCX heading styles using python-docx.""" + + @staticmethod + def extract(file_path: str | Path) -> TocResult: + """Extract TOC from DOCX heading styles. Args: file_path: Path to the DOCX file. Returns: - List of TOCEntry with level from heading style, or None. + TocResult with entries from heading styles. """ try: import docx - doc = docx.Document(file_path) + doc = docx.Document(str(file_path)) entries: List[TOCEntry] = [] for para in doc.paragraphs: style_name = para.style.name or "" - # Match heading styles across locales ("Heading 1", "标题 1", etc.) matched_prefix = "" for prefix in _HEADING_STYLE_PREFIXES: if style_name.startswith(prefix): @@ -233,47 +602,213 @@ def _extract_docx_toc(file_path: str) -> Optional[List[TOCEntry]]: title=title, level=level, char_start=0, + source="docx", )) - return entries if entries else None - except Exception: - return None - @staticmethod - def _extract_html_toc(content: str) -> Optional[List[TOCEntry]]: - """Extract TOC from HTML heading tags (
<h1> through <h6>
). + if not entries: + return TocResult(source="docx") + return TocResult(entries=entries, source="docx", confidence=0.85) + except Exception as exc: + logger.debug("DOCX TOC extraction failed: %s", exc) + return TocResult(source="docx") + - Uses regex to match heading tags and strips inner HTML tags - from the title text. +class HtmlTocExtractor: + """Extract TOC from HTML heading tags (
<h1> through <h6>
).""" + + _HTML_HEADING_RE: ClassVar[re.Pattern] = re.compile( + r"]*>(.*?)", + re.IGNORECASE | re.DOTALL, + ) + + @staticmethod + def extract(content: str) -> TocResult: + """Extract TOC from HTML heading tags. Args: content: HTML text content. Returns: - List of TOCEntry with level from tag number, or None. + TocResult with entries from
<h1> through <h6>
tags. """ try: - pattern = re.compile( - r"]*>(.*?)", - re.IGNORECASE | re.DOTALL, - ) - matches = pattern.findall(content) + matches = HtmlTocExtractor._HTML_HEADING_RE.findall(content) if not matches: - return None + return TocResult(source="html") entries: List[TOCEntry] = [] for level_str, raw_title in matches: - # Strip HTML tags from title title = re.sub(r"<[^>]+>", "", raw_title).strip() if title: entries.append(TOCEntry( title=title, level=int(level_str), char_start=0, + source="html", )) - return entries if entries else None - except Exception: + + if not entries: + return TocResult(source="html") + return TocResult(entries=entries, source="html", confidence=0.8) + except Exception as exc: + logger.debug("HTML TOC extraction failed: %s", exc) + return TocResult(source="html") + + +# --------------------------------------------------------------------------- +# Orchestrator: multi-layer fallback +# --------------------------------------------------------------------------- + + +class TOCExtractor: + """Orchestrates multi-layer TOC extraction with fallback strategy. + + All methods are static/classmethod — no instance state required. + The main ``extract()`` entry point dispatches by file extension and + applies the layered fallback for PDF files. + + Layer priority for PDFs: + 1. pypdf native outline (confidence 0.9) + 2. pdfminer.six detailed parsing (confidence 0.85) + 3. Text heading detection (confidence 0.5–0.7) + 4. LLM-assisted inference (confidence 0.7, optional) + + Design constraints: + - Layers 1–3 are pure-local, zero LLM calls + - Layer 4 is optional (requires llm_caller) + - Each layer is independently try-excepted; failure never blocks + subsequent layers + """ + + # Minimum entries to consider a TOC extraction successful + _MIN_ENTRIES_THRESHOLD: ClassVar[int] = 3 + + @classmethod + async def extract( + cls, + file_path: str, + content: str, + *, + llm_caller: Any | None = None, + ) -> Optional[List[TOCEntry]]: + """Extract TOC using layered fallback strategy. + + Tries extraction methods in order of reliability. Falls back to + the next layer when the current layer yields fewer than + ``_MIN_ENTRIES_THRESHOLD`` entries. + + Args: + file_path: Absolute path to the source file. + content: Extracted text content of the file. + llm_caller: Optional LLM caller for Layer 4. + + Returns: + List of TOCEntry with resolved char positions, or None if + no layer produced enough entries. + """ + ext = Path(file_path).suffix.lower() + + result: Optional[TocResult] = None + + if ext == ".pdf": + result = await cls._extract_pdf_layered( + file_path, content, llm_caller, + ) + elif ext in (".md", ".markdown"): + heading_result = HeadingTocExtractor.extract(content) + if cls._is_sufficient(heading_result): + result = heading_result + elif ext in (".docx",): + result = DocxTocExtractor.extract(file_path) + elif ext in (".html", ".htm"): + result = HtmlTocExtractor.extract(content) + else: + return None + + if result is None or not cls._is_sufficient(result): return None + entries = result.entries + total = cls._count_entries(entries) + if total < cls._MIN_ENTRIES_THRESHOLD: + return None + + # Resolve character positions in the extracted text + entries = cls._resolve_char_positions(entries, content) + return entries + + @classmethod + async def _extract_pdf_layered( + cls, + file_path: str, + content: str, + llm_caller: Any | None, + ) -> Optional[TocResult]: + """Apply layered extraction for PDF files. + + Args: + file_path: Path to the PDF file. + content: Extracted text content. 
+ llm_caller: Optional LLM caller for Layer 4. + + Returns: + Best TocResult from the layer cascade, or None. + """ + # Layer 1: pypdf + result = PypdfOutlineExtractor.extract(file_path) + if cls._is_sufficient(result): + logger.info( + "TOC Layer 1 (pypdf): %d entries for %s", + len(result.entries), Path(file_path).name, + ) + return result + + # Layer 2: pdfminer.six + result = PdfminerOutlineExtractor.extract(file_path) + if cls._is_sufficient(result): + logger.info( + "TOC Layer 2 (pdfminer): %d entries for %s", + len(result.entries), Path(file_path).name, + ) + return result + + # Layer 3: heading detection from content + if content: + result = HeadingTocExtractor.extract(content) + if cls._is_sufficient(result): + logger.info( + "TOC Layer 3 (heading): %d entries for %s", + len(result.entries), Path(file_path).name, + ) + return result + + # Layer 4: LLM-assisted (optional) + if llm_caller is not None and content: + result = await LlmTocExtractor.extract(content, llm_caller) + if cls._is_sufficient(result): + logger.info( + "TOC Layer 4 (LLM): %d entries for %s", + len(result.entries), Path(file_path).name, + ) + return result + + logger.debug( + "TOC extraction: no layer produced sufficient entries for %s", + Path(file_path).name, + ) + return None + + @classmethod + def _is_sufficient(cls, result: Optional[TocResult]) -> bool: + """Check whether a TocResult has enough entries to be useful.""" + if result is None: + return False + return len(result.entries) >= cls._MIN_ENTRIES_THRESHOLD + + # ------------------------------------------------------------------ # + # Character position resolution # + # ------------------------------------------------------------------ # + @staticmethod def _resolve_char_positions( entries: List[TOCEntry], @@ -311,26 +846,21 @@ def _resolve_char_positions( if not title_lower: entry.char_start = search_from continue - # Normalise whitespace for fuzzy matching (PDF extracts may - # insert extra spaces inside headings). + # Normalise whitespace for fuzzy matching title_normalised = re.sub(r"\s+", " ", title_lower) pos = content_lower.find(title_normalised, search_from) if pos < 0: - # Retry with the original (un-normalised) title pos = content_lower.find(title_lower, search_from) if pos >= 0: entry.char_start = pos search_from = pos + len(title_lower) else: - # Title not found after search_from; try from beginning pos = content_lower.find(title_normalised) if pos < 0: pos = content_lower.find(title_lower) if pos >= 0: entry.char_start = pos - # Do NOT reset search_from to avoid breaking order else: - # Last resort: place at current search frontier entry.char_start = search_from # Pass 2: resolve char_end as start of next entry (or len(content)) @@ -346,12 +876,7 @@ def _flatten_entries( entries: List[TOCEntry], flat: List[TOCEntry], ) -> None: - """Flatten nested TOCEntry tree into document-order list. - - Args: - entries: Nested entry list. - flat: Accumulator for flattened output. - """ + """Flatten nested TOCEntry tree into document-order list.""" for entry in entries: flat.append(entry) if entry.children: @@ -359,30 +884,7 @@ def _flatten_entries( @staticmethod def _count_entries(entries: List[TOCEntry]) -> int: - """Count total entries including nested children. - - Args: - entries: Nested entry list. - - Returns: - Total number of entries in the tree. 
- """ - count = 0 - for entry in entries: - count += 1 - if entry.children: - count += TOCExtractor._count_entries(entry.children) - return count - @staticmethod - def _count_entries(entries: List[TOCEntry]) -> int: - """Count total entries including nested children. - - Args: - entries: Nested entry list. - - Returns: - Total number of entries in the tree. - """ + """Count total entries including nested children.""" count = 0 for entry in entries: count += 1 diff --git a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py new file mode 100644 index 0000000..76e0f15 --- /dev/null +++ b/src/sirchmunk/utils/document_extractor.py @@ -0,0 +1,398 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Unified document extraction facade over kreuzberg. + +Centralizes all kreuzberg interaction into a single module, providing a clean, +configurable interface for document text extraction with support for tables, +metadata, language detection, OCR, and page-range filtering. + +All other modules should import from here rather than from kreuzberg directly. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, ClassVar, List, Optional, Sequence, Union + +from loguru import logger + + +# --------------------------------------------------------------------------- +# Configuration profile +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class ExtractionProfile: + """Immutable extraction configuration profile. + + Controls which kreuzberg features are enabled during document extraction. + Default values align with the legacy ``fast_extract()`` behavior + (plain text only, no extras). + """ + + output_format: str = "plain" + """Output format: ``plain`` | ``markdown`` | ``html`` | ``djot``.""" + + extract_tables: bool = False + """Whether to extract and return tables.""" + + extract_metadata: bool = False + """Whether to return document metadata.""" + + detect_language: bool = False + """Whether to detect document language.""" + + ocr_enabled: bool = False + """Whether to enable OCR fallback.""" + + ocr_backend: str = "tesseract" + """OCR engine: ``tesseract`` | ``easyocr`` | ``paddleocr``.""" + + ocr_language: str = "eng" + """OCR language code (e.g. ``eng``, ``chi_sim``).""" + + page_start: Optional[int] = None + """Page range start (0-indexed). ``None`` means first page.""" + + page_end: Optional[int] = None + """Page range end (inclusive). ``None`` means last page.""" + + pdf_extract_images: bool = False + """Extract images embedded in PDF pages.""" + + pdf_extract_metadata: bool = False + """Extract PDF-level metadata (author, title, etc.).""" + + force_ocr: bool = False + """Force OCR for all pages, bypassing native text extraction. + + Maps directly to kreuzberg's ``ExtractionConfig.force_ocr``. + Note: kreuzberg does not offer a "fallback" OCR mode — + when set, OCR is always applied regardless of text layer presence. + """ + + pdf_password: Optional[str] = None + """Password for encrypted PDFs.""" + + max_concurrent: Optional[int] = None + """Max concurrency for batch extraction.""" + + +# --------------------------------------------------------------------------- +# Extraction output +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class ExtractionOutput: + """Structured extraction result. + + Always contains ``content``. 
Other fields are populated based on the + :class:`ExtractionProfile` settings used during extraction. + """ + + content: str + """Extracted text content.""" + + mime_type: str = "" + """MIME type of the source document.""" + + metadata: dict[str, Any] = field(default_factory=dict) + """Document metadata (empty when ``extract_metadata`` is disabled).""" + + tables: list[dict[str, Any]] = field(default_factory=list) + """Extracted tables (empty when ``extract_tables`` is disabled).""" + + detected_languages: dict[str, float] = field(default_factory=dict) + """Language → confidence mapping (empty when ``detect_language`` is disabled).""" + + page_count: Optional[int] = None + """Number of pages in the source document (if available).""" + + +# --------------------------------------------------------------------------- +# Document extractor facade +# --------------------------------------------------------------------------- + +class DocumentExtractor: + """Unified document extraction facade over kreuzberg. + + Provides a clean, configurable interface for document text extraction, + centralizing all kreuzberg interaction within a single module. + + Usage:: + + # Basic extraction (identical to legacy fast_extract) + result = await DocumentExtractor.extract(path) + + # Enhanced extraction with tables and metadata + result = await DocumentExtractor.extract(path, DocumentExtractor.ENHANCED) + + # Custom profile + profile = ExtractionProfile(output_format="markdown", extract_tables=True) + result = await DocumentExtractor.extract(path, profile) + """ + + # Pre-defined profiles ------------------------------------------------- + + BASIC: ClassVar[ExtractionProfile] = ExtractionProfile() + """Plain-text extraction only — equivalent to legacy ``fast_extract()``.""" + + ENHANCED: ClassVar[ExtractionProfile] = ExtractionProfile( + output_format="markdown", + extract_tables=True, + extract_metadata=True, + pdf_extract_metadata=True, + force_ocr=True, + ) + """Rich extraction with tables, metadata, and OCR fallback.""" + + # Public API ----------------------------------------------------------- + + @staticmethod + async def extract( + file_path: Union[str, Path], + profile: Optional[ExtractionProfile] = None, + ) -> ExtractionOutput: + """Extract content from a single file. + + Args: + file_path: Path to the document. + profile: Extraction profile. Defaults to :attr:`BASIC`. + + Returns: + :class:`ExtractionOutput` with at least ``content`` populated. + + Raises: + FileNotFoundError: If *file_path* does not exist. + Exception: Propagates kreuzberg extraction errors after logging. + """ + from kreuzberg import extract_file + + profile = profile or DocumentExtractor.BASIC + config = DocumentExtractor._build_config(profile) + + try: + result = await extract_file(file_path=file_path, config=config) + return DocumentExtractor._convert_result(result, profile) + except Exception as exc: + logger.error( + "Document extraction failed for {}: {}", + file_path, + exc, + ) + raise + + @staticmethod + async def extract_bytes( + data: bytes, + mime_type: str, + profile: Optional[ExtractionProfile] = None, + ) -> ExtractionOutput: + """Extract content from raw bytes. + + Args: + data: File content as bytes. + mime_type: MIME type of the data (required for format detection). + profile: Extraction profile. Defaults to :attr:`BASIC`. + + Returns: + :class:`ExtractionOutput`. 
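+
+        A minimal sketch (file name and MIME type are illustrative)::
+
+            with open("notes.md", "rb") as fh:
+                data = fh.read()
+            result = await DocumentExtractor.extract_bytes(data, "text/markdown")
+            print(result.content[:200])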
+ """ + from kreuzberg import extract_bytes as _extract_bytes + + profile = profile or DocumentExtractor.BASIC + config = DocumentExtractor._build_config(profile) + + try: + result = await _extract_bytes(data=data, mime_type=mime_type, config=config) + return DocumentExtractor._convert_result(result, profile) + except Exception: + logger.error("Byte extraction failed for mime_type={}", mime_type) + raise + + @staticmethod + async def batch_extract( + file_paths: Sequence[Union[str, Path]], + profile: Optional[ExtractionProfile] = None, + ) -> List[ExtractionOutput]: + """Extract content from multiple files in parallel. + + Args: + file_paths: Sequence of document paths. + profile: Extraction profile. Defaults to :attr:`BASIC`. + + Returns: + List of :class:`ExtractionOutput`, one per input path. + """ + from kreuzberg import batch_extract_files + + profile = profile or DocumentExtractor.BASIC + config = DocumentExtractor._build_config(profile) + + try: + results = await batch_extract_files(paths=list(file_paths), config=config) + return [ + DocumentExtractor._convert_result(r, profile) for r in results + ] + except Exception: + logger.error("Batch extraction failed for {} files", len(file_paths)) + raise + + # Internal helpers ----------------------------------------------------- + + @staticmethod + def _build_config(profile: ExtractionProfile): + """Build a kreuzberg ``ExtractionConfig`` from an :class:`ExtractionProfile`. + + Maps profile fields to the kreuzberg configuration objects that are + actually available in the installed version. + """ + from kreuzberg import ( + ExtractionConfig, + OcrConfig, + OutputFormat, + PageConfig, + PdfConfig, + ) + + # --- Output format --- + format_map = { + "plain": OutputFormat.PLAIN, + "markdown": OutputFormat.MARKDOWN, + "html": OutputFormat.HTML, + "djot": OutputFormat.DJOT, + } + output_format = format_map.get(profile.output_format, OutputFormat.PLAIN) + + # --- OCR config --- + ocr_config: Optional[OcrConfig] = None + if profile.ocr_enabled: + ocr_config = OcrConfig( + backend=profile.ocr_backend, + language=profile.ocr_language, + ) + + # --- Page config --- + page_config: Optional[PageConfig] = None + if profile.page_start is not None or profile.page_end is not None: + # kreuzberg PageConfig.extract_pages expects a list of page indices + pages: Optional[list[int]] = None + if profile.page_start is not None: + end = profile.page_end if profile.page_end is not None else profile.page_start + pages = list(range(profile.page_start, end + 1)) + page_config = PageConfig(extract_pages=pages) + + # --- PDF config --- + pdf_config: Optional[PdfConfig] = None + if any([ + profile.pdf_extract_images, + profile.pdf_extract_metadata, + profile.pdf_password, + ]): + passwords = [profile.pdf_password] if profile.pdf_password else None + pdf_config = PdfConfig( + extract_images=profile.pdf_extract_images, + extract_metadata=profile.pdf_extract_metadata, + passwords=passwords, + ) + + # --- Language detection --- + lang_config = None + if profile.detect_language: + from kreuzberg import LanguageDetectionConfig + lang_config = LanguageDetectionConfig(enabled=True) + + # --- Assemble ExtractionConfig --- + kwargs: dict[str, Any] = { + "output_format": output_format, + } + if ocr_config is not None: + kwargs["ocr"] = ocr_config + if profile.force_ocr: + kwargs["force_ocr"] = True + if page_config is not None: + kwargs["pages"] = page_config + if pdf_config is not None: + kwargs["pdf_options"] = pdf_config + if lang_config is not None: + 
kwargs["language_detection"] = lang_config + if profile.max_concurrent is not None: + kwargs["max_concurrent_extractions"] = profile.max_concurrent + + return ExtractionConfig(**kwargs) + + @staticmethod + def _convert_result( + result: "ExtractionResult", + profile: ExtractionProfile, + ) -> ExtractionOutput: + """Convert a kreuzberg ``ExtractionResult`` to :class:`ExtractionOutput`. + + Only populates optional fields when the corresponding profile flag is + enabled, keeping the output lean for basic extraction. + """ + content: str = result.content or "" + mime_type: str = getattr(result, "mime_type", "") or "" + + # Metadata + metadata: dict[str, Any] = {} + if profile.extract_metadata: + raw_meta = getattr(result, "metadata", None) + if raw_meta is not None: + if isinstance(raw_meta, dict): + metadata = dict(raw_meta) + else: + # kreuzberg may return a non-dict metadata object + try: + metadata = dict(raw_meta) + except (TypeError, ValueError): + metadata = {"raw": str(raw_meta)} + + # Tables + tables: list[dict[str, Any]] = [] + if profile.extract_tables: + raw_tables = getattr(result, "tables", None) or [] + for t in raw_tables: + if isinstance(t, dict): + tables.append(t) + else: + # kreuzberg ExtractedTable has: cells, markdown, page_number + tables.append({ + "markdown": getattr(t, "markdown", ""), + "cells": getattr(t, "cells", []), + "page_number": getattr(t, "page_number", None), + }) + + # Language detection + detected_languages: dict[str, float] = {} + if profile.detect_language: + raw_langs = getattr(result, "detected_languages", None) + if raw_langs: + for entry in raw_langs: + if isinstance(entry, dict): + lang = entry.get("language", "") + conf = entry.get("confidence", 0.0) + else: + # kreuzberg DetectedLanguage object + lang = getattr(entry, "language", "") + conf = getattr(entry, "confidence", 0.0) + if lang: + detected_languages[lang] = float(conf) + + # Page count — prefer get_page_count() over get_chunk_count() + page_count: Optional[int] = None + get_page_count = getattr(result, "get_page_count", None) + if get_page_count and callable(get_page_count): + cnt = get_page_count() + if cnt is not None and cnt > 0: + page_count = cnt + + return ExtractionOutput( + content=content, + mime_type=mime_type, + metadata=metadata, + tables=tables, + detected_languages=detected_languages, + page_count=page_count, + ) diff --git a/src/sirchmunk/utils/file_utils.py b/src/sirchmunk/utils/file_utils.py index edbbc2d..df308fd 100644 --- a/src/sirchmunk/utils/file_utils.py +++ b/src/sirchmunk/utils/file_utils.py @@ -4,17 +4,29 @@ from pathlib import Path from typing import Union -from kreuzberg import ExtractionResult, extract_file from loguru import logger +from sirchmunk.utils.document_extractor import ( + DocumentExtractor, + ExtractionOutput, +) -async def fast_extract(file_path: Union[str, Path]) -> ExtractionResult: - """ - Automatically detects and extracts text content from various file formats like docx, pptx, pdf, xlsx. - """ - result: ExtractionResult = await extract_file(file_path=file_path) - return result +async def fast_extract(file_path: Union[str, Path]) -> ExtractionOutput: + """Extract text content from a file using kreuzberg. + + This is a backward-compatible wrapper around + :meth:`DocumentExtractor.extract` with the ``BASIC`` profile + (plain text, no extras). All callers that only need ``.content`` + continue to work unchanged. + + Args: + file_path: Path to the file to extract. + + Returns: + :class:`ExtractionOutput` with ``.content`` populated. 
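+
+    Example (a minimal sketch; the path is illustrative)::
+
+        result = await fast_extract("reports/quarterly.pdf")
+        text = result.content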
+ """ + return await DocumentExtractor.extract(file_path) def get_fast_hash(file_path: Union[str, Path], sample_size: int = 8192): From b2c26bbc5321d5f7f4506cafd3431af6d74719d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 19 Apr 2026 20:58:23 +0800 Subject: [PATCH 30/56] enhance compiler for tree indexing --- config/env.example | 5 + src/sirchmunk/learnings/compiler.py | 5 +- src/sirchmunk/learnings/toc_extractor.py | 71 ++++++ src/sirchmunk/learnings/tree_indexer.py | 118 ++++++++-- src/sirchmunk/search.py | 262 ++++++++++++++++------- 5 files changed, 354 insertions(+), 107 deletions(-) diff --git a/config/env.example b/config/env.example index 8272d03..4b8dcd7 100644 --- a/config/env.example +++ b/config/env.example @@ -126,3 +126,8 @@ SIRCHMUNK_DEBUG=false # Maximum concurrent WebSocket connections (default: 100) SIRCHMUNK_MAX_WS_CONNECTIONS=100 + +# ===== Ablation Experiment Settings ===== +# Pure tree search mode (ablation experiment, default: false) +# When enabled, search relies solely on tree index navigation, skipping rga keyword search +# SIRCHMUNK_PURE_TREE_SEARCH=false diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 4e31441..ad2a115 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -559,6 +559,7 @@ async def _compile_single_file( from sirchmunk.learnings.toc_extractor import TOCExtractor toc_entries = await TOCExtractor.extract( entry.path, content, + total_pages=extraction.page_count, ) if toc_entries: await self._log.info( @@ -568,7 +569,9 @@ async def _compile_single_file( if use_tree: result.tree = await self._tree_indexer.build_tree( - entry.path, content, toc_entries=toc_entries, + entry.path, content, + toc_entries=toc_entries, + total_pages=extraction.page_count, ) # Record TOC / tree metrics on the result for manifest persistence diff --git a/src/sirchmunk/learnings/toc_extractor.py b/src/sirchmunk/learnings/toc_extractor.py index 1516cfd..0197485 100644 --- a/src/sirchmunk/learnings/toc_extractor.py +++ b/src/sirchmunk/learnings/toc_extractor.py @@ -683,6 +683,46 @@ class TOCExtractor: # Minimum entries to consider a TOC extraction successful _MIN_ENTRIES_THRESHOLD: ClassVar[int] = 3 + @staticmethod + def _build_hierarchy(flat_entries: List["TOCEntry"]) -> List["TOCEntry"]: + """Convert flat TocEntry list to nested tree using level field. + + Uses stack-based algorithm, O(n). When encountering a deeper level + entry, push it as a child of the current stack top; when same or + shallower, pop back to the corresponding level. + + Args: + flat_entries: Flat list of TOCEntry objects with ``level`` set. + + Returns: + List of top-level TOCEntry objects with ``children`` populated. 
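+
+        Example (titles are illustrative)::
+
+            flat = [
+                TOCEntry(title="Introduction", level=1),
+                TOCEntry(title="Background", level=2),
+                TOCEntry(title="Methods", level=1),
+            ]
+            roots = TOCExtractor._build_hierarchy(flat)
+            # roots -> [Introduction (child: Background), Methods]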
+ """ + if not flat_entries: + return [] + + roots: List[TOCEntry] = [] + # Stack holds (level, entry) pairs representing the current path + stack: List[TOCEntry] = [] + + for entry in flat_entries: + # Reset children to avoid stale data from prior processing + entry.children = [] + + # Pop stack until we find the parent (shallower level) + while stack and stack[-1].level >= entry.level: + stack.pop() + + if stack: + # Attach as child of the current stack top + stack[-1].children.append(entry) + else: + # No parent — this is a root-level entry + roots.append(entry) + + stack.append(entry) + + return roots + @classmethod async def extract( cls, @@ -690,6 +730,7 @@ async def extract( content: str, *, llm_caller: Any | None = None, + total_pages: Optional[int] = None, ) -> Optional[List[TOCEntry]]: """Extract TOC using layered fallback strategy. @@ -701,6 +742,8 @@ async def extract( file_path: Absolute path to the source file. content: Extracted text content of the file. llm_caller: Optional LLM caller for Layer 4. + total_pages: Total page count of the source document, if known. + Used to estimate ``page_start`` for Layer 3/4 entries. Returns: List of TOCEntry with resolved char positions, or None if @@ -709,11 +752,16 @@ async def extract( ext = Path(file_path).suffix.lower() result: Optional[TocResult] = None + # Track whether the result came from pypdf (Layer 1) which + # already produces a properly nested tree with children. + is_pypdf = False if ext == ".pdf": result = await cls._extract_pdf_layered( file_path, content, llm_caller, ) + if result is not None: + is_pypdf = result.source == "pypdf" elif ext in (".md", ".markdown"): heading_result = HeadingTocExtractor.extract(content) if cls._is_sufficient(heading_result): @@ -728,7 +776,30 @@ async def extract( if result is None or not cls._is_sufficient(result): return None + # Merge total_pages from TocResult if not explicitly provided + if total_pages is None and result.page_count: + total_pages = result.page_count + entries = result.entries + + # Post-processing for non-pypdf layers: rebuild hierarchy from + # flat level-annotated entries (Layer 2/3/4 and format extractors + # produce flat lists; pypdf already builds a nested tree). 
+ if not is_pypdf: + entries = cls._build_hierarchy(entries) + + # Estimate page_start for Layer 3/4 entries that lack it + if total_pages and content: + flat_all: List[TOCEntry] = [] + cls._flatten_entries(entries, flat_all) + content_len = len(content) + for entry in flat_all: + if entry.page_start is None and entry.char_start is not None: + entry.page_start = min( + total_pages, + max(1, round(entry.char_start / content_len * total_pages) + 1), + ) + total = cls._count_entries(entries) if total < cls._MIN_ENTRIES_THRESHOLD: return None diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 26787eb..8d93a2c 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -32,6 +32,12 @@ # Summary snippet length extracted from section content (chars) _TOC_NODE_SUMMARY_MAX_CHARS = 300 +# Marker substring length for fuzzy fallback matching in _resolve_positions +_MARKER_SUBSTRING_LEN = 32 + +# Maximum span ratio: filter out overly large spans (>80% of document) +_MAX_SPAN_RATIO = 0.8 + # Adaptive preview window for LLM structure analysis _TREE_PREVIEW_MIN = 12_000 # Minimum preview window (chars) _TREE_PREVIEW_MAX = 50_000 # Maximum preview window (~12K tokens) @@ -201,7 +207,9 @@ async def build_tree( # TOC-accelerated path: skip recursive LLM analysis if toc_entries: - root = await self._build_tree_from_toc(toc_entries, content) + root = await self._build_tree_from_toc( + toc_entries, content, total_pages=total_pages, + ) if root is not None: tree = DocumentTree( file_path=file_path, @@ -300,6 +308,8 @@ async def _build_tree_from_toc( self, toc_entries: List[Any], content: str, + *, + total_pages: Optional[int] = None, ) -> Optional[TreeNode]: """Build tree directly from extracted TOC entries, avoiding recursive LLM. @@ -309,25 +319,29 @@ async def _build_tree_from_toc( Args: toc_entries: List of TOCEntry from toc_extractor. content: Full extracted text of the document. + total_pages: Total page count for page_range calculation. Returns: Root TreeNode, or None if no children could be created. """ seen_ids: set = set() children = self._toc_entries_to_nodes( - toc_entries, content, len(content), seen_ids, fallback_level=1, + toc_entries, content, len(content), seen_ids, + fallback_level=1, total_pages=total_pages, ) if not children: return None root_summary = await self._synthesize_root_summary(children) + root_page_range = (1, total_pages) if total_pages and total_pages > 0 else None return TreeNode( node_id=self._unique_node_id(0, seen_ids), title="Document", summary=root_summary, char_range=(0, len(content)), level=0, + page_range=root_page_range, children=children, ) @@ -338,15 +352,25 @@ def _toc_entries_to_nodes( parent_end: int, seen_ids: set, fallback_level: int, + total_pages: Optional[int] = None, ) -> List["TreeNode"]: """Recursively convert TOCEntry trees into TreeNode trees. Handles arbitrary nesting depth and guards against invalid - char_start / char_end values. + char_start / char_end values. Computes ``page_range`` using a + look-ahead algorithm when ``page_start`` is available on entries. + + Args: + entries: List of TOCEntry objects (may have children). + content: Full extracted text. + parent_end: End offset inherited from the parent node. + seen_ids: Set for unique node-id generation. + fallback_level: Default level when entry.level is 0. + total_pages: Total page count for page_range look-ahead. 
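+
+        Page-range sketch (illustrative numbers): for sibling entries
+        with page_start values [3, 10, None, 25] and ``total_pages=40``,
+        the look-ahead yields page_range (3, 10) for the first node,
+        (10, 25) for the second (the sibling without a page is skipped),
+        None for the third, and (25, 40) for the last.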
""" nodes: List[TreeNode] = [] content_len = len(content) - for entry in entries: + for i, entry in enumerate(entries): start = max(0, min(entry.char_start, content_len)) end = entry.char_end if entry.char_end and entry.char_end > start else parent_end end = min(end, content_len) @@ -355,11 +379,23 @@ def _toc_entries_to_nodes( nid = DocumentTreeIndexer._unique_node_id(start, seen_ids) level = entry.level if entry.level > 0 else fallback_level + # page_range: look-ahead algorithm + page_range = None + if hasattr(entry, 'page_start') and entry.page_start is not None: + # Find next sibling with page_start to determine page_end + page_end = total_pages or entry.page_start + for j in range(i + 1, len(entries)): + if hasattr(entries[j], 'page_start') and entries[j].page_start is not None: + page_end = entries[j].page_start + break + page_range = (entry.page_start, max(entry.page_start, page_end)) + child_nodes: List[TreeNode] = [] if entry.children: child_nodes = DocumentTreeIndexer._toc_entries_to_nodes( entry.children, content, end, seen_ids, fallback_level=level + 1, + total_pages=total_pages, ) node = TreeNode( @@ -368,6 +404,7 @@ def _toc_entries_to_nodes( summary=section_text.strip(), char_range=(start, end), level=level, + page_range=page_range, children=child_nodes, ) nodes.append(node) @@ -495,40 +532,65 @@ def _parse_sections( def _resolve_positions( items: List[Dict[str, Any]], full_text: str, ) -> List[Dict[str, Any]]: - """Resolve section start/end character offsets from marker text.""" + """Resolve section start/end character offsets from marker text. + + Two-pass algorithm: + Pass 1 — determine all start positions with tiered fallback: + exact match from prev_end -> substring match -> full-text fallback. + Pass 2 — set end[i] = start[i+1]; last end = text_len. + + Filters out invalid spans and overly large spans (> ``_MAX_SPAN_RATIO`` + of the document) to prevent accumulated positioning errors. 
+ """ + text_lower = full_text.lower() + text_len = len(full_text) resolved: List[Dict[str, Any]] = [] + + # Pass 1: determine all start positions prev_end = 0 - text_lower = full_text.lower() for item in items: title = item.get("title", "") - summary = item.get("summary", "") marker = item.get("start_marker", title) - pos = text_lower.find(marker.lower(), prev_end) if marker else -1 - start = pos if pos >= 0 else prev_end - - end_marker = item.get("end_marker", "") - if end_marker: - epos = text_lower.find(end_marker.lower(), start + 1) - end = epos if epos > start else min(start + 50000, len(full_text)) - else: - end = min(start + 50000, len(full_text)) + pos = -1 + if marker: + marker_lower = marker.lower() + # Level 1: exact match from prev_end + pos = text_lower.find(marker_lower, prev_end) + # Level 2: substring match (first N chars) from prev_end + if pos < 0 and len(marker_lower) > _MARKER_SUBSTRING_LEN: + pos = text_lower.find( + marker_lower[:_MARKER_SUBSTRING_LEN], prev_end, + ) + # Level 3: full text fallback from start + if pos < 0: + pos = text_lower.find(marker_lower, 0) + start = pos if pos >= 0 else prev_end resolved.append({ "title": title, - "summary": summary, + "summary": item.get("summary", ""), "start": start, - "end": end, + "end": text_len, # placeholder }) - prev_end = end + prev_end = ( + start + max(1, len(marker)) + if pos >= 0 + else prev_end + ) - # Fix gaps: each section ends where the next begins + # Pass 2: set end[i] = start[i+1], last end = text_len for i in range(len(resolved) - 1): resolved[i]["end"] = resolved[i + 1]["start"] if resolved: - resolved[-1]["end"] = len(full_text) + resolved[-1]["end"] = text_len - return [s for s in resolved if s["end"] > s["start"]] + # Filter out invalid spans and overly large spans + return [ + s for s in resolved + if s["end"] > s["start"] + and (s["end"] - s["start"]) / max(text_len, 1) < _MAX_SPAN_RATIO + ] async def _select_children( self, nodes: List[TreeNode], query: str, @@ -538,7 +600,7 @@ async def _select_children( return nodes listing = "\n".join( - f"[{i}] {n.title}: {n.summary[:150]}" + f"[{i}] {n.title}{self._format_page_range(n.page_range)}: {n.summary[:150]}" for i, n in enumerate(nodes) ) prompt = ( @@ -604,6 +666,16 @@ def _max_node_depth(node: TreeNode) -> int: return node.level return max(DocumentTreeIndexer._max_node_depth(c) for c in node.children) + @staticmethod + def _format_page_range( + page_range: "Optional[Tuple[int, int]]", + ) -> str: + """Format a page_range tuple into a human-readable string for prompts.""" + if not page_range: + return "" + ps, pe = page_range + return f" [pages {ps}-{pe}]" if ps != pe else f" [page {ps}]" + @staticmethod def should_build_tree(file_path: str, content_length: int) -> bool: """Determine whether a file is eligible for tree indexing.""" diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index c2f30f4..74b2b85 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -86,6 +86,10 @@ # Soft-similarity threshold for gradient cluster reuse (P2) _SOFT_SIM_THRESHOLD = 0.65 +# Pure tree search mode for ablation experiments. +# When enabled, search relies solely on tree index navigation, skipping rga keyword search. +_PURE_TREE_SEARCH: bool = os.getenv("SIRCHMUNK_PURE_TREE_SEARCH", "false").lower() == "true" + # Common English stop-words filtered out during keyword coverage computation. 
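Condensed, the two-pass resolution above is: find each section start with a tiered search (exact match from the previous end, then a 32-char prefix, then anywhere in the text), let each section run until the next one starts, and drop spans that are empty or cover most of the document. A self-contained sketch, with constants mirroring the hunk and a made-up toy string:

```python
MARKER_SUBSTRING_LEN = 32
MAX_SPAN_RATIO = 0.8

def resolve_spans(markers, full_text):
    text, n = full_text.lower(), len(full_text)
    starts, prev_end = [], 0
    for marker in markers:                               # pass 1: start positions
        m = marker.lower()
        pos = text.find(m, prev_end)
        if pos < 0 and len(m) > MARKER_SUBSTRING_LEN:
            pos = text.find(m[:MARKER_SUBSTRING_LEN], prev_end)
        if pos < 0:
            pos = text.find(m)                           # full-text fallback
        start = pos if pos >= 0 else prev_end
        starts.append(start)
        if pos >= 0:
            prev_end = start + max(1, len(m))
    spans = [(s, starts[i + 1] if i + 1 < len(starts) else n)   # pass 2
             for i, s in enumerate(starts)]
    return [(s, e) for s, e in spans
            if e > s and (e - s) / max(n, 1) < MAX_SPAN_RATIO]

doc = "Chapter One intro text. Chapter Two more text. Chapter Three end."
assert resolve_spans(["Chapter Two", "Chapter Three"], doc) == [(24, 47), (47, 65)]
```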
_STOP_WORDS: frozenset = frozenset({ "the", "is", "a", "an", "of", "in", "for", "to", "and", "or", @@ -1631,36 +1635,44 @@ async def _search_deep( # ============================================================== # Phase 2: Parallel retrieval — keyword search + dir_scan rank # ============================================================== - await self._logger.info("[Phase 2] Parallel retrieval: rga keyword search + dir_scan LLM rank") - context.increment_loop() + keyword_files: List[str] = [] + dir_scan_files: List[str] = [] - phase2_tasks = [] + if _PURE_TREE_SEARCH: + # Pure tree search mode: skip rga and dir_scan, rely solely on tree hits + await self._logger.info("[Phase 2:PureTree] Skipping rga keyword search and dir_scan") + context.increment_loop() + else: + await self._logger.info("[Phase 2] Parallel retrieval: rga keyword search + dir_scan LLM rank") + context.increment_loop() - if initial_keywords: - phase2_tasks.append( - self._retrieve_by_keywords( - initial_keywords, paths, - max_depth=max_depth, include=include, exclude=exclude, + phase2_tasks = [] + + if initial_keywords: + phase2_tasks.append( + self._retrieve_by_keywords( + initial_keywords, paths, + max_depth=max_depth, include=include, exclude=exclude, + ) ) - ) - else: - phase2_tasks.append(self._async_noop([])) + else: + phase2_tasks.append(self._async_noop([])) - if scan_result is not None and enable_dir_scan: - phase2_tasks.append( - self._rank_dir_scan_candidates(query, scan_result) - ) - else: - phase2_tasks.append(self._async_noop([])) + if scan_result is not None and enable_dir_scan: + phase2_tasks.append( + self._rank_dir_scan_candidates(query, scan_result) + ) + else: + phase2_tasks.append(self._async_noop([])) - phase2_results = await asyncio.gather(*phase2_tasks, return_exceptions=True) + phase2_results = await asyncio.gather(*phase2_tasks, return_exceptions=True) - keyword_files = phase2_results[0] if not isinstance(phase2_results[0], Exception) else [] - dir_scan_files = phase2_results[1] if not isinstance(phase2_results[1], Exception) else [] + keyword_files = phase2_results[0] if not isinstance(phase2_results[0], Exception) else [] + dir_scan_files = phase2_results[1] if not isinstance(phase2_results[1], Exception) else [] - for i, label in enumerate(["keyword_search", "dir_scan_rank"]): - if isinstance(phase2_results[i], Exception): - await self._logger.warning(f"[Phase 2] {label} failed: {phase2_results[i]}") + for i, label in enumerate(["keyword_search", "dir_scan_rank"]): + if isinstance(phase2_results[i], Exception): + await self._logger.warning(f"[Phase 2] {label} failed: {phase2_results[i]}") await self._logger.info( f"[Phase 2] Results: keyword_files={len(keyword_files)}, " @@ -1698,12 +1710,30 @@ async def _search_deep( extra_knowledge_files = knowledge_probe.file_paths if soft_hit: extra_knowledge_files = soft_hit.file_paths + extra_knowledge_files - merged_files = self._merge_file_paths( - keyword_files=list(tree_hits) + catalog_deep_hits + compile_hints.file_paths + summary_index_hits + keyword_files, - dir_scan_files=dir_scan_files, - knowledge_hits=extra_knowledge_files, - ) - await self._logger.info(f"[Phase 3] Merged {len(merged_files)} unique candidate files") + + if _PURE_TREE_SEARCH: + # Pure tree search: only use tree hits (+ soft-hit fallback if no tree hits) + pure_tree_files = list(tree_hits) + if not pure_tree_files and soft_hit: + pure_tree_files = soft_hit.file_paths + await self._logger.info( + f"[Phase 3:PureTree] No tree hits, using {len(pure_tree_files)} soft-hit files" + ) + 
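A note on using the ablation switch above: _PURE_TREE_SEARCH is a module-level constant, so the environment variable is read once when sirchmunk.search is first imported. To flip it for an experiment, set it before the import (via shell export or, as a sketch, in code):

```python
import os

# Must happen before sirchmunk.search is imported, because the flag is
# evaluated at module import time.
os.environ["SIRCHMUNK_PURE_TREE_SEARCH"] = "true"

import sirchmunk.search  # noqa: E402  (imported after the flag is set on purpose)
```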
merged_files = self._merge_file_paths( + keyword_files=pure_tree_files, + dir_scan_files=[], + knowledge_hits=[], + ) + await self._logger.info( + f"[Phase 3:PureTree] Merged {len(merged_files)} tree-only candidate files" + ) + else: + merged_files = self._merge_file_paths( + keyword_files=list(tree_hits) + catalog_deep_hits + compile_hints.file_paths + summary_index_hits + keyword_files, + dir_scan_files=dir_scan_files, + knowledge_hits=extra_knowledge_files, + ) + await self._logger.info(f"[Phase 3] Merged {len(merged_files)} unique candidate files") cluster: Optional[KnowledgeCluster] = None if merged_files: @@ -2181,6 +2211,8 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Maximum files returned by tree index probing in DEEP mode.""" _TREE_ROOT_HINT_TRUNCATE = 150 """Max chars of tree root summary in Step 1 structure hints.""" + _CHAR_RANGE_MAX_SPAN_RATIO: float = 0.8 + """char_range spanning more than this ratio of the document is treated as invalid.""" # --- Self-correction expanded sampling --- _SELF_CORRECT_EXPANDED_NAV_RESULTS: int = 6 @@ -2484,58 +2516,88 @@ async def _search_fast( evidence = "" file_path: Optional[str] = None # set when best_files found - # High-confidence catalog routing: skip rga, use catalog directly - if catalog_routed_files and catalog_confidence == "high": - used_level = "catalog_route" - await self._logger.info( - f"[FAST:Step2] High-confidence catalog routing → " - f"{[Path(p).name for p in catalog_routed_files[:top_k_files]]}" - ) - best_files = [ - {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} - for p in catalog_routed_files[:top_k_files] - ] - - if not best_files and primary: - best_files = await self._fast_find_best_file( - primary, top_k=top_k_files, keyword_idfs=keyword_idfs, - query=query, artifacts=artifacts, - **rga_kwargs, - ) + # --- Pure tree search mode: skip rga, use tree probe results directly --- + if _PURE_TREE_SEARCH: + if _tree_probed_files: + used_level = "pure_tree" + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in _tree_probed_files[:top_k_files] + ] + await self._logger.info( + f"[FAST:PureTree] Using {len(best_files)} tree-probed files: " + f"{[Path(p).name for p in _tree_probed_files[:top_k_files]]}" + ) + elif compile_hint_files: + # Tree probe returned nothing but compile hints have tree files + used_level = "pure_tree_hint" + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in compile_hint_files[:top_k_files] + ] + await self._logger.info( + f"[FAST:PureTree] No tree probes, falling back to " + f"{len(best_files)} compile-hint files" + ) + else: + await self._logger.warning( + "[FAST:PureTree] No tree probes available, returning empty" + ) + return _NO_RESULTS_MESSAGE, None, context + else: + # --- Original rga-based retrieval logic --- + # High-confidence catalog routing: skip rga, use catalog directly + if catalog_routed_files and catalog_confidence == "high": + used_level = "catalog_route" + await self._logger.info( + f"[FAST:Step2] High-confidence catalog routing → " + f"{[Path(p).name for p in catalog_routed_files[:top_k_files]]}" + ) + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in catalog_routed_files[:top_k_files] + ] - if not best_files and fallback: - used_level = "fallback" - await self._logger.info( - "[FAST:Step2] Primary miss, trying fine-grained fallback" - ) - best_files = await self._fast_find_best_file( - 
fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, - query=query, artifacts=artifacts, - **rga_kwargs, - ) + if not best_files and primary: + best_files = await self._fast_find_best_file( + primary, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + **rga_kwargs, + ) - # --- Fallback: compile-hint files when rga misses (catalog + P2 + P4) --- - if not best_files and compile_hint_files: - used_level = "compile_hint" - await self._logger.info( - f"[FAST:Step2] rga miss — using {len(compile_hint_files)} compile-hint files" - ) - best_files = [ - {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} - for p in compile_hint_files[:top_k_files] - ] + if not best_files and fallback: + used_level = "fallback" + await self._logger.info( + "[FAST:Step2] Primary miss, trying fine-grained fallback" + ) + best_files = await self._fast_find_best_file( + fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + **rga_kwargs, + ) - # --- Fallback: use dir_scan only when rga misses and dir scan is enabled --- - if not best_files and enable_dir_scan: - scan_result = await self._probe_dir_scan(paths, enable=True, max_files=300) - if scan_result is not None: - await self._logger.info("[FAST:Step2] rga miss — falling back to dir_scan ranking") - ranked_paths = await self._rank_dir_scan_candidates( - query, scan_result, top_k=10, include_medium=True, + # --- Fallback: compile-hint files when rga misses (catalog + P2 + P4) --- + if not best_files and compile_hint_files: + used_level = "compile_hint" + await self._logger.info( + f"[FAST:Step2] rga miss — using {len(compile_hint_files)} compile-hint files" ) - if ranked_paths: - used_level = "dir_scan" - best_files = [{"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} for p in ranked_paths[:top_k_files]] + best_files = [ + {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} + for p in compile_hint_files[:top_k_files] + ] + + # --- Fallback: use dir_scan only when rga misses and dir scan is enabled --- + if not best_files and enable_dir_scan: + scan_result = await self._probe_dir_scan(paths, enable=True, max_files=300) + if scan_result is not None: + await self._logger.info("[FAST:Step2] rga miss — falling back to dir_scan ranking") + ranked_paths = await self._rank_dir_scan_candidates( + query, scan_result, top_k=10, include_medium=True, + ) + if ranked_paths: + used_level = "dir_scan" + best_files = [{"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} for p in ranked_paths[:top_k_files]] if not best_files: if llm_fallback: @@ -3745,14 +3807,24 @@ async def _tree_guided_sample( total_chars = 0 for leaf in leaves[: self._TREE_SAMPLE_MAX_SECTIONS]: start, end = leaf.char_range - if full_text and end > start: + if self._is_valid_char_range(start, end, len(full_text)) and full_text: segment = full_text[start:end] + elif leaf.summary: + logger.debug( + f"[TreeNav] char_range degraded for '{leaf.title}' " + f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" + ) + segment = leaf.summary else: - segment = leaf.summary or "" + continue segment = segment[: self._TREE_SAMPLE_SECTION_MAX_CHARS] if not segment.strip(): continue - header = f"[{fname} \u2192 {leaf.title}]" + page_info = "" + if leaf.page_range: + ps, pe = leaf.page_range + page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" + header = f"[{fname} → {leaf.title}{page_info}]" chunk = f"{header}\n{segment}" if total_chars + len(chunk) > 
max_chars: remaining = max_chars - total_chars @@ -3798,6 +3870,20 @@ async def _tree_guided_sample( ) return evidence + def _is_valid_char_range( + self, start: int, end: int, text_len: int, + ) -> bool: + """Check whether a char_range is valid for slicing. + + A range is invalid when it covers more than + ``_CHAR_RANGE_MAX_SPAN_RATIO`` of the document (likely a + whole-document fallback) or when *end <= start*. + """ + if start < 0 or end <= start or text_len <= 0: + return False + span_ratio = (end - start) / text_len + return span_ratio < self._CHAR_RANGE_MAX_SPAN_RATIO + async def _navigate_tree_for_evidence( self, file_path: str, query: str, *, max_results: int = 3, ) -> Optional[str]: @@ -3834,12 +3920,22 @@ async def _navigate_tree_for_evidence( for leaf in leaves: start, end = leaf.char_range - if full_text and end > start: + if self._is_valid_char_range(start, end, len(full_text)) and full_text: segment = full_text[start:end] + elif leaf.summary: + logger.debug( + f"[TreeNav] char_range degraded for '{leaf.title}' " + f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" + ) + segment = leaf.summary else: - segment = leaf.summary or "" + continue if segment.strip(): - header = f"[{fname} → {leaf.title}]" + page_info = "" + if leaf.page_range: + ps, pe = leaf.page_range + page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" + header = f"[{fname} → {leaf.title}{page_info}]" parts.append(f"{header}\n{segment[:3000]}") if not parts: From d4e8fe3a83ec1504776be330643d13448e54d083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 19 Apr 2026 22:14:17 +0800 Subject: [PATCH 31/56] fix table extraction --- src/sirchmunk/learnings/compiler.py | 104 ++++++++++++++++++ src/sirchmunk/learnings/tree_indexer.py | 7 +- src/sirchmunk/search.py | 122 +++++++++++++++++++++- src/sirchmunk/utils/document_extractor.py | 13 +++ 4 files changed, 243 insertions(+), 3 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index ad2a115..037070b 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -76,6 +76,8 @@ class FileManifestEntry: has_explicit_toc: bool = False # Whether a native TOC was extracted from the file tree_node_count: int = 0 # Number of nodes in the tree index (quality metric) has_xlsx_digest: bool = False # Whether a pre-compiled Excel evidence digest exists + has_table_digest: bool = False # Whether PDF tables were extracted and stored + table_count: int = 0 # Number of tables in this file def to_dict(self) -> Dict[str, Any]: return { @@ -88,6 +90,8 @@ def to_dict(self) -> Dict[str, Any]: "has_explicit_toc": self.has_explicit_toc, "tree_node_count": self.tree_node_count, "has_xlsx_digest": self.has_xlsx_digest, + "has_table_digest": self.has_table_digest, + "table_count": self.table_count, } @classmethod @@ -102,6 +106,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "FileManifestEntry": has_explicit_toc=data.get("has_explicit_toc", False), tree_node_count=data.get("tree_node_count", 0), has_xlsx_digest=data.get("has_xlsx_digest", False), + has_table_digest=data.get("has_table_digest", False), + table_count=data.get("table_count", 0), ) @@ -167,6 +173,8 @@ class FileCompileResult: has_explicit_toc: bool = False # Whether TOC was extracted from native structure tree_node_count: int = 0 # Number of nodes in the tree index has_xlsx_digest: bool = False # Whether a pre-compiled Excel evidence digest exists + has_table_digest: bool = False # Whether a 
pre-compiled table digest exists + table_count: int = 0 # Number of tables extracted @dataclass @@ -402,6 +410,8 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: has_explicit_toc=result.has_explicit_toc, tree_node_count=result.tree_node_count, has_xlsx_digest=result.has_xlsx_digest, + has_table_digest=result.has_table_digest, + table_count=result.table_count, ) # Phase 3: aggregate results into knowledge network @@ -609,6 +619,29 @@ async def _compile_single_file( except Exception: pass + # Persist table digest for documents with extracted tables + if extraction.tables: + try: + table_digest = self._build_table_digest(extraction.tables) + if table_digest: + digest_dir = self._compile_dir / "table_digests" + digest_dir.mkdir(parents=True, exist_ok=True) + file_hash = get_fast_hash(entry.path) or "" + if file_hash: + digest_path = digest_dir / f"{file_hash}.json" + digest_path.write_text( + json.dumps(table_digest, ensure_ascii=False), + encoding="utf-8", + ) + result.has_table_digest = True + result.table_count = len(extraction.tables) + except Exception: + pass + + # Annotate tree nodes with table counts for navigation hints + if result.tree and result.tree.root and extraction.tables: + self._annotate_tree_with_table_counts(result.tree.root, extraction.tables) + except Exception as exc: result.error = str(exc) await self._log.warning(f"[Compile] Failed: {entry.path}: {exc}") @@ -1130,6 +1163,77 @@ def _add_edge( WeakSemanticEdge(target_cluster_id=target_id, weight=weight, source=source) ) + def _build_table_digest( + self, tables: List[Dict[str, Any]], + ) -> Optional[Dict[str, Any]]: + """Build a structured table digest from extraction output. + + Returns a versioned JSON-serializable dict containing all tables + with their page numbers, markdown representation, and cell data. + Tables are indexed for page-range-based retrieval at search time. + """ + if not tables: + return None + + digest_tables = [] + for idx, table in enumerate(tables): + markdown = table.get("markdown", "") + cells = table.get("cells", []) + if not markdown and not cells: + continue + + # Compute row/col counts from cells (kreuzberg returns List[List[str]]) + row_count = 0 + col_count = 0 + if cells: + row_count = len(cells) + col_count = max((len(row) for row in cells if isinstance(row, (list, tuple))), default=0) + elif markdown: + # Estimate from markdown lines + lines = [l for l in markdown.strip().split("\n") if l.strip().startswith("|")] + row_count = max(0, len(lines) - 1) # exclude separator + col_count = lines[0].count("|") - 1 if lines else 0 + + digest_tables.append({ + "index": idx, + "page_number": table.get("page_number"), + "markdown": markdown, + "row_count": row_count, + "col_count": col_count, + "cells": cells, + }) + + if not digest_tables: + return None + + return { + "version": 1, + "table_count": len(digest_tables), + "tables": digest_tables, + } + + def _annotate_tree_with_table_counts( + self, + node: "TreeNode", + tables: List[Dict[str, Any]], + ) -> None: + """Annotate tree nodes with table count based on page_range overlap. + + For each node with a valid page_range, counts how many extracted + tables fall within that range and sets node.table_count accordingly. 
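The overlap test described above is a straight page-interval containment check. A tiny sketch of the counting step, with table entries as made-up dicts in the digest's shape:

```python
from typing import Any, Dict, List, Optional, Tuple

def count_tables_in_range(tables: List[Dict[str, Any]],
                          page_range: Optional[Tuple[int, int]]) -> int:
    if not page_range:
        return 0
    ps, pe = page_range
    return sum(1 for t in tables
               if t.get("page_number") is not None and ps <= t["page_number"] <= pe)

tables = [{"page_number": 41}, {"page_number": 44}, {"page_number": 90}]
assert count_tables_in_range(tables, (40, 45)) == 2   # the page-90 table is outside
assert count_tables_in_range(tables, None) == 0
```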
+ """ + if node is None: + return + if node.page_range: + ps, pe = node.page_range + count = sum( + 1 for t in tables + if t.get("page_number") is not None and ps <= t["page_number"] <= pe + ) + node.table_count = count + for child in node.children: + self._annotate_tree_with_table_counts(child, tables) + @staticmethod def _count_tree_nodes(tree: Optional[DocumentTree]) -> int: """Count total nodes in a DocumentTree (recursive). diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 8d93a2c..6895745 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -65,6 +65,7 @@ class TreeNode: level: int = 0 page_range: Optional[Tuple[int, int]] = None children: List["TreeNode"] = field(default_factory=list) + table_count: int = 0 # Number of tables associated with this node's page range def to_dict(self) -> Dict[str, Any]: return { @@ -75,6 +76,7 @@ def to_dict(self) -> Dict[str, Any]: "level": self.level, "page_range": list(self.page_range) if self.page_range else None, "children": [c.to_dict() for c in self.children], + "table_count": self.table_count, } @classmethod @@ -89,6 +91,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "TreeNode": level=data.get("level", 0), page_range=tuple(pr) if pr else None, children=children, + table_count=data.get("table_count", 0), ) @property @@ -600,7 +603,9 @@ async def _select_children( return nodes listing = "\n".join( - f"[{i}] {n.title}{self._format_page_range(n.page_range)}: {n.summary[:150]}" + f"[{i}] {n.title}{self._format_page_range(n.page_range)}" + f"{' [' + str(n.table_count) + ' tables]' if n.table_count > 0 else ''}" + f": {n.summary[:150]}" for i, n in enumerate(nodes) ) prompt = ( diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 74b2b85..9374e27 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2671,6 +2671,18 @@ async def _rga_evidence() -> str: except Exception: pass + # 0.5 Table digest priority (pre-compiled PDF table evidence) + if ev is None and artifacts and artifacts.manifest_map: + _me = artifacts.manifest_map.get(fp) + if _me and getattr(_me, 'has_table_digest', False): + _all_tables = self._load_table_digest( + self.work_path, _me.file_hash, + ) + if _all_tables: + _table_ev = self._format_table_evidence(_all_tables) + if _table_ev: + ev = f"[{fn} - Table Evidence]\n{_table_ev}" + # 1. Tree-guided sampling FIRST for tree-indexed files if ( artifacts @@ -3810,7 +3822,7 @@ async def _tree_guided_sample( if self._is_valid_char_range(start, end, len(full_text)) and full_text: segment = full_text[start:end] elif leaf.summary: - logger.debug( + _loguru_logger.debug( f"[TreeNav] char_range degraded for '{leaf.title}' " f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" ) @@ -3884,6 +3896,81 @@ def _is_valid_char_range( span_ratio = (end - start) / text_len return span_ratio < self._CHAR_RANGE_MAX_SPAN_RATIO + @staticmethod + def _load_table_digest( + work_path: Path, file_hash: str, + ) -> Optional[List[Dict[str, Any]]]: + """Load pre-compiled table digest for a file. + + Returns the list of table entries from the digest JSON, or None + if no digest exists or loading fails. 
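For reference, the digest JSON that _load_table_digest reads back is the structure _build_table_digest persists per file hash; the values below are illustrative:

```python
digest = {
    "version": 1,
    "table_count": 1,
    "tables": [
        {
            "index": 0,
            "page_number": 42,
            "markdown": "| Year | Revenue |\n|---|---|\n| 2025 | $1.4B |",
            "row_count": 2,   # from cells when present, else estimated from markdown
            "col_count": 2,
            "cells": [["Year", "Revenue"], ["2025", "$1.4B"]],
        },
    ],
}
# Written by the compiler hunk above as <compile_dir>/table_digests/<file_hash>.json.
```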
+ """ + digest_path = ( + work_path / ".cache" / "compile" / "table_digests" / f"{file_hash}.json" + ) + if not digest_path.exists(): + return None + try: + data = json.loads(digest_path.read_text(encoding="utf-8")) + return data.get("tables", []) + except Exception: + return None + + @staticmethod + def _filter_tables_by_page_range( + tables: List[Dict[str, Any]], + page_start: int, + page_end: int, + ) -> List[Dict[str, Any]]: + """Filter tables whose page_number falls within the given range (inclusive).""" + return [ + t for t in tables + if t.get("page_number") is not None + and page_start <= t["page_number"] <= page_end + ] + + @staticmethod + def _format_table_evidence( + tables: List[Dict[str, Any]], + max_chars: int = 3000, + ) -> str: + """Format table digest entries as LLM-friendly evidence text. + + Strategy: + - Small tables (<1000 chars): preserve full Markdown + - Large tables: truncate to max_chars with "(truncated)" note + - Each table prefixed with "[Table from page N]" + + Returns concatenated formatted table evidence string. + """ + if not tables: + return "" + + parts: List[str] = [] + remaining = max_chars + + for table in tables: + if remaining <= 0: + break + + page = table.get("page_number", "?") + markdown = table.get("markdown", "") + + if not markdown: + continue + + header = f"[Table from page {page}]" + + if len(markdown) <= remaining: + parts.append(f"{header}\n{markdown}") + remaining -= len(markdown) + len(header) + 2 + else: + truncated = markdown[:remaining] + parts.append(f"{header}\n{truncated}\n(truncated)") + remaining = 0 + + return "\n\n".join(parts) + async def _navigate_tree_for_evidence( self, file_path: str, query: str, *, max_results: int = 3, ) -> Optional[str]: @@ -3923,7 +4010,7 @@ async def _navigate_tree_for_evidence( if self._is_valid_char_range(start, end, len(full_text)) and full_text: segment = full_text[start:end] elif leaf.summary: - logger.debug( + _loguru_logger.debug( f"[TreeNav] char_range degraded for '{leaf.title}' " f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" ) @@ -3941,6 +4028,37 @@ async def _navigate_tree_for_evidence( if not parts: return None + # Supplement with table evidence if available + try: + from sirchmunk.utils.file_utils import get_fast_hash + _file_hash = get_fast_hash(file_path) + if _file_hash: + _all_tables = self._load_table_digest( + self.work_path, _file_hash, + ) + if _all_tables and leaves: + _seen_pages: set = set() + for leaf in leaves: + if leaf.page_range: + ps, pe = leaf.page_range + page_key = (ps, pe) + if page_key in _seen_pages: + continue + _seen_pages.add(page_key) + leaf_tables = self._filter_tables_by_page_range( + _all_tables, ps, pe, + ) + if leaf_tables: + table_text = self._format_table_evidence( + leaf_tables, max_chars=2000, + ) + if table_text: + parts.append( + f"[Tables pp.{ps}-{pe}]\n{table_text}" + ) + except Exception: + pass + evidence = "\n\n".join(parts) await self._logger.info( f"[FAST:TreeNav] Extracted {len(parts)} sections, " diff --git a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py index 76e0f15..d72d397 100644 --- a/src/sirchmunk/utils/document_extractor.py +++ b/src/sirchmunk/utils/document_extractor.py @@ -303,6 +303,17 @@ def _build_config(profile: ExtractionProfile): from kreuzberg import LanguageDetectionConfig lang_config = LanguageDetectionConfig(enabled=True) + # --- Layout detection for table extraction --- + layout_config = None + if profile.extract_tables: + try: + from kreuzberg import 
LayoutDetectionConfig + layout_config = LayoutDetectionConfig() + except ImportError: + # kreuzberg <= 4.2.x extracts tables by default; + # filtering is handled in _convert_result(). + pass + # --- Assemble ExtractionConfig --- kwargs: dict[str, Any] = { "output_format": output_format, @@ -319,6 +330,8 @@ def _build_config(profile: ExtractionProfile): kwargs["language_detection"] = lang_config if profile.max_concurrent is not None: kwargs["max_concurrent_extractions"] = profile.max_concurrent + if layout_config is not None: + kwargs["layout"] = layout_config return ExtractionConfig(**kwargs) From 1d550bf9d73e5bc043f8a7ce91b8742c8859ef04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 20 Apr 2026 00:53:53 +0800 Subject: [PATCH 32/56] fix warning --- src/sirchmunk/learnings/compiler.py | 4 ++ src/sirchmunk/utils/document_extractor.py | 72 +++++++++++++++++++++-- 2 files changed, 71 insertions(+), 5 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 037070b..6f65e12 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -1194,6 +1194,10 @@ def _build_table_digest( row_count = max(0, len(lines) - 1) # exclude separator col_count = lines[0].count("|") - 1 if lines else 0 + # Skip pseudo-tables: single-column or insufficient structure + if col_count <= 1: + continue + digest_tables.append({ "index": idx, "page_number": table.get("page_number"), diff --git a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py index d72d397..d114b7d 100644 --- a/src/sirchmunk/utils/document_extractor.py +++ b/src/sirchmunk/utils/document_extractor.py @@ -143,9 +143,16 @@ class DocumentExtractor: extract_tables=True, extract_metadata=True, pdf_extract_metadata=True, - force_ocr=True, + force_ocr=False, ) - """Rich extraction with tables, metadata, and OCR fallback.""" + """Rich extraction with tables, metadata, and layout-based table detection. + + ``force_ocr`` is disabled because: + - Most documents (e.g. 10-K, 10-Q PDFs) already contain a native text layer. + - kreuzberg automatically falls back to OCR for scanned / image-only pages. + - Forcing OCR triggers Tesseract ObjectCache leak warnings in concurrent use + and significantly slows down compilation with no quality benefit. + """ # Public API ----------------------------------------------------------- @@ -174,7 +181,21 @@ async def extract( try: result = await extract_file(file_path=file_path, config=config) - return DocumentExtractor._convert_result(result, profile) + output = DocumentExtractor._convert_result(result, profile) + # Fallback: kreuzberg 4.9.1 returns page_count=0 when force_ocr=True; + # use pypdf to get the real page count when missing. 
+ if output.page_count is None: + fallback = DocumentExtractor._fallback_page_count(file_path) + if fallback is not None: + output = ExtractionOutput( + content=output.content, + mime_type=output.mime_type, + metadata=output.metadata, + tables=output.tables, + detected_languages=output.detected_languages, + page_count=fallback, + ) + return output except Exception as exc: logger.error( "Document extraction failed for {}: {}", @@ -232,15 +253,54 @@ async def batch_extract( try: results = await batch_extract_files(paths=list(file_paths), config=config) - return [ + outputs = [ DocumentExtractor._convert_result(r, profile) for r in results ] + # Apply page_count fallback for each output + fixed: List[ExtractionOutput] = [] + for output, fp in zip(outputs, file_paths): + if output.page_count is None: + fallback = DocumentExtractor._fallback_page_count(fp) + if fallback is not None: + output = ExtractionOutput( + content=output.content, + mime_type=output.mime_type, + metadata=output.metadata, + tables=output.tables, + detected_languages=output.detected_languages, + page_count=fallback, + ) + fixed.append(output) + return fixed except Exception: logger.error("Batch extraction failed for {} files", len(file_paths)) raise # Internal helpers ----------------------------------------------------- + @staticmethod + def _fallback_page_count( + file_path: Union[str, Path], + ) -> Optional[int]: + """Get page count via pypdf when kreuzberg fails to report it. + + kreuzberg >= 4.9.1 returns ``get_page_count() == 0`` when + ``force_ocr=True`` is set. This fallback uses pypdf (already a + transitive dependency) for a lightweight page-count-only read. + + Returns: + Page count, or None for non-PDF files or on error. + """ + if Path(file_path).suffix.lower() != ".pdf": + return None + try: + from pypdf import PdfReader + reader = PdfReader(str(file_path)) + count = len(reader.pages) + return count if count > 0 else None + except Exception: + return None + @staticmethod def _build_config(profile: ExtractionProfile): """Build a kreuzberg ``ExtractionConfig`` from an :class:`ExtractionProfile`. @@ -308,9 +368,11 @@ def _build_config(profile: ExtractionProfile): if profile.extract_tables: try: from kreuzberg import LayoutDetectionConfig + # kreuzberg >= 4.5.0: model-based table detection (RT-DETR v2) + # Default: table_model="tatr", apply_heuristics=True layout_config = LayoutDetectionConfig() except ImportError: - # kreuzberg <= 4.2.x extracts tables by default; + # kreuzberg < 4.5.0: tables extracted via heuristics only; # filtering is handled in _convert_result(). 
pass From 384d345d8c7f21e1ca435d3143e620f46b71ad79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 20 Apr 2026 02:02:52 +0800 Subject: [PATCH 33/56] enhance compiler --- src/sirchmunk/learnings/compiler.py | 175 ++++++++++++++- src/sirchmunk/learnings/tree_indexer.py | 282 ++++++++++++++++++++++-- src/sirchmunk/search.py | 50 +++-- 3 files changed, 458 insertions(+), 49 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 6f65e12..531cb28 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -638,9 +638,12 @@ async def _compile_single_file( except Exception: pass - # Annotate tree nodes with table counts for navigation hints + # Integrate tables into tree: annotate counts + create table child nodes if result.tree and result.tree.root and extraction.tables: - self._annotate_tree_with_table_counts(result.tree.root, extraction.tables) + self._integrate_tables_into_tree( + result.tree.root, extraction.tables, + content=content, total_pages=extraction.page_count, + ) except Exception as exc: result.error = str(exc) @@ -1216,27 +1219,175 @@ def _build_table_digest( "tables": digest_tables, } - def _annotate_tree_with_table_counts( + def _integrate_tables_into_tree( self, node: "TreeNode", tables: List[Dict[str, Any]], + content: str, + *, + total_pages: Optional[int] = None, + _counter: Optional[List[int]] = None, ) -> None: - """Annotate tree nodes with table count based on page_range overlap. + """Integrate tables into tree: annotate counts AND create table child nodes for leaf nodes. - For each node with a valid page_range, counts how many extracted - tables fall within that range and sets node.table_count accordingly. + For each node with a valid page_range, counts how many valid extracted + tables fall within that range (excluding pseudo-tables with col_count <= 1). + For leaf nodes with matching tables, creates dedicated TreeNode children + with ``content_type="table"``. 
""" + from sirchmunk.learnings.tree_indexer import TreeNode + if node is None: return + + if _counter is None: + _counter = [0] + + # Depth-first: process existing children first + for child in list(node.children): + self._integrate_tables_into_tree( + child, tables, content, + total_pages=total_pages, _counter=_counter, + ) + + # Match valid tables to this node's page_range + matched_tables: List[Dict[str, Any]] = [] if node.page_range: ps, pe = node.page_range - count = sum( - 1 for t in tables - if t.get("page_number") is not None and ps <= t["page_number"] <= pe + for t in tables: + pn = t.get("page_number") + if pn is None or not (ps <= pn <= pe): + continue + # Skip pseudo-tables + if self._is_pseudo_table(t): + continue + matched_tables.append(t) + + node.table_count = len(matched_tables) + + # Create table child nodes only for leaf nodes with matched tables + if not node.children and matched_tables: + try: + self._spawn_table_children( + node, matched_tables, content, _counter, + ) + except Exception: + pass # Never break compile for table node creation + + @staticmethod + def _is_pseudo_table(table: Dict[str, Any]) -> bool: + """Return True if the table lacks meaningful structure (col_count <= 1).""" + markdown = table.get("markdown", "") + cells = table.get("cells", []) + if not markdown and not cells: + return True + col_count = 0 + if cells: + col_count = max( + (len(row) for row in cells if isinstance(row, (list, tuple))), + default=0, ) - node.table_count = count - for child in node.children: - self._annotate_tree_with_table_counts(child, tables) + elif markdown: + lines = [l for l in markdown.strip().split("\n") if l.strip().startswith("|")] + col_count = (lines[0].count("|") - 1) if lines else 0 + return col_count <= 1 + + def _spawn_table_children( + self, + node: "TreeNode", + matched_tables: List[Dict[str, Any]], + content: str, + counter: List[int], + ) -> None: + """Create TreeNode children for each matched table under a leaf node. + + Also inserts a text-content sibling preserving the original leaf content. + """ + from sirchmunk.learnings.tree_indexer import TreeNode + + child_level = node.level + 1 + + # Preserve original text content as first child + text_child_id = f"T{counter[0]:06d}" + counter[0] += 1 + node.children.append( + TreeNode( + node_id=text_child_id, + title=node.title, + summary=node.summary[:300] if node.summary else "", + char_range=node.char_range, + level=child_level, + page_range=node.page_range, + children=[], + table_count=0, + content_type="text", + ) + ) + + # Create one child per table + for table in matched_tables: + tid = f"T{counter[0]:06d}" + counter[0] += 1 + + markdown = table.get("markdown", "") + title = self._extract_table_title(table) + page_number = table.get("page_number") + + # Attempt to locate table markdown in content + char_range = node.char_range + if markdown and content: + pos = content.find(markdown[:120]) + if pos >= 0: + char_range = (pos, pos + len(markdown)) + + page_range = ( + (page_number, page_number) if page_number is not None + else node.page_range + ) + + node.children.append( + TreeNode( + node_id=tid, + title=title, + summary=markdown[:300] if markdown else "", + char_range=char_range, + level=child_level, + page_range=page_range, + children=[], + table_count=0, + content_type="table", + ) + ) + + @staticmethod + def _extract_table_title(table: Dict[str, Any]) -> str: + """Extract a concise title from table markdown header row. 
+ + Parses the first meaningful line of the markdown table (skipping + separator rows like ``|---|---|``), strips ``|`` delimiters, and + returns the first 80 characters as the title. + """ + markdown = table.get("markdown", "") + if not markdown: + pn = table.get("page_number", "?") + return f"Table (p.{pn})" + + for line in markdown.strip().split("\n"): + stripped = line.strip() + if not stripped: + continue + # Skip separator rows (e.g. |---|---| or +---+---+) + content_chars = stripped.replace("|", "").replace("-", "").replace(":", "").replace("+", "").strip() + if not content_chars: + continue + # Extract cell contents + title = " | ".join( + seg.strip() for seg in stripped.split("|") if seg.strip() + ) + return title[:80] if title else f"Table (p.{table.get('page_number', '?')})" + + pn = table.get("page_number", "?") + return f"Table (p.{pn})" @staticmethod def _count_tree_nodes(tree: Optional[DocumentTree]) -> int: diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 6895745..060c9b8 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -66,6 +66,7 @@ class TreeNode: page_range: Optional[Tuple[int, int]] = None children: List["TreeNode"] = field(default_factory=list) table_count: int = 0 # Number of tables associated with this node's page range + content_type: str = "text" # "text" | "table" def to_dict(self) -> Dict[str, Any]: return { @@ -77,6 +78,7 @@ def to_dict(self) -> Dict[str, Any]: "page_range": list(self.page_range) if self.page_range else None, "children": [c.to_dict() for c in self.children], "table_count": self.table_count, + "content_type": self.content_type, } @classmethod @@ -92,6 +94,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "TreeNode": page_range=tuple(pr) if pr else None, children=children, table_count=data.get("table_count", 0), + content_type=data.get("content_type", "text"), ) @property @@ -214,6 +217,8 @@ async def build_tree( toc_entries, content, total_pages=total_pages, ) if root is not None: + await self._deepen_large_leaves(root, content, max_depth=effective_depth) + await self._enrich_node_summaries(root, content) tree = DocumentTree( file_path=file_path, file_hash=file_hash, @@ -233,6 +238,9 @@ async def build_tree( if root is None: return None + await self._deepen_large_leaves(root, content, max_depth=effective_depth) + await self._enrich_node_summaries(root, content) + tree = DocumentTree( file_path=file_path, file_hash=file_hash, @@ -254,36 +262,48 @@ async def navigate( query: str, *, max_results: int = 3, + max_nav_depth: int = 4, ) -> List[TreeNode]: """Reasoning-based tree navigation: LLM selects the most relevant branches. + Iteratively descends through the tree until leaf nodes are reached or + *max_nav_depth* selection rounds are exhausted. + Returns up to *max_results* leaf nodes with their char_range for precise evidence extraction. 
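A quick check of the table-title extraction added earlier in this patch: separator rows (pipes, dashes, colons, plus signs only) are skipped, header cells are joined with " | ", and the result is capped at 80 characters. The helper below mirrors the compiler method for illustration only; it is not the project API:

```python
def table_title(markdown: str, page: object = "?") -> str:
    for line in markdown.strip().split("\n"):
        s = line.strip()
        # skip blank lines and separator rows such as |---|---| or +---+---+
        if not s or not s.replace("|", "").replace("-", "").replace(":", "").replace("+", "").strip():
            continue
        cells = [seg.strip() for seg in s.split("|") if seg.strip()]
        return " | ".join(cells)[:80] if cells else f"Table (p.{page})"
    return f"Table (p.{page})"

md = "| Year | Revenue | Net income |\n|---|---|---|\n| 2025 | $1.4B | $0.3B |"
assert table_title(md, page=7) == "Year | Revenue | Net income"
assert table_title("", page=7) == "Table (p.7)"
```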
""" if tree.root is None: return [] - candidates = tree.root.children if tree.root.children else [tree.root] - if not candidates: + current = tree.root.children if tree.root.children else [tree.root] + if not current: return [tree.root] - selected = await self._select_children(candidates, query) - if not selected: - return [] - - result_leaves: List[TreeNode] = [] - for node in selected: - if node.leaf: - result_leaves.append(node) - else: - deeper = await self._select_children(node.children, query) - for d in (deeper or node.children[:1]): - result_leaves.extend(d.all_leaves()[:max_results]) + selected: List[TreeNode] = current + for _ in range(max_nav_depth): + selected = await self._select_children(current, query) + if not selected: + break + # All leaves — stop descending + if all(n.leaf for n in selected): + break + # Expand non-leaf children, keep leaves as-is + next_level: List[TreeNode] = [] + for n in selected: + if n.leaf: + next_level.append(n) + else: + next_level.extend(n.children) + if not next_level: + break + current = next_level + else: + selected = current # Deduplicate and cap - seen_ids = set() + seen_ids: set = set() unique: List[TreeNode] = [] - for n in result_leaves: + for n in (selected or current): if n.node_id not in seen_ids: seen_ids.add(n.node_id) unique.append(n) @@ -604,12 +624,17 @@ async def _select_children( listing = "\n".join( f"[{i}] {n.title}{self._format_page_range(n.page_range)}" + f" [{n.content_type.upper()}]" f"{' [' + str(n.table_count) + ' tables]' if n.table_count > 0 else ''}" - f": {n.summary[:150]}" + f": {n.summary[:200]}" for i, n in enumerate(nodes) ) prompt = ( f"Given the query: \"{query}\"\n\n" + "Guidelines:\n" + "- For numerical/financial data queries, prefer TABLE nodes and consolidated statements\n" + "- Prefer company-wide/consolidated data over segment-level unless query specifies a segment\n" + "- When multiple tables exist, select the one most directly answering the query\n\n" f"Select the 1-2 most relevant sections (by index number):\n{listing}\n\n" f"Return ONLY a JSON array of index numbers, e.g. 
[0, 2]" ) @@ -681,6 +706,229 @@ def _format_page_range( ps, pe = page_range return f" [pages {ps}-{pe}]" if ps != pe else f" [page {ps}]" + # ------------------------------------------------------------------ # + # Leaf deepening & summary enrichment # + # ------------------------------------------------------------------ # + + async def _deepen_large_leaves( + self, + node: TreeNode, + content: str, + *, + max_leaf_chars: int = 5000, + max_depth: int = 4, + _seen_ids: Optional[set] = None, + ) -> None: + """Recursively deepen leaf nodes whose char_range exceeds *max_leaf_chars* using LLM decomposition.""" + if _seen_ids is None: + _seen_ids = self._collect_node_ids(node) + + if not node.leaf: + for child in node.children: + await self._deepen_large_leaves( + child, content, + max_leaf_chars=max_leaf_chars, + max_depth=max_depth, + _seen_ids=_seen_ids, + ) + return + + start, end = node.char_range + span = end - start + if span <= max_leaf_chars or node.level >= max_depth: + return + + snippet = self._truncate_snippet(content[start:end]) + + prompt = ( + "Analyze this document section and identify 3-8 logical sub-sections.\n" + "For each sub-section, provide:\n" + '- "title": descriptive heading (concise)\n' + '- "start_text": the first 8-15 words that mark where this sub-section ' + "begins (must be exact text from the content)\n" + '- "content_type": "text" or "table"\n\n' + f'Section: "{node.title}"\n---\n{snippet}\n---\n\n' + 'Return ONLY a JSON array, e.g. ' + '[{"title": "...", "start_text": "...", "content_type": "text"}, ...]' + ) + + try: + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + sub_sections = self._parse_json_array(resp.content) + if not sub_sections or len(sub_sections) < 2: + return + except Exception: + return + + sub_nodes = self._build_sub_nodes_from_llm( + sub_sections, node, content, _seen_ids, + ) + if not sub_nodes: + return + + node.children = sub_nodes + await self._log.info( + f"[TreeIndexer] Deepened '{node.title}' into {len(sub_nodes)} sub-nodes" + ) + + # Recurse into newly created children + for child in node.children: + await self._deepen_large_leaves( + child, content, + max_leaf_chars=max_leaf_chars, + max_depth=max_depth, + _seen_ids=_seen_ids, + ) + + def _build_sub_nodes_from_llm( + self, + sub_sections: List[Dict[str, Any]], + parent: TreeNode, + content: str, + seen_ids: set, + ) -> List[TreeNode]: + """Create child TreeNodes from LLM-decomposed sub-sections.""" + parent_start, parent_end = parent.char_range + parent_span = max(parent_end - parent_start, 1) + parent_ps, parent_pe = parent.page_range if parent.page_range else (0, 0) + page_span = parent_pe - parent_ps + child_level = parent.level + 1 + + # Resolve char_start for each sub-section + positions: List[int] = [] + search_from = parent_start + for sec in sub_sections: + start_text = sec.get("start_text", "") + pos = content.find(start_text, search_from) if start_text else -1 + if pos < 0 or pos >= parent_end: + pos = search_from + positions.append(pos) + search_from = pos + 1 + + nodes: List[TreeNode] = [] + for i, sec in enumerate(sub_sections): + char_start = positions[i] + char_end = positions[i + 1] if i + 1 < len(positions) else parent_end + + # Estimate page_range proportionally from parent + page_range = None + if parent.page_range and parent_span > 0: + p_start = parent_ps + (char_start - parent_start) / parent_span * page_span + p_end = parent_ps + (char_end - parent_start) / parent_span * page_span + page_range = (int(p_start), max(int(p_start), 
int(p_end))) + + content_type = sec.get("content_type", "text") + if content_type not in ("text", "table"): + content_type = "text" + + nodes.append(TreeNode( + node_id=self._unique_node_id(char_start, seen_ids), + title=sec.get("title", f"Sub-section {i + 1}"), + summary="", + char_range=(char_start, char_end), + level=child_level, + page_range=page_range, + content_type=content_type, + )) + return nodes + + async def _enrich_node_summaries( + self, + node: TreeNode, + content: str, + *, + max_summary_len: int = 200, + ) -> None: + """Post-order traversal to enrich empty summaries: leaf from content, non-leaf via LLM.""" + # Post-order: process children first + for child in node.children: + await self._enrich_node_summaries( + child, content, max_summary_len=max_summary_len, + ) + + if self._summary_needs_enrichment(node.summary): + if node.leaf: + node.summary = self._extract_leaf_summary( + content, node.char_range, max_summary_len, + ) + else: + node.summary = await self._generate_nonleaf_summary( + node, max_summary_len, + ) + + @staticmethod + def _summary_needs_enrichment(summary: str) -> bool: + """Check whether a summary is empty or too short to be useful.""" + return not summary or len(summary.strip()) < 10 + + @staticmethod + def _extract_leaf_summary( + content: str, + char_range: Tuple[int, int], + max_len: int, + ) -> str: + """Extract a concise summary for a leaf node from its content slice.""" + start, end = char_range + raw = content[start:end][:500] + # Clean to single line + return " ".join(raw.split())[:max_len] + + async def _generate_nonleaf_summary( + self, + node: TreeNode, + max_summary_len: int, + ) -> str: + """Generate a summary for a non-leaf node via LLM, with fallback.""" + children_listing = "\n".join( + f"- {c.title}: {c.summary[:100]}" for c in node.children + ) + prompt = ( + "Summarize this document section in 1-2 concise sentences.\n" + f'Section: "{node.title}"\n' + f"Sub-sections:\n{children_listing}\n\n" + "Return ONLY the summary text." 
+ ) + try: + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + return resp.content.strip()[:max_summary_len] + except Exception: + # Fallback: concatenate children titles + return ", ".join(c.title for c in node.children)[:max_summary_len] + + # ------------------------------------------------------------------ # + # Parsing / snippet helpers # + # ------------------------------------------------------------------ # + + @staticmethod + def _truncate_snippet( + text: str, + *, + head_chars: int = 3000, + tail_chars: int = 1000, + ) -> str: + """Truncate a long text snippet keeping head and tail with an ellipsis marker.""" + if len(text) <= head_chars + tail_chars: + return text + return text[:head_chars] + "\n...[truncated]...\n" + text[-tail_chars:] + + @staticmethod + def _parse_json_array(raw: str) -> List[Dict[str, Any]]: + """Extract and parse a JSON array from LLM output.""" + cleaned = re.sub(r"^```(?:json)?\s*", "", raw, flags=re.MULTILINE) + cleaned = re.sub(r"```\s*$", "", cleaned, flags=re.MULTILINE).strip() + m = re.search(r"\[.*\]", cleaned, re.DOTALL) + if m: + return json.loads(m.group()) + return [] + + @staticmethod + def _collect_node_ids(node: TreeNode) -> set: + """Collect all existing node_ids in the subtree.""" + ids = {node.node_id} + for c in node.children: + ids.update(DocumentTreeIndexer._collect_node_ids(c)) + return ids + @staticmethod def should_build_tree(file_path: str, content_length: int) -> bool: """Determine whether a file is eligible for tree indexing.""" diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 9374e27..42424ab 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -3818,17 +3818,21 @@ async def _tree_guided_sample( parts: List[str] = [] total_chars = 0 for leaf in leaves[: self._TREE_SAMPLE_MAX_SECTIONS]: - start, end = leaf.char_range - if self._is_valid_char_range(start, end, len(full_text)) and full_text: - segment = full_text[start:end] - elif leaf.summary: - _loguru_logger.debug( - f"[TreeNav] char_range degraded for '{leaf.title}' " - f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" - ) + # Table nodes: prefer summary (contains table markdown) + if getattr(leaf, 'content_type', 'text') == 'table' and leaf.summary: segment = leaf.summary else: - continue + start, end = leaf.char_range + if self._is_valid_char_range(start, end, len(full_text)) and full_text: + segment = full_text[start:end] + elif leaf.summary: + _loguru_logger.debug( + f"[TreeNav] char_range degraded for '{leaf.title}' " + f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" + ) + segment = leaf.summary + else: + continue segment = segment[: self._TREE_SAMPLE_SECTION_MAX_CHARS] if not segment.strip(): continue @@ -3836,7 +3840,8 @@ async def _tree_guided_sample( if leaf.page_range: ps, pe = leaf.page_range page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" - header = f"[{fname} → {leaf.title}{page_info}]" + type_tag = " [TABLE]" if getattr(leaf, 'content_type', 'text') == 'table' else "" + header = f"[{fname} → {leaf.title}{page_info}{type_tag}]" chunk = f"{header}\n{segment}" if total_chars + len(chunk) > max_chars: remaining = max_chars - total_chars @@ -4006,23 +4011,28 @@ async def _navigate_tree_for_evidence( full_text = "" for leaf in leaves: - start, end = leaf.char_range - if self._is_valid_char_range(start, end, len(full_text)) and full_text: - segment = full_text[start:end] - elif leaf.summary: - _loguru_logger.debug( - f"[TreeNav] char_range degraded 
for '{leaf.title}' " - f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" - ) + # Table nodes: prefer summary (contains table markdown) + if getattr(leaf, 'content_type', 'text') == 'table' and leaf.summary: segment = leaf.summary else: - continue + start, end = leaf.char_range + if self._is_valid_char_range(start, end, len(full_text)) and full_text: + segment = full_text[start:end] + elif leaf.summary: + _loguru_logger.debug( + f"[TreeNav] char_range degraded for '{leaf.title}' " + f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" + ) + segment = leaf.summary + else: + continue if segment.strip(): page_info = "" if leaf.page_range: ps, pe = leaf.page_range page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" - header = f"[{fname} → {leaf.title}{page_info}]" + type_tag = " [TABLE]" if getattr(leaf, 'content_type', 'text') == 'table' else "" + header = f"[{fname} → {leaf.title}{page_info}{type_tag}]" parts.append(f"{header}\n{segment[:3000]}") if not parts: From 579f8d6c64514d6f0b21d29a1ac304299111c275 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 20 Apr 2026 02:33:44 +0800 Subject: [PATCH 34/56] fix robust issue --- src/sirchmunk/learnings/compiler.py | 18 ++++++++++-------- src/sirchmunk/learnings/tree_indexer.py | 10 ++++++++-- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 531cb28..df2ca12 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -1265,14 +1265,16 @@ def _integrate_tables_into_tree( node.table_count = len(matched_tables) - # Create table child nodes only for leaf nodes with matched tables - if not node.children and matched_tables: - try: - self._spawn_table_children( - node, matched_tables, content, _counter, - ) - except Exception: - pass # Never break compile for table node creation + # NOTE: _spawn_table_children disabled - converting leaf to non-leaf breaks + # search navigation which expects leaves for char_range extraction. + # TODO: Re-enable when search can properly handle mixed text+table children. + # if not node.children and matched_tables: + # try: + # self._spawn_table_children( + # node, matched_tables, content, _counter, + # ) + # except Exception: + # pass @staticmethod def _is_pseudo_table(table: Dict[str, Any]) -> bool: diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 060c9b8..eb56f6e 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -217,7 +217,10 @@ async def build_tree( toc_entries, content, total_pages=total_pages, ) if root is not None: - await self._deepen_large_leaves(root, content, max_depth=effective_depth) + # NOTE: _deepen_large_leaves disabled - char_range anchoring via LLM start_text + # is unreliable, causing overlapping ranges and search failures. + # TODO: Re-enable when robust char_range calculation is implemented. + # await self._deepen_large_leaves(root, content, max_depth=effective_depth) await self._enrich_node_summaries(root, content) tree = DocumentTree( file_path=file_path, @@ -238,7 +241,10 @@ async def build_tree( if root is None: return None - await self._deepen_large_leaves(root, content, max_depth=effective_depth) + # NOTE: _deepen_large_leaves disabled - char_range anchoring via LLM start_text + # is unreliable, causing overlapping ranges and search failures. 
+ # TODO: Re-enable when robust char_range calculation is implemented. + # await self._deepen_large_leaves(root, content, max_depth=effective_depth) await self._enrich_node_summaries(root, content) tree = DocumentTree( From 78c11170ba6ed96b361529baba1f89495905aa13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 20 Apr 2026 15:56:50 +0800 Subject: [PATCH 35/56] fix pure tree search env --- benchmarks/financebench/run_benchmark.py | 5 ++ src/sirchmunk/learnings/tree_indexer.py | 64 +++++++++--------------- src/sirchmunk/search.py | 2 + 3 files changed, 31 insertions(+), 40 deletions(-) diff --git a/benchmarks/financebench/run_benchmark.py b/benchmarks/financebench/run_benchmark.py index cf7b30a..65af87d 100644 --- a/benchmarks/financebench/run_benchmark.py +++ b/benchmarks/financebench/run_benchmark.py @@ -27,6 +27,8 @@ from pathlib import Path from typing import List +from dotenv import load_dotenv + from config import FinanceBenchConfig from data_loader import FinanceBenchLoader from evaluate import compute_metrics @@ -158,6 +160,9 @@ def main() -> None: ) args = parser.parse_args() + # Load .env into os.environ so SIRCHMUNK_* variables are visible globally + load_dotenv(args.env, override=True) + # 1. Load config cfg = FinanceBenchConfig.from_env(args.env) if args.limit is not None: diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index eb56f6e..e1a652f 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -221,7 +221,9 @@ async def build_tree( # is unreliable, causing overlapping ranges and search failures. # TODO: Re-enable when robust char_range calculation is implemented. # await self._deepen_large_leaves(root, content, max_depth=effective_depth) - await self._enrich_node_summaries(root, content) + # NOTE: _enrich_node_summaries disabled temporarily to isolate its impact. + # The summaries may inadvertently bias _select_children() navigation. + # await self._enrich_node_summaries(root, content) tree = DocumentTree( file_path=file_path, file_hash=file_hash, @@ -245,7 +247,9 @@ async def build_tree( # is unreliable, causing overlapping ranges and search failures. # TODO: Re-enable when robust char_range calculation is implemented. # await self._deepen_large_leaves(root, content, max_depth=effective_depth) - await self._enrich_node_summaries(root, content) + # NOTE: _enrich_node_summaries disabled temporarily to isolate its impact. + # The summaries may inadvertently bias _select_children() navigation. + # await self._enrich_node_summaries(root, content) tree = DocumentTree( file_path=file_path, @@ -268,48 +272,32 @@ async def navigate( query: str, *, max_results: int = 3, - max_nav_depth: int = 4, ) -> List[TreeNode]: - """Reasoning-based tree navigation: LLM selects the most relevant branches. - - Iteratively descends through the tree until leaf nodes are reached or - *max_nav_depth* selection rounds are exhausted. - - Returns up to *max_results* leaf nodes with their char_range for - precise evidence extraction. 
- """ + """LLM-driven branch selection using _select_children().""" if tree.root is None: return [] - current = tree.root.children if tree.root.children else [tree.root] - if not current: + candidates = tree.root.children if tree.root.children else [tree.root] + if not candidates: return [tree.root] - selected: List[TreeNode] = current - for _ in range(max_nav_depth): - selected = await self._select_children(current, query) - if not selected: - break - # All leaves — stop descending - if all(n.leaf for n in selected): - break - # Expand non-leaf children, keep leaves as-is - next_level: List[TreeNode] = [] - for n in selected: - if n.leaf: - next_level.append(n) - else: - next_level.extend(n.children) - if not next_level: - break - current = next_level - else: - selected = current + selected = await self._select_children(candidates, query) + if not selected: + return [] + + result_leaves: List[TreeNode] = [] + for node in selected: + if node.leaf: + result_leaves.append(node) + else: + deeper = await self._select_children(node.children, query) + for d in (deeper or node.children[:1]): + result_leaves.extend(d.all_leaves()[:max_results]) # Deduplicate and cap seen_ids: set = set() unique: List[TreeNode] = [] - for n in (selected or current): + for n in result_leaves: if n.node_id not in seen_ids: seen_ids.add(n.node_id) unique.append(n) @@ -630,17 +618,13 @@ async def _select_children( listing = "\n".join( f"[{i}] {n.title}{self._format_page_range(n.page_range)}" - f" [{n.content_type.upper()}]" f"{' [' + str(n.table_count) + ' tables]' if n.table_count > 0 else ''}" - f": {n.summary[:200]}" + f": {n.summary[:150]}" for i, n in enumerate(nodes) ) + prompt = ( f"Given the query: \"{query}\"\n\n" - "Guidelines:\n" - "- For numerical/financial data queries, prefer TABLE nodes and consolidated statements\n" - "- Prefer company-wide/consolidated data over segment-level unless query specifies a segment\n" - "- When multiple tables exist, select the one most directly answering the query\n\n" f"Select the 1-2 most relevant sections (by index number):\n{listing}\n\n" f"Return ONLY a JSON array of index numbers, e.g. 
[0, 2]" ) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 42424ab..5e497f2 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -1453,6 +1453,8 @@ async def search( return ctx return msg + await self._logger.info(f"[SearchConfig] PURE_TREE_SEARCH={'enabled' if _PURE_TREE_SEARCH else 'disabled'}") + # ---- Chat intent short-circuit (rule-based, no LLM cost) ---- if mode != "FILENAME_ONLY" and self._is_chat_query(query): answer, cluster, ctx = await self._respond_chat(query, chat_history=chat_history) From 86d528ed208290501d730d3059c6a0ea4df94768 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 20 Apr 2026 20:32:02 +0800 Subject: [PATCH 36/56] improve tree index --- src/sirchmunk/learnings/compiler.py | 290 +++++++++++++ src/sirchmunk/learnings/tree_indexer.py | 480 +++++++++++++++++++++- src/sirchmunk/search.py | 248 +++++++++-- src/sirchmunk/utils/document_extractor.py | 91 ++++ 4 files changed, 1048 insertions(+), 61 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index df2ca12..92dba7f 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -12,6 +12,7 @@ import math import os import random +import re import hashlib from dataclasses import dataclass, field from datetime import datetime, timezone @@ -51,6 +52,17 @@ _SUMMARY_SAMPLE_SECTIONS = 3 # Number of sections to sample for large docs _SUMMARY_SAMPLE_SECTION_CHARS = 5_000 # Chars per sampled section +# Targeted table extraction: max chars per table region +_TARGETED_TABLE_MAX_CHARS = 5000 + +# Targeted table extraction: only process nodes spanning <= N pages +_TABLE_PAGE_SPAN_LIMIT = 5 + +# Numeric density threshold – fraction of numeric/symbol chars ($, %, digits, +# parenthesised numbers) relative to total non-whitespace chars. Pages below +# this threshold are skipped during targeted extraction. 
+_TABLE_NUMERIC_DENSITY_THRESHOLD = 0.15 + # Excel table-level adaptive sampling constants _XLSX_TOTAL_ROW_BUDGET = 100 # Total sampled rows budget across all sheets _XLSX_MIN_ROWS_PER_SHEET = 3 # Minimum sampled rows per sheet @@ -645,6 +657,51 @@ async def _compile_single_file( content=content, total_pages=extraction.page_count, ) + # Phase 2.5: Targeted table extraction via generic structural signals + if result.tree and result.tree.root and ext == ".pdf": + targeted_tables = await self._targeted_table_extraction( + entry.path, result.tree, + ) + if targeted_tables: + # Load existing table digest (if any) and merge + digest_dir = self._compile_dir / "table_digests" + file_hash = get_fast_hash(entry.path) or "" + existing_digest: list[dict] = [] + if file_hash and result.has_table_digest: + digest_path = digest_dir / f"{file_hash}.json" + if digest_path.exists(): + try: + raw = json.loads( + digest_path.read_text(encoding="utf-8") + ) + existing_digest = raw.get("tables", []) + except Exception: + pass + merged = self._merge_table_digests( + existing_digest, targeted_tables, + ) + if merged and file_hash: + digest_dir.mkdir(parents=True, exist_ok=True) + digest_path = digest_dir / f"{file_hash}.json" + digest_path.write_text( + json.dumps( + { + "version": 1, + "table_count": len(merged), + "tables": merged, + }, + ensure_ascii=False, + ), + encoding="utf-8", + ) + result.has_table_digest = True + result.table_count = len(merged) + await self._log.info( + f"[Compile] Targeted table extraction added " + f"{len(targeted_tables)} tables for " + f"{Path(entry.path).name}" + ) + except Exception as exc: result.error = str(exc) await self._log.warning(f"[Compile] Failed: {entry.path}: {exc}") @@ -1409,6 +1466,239 @@ def _count(node: Any) -> int: return _count(tree.root) + # ------------------------------------------------------------------ # + # Targeted table extraction # + # ------------------------------------------------------------------ # + + async def _targeted_table_extraction( + self, file_path: str, tree: DocumentTree, + ) -> list[dict]: + """Extract tables from tree nodes likely containing tabular data. + + Uses generic structural signals (metadata, page span, numeric + density) instead of domain-specific title keywords. For each + candidate with a valid ``page_range``, extracts per-page text + via :meth:`DocumentExtractor.extract_page_range` and applies + heuristic table-region detection. Pages whose numeric density + falls below ``_TABLE_NUMERIC_DENSITY_THRESHOLD`` are skipped. 
+ + Returns: + List of table dicts compatible with the table-digest format:: + + {"page": int, "content": str, "source": str} + """ + if tree is None or tree.root is None: + return [] + + candidates = self._find_table_candidate_nodes(tree.root) + if not candidates: + return [] + + await self._log.info( + f"[Compile] Targeted extraction: {len(candidates)} candidate " + f"nodes in {Path(file_path).name}" + ) + + results: list[dict] = [] + seen_pages: set[int] = set() + + for node in candidates: + if node.page_range is None: + continue + start_page, end_page = node.page_range + # Skip pages already processed by another candidate + page_nums = [p for p in range(start_page, end_page + 1) + if p not in seen_pages] + if not page_nums: + continue + + try: + pages = DocumentExtractor.extract_page_range( + file_path, start_page, end_page, + ) + except Exception as exc: + await self._log.warning( + f"[Compile] Targeted extraction page read failed " + f"({start_page}-{end_page}): {exc}" + ) + continue + + for pc in pages: + if pc.page_number in seen_pages: + continue + seen_pages.add(pc.page_number) + # Numeric density gate – skip pages unlikely to contain tables + if not self._page_has_table_density(pc.content): + continue + regions = self._identify_table_regions(pc.content) + for region in regions: + truncated = region[:_TARGETED_TABLE_MAX_CHARS] + results.append({ + "page": pc.page_number, + "content": truncated, + "source": f"targeted:{node.title[:80]}", + }) + + return results + + def _find_table_candidate_nodes( + self, root: "TreeNode", + ) -> list["TreeNode"]: + """Collect leaf nodes that likely contain tables. + + Uses generic, domain-agnostic structural signals (any match + suffices): + + - ``node.content_type == "table"`` – already tagged during compile. + - ``node.table_count > 0`` – known to contain tables. + - Has a valid ``page_range`` with span ≤ ``_TABLE_PAGE_SPAN_LIMIT``. + """ + candidates: list = [] + + def _walk(node: "TreeNode") -> None: + if node.leaf: + # Signal 1: content_type marked as table + if getattr(node, "content_type", None) == "table": + candidates.append(node) + return + # Signal 2: known to contain tables + if getattr(node, "table_count", 0) > 0: + candidates.append(node) + return + # Signal 3: moderate page span (tables rarely span many pages) + page_range = getattr(node, "page_range", None) + if page_range and len(page_range) == 2: + span = page_range[1] - page_range[0] + 1 + if 1 <= span <= _TABLE_PAGE_SPAN_LIMIT: + candidates.append(node) + else: + for child in node.children: + _walk(child) + + _walk(root) + return candidates + + @staticmethod + def _page_has_table_density(page_text: str) -> bool: + """Return True if *page_text* has numeric density above the threshold. + + Counts digits and common table symbols (``$``, ``%``, ``(``, ``)``) + relative to total non-whitespace characters. + """ + if not page_text: + return False + non_ws = sum(1 for ch in page_text if not ch.isspace()) + if non_ws == 0: + return False + numeric_chars = sum( + 1 for ch in page_text + if ch.isdigit() or ch in "$%(),.+-" + ) + return (numeric_chars / non_ws) >= _TABLE_NUMERIC_DENSITY_THRESHOLD + + @staticmethod + def _identify_table_regions(page_text: str) -> list[str]: + """Identify contiguous table-like regions in *page_text*. + + Heuristic rules: + - Lines containing multiple numeric tokens (dollar amounts, %, + parenthesised negatives) are considered *numeric rows*. + - A run of >= 3 consecutive numeric rows forms a table region. + - Leading/trailing whitespace rows are trimmed. 
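A concrete illustration of these rules, using a simplified form of the _NUM_TOKEN pattern defined below (invented sample lines; exponent handling omitted):

    import re

    NUM = re.compile(r"[\$€£¥]\s*[\d,.]+|\([\d,.]+\)|[\d,.]+%|\d+\.\d+|[\d,]{2,}")
    lines = [
        "The following table summarizes net sales by segment:",   # prose row
        "Segment A      $ 1,234      $ 1,105",                    # numeric row
        "Segment B        (567)        (512)",                    # numeric row
        "Segment C        54.1%        53.7%",                    # numeric row
    ]
    print([len(NUM.findall(l)) >= 2 for l in lines])
    # [False, True, True, True] -> a run of 3 numeric rows, emitted as one region
    # together with the prose line kept as the single line of leading context.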
+ + Returns: + List of extracted region strings (may be empty). + """ + if not page_text: + return [] + + # Pattern: line has at least 2 numeric-looking tokens + _NUM_TOKEN = re.compile( + r"(?:" + r"[\$€£¥]\s*[\d,.]+|" + r"\([\d,.]+\)|" + r"[\d,.]+%|" + r"[\d]+\.[\d]+(?:[eE][+-]?\d+)?|" + r"[\d,]{2,}" + r")" + ) + _MIN_NUMS_PER_LINE = 2 + _MIN_CONSECUTIVE = 3 + + lines = page_text.split("\n") + is_numeric = [ + len(_NUM_TOKEN.findall(line)) >= _MIN_NUMS_PER_LINE + for line in lines + ] + + regions: list[str] = [] + run_start: int | None = None + + for i, flag in enumerate(is_numeric): + if flag: + if run_start is None: + run_start = i + else: + if run_start is not None: + run_len = i - run_start + if run_len >= _MIN_CONSECUTIVE: + # Include one context line above/below + start = max(0, run_start - 1) + end = min(len(lines), i + 1) + regions.append( + "\n".join(lines[start:end]).strip() + ) + run_start = None + + # Flush trailing run + if run_start is not None: + run_len = len(lines) - run_start + if run_len >= _MIN_CONSECUTIVE: + start = max(0, run_start - 1) + regions.append( + "\n".join(lines[start:]).strip() + ) + + return regions + + @staticmethod + def _get_table_page(entry: dict) -> int | None: + """统一获取表格条目的页码,兼容 page_number 和 page 两种字段名。""" + p = entry.get("page_number") or entry.get("page") + return int(p) if p is not None else None + + @classmethod + def _merge_table_digests( + cls, existing: list[dict], new_tables: list[dict], + ) -> list[dict]: + """Merge *new_tables* into *existing* digest, deduplicating by page. + + If an existing entry and a new entry share the same page number, + the new entry is skipped (existing kreuzberg-detected table takes + precedence because it has richer structure like cells/markdown). + + Returns: + Merged list suitable for storage in the table-digest JSON. + """ + existing_pages = {cls._get_table_page(e) for e in existing} + existing_pages.discard(None) + + merged = list(existing) + for tbl in new_tables: + page = cls._get_table_page(tbl) + if page is not None and page in existing_pages: + continue + # Normalise to digest table format for consistency + merged.append({ + "page_number": page, + "markdown": tbl.get("content", ""), + "row_count": None, + "col_count": None, + "cells": [], + "source": tbl.get("source", "targeted"), + }) + return merged + # ------------------------------------------------------------------ # # Summary index for embedding + BM25 fallback # # ------------------------------------------------------------------ # diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index e1a652f..96c44b9 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -8,7 +8,9 @@ """ import json +import math import re +from collections import Counter from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path @@ -153,6 +155,13 @@ def from_json(cls, json_str: str) -> "DocumentTree": class DocumentTreeIndexer: """Build and cache PageIndex-style hierarchical tree indices for documents.""" + # Maximum child nodes before switching to paginated LLM selection. + # Balance: lower = more LLM calls, higher = more tokens per call. + _PAGE_SIZE_THRESHOLD: int = 15 + + # Number of nodes per group in paginated selection. 
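For orientation, a rough sketch of how a 40-child node is partitioned under these defaults (illustrative only; the two-phase selection itself is implemented further down in this patch):

    nodes = [f"section-{i}" for i in range(40)]
    page_size = 15                                  # _GROUP_PAGE_SIZE below
    groups = [nodes[i:i + page_size] for i in range(0, len(nodes), page_size)]
    print(len(groups), [len(g) for g in groups])    # 3 [15, 15, 10]
    # Phase 1 shows the LLM only "first ... last (n sections)" per group and asks
    # for 1-2 group indices; phase 2 re-ranks the sections inside the chosen groups,
    # so each prompt stays bounded.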
+ _GROUP_PAGE_SIZE: int = 15 + def __init__( self, llm: OpenAIChat, @@ -272,8 +281,23 @@ async def navigate( query: str, *, max_results: int = 3, + max_depth: int = 4, ) -> List[TreeNode]: - """LLM-driven branch selection using _select_children().""" + """Adaptive-depth LLM-driven tree navigation. + + Iteratively descends the tree using _select_children() at each level, + collecting leaf nodes until *max_results* are found or *max_depth* is + reached. + + Args: + tree: DocumentTree with a root node. + query: Search query for relevance selection. + max_results: Maximum number of leaf nodes to return. + max_depth: Maximum descent depth (default 4). + + Returns: + List of the most relevant leaf TreeNodes. + """ if tree.root is None: return [] @@ -281,18 +305,41 @@ async def navigate( if not candidates: return [tree.root] - selected = await self._select_children(candidates, query) - if not selected: - return [] - result_leaves: List[TreeNode] = [] - for node in selected: - if node.leaf: - result_leaves.append(node) - else: - deeper = await self._select_children(node.children, query) - for d in (deeper or node.children[:1]): - result_leaves.extend(d.all_leaves()[:max_results]) + visited: set = set() # prevent cycles + frontier = candidates + selected: List[TreeNode] = [] + + depth = 0 + while depth < max_depth and frontier: + selected = await self._select_children( + frontier, query, max_selections=max_results, + ) + if not selected: + break + + next_frontier: List[TreeNode] = [] + for node in selected: + node_id = id(node) + if node_id in visited: + continue + visited.add(node_id) + + if node.leaf or not node.children: + result_leaves.append(node) + else: + next_frontier.extend(node.children) + + if len(result_leaves) >= max_results: + break + + frontier = next_frontier + depth += 1 + + # Fallback: if no leaves found, expand last selected nodes + if not result_leaves and selected: + for node in selected: + result_leaves.extend(node.all_leaves()[:max_results]) # Deduplicate and cap seen_ids: set = set() @@ -341,6 +388,9 @@ async def _build_tree_from_toc( Returns: Root TreeNode, or None if no children could be created. """ + # Infer hierarchy when TOC entries are flat (all same level) + toc_entries = self._infer_hierarchy(toc_entries) + seen_ids: set = set() children = self._toc_entries_to_nodes( toc_entries, content, len(content), seen_ids, @@ -610,12 +660,21 @@ def _resolve_positions( ] async def _select_children( - self, nodes: List[TreeNode], query: str, + self, nodes: List[TreeNode], query: str, *, max_selections: int = 3, ) -> List[TreeNode]: - """LLM-driven branch selection: pick the most relevant children.""" + """LLM-driven branch selection: pick the most relevant children. + + Dispatches to paginated selection when *nodes* exceeds + ``_PAGE_SIZE_THRESHOLD`` to avoid overwhelming the LLM. 
+ """ if len(nodes) <= 2: return nodes + if len(nodes) > self._PAGE_SIZE_THRESHOLD: + return await self._select_children_paginated( + nodes, query, max_selections=max_selections, + ) + listing = "\n".join( f"[{i}] {n.title}{self._format_page_range(n.page_range)}" f"{' [' + str(n.table_count) + ' tables]' if n.table_count > 0 else ''}" @@ -634,10 +693,120 @@ async def _select_children( m = re.search(r"\[[\d\s,]+\]", raw) if m: indices = json.loads(m.group()) - return [nodes[i] for i in indices if 0 <= i < len(nodes)] + selected = [nodes[i] for i in indices if 0 <= i < len(nodes)] + return selected if selected else nodes[:max_selections] except (json.JSONDecodeError, IndexError, TypeError): pass - return nodes[:2] + return nodes[:max_selections] + + async def _select_children_paginated( + self, + nodes: List[TreeNode], + query: str, + *, + page_size: int = 15, + max_selections: int = 3, + ) -> List[TreeNode]: + """Two-phase paginated selection for large node sets. + + Phase 1: partition *nodes* into sequential groups of *page_size*, + present group summaries to LLM, and select 1-2 groups. + Phase 2: run fine-grained selection within each chosen group. + + Falls back to the first *max_selections* nodes on any LLM failure. + """ + page_size = max(page_size, self._GROUP_PAGE_SIZE) + + # --- Phase 0: build groups --- + groups: List[List[TreeNode]] = [] + for start in range(0, len(nodes), page_size): + groups.append(nodes[start:start + page_size]) + + if len(groups) <= 1: + # Only one group — skip directly to fine-grained selection + return await self._select_from_group(nodes, query, max_selections) + + # --- Phase 1: group-level selection --- + group_listing = "\n".join( + f"[{i}] {g[0].title} ... {g[-1].title} ({len(g)} sections)" + for i, g in enumerate(groups) + ) + group_prompt = ( + f"Given the query: \"{query}\"\n\n" + f"The document has {len(nodes)} sections organized into " + f"{len(groups)} groups.\n" + f"Select the 1-2 most relevant groups (by index number):\n" + f"{group_listing}\n\n" + f"Return ONLY a JSON array of group index numbers, e.g. 
[0, 2]" + ) + + selected_groups: List[List[TreeNode]] = [] + try: + resp = await self._llm.achat( + [{"role": "user", "content": group_prompt}], + ) + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + g_indices = json.loads(m.group()) + selected_groups = [ + groups[i] for i in g_indices if 0 <= i < len(groups) + ] + except (json.JSONDecodeError, IndexError, TypeError): + pass + + if not selected_groups: + # Fallback: take the first group + selected_groups = [groups[0]] + + # --- Phase 2: fine-grained selection within chosen groups --- + results: List[TreeNode] = [] + for group in selected_groups: + picked = await self._select_from_group(group, query, max_selections) + results.extend(picked) + + # Deduplicate by node_id and cap + seen: set = set() + unique: List[TreeNode] = [] + for n in results: + if n.node_id not in seen: + seen.add(n.node_id) + unique.append(n) + return unique[:max_selections] if unique else nodes[:max_selections] + + async def _select_from_group( + self, + group: List[TreeNode], + query: str, + max_selections: int, + ) -> List[TreeNode]: + """Select the most relevant nodes within a single group via LLM.""" + if len(group) <= 2: + return group + + listing = "\n".join( + f"[{i}] {n.title}{self._format_page_range(n.page_range)}" + f"{' [' + str(n.table_count) + ' tables]' if n.table_count > 0 else ''}" + f": {n.summary[:150]}" + for i, n in enumerate(group) + ) + prompt = ( + f"Given the query: \"{query}\"\n\n" + f"Select the 1-2 most relevant sections (by index number):\n{listing}\n\n" + f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" + ) + try: + resp = await self._llm.achat([{"role": "user", "content": prompt}]) + raw = resp.content.strip() + m = re.search(r"\[[\d\s,]+\]", raw) + if m: + indices = json.loads(m.group()) + selected = [group[i] for i in indices if 0 <= i < len(group)] + if selected: + return selected[:max_selections] + except (json.JSONDecodeError, IndexError, TypeError): + pass + return group[:max_selections] # ------------------------------------------------------------------ # # Cache I/O # @@ -924,3 +1093,282 @@ def should_build_tree(file_path: str, content_length: int) -> bool: """Determine whether a file is eligible for tree indexing.""" ext = Path(file_path).suffix.lower() return ext in _TREE_EXTENSIONS and content_length >= _TREE_MIN_CHARS + + # ------------------------------------------------------------------ # + # Hierarchy inference for flat TOC entries # + # ------------------------------------------------------------------ # + + # Minimum number of TOC entries to trigger hierarchy inference. + # Documents with fewer entries are typically already well-structured. + _FLAT_ENTRY_THRESHOLD = 20 + + # If this fraction of entries share the same level, consider it "flat" + # and apply hierarchy inference. Real hierarchies typically have + # varied level distribution. + _FLAT_LEVEL_RATIO = 0.9 + + # Number of entries per virtual group when using uniform grouping fallback. + _GROUP_SIZE = 15 + + @staticmethod + def _infer_hierarchy(entries: List[Any]) -> List[Any]: + """When all entries share the same level, infer hierarchy from title patterns. + + Applies three strategies in priority order: + A. Keyword groups — detect repeated structural prefixes (generic) + B. Generic numbering patterns (1., 1.1, I., A., etc.) + C. Uniform grouping fallback (virtual parent nodes) + + Only activates when >90% of entries share the same level and + the total count exceeds ``_FLAT_ENTRY_THRESHOLD``. 
+ + Args: + entries: List of TOCEntry (may be nested). + + Returns: + Possibly restructured list of TOCEntry with updated levels + and rebuilt hierarchy. + """ + if not entries: + return entries or [] + + try: + from sirchmunk.learnings.toc_extractor import TOCExtractor + flat: List[Any] = [] + TOCExtractor._flatten_entries(entries, flat) + except Exception: + return entries # Cannot flatten; return original entries + + if not flat: + return entries + + if len(flat) <= DocumentTreeIndexer._FLAT_ENTRY_THRESHOLD: + return entries + + # Validate level field: skip entries with invalid levels + valid_flat = [e for e in flat if hasattr(e, 'level') and isinstance(e.level, (int, float))] + if not valid_flat: + return entries + + # Check if >90% share the same level + level_counts = Counter(e.level for e in valid_flat) + dominant_level, dominant_count = level_counts.most_common(1)[0] + if dominant_count / len(flat) <= DocumentTreeIndexer._FLAT_LEVEL_RATIO: + return entries # Already has meaningful hierarchy + + # Try strategies in priority order + modified = DocumentTreeIndexer._strategy_keyword_groups(flat, dominant_level) + if modified is None: + modified = DocumentTreeIndexer._strategy_numbering(flat, dominant_level) + if modified is None: + modified = DocumentTreeIndexer._strategy_uniform_grouping( + flat, dominant_level, + ) + if modified is None: + return entries + + # Rebuild hierarchy from the re-leveled flat list + return TOCExtractor._build_hierarchy(modified) + + # -- Strategy A: keyword groups (generic structural prefix detection) # + + # Pattern: title starts with a capitalized word optionally followed by + # a Roman numeral or Arabic number (e.g. "PART IV", "Item 1A", + # "Section 3", "Chapter 12", "Article II"). + _RE_STRUCTURAL_PREFIX = re.compile( + r'^([A-Z][A-Za-z]*(?:\s+[IVXLCDM\d]+[A-Za-z]?)?)\b', + ) + + @staticmethod + def _extract_structural_prefix(title: str) -> Optional[str]: + """Extract a structural prefix from a title. + + Matches leading capitalized words optionally followed by a number + or Roman numeral (e.g. "PART IV", "Item 1A", "Section 3"). + Returns the normalized (uppercased) prefix, or None. + """ + if not title or not title.strip(): + return None + m = DocumentTreeIndexer._RE_STRUCTURAL_PREFIX.match(title.strip()) + if m: + prefix = m.group(1).strip() + # Prefix must not be too long (avoid capturing entire title) + if len(prefix) <= 20: + return prefix.upper() + return None + + @staticmethod + def _strategy_keyword_groups( + flat: List[Any], + dominant_level: int, + ) -> Optional[List[Any]]: + """Strategy A — detect repeated structural prefixes and infer levels. + + Works for any document with repetitive heading patterns (SEC filings, + legal contracts, technical specs, etc.). Automatically discovers + prefix groups and assigns hierarchical levels based on frequency: + lower-frequency prefixes become higher-level parents. + + Returns re-leveled flat list, or None if coverage is insufficient. + """ + # 1. Extract prefix for each entry + prefix_map: Dict[str, List[int]] = {} # prefix -> [entry indices] + for i, e in enumerate(flat): + prefix = DocumentTreeIndexer._extract_structural_prefix(e.title) + if prefix: + prefix_map.setdefault(prefix, []).append(i) + + # 2. Keep only prefixes appearing >= 2 times + repeated_prefixes = {k: v for k, v in prefix_map.items() if len(v) >= 2} + if not repeated_prefixes: + return None + + # 3. 
Check coverage: at least 30% of entries must be covered + covered = sum(len(indices) for indices in repeated_prefixes.values()) + if covered < len(flat) * 0.3: + return None + + # 4. Sort prefixes by frequency (ascending) then by first appearance + # Low frequency = higher level (parent), high frequency = lower level + sorted_prefixes = sorted( + repeated_prefixes.items(), + key=lambda x: (len(x[1]), min(x[1])), + ) + + # 5. Assign level per prefix group + prefix_to_level: Dict[str, int] = {} + for level_idx, (prefix, _) in enumerate(sorted_prefixes): + prefix_to_level[prefix] = level_idx + 1 + + # 6. Determine the "other" level for entries without a known prefix + max_level = max(prefix_to_level.values()) + 1 + + # 7. Apply levels + for i, e in enumerate(flat): + prefix = DocumentTreeIndexer._extract_structural_prefix(e.title) + if prefix and prefix in prefix_to_level: + e.level = prefix_to_level[prefix] + else: + e.level = max_level + e.children = [] + + return flat + + # -- Strategy B: generic numbering --------------------------------- # + + # Three-level numbering: 1.1.1, (a), (i), (1) + _RE_NUM_LEVEL3 = re.compile( + r"^\s*(?:\d+\.\d+\.\d+|\([a-z]\)|\([ivx]+\)|\(\d+\))\s", + re.IGNORECASE, + ) + # Two-level numbering: 1.1, A., B., a., b. + _RE_NUM_LEVEL2 = re.compile( + r"^\s*(?:\d+\.\d+(?!\.)\b|[A-Z]\.\s|[a-z]\.\s)", + ) + # Top-level numbering: 1., 2., I., II. + _RE_NUM_LEVEL1 = re.compile( + r"^\s*(?:\d+\.\s|[IVXLC]+\.\s)", + ) + + @staticmethod + def _strategy_numbering( + flat: List[Any], + dominant_level: int, + ) -> Optional[List[Any]]: + """Strategy B — detect generic numbering patterns. + + Returns re-leveled flat list, or None if fewer than 30% of + entries match any numbering pattern. + """ + matched = 0 + assignments: List[Optional[int]] = [] + + for e in flat: + title = e.title + if DocumentTreeIndexer._RE_NUM_LEVEL3.match(title): + assignments.append(3) + matched += 1 + elif DocumentTreeIndexer._RE_NUM_LEVEL2.match(title): + assignments.append(2) + matched += 1 + elif DocumentTreeIndexer._RE_NUM_LEVEL1.match(title): + assignments.append(1) + matched += 1 + else: + assignments.append(None) + + if matched < len(flat) * 0.3: + return None + + # Apply assignments; entries without a pattern get the level of + # the previous entry + 1 (capped at 3) + prev_level = 1 + for i, e in enumerate(flat): + if assignments[i] is not None: + e.level = assignments[i] + else: + e.level = min(prev_level + 1, 3) + prev_level = e.level + e.children = [] + return flat + + # -- Strategy C: uniform grouping fallback ------------------------- # + + @staticmethod + def _strategy_uniform_grouping( + flat: List[Any], + dominant_level: int, + ) -> Optional[List[Any]]: + """Strategy C — group entries into fixed-size buckets with virtual parents. + + Creates synthetic parent TOCEntry nodes whose char_start/char_end + and page_start/page_end are derived from the first and last child + in each group. + + Returns the re-leveled flat list including virtual parents, or None + on error. 
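Worked counts for the grouping arithmetic (illustrative only; when the result would be a single group, the strategy returns None instead of creating a virtual parent):

    import math

    _GROUP_SIZE = 15   # mirrors DocumentTreeIndexer._GROUP_SIZE
    for n in (21, 45):
        sizes = [min(_GROUP_SIZE, n - i) for i in range(0, n, _GROUP_SIZE)]
        print(n, "entries ->", math.ceil(n / _GROUP_SIZE), "virtual parents", sizes)
    # 21 entries -> 2 virtual parents [15, 6]
    # 45 entries -> 3 virtual parents [15, 15, 15]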
+ """ + from sirchmunk.learnings.toc_extractor import TOCEntry + + group_size = DocumentTreeIndexer._GROUP_SIZE + num_groups = math.ceil(len(flat) / group_size) + if num_groups <= 1: + return None # Grouping would not improve anything + + parent_level = max(1, dominant_level - 1) if dominant_level > 1 else 1 + child_level = parent_level + 1 + + result: List[Any] = [] + for g in range(num_groups): + start_idx = g * group_size + end_idx = min((g + 1) * group_size, len(flat)) + group = flat[start_idx:end_idx] + + first = group[0] + last = group[-1] + + # Derive positions from children + char_start = first.char_start + char_end = last.char_end if last.char_end else None + page_start = first.page_start + page_end = last.page_start # Best available estimate + + virtual_parent = TOCEntry( + title=f"{first.title} \u2013 {last.title}", + level=parent_level, + char_start=char_start, + char_end=char_end, + page_start=page_start, + page_end=page_end, + children=[], + source="inferred", + ) + result.append(virtual_parent) + + # Set child level + for e in group: + e.level = child_level + e.children = [] + result.extend(group) + + return result diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 5e497f2..52c0db3 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -15,6 +15,7 @@ from sirchmunk.base import BaseSearch from sirchmunk.learnings.knowledge_base import KnowledgeBase +from sirchmunk.utils.document_extractor import DocumentExtractor from sirchmunk.llm.openai_chat import OpenAIChat from sirchmunk.llm.prompts import ( KEYWORD_QUERY_PLACEHOLDER, @@ -3808,42 +3809,100 @@ async def _tree_guided_sample( if not leaves: return None - # --- Read full text once for char_range slicing --- - try: - from sirchmunk.utils.file_utils import fast_extract - extraction = await fast_extract(file_path=file_path) - full_text = extraction.content or "" - except Exception: - full_text = "" + # --- Classify leaves by extraction method --- + trimmed = leaves[: self._TREE_SAMPLE_MAX_SECTIONS] + page_leaves, char_leaves, table_and_summary = self._classify_leaves(trimmed) - # --- Extract tree sections --- - parts: List[str] = [] - total_chars = 0 - for leaf in leaves[: self._TREE_SAMPLE_MAX_SECTIONS]: - # Table nodes: prefer summary (contains table markdown) - if getattr(leaf, 'content_type', 'text') == 'table' and leaf.summary: - segment = leaf.summary + # Collect (leaf, segment) pairs preserving original leaf order + leaf_segments: List[tuple] = [] # (leaf, segment_text) + + # -- Phase A: table / summary-only leaves -- + for leaf in table_and_summary: + leaf_segments.append((leaf, leaf.summary)) + + # -- Phase B: batch page-level extraction (single IO) -- + page_segment_map: dict = {} # id(leaf) -> segment + if page_leaves: + all_pages: set = set() + for _leaf, (sp, ep) in page_leaves: + all_pages.update(range(sp, ep + 1)) + try: + page_contents = DocumentExtractor.extract_pages( + file_path, sorted(all_pages), + ) + page_map = {pc.page_number: pc.content for pc in page_contents} + + for leaf, (sp, ep) in page_leaves: + seg_parts = [] + for p in range(sp, ep + 1): + text = page_map.get(p, "") + if text.strip(): + seg_parts.append(text) + if seg_parts: + page_segment_map[id(leaf)] = "\n".join(seg_parts) + elif getattr(leaf, 'summary', None): + page_segment_map[id(leaf)] = leaf.summary + except (FileNotFoundError, PermissionError): + raise # 文件系统错误应传播 + except Exception as e: + _loguru_logger.warning( + f"[TreeSample] Page extraction failed for {fname}: {e}, " + f"falling back to char_range 
for {len(page_leaves)} leaves" + ) + # Demote page_leaves → char_leaves + for leaf, _ in page_leaves: + if hasattr(leaf, 'char_range') and leaf.char_range: + char_leaves.append(leaf) + elif getattr(leaf, 'summary', None): + leaf_segments.append((leaf, leaf.summary)) + page_leaves_ok = False else: + page_leaves_ok = True + + if page_leaves_ok: + for leaf, _ in page_leaves: + seg = page_segment_map.get(id(leaf)) + if seg: + leaf_segments.append((leaf, seg)) + # If page extraction failed, demoted leaves are now in char_leaves + + # -- Phase C: char_range fallback (lazy full-text extraction) -- + if char_leaves: + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" + + for leaf in char_leaves: start, end = leaf.char_range if self._is_valid_char_range(start, end, len(full_text)) and full_text: segment = full_text[start:end] - elif leaf.summary: + if segment.strip(): + leaf_segments.append((leaf, segment)) + elif getattr(leaf, 'summary', None): + leaf_segments.append((leaf, leaf.summary)) + elif getattr(leaf, 'summary', None): _loguru_logger.debug( - f"[TreeNav] char_range degraded for '{leaf.title}' " + f"[TreeSample] char_range degraded for '{leaf.title}' " f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" ) - segment = leaf.summary - else: - continue + leaf_segments.append((leaf, leaf.summary)) + + # --- Build parts with budget control --- + parts: List[str] = [] + total_chars = 0 + for leaf, segment in leaf_segments: segment = segment[: self._TREE_SAMPLE_SECTION_MAX_CHARS] if not segment.strip(): continue page_info = "" - if leaf.page_range: + if getattr(leaf, 'page_range', None): ps, pe = leaf.page_range page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" type_tag = " [TABLE]" if getattr(leaf, 'content_type', 'text') == 'table' else "" - header = f"[{fname} → {leaf.title}{page_info}{type_tag}]" + header = f"[{fname} \u2192 {leaf.title}{page_info}{type_tag}]" chunk = f"{header}\n{segment}" if total_chars + len(chunk) > max_chars: remaining = max_chars - total_chars @@ -3889,6 +3948,36 @@ async def _tree_guided_sample( ) return evidence + @staticmethod + def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: + """将叶节点按提取策略分类。 + + Returns: + (page_leaves, char_leaves, summary_leaves) 三元组: + - page_leaves: list of (leaf, page_range) tuples — 有有效 page_range 的 + - char_leaves: list of leaf — 需要 char_range fallback 的 + - summary_leaves: list of leaf — 只有 summary 可用的 + """ + page_leaves: List[tuple] = [] + char_leaves: List = [] + summary_leaves: List = [] + + for leaf in leaves: + # 表格类型节点优先使用 summary(结构化摘要) + if getattr(leaf, 'content_type', 'text') == 'table' and getattr(leaf, 'summary', None): + summary_leaves.append(leaf) + continue + + page_range = getattr(leaf, 'page_range', None) + if page_range and len(page_range) == 2 and page_range[0] is not None and page_range[0] > 0: + page_leaves.append((leaf, page_range)) + elif hasattr(leaf, 'char_range') and leaf.char_range: + char_leaves.append(leaf) + elif getattr(leaf, 'summary', None): + summary_leaves.append(leaf) + + return page_leaves, char_leaves, summary_leaves + def _is_valid_char_range( self, start: int, end: int, text_len: int, ) -> bool: @@ -3978,6 +4067,23 @@ def _format_table_evidence( return "\n\n".join(parts) + @staticmethod + def _append_evidence_part( + parts: List[str], fname: str, leaf, segment: str, + *, max_chars: int = 3000, + ) -> None: + 
"""Format and append one leaf's evidence to *parts* (in-place).""" + text = segment[:max_chars] + if not text.strip(): + return + page_info = "" + if getattr(leaf, 'page_range', None): + ps, pe = leaf.page_range + page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" + type_tag = " [TABLE]" if getattr(leaf, 'content_type', 'text') == 'table' else "" + header = f"[{fname} \u2192 {leaf.title}{page_info}{type_tag}]" + parts.append(f"{header}\n{text}") + async def _navigate_tree_for_evidence( self, file_path: str, query: str, *, max_results: int = 3, ) -> Optional[str]: @@ -3986,6 +4092,11 @@ async def _navigate_tree_for_evidence( Uses 1 LLM call to drill into the compiled tree index for *file_path*, returning concatenated leaf content as evidence. Returns None when no tree cache is available. + + Extraction priority (highest first): + 1. page_range – page-level extraction via DocumentExtractor + 2. char_range – full-text extraction + slice (fallback) + 3. leaf.summary – last resort """ indexer = self._get_tree_indexer() if indexer is None: @@ -4003,39 +4114,86 @@ async def _navigate_tree_for_evidence( return None fname = Path(file_path).name - # Read leaf content from the original document via char_range parts: List[str] = [] - try: - from sirchmunk.utils.file_utils import fast_extract - extraction = await fast_extract(file_path=file_path) - full_text = extraction.content or "" - except Exception: - full_text = "" - for leaf in leaves: - # Table nodes: prefer summary (contains table markdown) - if getattr(leaf, 'content_type', 'text') == 'table' and leaf.summary: - segment = leaf.summary - else: + # ── Phase 1: classify leaves by available extraction method ── + page_leaves, char_leaves, summary_only = self._classify_leaves(leaves) + + for leaf in summary_only: + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + + # ── Phase 2: batch page-level extraction (single IO) ── + if page_leaves: + all_pages: set = set() + for _leaf, (sp, ep) in page_leaves: + all_pages.update(range(sp, ep + 1)) + try: + page_contents = DocumentExtractor.extract_pages( + file_path, sorted(all_pages), + ) + page_map = {pc.page_number: pc.content for pc in page_contents} + + for leaf, (sp, ep) in page_leaves: + segment_parts = [] + for p in range(sp, ep + 1): + text = page_map.get(p, "") + if text.strip(): + segment_parts.append(text) + if segment_parts: + self._append_evidence_part( + parts, fname, leaf, "\n".join(segment_parts), + ) + elif getattr(leaf, 'summary', None): + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + except (FileNotFoundError, PermissionError): + raise # 文件系统错误应传播 + except Exception as e: + _loguru_logger.warning( + f"[TreeNav] Page extraction failed for {fname}: {e}, " + f"falling back to char_range for {len(page_leaves)} leaves" + ) + # Demote page_leaves → char_leaves for char_range fallback + for leaf, _ in page_leaves: + if hasattr(leaf, 'char_range') and leaf.char_range: + char_leaves.append(leaf) + elif getattr(leaf, 'summary', None): + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + + # ── Phase 3: char_range fallback (lazy full-text extraction) ── + if char_leaves: + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" + + for leaf in char_leaves: start, end = leaf.char_range if self._is_valid_char_range(start, end, len(full_text)) and full_text: segment = full_text[start:end] - elif 
leaf.summary: + if segment.strip(): + self._append_evidence_part( + parts, fname, leaf, segment, + ) + elif getattr(leaf, 'summary', None): + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + elif getattr(leaf, 'summary', None): _loguru_logger.debug( f"[TreeNav] char_range degraded for '{leaf.title}' " f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" ) - segment = leaf.summary - else: - continue - if segment.strip(): - page_info = "" - if leaf.page_range: - ps, pe = leaf.page_range - page_info = f" (pp.{ps}-{pe})" if ps != pe else f" (p.{ps})" - type_tag = " [TABLE]" if getattr(leaf, 'content_type', 'text') == 'table' else "" - header = f"[{fname} → {leaf.title}{page_info}{type_tag}]" - parts.append(f"{header}\n{segment[:3000]}") + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) if not parts: return None diff --git a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py index d114b7d..a022a8d 100644 --- a/src/sirchmunk/utils/document_extractor.py +++ b/src/sirchmunk/utils/document_extractor.py @@ -110,6 +110,25 @@ class ExtractionOutput: """Number of pages in the source document (if available).""" +# --------------------------------------------------------------------------- +# Page-level extraction output +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class PageContent: + """Single page extraction result. + + Returned by :meth:`DocumentExtractor.extract_pages` to represent the + text content of one PDF page. + """ + + page_number: int + """1-indexed page number.""" + + content: str + """Extracted text content (may be empty string).""" + + # --------------------------------------------------------------------------- # Document extractor facade # --------------------------------------------------------------------------- @@ -276,6 +295,78 @@ async def batch_extract( logger.error("Batch extraction failed for {} files", len(file_paths)) raise + # Page-level extraction ------------------------------------------------- + + @staticmethod + def extract_pages( + file_path: Union[str, Path], + pages: list[int], + ) -> list[PageContent]: + """Extract text content from specific PDF pages. + + Uses pypdf to read individual pages by 1-indexed page number. + Invalid page numbers (< 1 or > total pages) are silently skipped. + + Args: + file_path: Path to a PDF file. + pages: List of 1-indexed page numbers to extract. + + Returns: + List of :class:`PageContent` for each valid requested page, + in the order given by *pages*. + + Raises: + FileNotFoundError: If *file_path* does not exist. + Exception: On PDF parsing failure (logged before re-raise). + """ + path = Path(file_path) + if not path.exists(): + raise FileNotFoundError(f"PDF file not found: {path}") + + try: + from pypdf import PdfReader + + reader = PdfReader(str(path)) + total = len(reader.pages) + valid_pages = [p for p in pages if 1 <= p <= total] + return [ + PageContent( + page_number=p, + content=reader.pages[p - 1].extract_text() or "", + ) + for p in valid_pages + ] + except FileNotFoundError: + raise + except Exception as exc: + logger.error( + "Page-level extraction failed for {}: {}", + file_path, + exc, + ) + raise + + @staticmethod + def extract_page_range( + file_path: Union[str, Path], + start_page: int, + end_page: int, + ) -> list[PageContent]: + """Extract text content from a contiguous range of PDF pages. + + Convenience wrapper around :meth:`extract_pages`. 
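A usage sketch for the two new helpers (the path is a placeholder):

    from sirchmunk.utils.document_extractor import DocumentExtractor

    # Pull pages 10-12 of a report (1-indexed, inclusive on both ends).
    for pc in DocumentExtractor.extract_page_range("/data/docs/report.pdf", 10, 12):
        print(pc.page_number, len(pc.content))

    # Non-contiguous pages can be requested directly; out-of-range numbers are skipped.
    pages = DocumentExtractor.extract_pages("/data/docs/report.pdf", [2, 57, 9999])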
+ + Args: + file_path: Path to a PDF file. + start_page: First page (1-indexed, inclusive). + end_page: Last page (1-indexed, inclusive). + + Returns: + List of :class:`PageContent` for the requested range. + """ + pages = list(range(start_page, end_page + 1)) + return DocumentExtractor.extract_pages(file_path, pages) + # Internal helpers ----------------------------------------------------- @staticmethod From 2f3a25777f212e5d43c2f2a3648f508d75545c0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 26 Apr 2026 17:53:10 +0800 Subject: [PATCH 37/56] improve search tree index --- src/sirchmunk/learnings/tree_indexer.py | 15 ++++--- src/sirchmunk/search.py | 59 ++++++++++++++++++------- 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 96c44b9..a720bf3 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -9,6 +9,7 @@ import json import math +import os import re from collections import Counter from dataclasses import dataclass, field @@ -230,9 +231,10 @@ async def build_tree( # is unreliable, causing overlapping ranges and search failures. # TODO: Re-enable when robust char_range calculation is implemented. # await self._deepen_large_leaves(root, content, max_depth=effective_depth) - # NOTE: _enrich_node_summaries disabled temporarily to isolate its impact. - # The summaries may inadvertently bias _select_children() navigation. - # await self._enrich_node_summaries(root, content) + # Node summary enrichment: controlled by SIRCHMUNK_SKIP_NODE_SUMMARIES env var. + # Set to "true" to skip during debugging / performance testing. + if os.getenv("SIRCHMUNK_SKIP_NODE_SUMMARIES", "").lower() not in ("true", "1", "yes"): + await self._enrich_node_summaries(root, content) tree = DocumentTree( file_path=file_path, file_hash=file_hash, @@ -256,9 +258,10 @@ async def build_tree( # is unreliable, causing overlapping ranges and search failures. # TODO: Re-enable when robust char_range calculation is implemented. # await self._deepen_large_leaves(root, content, max_depth=effective_depth) - # NOTE: _enrich_node_summaries disabled temporarily to isolate its impact. - # The summaries may inadvertently bias _select_children() navigation. - # await self._enrich_node_summaries(root, content) + # Node summary enrichment: controlled by SIRCHMUNK_SKIP_NODE_SUMMARIES env var. + # Set to "true" to skip during debugging / performance testing. 
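Usage note (assumed workflow, not part of the patch): the flag is read straight from the process environment, so it can be toggled from the shell or programmatically before a compile run:

    import os

    # Skip the per-node summary LLM calls for a quick, structure-only compile run.
    os.environ["SIRCHMUNK_SKIP_NODE_SUMMARIES"] = "true"    # "1" and "yes" also work

    # Any other value, or unsetting the variable, re-enables summary enrichment.
    os.environ.pop("SIRCHMUNK_SKIP_NODE_SUMMARIES", None)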
+ if os.getenv("SIRCHMUNK_SKIP_NODE_SUMMARIES", "").lower() not in ("true", "1", "yes"): + await self._enrich_node_summaries(root, content) tree = DocumentTree( file_path=file_path, diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 52c0db3..1c47a77 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2543,10 +2543,21 @@ async def _search_fast( f"{len(best_files)} compile-hint files" ) else: - await self._logger.warning( - "[FAST:PureTree] No tree probes available, returning empty" + # Graceful degradation: fall back to keyword search when no tree is available + await self._logger.info( + "[FAST:PureTree] No tree probes available, falling back to keyword search" ) - return _NO_RESULTS_MESSAGE, None, context + best_files = await self._fast_find_best_file( + primary, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, **rga_kwargs, + ) + if not best_files and fallback: + best_files = await self._fast_find_best_file( + fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, **rga_kwargs, + ) + if not best_files: + return _NO_RESULTS_MESSAGE, None, context else: # --- Original rga-based retrieval logic --- # High-confidence catalog routing: skip rga, use catalog directly @@ -2675,16 +2686,32 @@ async def _rga_evidence() -> str: pass # 0.5 Table digest priority (pre-compiled PDF table evidence) - if ev is None and artifacts and artifacts.manifest_map: - _me = artifacts.manifest_map.get(fp) - if _me and getattr(_me, 'has_table_digest', False): - _all_tables = self._load_table_digest( - self.work_path, _me.file_hash, - ) - if _all_tables: - _table_ev = self._format_table_evidence(_all_tables) - if _table_ev: - ev = f"[{fn} - Table Evidence]\n{_table_ev}" + _all_tables = None + if ev is None and artifacts: + # Primary: manifest-based lookup + if artifacts.manifest_map: + _me = artifacts.manifest_map.get(fp) + if _me and getattr(_me, 'has_table_digest', False): + _all_tables = self._load_table_digest( + self.work_path, _me.file_hash, + ) + + # Fallback: direct hash-based lookup when manifest misses + if not _all_tables: + try: + from sirchmunk.utils.file_utils import get_fast_hash + _file_hash = get_fast_hash(fp) + if _file_hash: + _all_tables = self._load_table_digest( + self.work_path, _file_hash, + ) + except Exception: + pass + + if _all_tables: + _table_ev = self._format_table_evidence(_all_tables) + if _table_ev: + ev = f"[{fn} - Table Evidence]\n{_table_ev}" # 1. Tree-guided sampling FIRST for tree-indexed files if ( @@ -3492,8 +3519,8 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: # Prefer manifest-based detection (fast, O(1) per file) if manifest_map: tree_paths = {fp for fp, entry in manifest_map.items() if entry.has_tree} - # Fallback: scan tree cache directory (legacy path) - elif indexer is not None: + # Always try directory fallback if manifest-based detection found nothing + if not tree_paths and indexer is not None: tree_cache = self.work_path / ".cache" / "compile" / "trees" if tree_cache.exists(): try: @@ -4904,7 +4931,7 @@ async def _probe_tree_for_fast( Returns file paths of selected documents, or empty list when trees are unavailable or cover too few files to justify an LLM call. 
""" - if not artifacts or len(artifacts.tree_available_paths) <= 2: + if not artifacts or not artifacts.tree_available_paths: return [] try: From 9dd47bed8cb12ff96b49ad0526cdc77e1e7dd88c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 26 Apr 2026 19:53:14 +0800 Subject: [PATCH 38/56] update log --- benchmarks/financebench/run_benchmark.py | 75 ++++++++++++++++++++++-- src/sirchmunk/learnings/compiler.py | 6 ++ src/sirchmunk/learnings/tree_indexer.py | 17 +++++- src/sirchmunk/search.py | 45 ++++++++++++++ 4 files changed, 134 insertions(+), 9 deletions(-) diff --git a/benchmarks/financebench/run_benchmark.py b/benchmarks/financebench/run_benchmark.py index 65af87d..183a6d3 100644 --- a/benchmarks/financebench/run_benchmark.py +++ b/benchmarks/financebench/run_benchmark.py @@ -34,24 +34,61 @@ from evaluate import compute_metrics from runner import run_batch +# --------------------------------------------------------------------------- +# Tee stdout to log file +# --------------------------------------------------------------------------- + + +class _TeeWriter: + """Duplicate stdout to both terminal and a log file.""" + + def __init__(self, log_path: str) -> None: + self._terminal = sys.stdout + self._log = open(log_path, "w", encoding="utf-8") # noqa: SIM115 + + def write(self, msg: str) -> int: + self._terminal.write(msg) + self._log.write(msg) + return len(msg) + + def flush(self) -> None: + self._terminal.flush() + self._log.flush() + + def close(self) -> None: + self._log.close() + + # Let logging / other code check the stream capabilities + @property + def encoding(self) -> str: + return getattr(self._terminal, "encoding", "utf-8") + + def isatty(self) -> bool: + return False + + def fileno(self) -> int: + return self._terminal.fileno() + + # --------------------------------------------------------------------------- # Logging # --------------------------------------------------------------------------- -def setup_logging(output_dir: str) -> str: +def setup_logging(output_dir: str, ts: str | None = None) -> tuple[str, str]: """Configure logging to file + console. Creates a timestamped log file under ``logs/`` (relative to *output_dir*'s parent, i.e. the benchmark root directory). Returns: - Absolute path to the log file. + Tuple of (absolute path to the log file, timestamp string). """ log_dir = Path("logs") log_dir.mkdir(parents=True, exist_ok=True) - ts = datetime.now().strftime("%Y%m%d_%H%M%S") + if ts is None: + ts = datetime.now().strftime("%Y%m%d_%H%M%S") log_path = log_dir / f"benchmark_{ts}.log" root_logger = logging.getLogger("financebench") @@ -77,7 +114,7 @@ def setup_logging(output_dir: str) -> str: root_logger.addHandler(fh) root_logger.addHandler(ch) - return str(log_path.resolve()) + return str(log_path.resolve()), ts # --------------------------------------------------------------------------- @@ -169,9 +206,17 @@ def main() -> None: cfg.limit = args.limit # 2. Setup logging - log_path = setup_logging(cfg.output_dir) + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + log_path, ts = setup_logging(cfg.output_dir, ts=ts) logger = logging.getLogger("financebench") + # 2b. 
Tee stdout → debug log so SEARCH_WIKI_DEBUG prints are captured + log_dir = Path("logs") + log_dir.mkdir(parents=True, exist_ok=True) + debug_log_path = log_dir / f"benchmark_{ts}_debug.log" + tee = _TeeWriter(str(debug_log_path)) + sys.stdout = tee + # Print config source info work_env = Path(cfg.work_path) / ".env" logger.info("=" * 50) @@ -250,7 +295,25 @@ def main() -> None: # 10. Print summary _print_summary(results, metrics, total_time, results_path, metrics_path, log_path) + print(f" Debug log: {debug_log_path.resolve()}") + + # 11. Restore stdout + sys.stdout = tee._terminal + tee.close() + + +def _main_safe() -> None: + """Wrapper that guarantees stdout is restored even on exceptions.""" + try: + main() + except (KeyboardInterrupt, Exception): + # Restore stdout if tee was installed + if hasattr(sys.stdout, "_terminal"): + terminal = sys.stdout._terminal + sys.stdout.close() + sys.stdout = terminal + raise if __name__ == "__main__": - main() + _main_safe() diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 92dba7f..62f3e19 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -425,6 +425,8 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: has_table_digest=result.has_table_digest, table_count=result.table_count, ) + _mentry = manifest.files[result.path] + print(f"SEARCH_WIKI_DEBUG [C4] manifest_entry: has_tree={_mentry.has_tree}, has_table_digest={_mentry.has_table_digest}, file_hash={_mentry.file_hash}", flush=True) # Phase 3: aggregate results into knowledge network await self._log.info("[Compile] Phase 3: Knowledge aggregation") @@ -559,6 +561,7 @@ async def _compile_single_file( the pipeline skips tree building and summarises via a direct LLM call. """ result = FileCompileResult(path=entry.path) + print(f"SEARCH_WIKI_DEBUG [C1] _compile_single_file: file_path={entry.path}, file_hash={entry.file_hash}", flush=True) try: await self._log.info(f"[Compile] Processing: {Path(entry.path).name}") @@ -599,6 +602,7 @@ async def _compile_single_file( # Record TOC / tree metrics on the result for manifest persistence result.has_explicit_toc = toc_entries is not None and len(toc_entries) > 0 result.tree_node_count = self._count_tree_nodes(result.tree) + print(f"SEARCH_WIKI_DEBUG [C2] tree_build: success={result.tree is not None}, nodes={result.tree_node_count}, tree.file_path={result.tree.file_path if result.tree else 'N/A'}", flush=True) # Enrich content with structural metadata for non-text types ext = Path(entry.path).suffix.lower() @@ -650,6 +654,8 @@ async def _compile_single_file( except Exception: pass + print(f"SEARCH_WIKI_DEBUG [C3] table_digest: generated={result.has_table_digest}, count={result.table_count}", flush=True) + # Integrate tables into tree: annotate counts + create table child nodes if result.tree and result.tree.root and extraction.tables: self._integrate_tables_into_tree( diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index a720bf3..2e2f909 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -233,7 +233,9 @@ async def build_tree( # await self._deepen_large_leaves(root, content, max_depth=effective_depth) # Node summary enrichment: controlled by SIRCHMUNK_SKIP_NODE_SUMMARIES env var. # Set to "true" to skip during debugging / performance testing. 
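Since the SEARCH_WIKI_DEBUG lines introduced in this patch share one marker format and end up in the tee'd *_debug.log, they can be grouped back by tag afterwards; a hypothetical post-processing helper (not part of the patch):

    import re

    MARKER = re.compile(r"SEARCH_WIKI_DEBUG \[([A-Z]\d+)\] (.*)")

    def scan_debug_log(path: str) -> dict:
        # Group the tee'd debug lines by tag, e.g. "C2", "T4", "D15".
        by_tag: dict = {}
        with open(path, encoding="utf-8") as fh:
            for line in fh:
                m = MARKER.search(line)
                if m:
                    by_tag.setdefault(m.group(1), []).append(m.group(2).strip())
        return by_tag

    # e.g. scan_debug_log("logs/benchmark_<timestamp>_debug.log").get("T4", [])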
- if os.getenv("SIRCHMUNK_SKIP_NODE_SUMMARIES", "").lower() not in ("true", "1", "yes"): + _skip_summaries = os.getenv("SIRCHMUNK_SKIP_NODE_SUMMARIES", "").lower() in ("true", "1", "yes") + print(f"SEARCH_WIKI_DEBUG [T1] enrich_node_summaries (TOC path): skip={_skip_summaries}, env={os.getenv('SIRCHMUNK_SKIP_NODE_SUMMARIES', '')}", flush=True) + if not _skip_summaries: await self._enrich_node_summaries(root, content) tree = DocumentTree( file_path=file_path, @@ -260,7 +262,9 @@ async def build_tree( # await self._deepen_large_leaves(root, content, max_depth=effective_depth) # Node summary enrichment: controlled by SIRCHMUNK_SKIP_NODE_SUMMARIES env var. # Set to "true" to skip during debugging / performance testing. - if os.getenv("SIRCHMUNK_SKIP_NODE_SUMMARIES", "").lower() not in ("true", "1", "yes"): + _skip_summaries = os.getenv("SIRCHMUNK_SKIP_NODE_SUMMARIES", "").lower() in ("true", "1", "yes") + print(f"SEARCH_WIKI_DEBUG [T1] enrich_node_summaries (recursive path): skip={_skip_summaries}, env={os.getenv('SIRCHMUNK_SKIP_NODE_SUMMARIES', '')}", flush=True) + if not _skip_summaries: await self._enrich_node_summaries(root, content) tree = DocumentTree( @@ -304,6 +308,8 @@ async def navigate( if tree.root is None: return [] + print(f"SEARCH_WIKI_DEBUG [T2] navigate: query={query[:80]}, total_nodes={self._count_nodes(tree.root)}", flush=True) + candidates = tree.root.children if tree.root.children else [tree.root] if not candidates: return [tree.root] @@ -318,6 +324,7 @@ async def navigate( selected = await self._select_children( frontier, query, max_selections=max_results, ) + print(f"SEARCH_WIKI_DEBUG [T3] navigate layer: depth={depth}, selected={len(selected)}, names={[n.title[:30] for n in selected][:5]}", flush=True) if not selected: break @@ -351,7 +358,10 @@ async def navigate( if n.node_id not in seen_ids: seen_ids.add(n.node_id) unique.append(n) - return unique[:max_results] + leaves = unique[:max_results] + _page_valid = sum(1 for l in leaves if getattr(l, 'page_range', None) and len(l.page_range) == 2 and l.page_range[0]) + print(f"SEARCH_WIKI_DEBUG [T4] navigate result: leaves={len(leaves)}, page_range_valid={_page_valid}", flush=True) + return leaves def load_tree(self, file_path: str) -> Optional[DocumentTree]: """Load a cached tree index for the given file (sync).""" @@ -821,6 +831,7 @@ def _cache_path(self, file_hash: str) -> Path: def _save_cache(self, file_hash: str, tree: DocumentTree) -> None: path = self._cache_path(file_hash) path.write_text(tree.to_json(), encoding="utf-8") + print(f"SEARCH_WIKI_DEBUG [C5] tree_json_saved: path={path}", flush=True) def _load_cache(self, file_hash: str) -> Optional[DocumentTree]: path = self._cache_path(file_hash) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 1c47a77..738db15 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2527,6 +2527,8 @@ async def _search_fast( {"path": p, "matches": [], "total_matches": 0, "weighted_score": 0.0} for p in _tree_probed_files[:top_k_files] ] + print(f"SEARCH_WIKI_DEBUG [D7] _tree_probed_files={_tree_probed_files}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D8] best_files={[bf['path'] for bf in best_files]}", flush=True) await self._logger.info( f"[FAST:PureTree] Using {len(best_files)} tree-probed files: " f"{[Path(p).name for p in _tree_probed_files[:top_k_files]]}" @@ -2651,6 +2653,11 @@ async def _search_fast( tree_nav_done: Set[str] = set() tree_nav_target = best_files[0]["path"] + print(f"SEARCH_WIKI_DEBUG [D9] tree_nav_target={tree_nav_target}", 
flush=True) + print(f"SEARCH_WIKI_DEBUG [D10] tree_nav_match={tree_nav_target in (artifacts.tree_available_paths if artifacts else set())}", flush=True) + if artifacts and tree_nav_target not in artifacts.tree_available_paths: + print(f"SEARCH_WIKI_DEBUG [D11] MISMATCH! tree_available_paths={artifacts.tree_available_paths}", flush=True) + if artifacts and tree_nav_target in artifacts.tree_available_paths: tree_task = self._navigate_tree_for_evidence(tree_nav_target, query) tree_nav_done.add(tree_nav_target) @@ -2669,6 +2676,8 @@ async def _rga_evidence() -> str: ext = Path(fp).suffix.lower() ev = None + print(f"SEARCH_WIKI_DEBUG [D12] _rga_evidence: fp={fp}", flush=True) + # 0. Excel digest priority (pre-compiled evidence) if artifacts and artifacts.manifest_map: manifest_entry = artifacts.manifest_map.get(fp) @@ -2708,12 +2717,16 @@ async def _rga_evidence() -> str: except Exception: pass + print(f"SEARCH_WIKI_DEBUG [D13] table_digest: manifest_lookup={'found' if artifacts.manifest_map and artifacts.manifest_map.get(fp) else 'miss'}, has_table_digest={getattr(artifacts.manifest_map.get(fp), 'has_table_digest', False) if artifacts.manifest_map else 'N/A'}, hash_fallback={'tried' if not _all_tables else 'skipped'}, tables_count={len(_all_tables) if _all_tables else 0}", flush=True) + if _all_tables: _table_ev = self._format_table_evidence(_all_tables) if _table_ev: ev = f"[{fn} - Table Evidence]\n{_table_ev}" # 1. Tree-guided sampling FIRST for tree-indexed files + _tree_cond = artifacts and fp in artifacts.tree_available_paths and fp not in tree_nav_done + print(f"SEARCH_WIKI_DEBUG [D14] tree_sample: cond={_tree_cond}, in_tree_paths={fp in (artifacts.tree_available_paths if artifacts else set())}, in_nav_done={fp in tree_nav_done}", flush=True) if ( artifacts and fp in artifacts.tree_available_paths @@ -2755,6 +2768,14 @@ async def _rga_evidence() -> str: parts.append(ev[:remaining]) chars += len(parts[-1]) context.mark_file_read(fp) + + _ev_source = "none" + if ev: + if "Table Evidence" in ev: _ev_source = "table_digest" + elif "Pre-compiled" in ev: _ev_source = "excel_digest" + elif "TreeSample" in str(ev)[:50] or "TreeNav" in str(ev)[:50]: _ev_source = "tree" + else: _ev_source = "rga_or_other" + print(f"SEARCH_WIKI_DEBUG [D15] ev_source={_ev_source}, ev_len={len(ev) if ev else 0}", flush=True) return "\n\n---\n\n".join(parts) # Launch tree navigation for the primary file alongside rga @@ -2770,6 +2791,10 @@ async def _rga_evidence() -> str: evidence_parts_final.append(rga_ev) evidence = "\n\n---\n\n".join(evidence_parts_final) + print(f"SEARCH_WIKI_DEBUG [D16] tree_ev: {'yes' if tree_ev else 'no'}, len={len(tree_ev) if tree_ev else 0}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D17] rga_ev: {'yes' if rga_ev else 'no'}, len={len(rga_ev) if rga_ev else 0}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D18] final_evidence_len={len(evidence)}", flush=True) + if not evidence or len(evidence.strip()) < 20: if llm_fallback: await self._logger.info( @@ -3549,6 +3574,9 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: except Exception: pass + print(f"SEARCH_WIKI_DEBUG [D1] manifest_map: {len(manifest_map)} entries, keys={list(manifest_map.keys())[:3]}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D2] tree_available_paths: {tree_paths}", flush=True) + print(f"SEARCH_WIKI_DEBUG [D3] manifest_fallback_executed: {manifest_map and not tree_paths}", flush=True) return CompileArtifacts( catalog=catalog, catalog_map=catalog_map, @@ -3804,6 +3832,8 @@ async def _tree_guided_sample( if max_chars <= 0: 
max_chars = self._FAST_MAX_EVIDENCE_CHARS + print(f"SEARCH_WIKI_DEBUG [S1] _tree_guided_sample: file_path={file_path}", flush=True) + # --- Guard: tree availability --- if artifacts is not None: if file_path not in artifacts.tree_available_paths: @@ -3839,6 +3869,7 @@ async def _tree_guided_sample( # --- Classify leaves by extraction method --- trimmed = leaves[: self._TREE_SAMPLE_MAX_SECTIONS] page_leaves, char_leaves, table_and_summary = self._classify_leaves(trimmed) + print(f"SEARCH_WIKI_DEBUG [S2] classify_leaves: page={len(page_leaves)}, char={len(char_leaves)}, table_summary={len(table_and_summary)}", flush=True) # Collect (leaf, segment) pairs preserving original leaf order leaf_segments: List[tuple] = [] # (leaf, segment_text) @@ -3968,6 +3999,7 @@ async def _tree_guided_sample( return None evidence = "\n\n".join(parts) + print(f"SEARCH_WIKI_DEBUG [S3] _tree_guided_sample result: len={len(evidence) if evidence else 0}", flush=True) await self._logger.info( f"[TreeSample] {fname}: " f"{len(parts)} sections, {total_chars} chars " @@ -4126,6 +4158,7 @@ async def _navigate_tree_for_evidence( 3. leaf.summary – last resort """ indexer = self._get_tree_indexer() + print(f"SEARCH_WIKI_DEBUG [N1] _navigate_tree_for_evidence: file_path={file_path}", flush=True) if indexer is None: return None tree = indexer.load_tree(file_path) @@ -4137,6 +4170,8 @@ async def _navigate_tree_for_evidence( except Exception: return None + print(f"SEARCH_WIKI_DEBUG [N2] navigate_result: {len(leaves) if leaves else 0} leaves", flush=True) + if not leaves: return None @@ -4145,6 +4180,7 @@ async def _navigate_tree_for_evidence( # ── Phase 1: classify leaves by available extraction method ── page_leaves, char_leaves, summary_only = self._classify_leaves(leaves) + print(f"SEARCH_WIKI_DEBUG [N3] classify_leaves: page={len(page_leaves)}, char={len(char_leaves)}, summary={len(summary_only)}", flush=True) for leaf in summary_only: self._append_evidence_part( @@ -4191,6 +4227,9 @@ async def _navigate_tree_for_evidence( self._append_evidence_part( parts, fname, leaf, leaf.summary, ) + print(f"SEARCH_WIKI_DEBUG [N4] page_extraction: page_leaves_ok=False", flush=True) + else: + print(f"SEARCH_WIKI_DEBUG [N4] page_extraction: page_leaves_ok=True", flush=True) # ── Phase 3: char_range fallback (lazy full-text extraction) ── if char_leaves: @@ -4256,7 +4295,10 @@ async def _navigate_tree_for_evidence( except Exception: pass + print(f"SEARCH_WIKI_DEBUG [N5] table_supplement: tables_loaded={len(_all_tables) if '_all_tables' in dir() and _all_tables else 0}", flush=True) + evidence = "\n\n".join(parts) + print(f"SEARCH_WIKI_DEBUG [N6] _navigate_tree_for_evidence result: len={len(evidence) if evidence else 0}", flush=True) await self._logger.info( f"[FAST:TreeNav] Extracted {len(parts)} sections, " f"{len(evidence)} chars from {fname}" @@ -4931,16 +4973,19 @@ async def _probe_tree_for_fast( Returns file paths of selected documents, or empty list when trees are unavailable or cover too few files to justify an LLM call. 
""" + print(f"SEARCH_WIKI_DEBUG [D4] _probe_tree_for_fast: tree_available_paths={len(artifacts.tree_available_paths) if artifacts else 0}", flush=True) if not artifacts or not artifacts.tree_available_paths: return [] try: trees = self._load_cached_trees() + print(f"SEARCH_WIKI_DEBUG [D5] loaded_trees: {len(trees)} trees, paths={[t.file_path for t in trees][:3]}", flush=True) if not trees: return [] result = await self._llm_select_from_trees( query, trees, max_select=self._FAST_TREE_PROBE_MAX_FILES, ) + print(f"SEARCH_WIKI_DEBUG [D6] llm_select_result: {result}", flush=True) if result: await self._logger.info( f"[FAST:TreeProbe] Selected {len(result)} files " From 8ff1f98c207e3c998bfaf10155cd72c9b0bfb871 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 26 Apr 2026 21:32:01 +0800 Subject: [PATCH 39/56] enhance search fast for compile --- src/sirchmunk/learnings/tree_indexer.py | 283 +++++++++++++++++++++--- src/sirchmunk/search.py | 17 +- 2 files changed, 269 insertions(+), 31 deletions(-) diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 2e2f909..2d7e277 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -163,6 +163,9 @@ class DocumentTreeIndexer: # Number of nodes per group in paginated selection. _GROUP_PAGE_SIZE: int = 15 + # Minimum navigation depth before allowing early termination. + _NAV_MIN_DEPTH: int = 2 + def __init__( self, llm: OpenAIChat, @@ -289,18 +292,21 @@ async def navigate( *, max_results: int = 3, max_depth: int = 4, + min_depth: int = 2, ) -> List[TreeNode]: """Adaptive-depth LLM-driven tree navigation. Iteratively descends the tree using _select_children() at each level, collecting leaf nodes until *max_results* are found or *max_depth* is - reached. + reached. Enforces *min_depth* descent before allowing early + termination to avoid shallow results. Args: tree: DocumentTree with a root node. query: Search query for relevance selection. max_results: Maximum number of leaf nodes to return. max_depth: Maximum descent depth (default 4). + min_depth: Minimum depth before early termination (default 2). Returns: List of the most relevant leaf TreeNodes. 
@@ -314,6 +320,10 @@ async def navigate( if not candidates: return [tree.root] + # Adaptive min-depth: clamp to tree's actual depth + tree_max_depth = self._max_node_depth(tree.root) + effective_min_depth = min(min_depth, max(tree_max_depth - 1, 1)) + result_leaves: List[TreeNode] = [] visited: set = set() # prevent cycles frontier = candidates @@ -325,7 +335,21 @@ async def navigate( frontier, query, max_selections=max_results, ) print(f"SEARCH_WIKI_DEBUG [T3] navigate layer: depth={depth}, selected={len(selected)}, names={[n.title[:30] for n in selected][:5]}", flush=True) + if not selected: + # Fix A.1: when depth < effective_min_depth, expand all frontier children + if depth < effective_min_depth: + next_frontier: List[TreeNode] = [] + for node in frontier: + if node.children: + next_frontier.extend(node.children) + else: + result_leaves.append(node) + if not next_frontier: + break + frontier = next_frontier + depth += 1 + continue break next_frontier: List[TreeNode] = [] @@ -335,14 +359,25 @@ async def navigate( continue visited.add(node_id) + # Fix A.2: leaf determination with depth constraint if node.leaf or not node.children: - result_leaves.append(node) + if depth >= effective_min_depth: + result_leaves.append(node) + elif node.children: + next_frontier.extend(node.children) + else: + # True leaf (no children), cannot descend further + result_leaves.append(node) else: next_frontier.extend(node.children) - if len(result_leaves) >= max_results: + # Fix A.3: early termination requires depth >= effective_min_depth + if len(result_leaves) >= max_results and depth >= effective_min_depth: break + # Fix A.4: check for empty next_frontier + if not next_frontier: + break frontier = next_frontier depth += 1 @@ -404,6 +439,9 @@ async def _build_tree_from_toc( # Infer hierarchy when TOC entries are flat (all same level) toc_entries = self._infer_hierarchy(toc_entries) + # Merge consecutive fragment entries into virtual parents + toc_entries = self._merge_fragment_entries(toc_entries) + seen_ids: set = set() children = self._toc_entries_to_nodes( toc_entries, content, len(content), seen_ids, @@ -425,6 +463,78 @@ async def _build_tree_from_toc( children=children, ) + @staticmethod + def _merge_fragment_entries(entries: List[Any]) -> List[Any]: + """Merge consecutive fragment TOC entries into virtual parent nodes. + + Detects runs of >=3 consecutive entries that have tiny char_range + spans (<500) and no children, then collapses them into a single + virtual 'Preamble' entry. Uses only structural signals (char spans, + children counts) — no domain-specific keywords. + + Safety valve: returns original *entries* if result has < 2 entries. 
+ """ + if len(entries) <= 5: + return entries + + # Phase 1: Detect fragment runs + def _is_fragment(e: Any) -> bool: + span = 0 + if hasattr(e, 'char_start') and hasattr(e, 'char_end'): + if e.char_end and e.char_start is not None: + span = e.char_end - e.char_start + has_children = bool(getattr(e, 'children', None)) + return span < 500 and not has_children + + # Find runs of consecutive fragments + runs: List[List[int]] = [] # list of [start_idx, end_idx] inclusive + i = 0 + while i < len(entries): + if _is_fragment(entries[i]): + run_start = i + while i < len(entries) and _is_fragment(entries[i]): + i += 1 + if (i - run_start) >= 3: # Only merge runs of 3+ + runs.append([run_start, i - 1]) + else: + i += 1 + + if not runs: + return entries + + # Phase 2: Merge each run into a virtual parent + from copy import deepcopy + + result: List[Any] = [] + prev_end = -1 + for run_start, run_end in runs: + # Add non-fragment entries before this run + for j in range(prev_end + 1, run_start): + result.append(entries[j]) + + # Create virtual parent from the run + first_entry = entries[run_start] + last_entry = entries[run_end] + + merged = deepcopy(first_entry) + merged.title = f"Preamble ({run_end - run_start + 1} sections)" + if hasattr(last_entry, 'char_end') and last_entry.char_end: + merged.char_end = last_entry.char_end + # Set children to the original entries + merged.children = list(entries[run_start:run_end + 1]) + result.append(merged) + prev_end = run_end + + # Add remaining entries after last run + for j in range(prev_end + 1, len(entries)): + result.append(entries[j]) + + # Safety valve + if len(result) < 2: + return entries + + return result + @staticmethod def _toc_entries_to_nodes( entries: List[Any], @@ -672,34 +782,161 @@ def _resolve_positions( and (s["end"] - s["start"]) / max(text_len, 1) < _MAX_SPAN_RATIO ] + @staticmethod + def _filter_low_value_nodes( + nodes: List["TreeNode"], + *, + min_remaining: int = 3, + ) -> List["TreeNode"]: + """Filter out low-value fragment nodes using structural signals. + + Applies three generic heuristics (no domain-specific keywords): + 1. Short-page leaf: page_range spans <= 2 pages AND no children AND + summary length < 100 chars. + 2. Tiny fragment: title < 10 chars AND no children AND + char_range span < 200 chars. + 3. Duplicate page_range: among nodes sharing the same page_range, + keep only the one with the largest char_range span. + + Safety valve: returns original *nodes* if fewer than *min_remaining* + survive filtering. 
+ """ + if len(nodes) <= min_remaining: + return nodes + + # Pass 1: identify fragment nodes + keep: List[bool] = [True] * len(nodes) + + for i, n in enumerate(nodes): + pr = getattr(n, 'page_range', None) + has_children = bool(n.children) + summary_len = len(n.summary) if n.summary else 0 + title_len = len(n.title.strip()) if n.title else 0 + cr = getattr(n, 'char_range', (0, 0)) + span = (cr[1] - cr[0]) if cr and len(cr) == 2 else 0 + + # Heuristic 1: short-page leaf + if ( + pr and len(pr) == 2 + and pr[0] is not None and pr[1] is not None + and (pr[1] - pr[0]) <= 1 + and not has_children + and summary_len < 100 + ): + keep[i] = False + continue + + # Heuristic 2: tiny fragment + if title_len < 10 and not has_children and span < 200: + keep[i] = False + continue + + # Pass 2: deduplicate by page_range + page_range_groups: dict = {} # page_range -> list of (index, span) + for i, n in enumerate(nodes): + if not keep[i]: + continue + pr = getattr(n, 'page_range', None) + if pr and len(pr) == 2: + key = (pr[0], pr[1]) + cr = getattr(n, 'char_range', (0, 0)) + span = (cr[1] - cr[0]) if cr and len(cr) == 2 else 0 + page_range_groups.setdefault(key, []).append((i, span)) + + for key, group in page_range_groups.items(): + if len(group) > 1: + # Keep only the node with largest char_range span + best_idx = max(group, key=lambda x: x[1])[0] + for idx, _ in group: + if idx != best_idx: + keep[idx] = False + + filtered = [n for i, n in enumerate(nodes) if keep[i]] + return filtered if len(filtered) >= min_remaining else nodes + + @staticmethod + def _build_node_descriptor(node: "TreeNode", index: int) -> str: + """Build a rich descriptor string for a single tree node. + + Includes structural signals: page span, table count, subsection + count, and depth information to help LLM make informed selections. + """ + parts = [f"[{index}] {node.title}"] + + # Page range with span + pr = getattr(node, 'page_range', None) + if pr and len(pr) == 2 and pr[0] is not None: + span_pages = pr[1] - pr[0] + 1 if pr[1] else 1 + parts.append(f"[pages {pr[0]}-{pr[1]}, {span_pages}p]") + + # Table count + if node.table_count > 0: + parts.append(f"[{node.table_count} tables]") + + # Subsections + child_count = len(node.children) + if child_count > 0: + parts.append(f"[{child_count} subsections]") + + # Summary + summary = (node.summary or "")[:200] + if summary: + parts.append(f": {summary}") + + return " ".join(parts) + + @staticmethod + def _build_selection_prompt( + nodes: List["TreeNode"], + query: str, + max_selections: int, + ) -> str: + """Build unified LLM prompt for branch selection. + + Uses structural signals to guide LLM toward high-value sections: + tables, subsection depth, page span. No domain-specific keywords. + """ + listing = "\n".join( + DocumentTreeIndexer._build_node_descriptor(n, i) + for i, n in enumerate(nodes) + ) + + sel_hint = f"1-{min(max_selections, len(nodes))}" + + return ( + f"Given the query: \"{query}\"\n\n" + f"Select the {sel_hint} most relevant sections (by index number):\n" + f"{listing}\n\n" + f"Selection criteria:\n" + f"- Prioritize sections containing tables and data\n" + f"- Prefer sections with many subsections over small leaf fragments\n" + f"- Avoid sections covering only 1-2 pages with no subsections\n" + f"- When uncertain, prefer larger sections that can be narrowed later\n\n" + f"Return ONLY a JSON array of index numbers, e.g. 
[0, 2]" + ) + async def _select_children( self, nodes: List[TreeNode], query: str, *, max_selections: int = 3, ) -> List[TreeNode]: """LLM-driven branch selection: pick the most relevant children. - Dispatches to paginated selection when *nodes* exceeds - ``_PAGE_SIZE_THRESHOLD`` to avoid overwhelming the LLM. + Pre-filters low-value fragments, then dispatches to paginated + selection when *nodes* exceeds ``_PAGE_SIZE_THRESHOLD``. """ if len(nodes) <= 2: return nodes + # Pre-filter low-value fragment nodes + nodes = self._filter_low_value_nodes(nodes) + if len(nodes) <= 2: + return nodes + if len(nodes) > self._PAGE_SIZE_THRESHOLD: return await self._select_children_paginated( nodes, query, max_selections=max_selections, ) - listing = "\n".join( - f"[{i}] {n.title}{self._format_page_range(n.page_range)}" - f"{' [' + str(n.table_count) + ' tables]' if n.table_count > 0 else ''}" - f": {n.summary[:150]}" - for i, n in enumerate(nodes) - ) - - prompt = ( - f"Given the query: \"{query}\"\n\n" - f"Select the 1-2 most relevant sections (by index number):\n{listing}\n\n" - f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" - ) + prompt = self._build_selection_prompt(nodes, query, max_selections) resp = await self._llm.achat([{"role": "user", "content": prompt}]) try: raw = resp.content.strip() @@ -797,17 +1034,7 @@ async def _select_from_group( if len(group) <= 2: return group - listing = "\n".join( - f"[{i}] {n.title}{self._format_page_range(n.page_range)}" - f"{' [' + str(n.table_count) + ' tables]' if n.table_count > 0 else ''}" - f": {n.summary[:150]}" - for i, n in enumerate(group) - ) - prompt = ( - f"Given the query: \"{query}\"\n\n" - f"Select the 1-2 most relevant sections (by index number):\n{listing}\n\n" - f"Return ONLY a JSON array of index numbers, e.g. 
[0, 2]" - ) + prompt = self._build_selection_prompt(group, query, max_selections) try: resp = await self._llm.achat([{"role": "user", "content": prompt}]) raw = resp.content.strip() diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 738db15..e12fd39 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -4022,9 +4022,20 @@ def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: summary_leaves: List = [] for leaf in leaves: - # 表格类型节点优先使用 summary(结构化摘要) - if getattr(leaf, 'content_type', 'text') == 'table' and getattr(leaf, 'summary', None): - summary_leaves.append(leaf) + # 表格类型节点:优先 page-level 提取获取完整原始内容 + if getattr(leaf, 'content_type', 'text') == 'table': + page_range = getattr(leaf, 'page_range', None) + if ( + page_range + and len(page_range) == 2 + and page_range[0] is not None + and page_range[0] > 0 + ): + page_leaves.append((leaf, page_range)) + elif getattr(leaf, 'summary', None): + summary_leaves.append(leaf) + else: + char_leaves.append(leaf) continue page_range = getattr(leaf, 'page_range', None) From 464d8d511093f07d388313549df7f1fab8dfbb43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 26 Apr 2026 22:49:59 +0800 Subject: [PATCH 40/56] enhance tree index --- src/sirchmunk/learnings/tree_indexer.py | 130 +++++++++++------------- src/sirchmunk/llm/prompts.py | 12 ++- src/sirchmunk/search.py | 114 ++++++++++++++++----- 3 files changed, 158 insertions(+), 98 deletions(-) diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 2d7e277..9cf450e 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -292,21 +292,21 @@ async def navigate( *, max_results: int = 3, max_depth: int = 4, - min_depth: int = 2, + min_depth: int = 1, ) -> List[TreeNode]: """Adaptive-depth LLM-driven tree navigation. Iteratively descends the tree using _select_children() at each level, collecting leaf nodes until *max_results* are found or *max_depth* is reached. Enforces *min_depth* descent before allowing early - termination to avoid shallow results. + termination to avoid overly shallow results. Args: tree: DocumentTree with a root node. query: Search query for relevance selection. max_results: Maximum number of leaf nodes to return. max_depth: Maximum descent depth (default 4). - min_depth: Minimum depth before early termination (default 2). + min_depth: Minimum depth before early termination (default 1). Returns: List of the most relevant leaf TreeNodes. @@ -320,6 +320,16 @@ async def navigate( if not candidates: return [tree.root] + # Skip single-child container chains (e.g. SEC boilerplate wrappers + # like "UNITED STATES SECURITIES AND EXCHANGE COMMISSION" → "FORM 10-K") + # to avoid wasting navigation depth on structural-only nodes. 
+ while ( + len(candidates) == 1 + and candidates[0].children + and not candidates[0].leaf + ): + candidates = candidates[0].children + # Adaptive min-depth: clamp to tree's actual depth tree_max_depth = self._max_node_depth(tree.root) effective_min_depth = min(min_depth, max(tree_max_depth - 1, 1)) @@ -359,17 +369,10 @@ async def navigate( continue visited.add(node_id) - # Fix A.2: leaf determination with depth constraint - if node.leaf or not node.children: - if depth >= effective_min_depth: - result_leaves.append(node) - elif node.children: - next_frontier.extend(node.children) - else: - # True leaf (no children), cannot descend further - result_leaves.append(node) - else: + if node.children: next_frontier.extend(node.children) + else: + result_leaves.append(node) # Fix A.3: early termination requires depth >= effective_min_depth if len(result_leaves) >= max_results and depth >= effective_min_depth: @@ -788,68 +791,58 @@ def _filter_low_value_nodes( *, min_remaining: int = 3, ) -> List["TreeNode"]: - """Filter out low-value fragment nodes using structural signals. - - Applies three generic heuristics (no domain-specific keywords): - 1. Short-page leaf: page_range spans <= 2 pages AND no children AND - summary length < 100 chars. - 2. Tiny fragment: title < 10 chars AND no children AND - char_range span < 200 chars. - 3. Duplicate page_range: among nodes sharing the same page_range, - keep only the one with the largest char_range span. - - Safety valve: returns original *nodes* if fewer than *min_remaining* - survive filtering. + """Remove only structurally empty or exact-duplicate nodes. + + Intentionally conservative: the LLM selection step receives rich + structural descriptors (page span, table count, subsection count) + and is trusted to judge relevance. This filter removes only + definitive noise that would waste LLM context: + + 1. Empty placeholders — no title, no children, zero char span, + and no summary. + 2. Exact duplicates — identical (title, page_range) pairs; among + duplicates the node with the richest structure is kept. + + Safety: returns original *nodes* when fewer than *min_remaining* + would survive. """ if len(nodes) <= min_remaining: return nodes - # Pass 1: identify fragment nodes keep: List[bool] = [True] * len(nodes) - for i, n in enumerate(nodes): - pr = getattr(n, 'page_range', None) - has_children = bool(n.children) - summary_len = len(n.summary) if n.summary else 0 - title_len = len(n.title.strip()) if n.title else 0 - cr = getattr(n, 'char_range', (0, 0)) - span = (cr[1] - cr[0]) if cr and len(cr) == 2 else 0 - - # Heuristic 1: short-page leaf - if ( - pr and len(pr) == 2 - and pr[0] is not None and pr[1] is not None - and (pr[1] - pr[0]) <= 1 - and not has_children - and summary_len < 100 - ): - keep[i] = False - continue + def _char_span(n: "TreeNode") -> int: + cr = getattr(n, "char_range", (0, 0)) + return (cr[1] - cr[0]) if cr and len(cr) == 2 else 0 - # Heuristic 2: tiny fragment - if title_len < 10 and not has_children and span < 200: + # Pass 1: remove structurally empty placeholder nodes + for i, n in enumerate(nodes): + title = (n.title or "").strip() + if not title and not n.children and _char_span(n) == 0 and not n.summary: keep[i] = False - continue - # Pass 2: deduplicate by page_range - page_range_groups: dict = {} # page_range -> list of (index, span) + # Pass 2: deduplicate exact (title, page_range) pairs — + # keep the node with more structural richness. 
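# NOTE (editor's annotation, not from the patch): duplicates sharing the same
# (title, page_range) are resolved by lexicographic comparison of the richness
# tuple (child_count, table_count, char_span), so subsection count dominates
# table count, which in turn dominates raw span, e.g.:
#     (3, 0, 1_200) > (0, 2, 9_000)    # a node with subsections wins
#     (0, 2, 9_000) > (0, 0, 50_000)   # then tables beat span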
+ seen: dict = {} # (title, page_range_key) → index for i, n in enumerate(nodes): if not keep[i]: continue - pr = getattr(n, 'page_range', None) - if pr and len(pr) == 2: - key = (pr[0], pr[1]) - cr = getattr(n, 'char_range', (0, 0)) - span = (cr[1] - cr[0]) if cr and len(cr) == 2 else 0 - page_range_groups.setdefault(key, []).append((i, span)) - - for key, group in page_range_groups.items(): - if len(group) > 1: - # Keep only the node with largest char_range span - best_idx = max(group, key=lambda x: x[1])[0] - for idx, _ in group: - if idx != best_idx: - keep[idx] = False + title = (n.title or "").strip() + pr = getattr(n, "page_range", None) + pr_key = (pr[0], pr[1]) if pr and len(pr) == 2 else None + dup_key = (title, pr_key) + if dup_key in seen: + prev_i = seen[dup_key] + prev = nodes[prev_i] + richness = (len(n.children), getattr(n, "table_count", 0), _char_span(n)) + prev_richness = (len(prev.children), getattr(prev, "table_count", 0), _char_span(prev)) + if richness > prev_richness: + keep[prev_i] = False + seen[dup_key] = i + else: + keep[i] = False + else: + seen[dup_key] = i filtered = [n for i, n in enumerate(nodes) if keep[i]] return filtered if len(filtered) >= min_remaining else nodes @@ -908,9 +901,9 @@ def _build_selection_prompt( f"Select the {sel_hint} most relevant sections (by index number):\n" f"{listing}\n\n" f"Selection criteria:\n" - f"- Prioritize sections containing tables and data\n" - f"- Prefer sections with many subsections over small leaf fragments\n" - f"- Avoid sections covering only 1-2 pages with no subsections\n" + f"- Prioritize sections most likely to answer the query\n" + f"- Sections with tables, data, or subsections are often high-value\n" + f"- Short sections containing relevant data should not be dismissed\n" f"- When uncertain, prefer larger sections that can be narrowed later\n\n" f"Return ONLY a JSON array of index numbers, e.g. [0, 2]" ) @@ -920,8 +913,9 @@ async def _select_children( ) -> List[TreeNode]: """LLM-driven branch selection: pick the most relevant children. - Pre-filters low-value fragments, then dispatches to paginated - selection when *nodes* exceeds ``_PAGE_SIZE_THRESHOLD``. + Removes only definitive noise (empty / duplicate nodes), then + dispatches to paginated selection when *nodes* exceeds + ``_PAGE_SIZE_THRESHOLD``. Relevance judgment is delegated to the LLM. """ if len(nodes) <= 2: return nodes diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 27338a2..074847e 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -422,6 +422,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 1. **Language Continuity**: The output must be in the SAME language as the User Input. 2. **Format**: Use Markdown (headings, bullet points, and bold text) for high readability. 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. +4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. ### Input Data - **User Input**: {user_input} @@ -442,8 +443,11 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". ### Output Format + +[If the query asks for a specific value, ratio, number, or factual answer, state ONLY the direct answer here (e.g. 
"0.83", "$1,832 million", "Yes", "Increased from 0.67 to 0.69"). If the query is open-ended, write a one-sentence conclusion.] + -[Generate the Markdown Briefing here] +[Generate the Markdown Briefing here with detailed analysis and supporting evidence] true/false true/false @@ -458,6 +462,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 1. **Language Continuity**: The output must be in the SAME language as the User Input. 2. **Format**: Use Markdown (headings, bullet points, and bold text) for high readability. 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. +4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. ### Document Context {document_context} @@ -481,8 +486,11 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". ### Output Format + +[If the query asks for a specific value, ratio, number, or factual answer, state ONLY the direct answer here (e.g. "0.83", "$1,832 million", "Yes", "Increased from 0.67 to 0.69"). If the query is open-ended, write a one-sentence conclusion.] + -[Generate the Markdown Briefing here] +[Generate the Markdown Briefing here with detailed analysis and supporting evidence] true/false true/false diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index e12fd39..921c383 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -934,20 +934,21 @@ async def _search_by_filename( @staticmethod def _parse_summary_response(llm_response: str) -> Tuple[str, bool, bool]: - """ - Parse LLM response to extract summary and quality decisions. + """Parse LLM response to extract summary, precise answer, and quality decisions. - Args: - llm_response: Raw LLM response containing SUMMARY, SHOULD_ANSWER and SHOULD_SAVE tags + When a ```` tag is present, its content is prepended to + the summary so downstream consumers (evaluation judges, UIs) see the + direct answer prominently without needing separate tag awareness. Returns: Tuple of (summary_text, should_save_flag, should_answer_flag) """ summary_fields = extract_fields( content=llm_response, - tags=["SUMMARY", "SHOULD_ANSWER", "SHOULD_SAVE"], + tags=["PRECISE_ANSWER", "SUMMARY", "SHOULD_ANSWER", "SHOULD_SAVE"], ) + precise = str(summary_fields.get("precise_answer") or "").strip() summary = str(summary_fields.get("summary") or "").strip() should_answer_str = str(summary_fields.get("should_answer") or "false").strip().lower() should_save_str = str(summary_fields.get("should_save") or "false").strip().lower() @@ -955,8 +956,11 @@ def _parse_summary_response(llm_response: str) -> Tuple[str, bool, bool]: should_answer = should_answer_str in ["true", "yes", "1"] should_save = should_save_str in ["true", "yes", "1"] - # If extraction failed, use entire response as summary and default to conservative: - # not answerable and not saveable. 
+ if precise and summary: + summary = f"**Answer: {precise}**\n\n{summary}" + elif precise: + summary = precise + if not summary: summary = llm_response.strip() should_answer = False @@ -2198,7 +2202,7 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Maximum files returned by catalog keyword-overlap probe in DEEP mode.""" # --- Tree-guided sampling constants --- - _TREE_SAMPLE_MAX_SECTIONS = 3 + _TREE_SAMPLE_MAX_SECTIONS = 5 """Max tree sections to include per file in tree-guided sampling.""" _TREE_SAMPLE_SECTION_MAX_CHARS = 3000 """Max chars per tree section.""" @@ -2218,10 +2222,10 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """char_range spanning more than this ratio of the document is treated as invalid.""" # --- Self-correction expanded sampling --- - _SELF_CORRECT_EXPANDED_NAV_RESULTS: int = 6 - """Expanded tree navigation leaf count for same-file re-sampling (default nav uses 3).""" - _SELF_CORRECT_EXPANDED_SECTIONS: int = 5 - """Expanded tree sample sections for same-file re-sampling (default uses 3).""" + _SELF_CORRECT_EXPANDED_NAV_RESULTS: int = 10 + """Expanded tree navigation leaf count for same-file re-sampling (default nav uses 5).""" + _SELF_CORRECT_EXPANDED_SECTIONS: int = 8 + """Expanded tree sample sections for same-file re-sampling (default uses 5).""" # --- Evidence acceptance thresholds --- _EVIDENCE_MIN_ACCEPT_LENGTH: int = 800 @@ -2720,7 +2724,9 @@ async def _rga_evidence() -> str: print(f"SEARCH_WIKI_DEBUG [D13] table_digest: manifest_lookup={'found' if artifacts.manifest_map and artifacts.manifest_map.get(fp) else 'miss'}, has_table_digest={getattr(artifacts.manifest_map.get(fp), 'has_table_digest', False) if artifacts.manifest_map else 'N/A'}, hash_fallback={'tried' if not _all_tables else 'skipped'}, tables_count={len(_all_tables) if _all_tables else 0}", flush=True) if _all_tables: - _table_ev = self._format_table_evidence(_all_tables) + _table_ev = self._format_table_evidence( + _all_tables, query=query, + ) if _table_ev: ev = f"[{fn} - Table Evidence]\n{_table_ev}" @@ -4009,20 +4015,26 @@ async def _tree_guided_sample( @staticmethod def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: - """将叶节点按提取策略分类。 + """Classify leaf nodes by preferred extraction strategy. + + For non-table leaves, **char_range** (kreuzberg markdown) is preferred + over page_range (pypdf raw text) because compile-time extraction + preserves table layout and column structure far better than pypdf's + ``extract_text()``. page_range remains available on each leaf for + table-supplement filtering even when the leaf is routed to char_leaves. 
Returns: - (page_leaves, char_leaves, summary_leaves) 三元组: - - page_leaves: list of (leaf, page_range) tuples — 有有效 page_range 的 - - char_leaves: list of leaf — 需要 char_range fallback 的 - - summary_leaves: list of leaf — 只有 summary 可用的 + (page_leaves, char_leaves, summary_leaves) triple: + - page_leaves: list of (leaf, page_range) — page-level extraction + - char_leaves: list of leaf — kreuzberg char_range extraction + - summary_leaves: list of leaf — only summary available """ page_leaves: List[tuple] = [] char_leaves: List = [] summary_leaves: List = [] for leaf in leaves: - # 表格类型节点:优先 page-level 提取获取完整原始内容 + # Table nodes: prefer page-level extraction for raw original content if getattr(leaf, 'content_type', 'text') == 'table': page_range = getattr(leaf, 'page_range', None) if ( @@ -4038,11 +4050,21 @@ def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: char_leaves.append(leaf) continue + # Non-table leaves: prefer char_range (kreuzberg markdown) over + # page_range (pypdf raw text) for higher-fidelity table rendering. + has_char = hasattr(leaf, 'char_range') and leaf.char_range page_range = getattr(leaf, 'page_range', None) - if page_range and len(page_range) == 2 and page_range[0] is not None and page_range[0] > 0: - page_leaves.append((leaf, page_range)) - elif hasattr(leaf, 'char_range') and leaf.char_range: + has_page = ( + page_range + and len(page_range) == 2 + and page_range[0] is not None + and page_range[0] > 0 + ) + + if has_char: char_leaves.append(leaf) + elif has_page: + page_leaves.append((leaf, page_range)) elif getattr(leaf, 'summary', None): summary_leaves.append(leaf) @@ -4095,27 +4117,62 @@ def _filter_tables_by_page_range( and page_start <= t["page_number"] <= page_end ] + @staticmethod + def _score_table_relevance( + markdown: str, query_tokens: frozenset, + ) -> float: + """Score a table's relevance to the query via token overlap. + + Returns a value in [0, 1] representing the fraction of *query_tokens* + found in the table's markdown text (case-insensitive). + """ + if not markdown or not query_tokens: + return 0.0 + md_lower = markdown.lower() + hits = sum(1 for tok in query_tokens if tok in md_lower) + return hits / len(query_tokens) + @staticmethod def _format_table_evidence( tables: List[Dict[str, Any]], - max_chars: int = 3000, + max_chars: int = 6000, + query: str = "", ) -> str: """Format table digest entries as LLM-friendly evidence text. + When *query* is provided, tables are **sorted by relevance** to the + query before budget truncation, ensuring critical tables are included + even when they appear late in page order. + Strategy: - - Small tables (<1000 chars): preserve full Markdown - - Large tables: truncate to max_chars with "(truncated)" note + - Query-relevant tables are prioritised via keyword overlap scoring - Each table prefixed with "[Table from page N]" + - Large tables truncated with "(truncated)" note Returns concatenated formatted table evidence string. 
""" if not tables: return "" + ordered = tables + if query: + query_tokens = frozenset( + tok for tok in query.lower().split() if len(tok) > 2 + ) + if query_tokens: + scored = [ + (AgenticSearch._score_table_relevance( + t.get("markdown", ""), query_tokens, + ), idx, t) + for idx, t in enumerate(tables) + ] + scored.sort(key=lambda x: (-x[0], x[1])) + ordered = [t for _, _, t in scored] + parts: List[str] = [] remaining = max_chars - for table in tables: + for table in ordered: if remaining <= 0: break @@ -4155,7 +4212,7 @@ def _append_evidence_part( parts.append(f"{header}\n{text}") async def _navigate_tree_for_evidence( - self, file_path: str, query: str, *, max_results: int = 3, + self, file_path: str, query: str, *, max_results: int = 5, ) -> Optional[str]: """LLM-driven tree navigation: select relevant sections and read leaf content. @@ -4297,7 +4354,8 @@ async def _navigate_tree_for_evidence( ) if leaf_tables: table_text = self._format_table_evidence( - leaf_tables, max_chars=2000, + leaf_tables, max_chars=4000, + query=query, ) if table_text: parts.append( From bdd8bdc666628de6120075bcadfffd3d4f37b9a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sun, 26 Apr 2026 23:05:03 +0800 Subject: [PATCH 41/56] fix review --- src/sirchmunk/learnings/compiler.py | 13 +++ src/sirchmunk/search.py | 124 ++++++++++++++++++++++------ 2 files changed, 114 insertions(+), 23 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 62f3e19..ef901aa 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -635,6 +635,19 @@ async def _compile_single_file( except Exception: pass + # Cache compile-time ENHANCED content so search can slice + # char_range from the same text the tree was built from. + try: + file_hash_content = get_fast_hash(entry.path) or "" + if file_hash_content and content: + content_dir = self._compile_dir / "content" + content_dir.mkdir(parents=True, exist_ok=True) + (content_dir / f"{file_hash_content}.txt").write_text( + content, encoding="utf-8", + ) + except Exception: + pass + # Persist table digest for documents with extracted tables if extraction.tables: try: diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 921c383..c38dfee 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -3930,14 +3930,16 @@ async def _tree_guided_sample( leaf_segments.append((leaf, seg)) # If page extraction failed, demoted leaves are now in char_leaves - # -- Phase C: char_range fallback (lazy full-text extraction) -- + # -- Phase C: char_range extraction (compile-consistent content) -- if char_leaves: - try: - from sirchmunk.utils.file_utils import fast_extract - extraction = await fast_extract(file_path=file_path) - full_text = extraction.content or "" - except Exception: - full_text = "" + full_text = self._load_compile_content(self.work_path, file_path) + if not full_text: + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" for leaf in char_leaves: start, end = leaf.char_range @@ -4084,6 +4086,31 @@ def _is_valid_char_range( span_ratio = (end - start) / text_len return span_ratio < self._CHAR_RANGE_MAX_SPAN_RATIO + @staticmethod + def _load_compile_content( + work_path: Path, file_path: str, + ) -> Optional[str]: + """Load the ENHANCED content cached at compile time. 
+ + Compile stores the kreuzberg ENHANCED-profile content alongside the + tree index so that search-time ``char_range`` slicing operates on + the *same* text the ranges were computed from. Returns ``None`` + when the cache file is missing (e.g. pre-cache compile run). + """ + try: + from sirchmunk.utils.file_utils import get_fast_hash + file_hash = get_fast_hash(file_path) + if not file_hash: + return None + cache_path = ( + work_path / ".cache" / "compile" / "content" / f"{file_hash}.txt" + ) + if cache_path.exists(): + return cache_path.read_text(encoding="utf-8") + except Exception: + pass + return None + @staticmethod def _load_table_digest( work_path: Path, file_hash: str, @@ -4221,8 +4248,8 @@ async def _navigate_tree_for_evidence( Returns None when no tree cache is available. Extraction priority (highest first): - 1. page_range – page-level extraction via DocumentExtractor - 2. char_range – full-text extraction + slice (fallback) + 1. char_range – compile-time ENHANCED content slice (preserves tables) + 2. page_range – page-level extraction via DocumentExtractor (fallback) 3. leaf.summary – last resort """ indexer = self._get_tree_indexer() @@ -4299,14 +4326,22 @@ async def _navigate_tree_for_evidence( else: print(f"SEARCH_WIKI_DEBUG [N4] page_extraction: page_leaves_ok=True", flush=True) - # ── Phase 3: char_range fallback (lazy full-text extraction) ── + # ── Phase 3: char_range extraction (compile-consistent content) ── if char_leaves: - try: - from sirchmunk.utils.file_utils import fast_extract - extraction = await fast_extract(file_path=file_path) - full_text = extraction.content or "" - except Exception: - full_text = "" + # Prefer compile-time ENHANCED content (matches char_range offsets + # exactly). Fall back to fast_extract only when cache is absent. + full_text = self._load_compile_content(self.work_path, file_path) + if not full_text: + try: + from sirchmunk.utils.file_utils import fast_extract + extraction = await fast_extract(file_path=file_path) + full_text = extraction.content or "" + except Exception: + full_text = "" + + # Leaves whose char_range is invalid but have a valid page_range + # are demoted to page extraction instead of discarding to summary. + page_fallback_leaves: List[tuple] = [] for leaf in char_leaves: start, end = leaf.char_range @@ -4320,14 +4355,57 @@ async def _navigate_tree_for_evidence( self._append_evidence_part( parts, fname, leaf, leaf.summary, ) - elif getattr(leaf, 'summary', None): - _loguru_logger.debug( - f"[TreeNav] char_range degraded for '{leaf.title}' " - f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), using summary" - ) - self._append_evidence_part( - parts, fname, leaf, leaf.summary, + else: + # char_range covers too much of the document (or text is + # empty). Try page_range extraction before falling back + # to summary. 
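# NOTE (editor's annotation, not part of the patch): _load_compile_content above is
# the read half of the cache introduced in this commit's compiler.py change: the
# compiler writes the ENHANCED-profile text to _compile_dir/content/<fast_hash>.txt,
# and this loader reads it back from <work_path>/.cache/compile/content/<fast_hash>.txt
# (the two locations are assumed to coincide). Tree-node char_range offsets were
# computed against that exact cached text, which is why it is preferred over a fresh
# fast_extract() pass here.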
+ pr = getattr(leaf, 'page_range', None) + if ( + pr + and len(pr) == 2 + and pr[0] is not None + and pr[0] > 0 + ): + page_fallback_leaves.append((leaf, pr)) + elif getattr(leaf, 'summary', None): + _loguru_logger.debug( + f"[TreeNav] char_range degraded for '{leaf.title}' " + f"(span_ratio={(end - start) / max(len(full_text), 1):.2f}), " + f"using summary" + ) + self._append_evidence_part( + parts, fname, leaf, leaf.summary, + ) + + # Batch page extraction for demoted leaves (same pattern as Phase 2) + if page_fallback_leaves: + all_fb_pages: set = set() + for _lf, (sp, ep) in page_fallback_leaves: + all_fb_pages.update(range(sp, ep + 1)) + try: + fb_contents = DocumentExtractor.extract_pages( + file_path, sorted(all_fb_pages), ) + fb_map = {pc.page_number: pc.content for pc in fb_contents} + for lf, (sp, ep) in page_fallback_leaves: + seg_parts = [ + fb_map[p] for p in range(sp, ep + 1) + if fb_map.get(p, "").strip() + ] + if seg_parts: + self._append_evidence_part( + parts, fname, lf, "\n".join(seg_parts), + ) + elif getattr(lf, 'summary', None): + self._append_evidence_part( + parts, fname, lf, lf.summary, + ) + except Exception: + for lf, _ in page_fallback_leaves: + if getattr(lf, 'summary', None): + self._append_evidence_part( + parts, fname, lf, lf.summary, + ) if not parts: return None From d3b91d679e17ae2e697ae5913cec702eaa41ce52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 27 Apr 2026 02:09:00 +0800 Subject: [PATCH 42/56] improve kreuzberg table extraction --- src/sirchmunk/learnings/compiler.py | 295 ++++++++++++++++++---- src/sirchmunk/utils/document_extractor.py | 20 +- 2 files changed, 265 insertions(+), 50 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index ef901aa..ddd509b 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -63,6 +63,9 @@ # this threshold are skipped during targeted extraction. 
_TABLE_NUMERIC_DENSITY_THRESHOLD = 0.15 +# Selective force-OCR: max pages to re-extract with forced OCR per document +_FORCE_OCR_MAX_PAGES = 30 + # Excel table-level adaptive sampling constants _XLSX_TOTAL_ROW_BUDGET = 100 # Total sampled rows budget across all sheets _XLSX_MIN_ROWS_PER_SHEET = 3 # Minimum sampled rows per sheet @@ -676,50 +679,43 @@ async def _compile_single_file( content=content, total_pages=extraction.page_count, ) - # Phase 2.5: Targeted table extraction via generic structural signals + # Phase 2.5: Targeted table extraction via tree-node structural signals if result.tree and result.tree.root and ext == ".pdf": targeted_tables = await self._targeted_table_extraction( entry.path, result.tree, ) - if targeted_tables: - # Load existing table digest (if any) and merge - digest_dir = self._compile_dir / "table_digests" - file_hash = get_fast_hash(entry.path) or "" - existing_digest: list[dict] = [] - if file_hash and result.has_table_digest: - digest_path = digest_dir / f"{file_hash}.json" - if digest_path.exists(): - try: - raw = json.loads( - digest_path.read_text(encoding="utf-8") - ) - existing_digest = raw.get("tables", []) - except Exception: - pass - merged = self._merge_table_digests( - existing_digest, targeted_tables, + await self._supplement_table_digest( + entry.path, targeted_tables, result, + source_label="Targeted extraction", + ) + + # Phase 2.6: Content-based full-page table scan (tree-independent) + if ext == ".pdf" and extraction.page_count: + covered_pages = self._get_covered_table_pages(entry.path) + content_tables = await self._content_based_table_scan( + entry.path, + extraction.page_count, + covered_pages, + ) + await self._supplement_table_digest( + entry.path, content_tables, result, + source_label="Content-based scan", + ) + + # Phase 2.7: Selective force-OCR for high-density gap pages + if ext == ".pdf" and extraction.page_count: + covered_after_scan = self._get_covered_table_pages(entry.path) + gap_pages = self._find_force_ocr_candidates( + entry.path, extraction.page_count, covered_after_scan, + ) + if gap_pages: + ocr_tables = await self._selective_force_ocr_tables( + entry.path, gap_pages, + ) + await self._supplement_table_digest( + entry.path, ocr_tables, result, + source_label="Selective force-OCR", ) - if merged and file_hash: - digest_dir.mkdir(parents=True, exist_ok=True) - digest_path = digest_dir / f"{file_hash}.json" - digest_path.write_text( - json.dumps( - { - "version": 1, - "table_count": len(merged), - "tables": merged, - }, - ensure_ascii=False, - ), - encoding="utf-8", - ) - result.has_table_digest = True - result.table_count = len(merged) - await self._log.info( - f"[Compile] Targeted table extraction added " - f"{len(targeted_tables)} tables for " - f"{Path(entry.path).name}" - ) except Exception as exc: result.error = str(exc) @@ -1707,17 +1703,226 @@ def _merge_table_digests( page = cls._get_table_page(tbl) if page is not None and page in existing_pages: continue - # Normalise to digest table format for consistency merged.append({ "page_number": page, - "markdown": tbl.get("content", ""), - "row_count": None, - "col_count": None, - "cells": [], - "source": tbl.get("source", "targeted"), + "markdown": tbl.get("markdown", "") or tbl.get("content", ""), + "row_count": tbl.get("row_count"), + "col_count": tbl.get("col_count"), + "cells": tbl.get("cells", []), + "source": tbl.get("source", "supplementary"), }) return merged + async def _supplement_table_digest( + self, + file_path: str, + new_tables: list[dict], + result: 
"FileCompileResult", + *, + source_label: str, + ) -> None: + """Merge supplementary tables into the persisted table digest. + + Loads the existing digest (if any), merges *new_tables* with + page-level deduplication, and writes the updated digest back. + Updates *result* metadata in place. + """ + if not new_tables: + return + + file_hash = get_fast_hash(file_path) or "" + if not file_hash: + return + + digest_dir = self._compile_dir / "table_digests" + digest_path = digest_dir / f"{file_hash}.json" + + existing: list[dict] = [] + if result.has_table_digest and digest_path.exists(): + try: + raw = json.loads(digest_path.read_text(encoding="utf-8")) + existing = raw.get("tables", []) + except Exception: + pass + + merged = self._merge_table_digests(existing, new_tables) + if not merged: + return + + digest_dir.mkdir(parents=True, exist_ok=True) + digest_path.write_text( + json.dumps( + {"version": 1, "table_count": len(merged), "tables": merged}, + ensure_ascii=False, + ), + encoding="utf-8", + ) + result.has_table_digest = True + result.table_count = len(merged) + await self._log.info( + f"[Compile] {source_label}: +{len(new_tables)} tables for " + f"{Path(file_path).name} (total={len(merged)})" + ) + + def _get_covered_table_pages(self, file_path: str) -> Set[int]: + """Return the set of page numbers already present in the table digest.""" + file_hash = get_fast_hash(file_path) or "" + if not file_hash: + return set() + + digest_path = ( + self._compile_dir / "table_digests" / f"{file_hash}.json" + ) + if not digest_path.exists(): + return set() + + try: + raw = json.loads(digest_path.read_text(encoding="utf-8")) + pages: Set[int] = set() + for t in raw.get("tables", []): + p = self._get_table_page(t) + if p is not None: + pages.add(p) + return pages + except Exception: + return set() + + # ------------------------------------------------------------------ # + # Tree-independent content-based table scanning (P1) # + # ------------------------------------------------------------------ # + + async def _content_based_table_scan( + self, + file_path: str, + total_pages: Optional[int], + kreuzberg_table_pages: Set[int], + ) -> list[dict]: + """Scan *all* PDF pages for table-like regions via numeric density. + + Unlike :meth:`_targeted_table_extraction` this method does **not** + depend on tree node metadata (``page_range``, ``table_count``). + It reads every page through pypdf and applies the same density + + region-detection heuristics, skipping pages that already have a + kreuzberg-detected table. + + Args: + file_path: Path to the PDF file. + total_pages: Total page count (from extraction metadata). + kreuzberg_table_pages: Page numbers already covered by kreuzberg + layout-detected tables. 
+ + Returns: + List of table dicts compatible with the digest format:: + + {"page": int, "content": str, "source": "content_scan"} + """ + if not total_pages or total_pages <= 0: + return [] + + all_page_nums = list(range(1, total_pages + 1)) + try: + pages = DocumentExtractor.extract_pages(file_path, all_page_nums) + except Exception as exc: + await self._log.warning( + f"[Compile] Content-based scan: page read failed for " + f"{Path(file_path).name}: {exc}" + ) + return [] + + results: list[dict] = [] + for pc in pages: + if pc.page_number in kreuzberg_table_pages: + continue + if not self._page_has_table_density(pc.content): + continue + regions = self._identify_table_regions(pc.content) + for region in regions: + results.append({ + "page": pc.page_number, + "content": region[:_TARGETED_TABLE_MAX_CHARS], + "source": "content_scan", + }) + return results + + def _find_force_ocr_candidates( + self, + file_path: str, + total_pages: Optional[int], + covered_pages: Set[int], + ) -> List[int]: + """Identify pages worth re-extracting with forced OCR. + + Returns 0-indexed page numbers for pages that have high numeric + density (suggesting tabular content) but are NOT already covered + by any table in the digest. The result is capped at + :data:`_FORCE_OCR_MAX_PAGES`. + """ + if not total_pages or total_pages <= 0: + return [] + + all_page_nums = list(range(1, total_pages + 1)) + try: + pages = DocumentExtractor.extract_pages(file_path, all_page_nums) + except Exception: + return [] + + candidates: List[int] = [] + for pc in pages: + if pc.page_number in covered_pages: + continue + if self._page_has_table_density(pc.content): + candidates.append(pc.page_number - 1) # 0-indexed for kreuzberg + + return sorted(candidates)[:_FORCE_OCR_MAX_PAGES] + + # ------------------------------------------------------------------ # + # Selective force-OCR re-extraction (P2) # + # ------------------------------------------------------------------ # + + async def _selective_force_ocr_tables( + self, + file_path: str, + gap_pages: List[int], + ) -> list[dict[str, Any]]: + """Re-extract specific pages with forced OCR + layout detection. + + For pages where the native text layer was not recognized as tables + by kreuzberg's RT-DETR model, re-rendering as images may yield + better layout detection results. Uses ``force_ocr_pages`` so only + the targeted pages are OCR'd (no full-document penalty). + + Args: + file_path: Path to the PDF. + gap_pages: 0-indexed page numbers to force OCR on. Capped at + :data:`_FORCE_OCR_MAX_PAGES` to bound compile time. + + Returns: + List of kreuzberg-format table dicts (with ``markdown``, + ``cells``, ``page_number``). 
+ """ + from sirchmunk.utils.document_extractor import ExtractionProfile + + if not gap_pages: + return [] + + capped = sorted(gap_pages)[:_FORCE_OCR_MAX_PAGES] + + profile = ExtractionProfile( + output_format="markdown", + extract_tables=True, + force_ocr_pages=tuple(capped), + ) + try: + extraction = await DocumentExtractor.extract(file_path, profile) + except Exception as exc: + await self._log.warning( + f"[Compile] Selective force-OCR failed for " + f"{Path(file_path).name}: {exc}" + ) + return [] + + return extraction.tables + # ------------------------------------------------------------------ # # Summary index for embedding + BM25 fallback # # ------------------------------------------------------------------ # diff --git a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py index a022a8d..b2835f5 100644 --- a/src/sirchmunk/utils/document_extractor.py +++ b/src/sirchmunk/utils/document_extractor.py @@ -72,6 +72,14 @@ class ExtractionProfile: when set, OCR is always applied regardless of text layer presence. """ + force_ocr_pages: Optional[tuple[int, ...]] = None + """Force OCR on specific pages only (0-indexed). + + Maps to kreuzberg's ``ExtractionConfig.force_ocr_pages``. + Mutually exclusive with :attr:`force_ocr` — when both are set, + ``force_ocr`` takes precedence. + """ + pdf_password: Optional[str] = None """Password for encrypted PDFs.""" @@ -459,12 +467,12 @@ def _build_config(profile: ExtractionProfile): if profile.extract_tables: try: from kreuzberg import LayoutDetectionConfig - # kreuzberg >= 4.5.0: model-based table detection (RT-DETR v2) - # Default: table_model="tatr", apply_heuristics=True - layout_config = LayoutDetectionConfig() + layout_config = LayoutDetectionConfig( + confidence_threshold=0.3, + apply_heuristics=True, + table_model="slanet_auto", + ) except ImportError: - # kreuzberg < 4.5.0: tables extracted via heuristics only; - # filtering is handled in _convert_result(). pass # --- Assemble ExtractionConfig --- @@ -475,6 +483,8 @@ def _build_config(profile: ExtractionProfile): kwargs["ocr"] = ocr_config if profile.force_ocr: kwargs["force_ocr"] = True + elif profile.force_ocr_pages: + kwargs["force_ocr_pages"] = list(profile.force_ocr_pages) if page_config is not None: kwargs["pages"] = page_config if pdf_config is not None: From 63ed04795a704da5e6214b652871d0882df776ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 27 Apr 2026 02:49:02 +0800 Subject: [PATCH 43/56] enhance compiler table extraction --- src/sirchmunk/learnings/compiler.py | 252 ++++++++++++++++++++++------ 1 file changed, 198 insertions(+), 54 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index ddd509b..3316e54 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -8,6 +8,7 @@ """ import asyncio +import bisect import json import math import os @@ -66,6 +67,23 @@ # Selective force-OCR: max pages to re-extract with forced OCR per document _FORCE_OCR_MAX_PAGES = 30 +# Shared numeric-token regex for table detection heuristics. +# Matches: $1,234 (1,234) 12.5% 3.14e-5 1,000 +_NUM_TOKEN_RE = re.compile( + r"(?:" + r"[\$€£¥]\s*[\d,.]+|" + r"\([\d,.]+\)|" + r"[\d,.]+%|" + r"[\d]+\.[\d]+(?:[eE][+-]?\d+)?|" + r"[\d,]{2,}" + r")" +) + +# A single line with >= this many numeric tokens is treated as a dense +# table row (or multiple rows concatenated), enabling detection even when +# pypdf flattens the entire page to one or two lines. 
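+#
+# Illustrative (hypothetical) flattened page text: "Revenue 1,234 5,678
+# Cost of sales 2,345 6,789 ..." packed onto one physical line still
+# triggers detection (15+ numeric tokens) even without per-row breaks.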
+_DENSE_LINE_MIN_TOKENS = 15 + # Excel table-level adaptive sampling constants _XLSX_TOTAL_ROW_BUDGET = 100 # Total sampled rows budget across all sheets _XLSX_MIN_ROWS_PER_SHEET = 3 # Minimum sampled rows per sheet @@ -692,10 +710,16 @@ async def _compile_single_file( # Phase 2.6: Content-based full-page table scan (tree-independent) if ext == ".pdf" and extraction.page_count: covered_pages = self._get_covered_table_pages(entry.path) + tree_root = ( + result.tree.root + if result.tree and result.tree.root else None + ) content_tables = await self._content_based_table_scan( entry.path, extraction.page_count, covered_pages, + enhanced_content=content, + tree_root=tree_root, ) await self._supplement_table_digest( entry.path, content_tables, result, @@ -1595,10 +1619,15 @@ def _walk(node: "TreeNode") -> None: @staticmethod def _page_has_table_density(page_text: str) -> bool: - """Return True if *page_text* has numeric density above the threshold. + """Return True if *page_text* likely contains tabular numeric data. - Counts digits and common table symbols (``$``, ``%``, ``(``, ``)``) - relative to total non-whitespace characters. + Two independent signals (either suffices): + + 1. **Character-level density** — fraction of digit/symbol chars + relative to total non-whitespace exceeds the threshold. + 2. **Token-dense line** — any single line contains + ``_DENSE_LINE_MIN_TOKENS`` or more numeric tokens, which + catches pages where pypdf flattens all content into ≤ 2 lines. """ if not page_text: return False @@ -1609,17 +1638,26 @@ def _page_has_table_density(page_text: str) -> bool: 1 for ch in page_text if ch.isdigit() or ch in "$%(),.+-" ) - return (numeric_chars / non_ws) >= _TABLE_NUMERIC_DENSITY_THRESHOLD + if (numeric_chars / non_ws) >= _TABLE_NUMERIC_DENSITY_THRESHOLD: + return True + return any( + len(_NUM_TOKEN_RE.findall(line)) >= _DENSE_LINE_MIN_TOKENS + for line in page_text.split("\n") + ) @staticmethod def _identify_table_regions(page_text: str) -> list[str]: """Identify contiguous table-like regions in *page_text*. - Heuristic rules: - - Lines containing multiple numeric tokens (dollar amounts, %, - parenthesised negatives) are considered *numeric rows*. - - A run of >= 3 consecutive numeric rows forms a table region. - - Leading/trailing whitespace rows are trimmed. + Two complementary strategies: + + 1. **Consecutive-line detection** — a run of ≥ 3 lines each + containing ≥ 2 numeric tokens forms a table region. Works + well when pypdf preserves per-row line breaks. + 2. **Dense-line detection** — a *single* line with ≥ + ``_DENSE_LINE_MIN_TOKENS`` numeric tokens is treated as a + table region. This handles PDFs where pypdf collapses + the entire page into one or two very long lines. Returns: List of extracted region strings (may be empty). 
@@ -1627,52 +1665,44 @@ def _identify_table_regions(page_text: str) -> list[str]: if not page_text: return [] - # Pattern: line has at least 2 numeric-looking tokens - _NUM_TOKEN = re.compile( - r"(?:" - r"[\$€£¥]\s*[\d,.]+|" - r"\([\d,.]+\)|" - r"[\d,.]+%|" - r"[\d]+\.[\d]+(?:[eE][+-]?\d+)?|" - r"[\d,]{2,}" - r")" - ) _MIN_NUMS_PER_LINE = 2 _MIN_CONSECUTIVE = 3 lines = page_text.split("\n") - is_numeric = [ - len(_NUM_TOKEN.findall(line)) >= _MIN_NUMS_PER_LINE - for line in lines + token_counts = [ + len(_NUM_TOKEN_RE.findall(line)) for line in lines ] regions: list[str] = [] - run_start: int | None = None + captured_lines: set[int] = set() - for i, flag in enumerate(is_numeric): - if flag: + # --- Strategy 1: consecutive-line runs --- + run_start: int | None = None + for i, cnt in enumerate(token_counts): + if cnt >= _MIN_NUMS_PER_LINE: if run_start is None: run_start = i else: if run_start is not None: - run_len = i - run_start - if run_len >= _MIN_CONSECUTIVE: - # Include one context line above/below + if i - run_start >= _MIN_CONSECUTIVE: start = max(0, run_start - 1) end = min(len(lines), i + 1) regions.append( "\n".join(lines[start:end]).strip() ) + captured_lines.update(range(start, end)) run_start = None - - # Flush trailing run - if run_start is not None: - run_len = len(lines) - run_start - if run_len >= _MIN_CONSECUTIVE: - start = max(0, run_start - 1) - regions.append( - "\n".join(lines[start:]).strip() - ) + if run_start is not None and len(lines) - run_start >= _MIN_CONSECUTIVE: + start = max(0, run_start - 1) + regions.append("\n".join(lines[start:]).strip()) + captured_lines.update(range(start, len(lines))) + + # --- Strategy 2: dense-line detection --- + for i, cnt in enumerate(token_counts): + if cnt >= _DENSE_LINE_MIN_TOKENS and i not in captured_lines: + start = max(0, i - 1) + end = min(len(lines), i + 2) + regions.append("\n".join(lines[start:end]).strip()) return regions @@ -1795,30 +1825,54 @@ async def _content_based_table_scan( self, file_path: str, total_pages: Optional[int], - kreuzberg_table_pages: Set[int], + covered_pages: Set[int], + *, + enhanced_content: Optional[str] = None, + tree_root: Optional[Any] = None, ) -> list[dict]: - """Scan *all* PDF pages for table-like regions via numeric density. + """Scan PDF pages for table-like regions via numeric density. - Unlike :meth:`_targeted_table_extraction` this method does **not** - depend on tree node metadata (``page_range``, ``table_count``). - It reads every page through pypdf and applies the same density + - region-detection heuristics, skipping pages that already have a - kreuzberg-detected table. + Uses a two-tier strategy: + + 1. **pypdf page scan** — reads every page individually. Works well + when pypdf preserves per-row line breaks. + 2. **ENHANCED content fallback** — if pypdf yields poor line + structure (> 50 % of pages have ≤ 3 lines), falls back to + scanning the kreuzberg ENHANCED markdown content, which often + has better formatting. Page numbers are recovered via the + tree's ``char_range → page_range`` mapping. Args: - file_path: Path to the PDF file. - total_pages: Total page count (from extraction metadata). - kreuzberg_table_pages: Page numbers already covered by kreuzberg - layout-detected tables. + file_path: Path to the PDF file. + total_pages: Total page count. + covered_pages: Page numbers already in the table digest. + enhanced_content: Cached kreuzberg ENHANCED text (optional). + tree_root: Tree root node for char → page mapping (optional). 
Returns: - List of table dicts compatible with the digest format:: - - {"page": int, "content": str, "source": "content_scan"} + List of table dicts compatible with the digest format. """ if not total_pages or total_pages <= 0: return [] + results = await self._pypdf_page_scan( + file_path, total_pages, covered_pages, + ) + + if results or not enhanced_content or not tree_root: + return results + + return self._enhanced_content_scan( + enhanced_content, total_pages, covered_pages, tree_root, + ) + + async def _pypdf_page_scan( + self, + file_path: str, + total_pages: int, + covered_pages: Set[int], + ) -> list[dict]: + """Primary scan: per-page pypdf extraction with density heuristics.""" all_page_nums = list(range(1, total_pages + 1)) try: pages = DocumentExtractor.extract_pages(file_path, all_page_nums) @@ -1830,20 +1884,110 @@ async def _content_based_table_scan( return [] results: list[dict] = [] + poor_line_count = 0 for pc in pages: - if pc.page_number in kreuzberg_table_pages: + if len(pc.content.split("\n")) <= 3: + poor_line_count += 1 + if pc.page_number in covered_pages: continue if not self._page_has_table_density(pc.content): continue - regions = self._identify_table_regions(pc.content) - for region in regions: + for region in self._identify_table_regions(pc.content): results.append({ "page": pc.page_number, "content": region[:_TARGETED_TABLE_MAX_CHARS], "source": "content_scan", }) + + if results: + return results + + # Signal that pypdf line quality is poor — caller should try fallback + if poor_line_count > total_pages * 0.5: + return [] + + return results + + @staticmethod + def _enhanced_content_scan( + enhanced_content: str, + total_pages: int, + covered_pages: Set[int], + tree_root: Any, + ) -> list[dict]: + """Fallback scan: use ENHANCED (kreuzberg markdown) content. + + Scans the full ENHANCED text line-by-line for dense-token lines, + then maps each detected region back to a page number using the + tree's ``char_range → page_range`` mapping. + """ + char_page_map = KnowledgeCompiler._build_char_to_page_map( + tree_root, total_pages, + ) + if not char_page_map: + return [] + + breakpoints = [cp[0] for cp in char_page_map] + + results: list[dict] = [] + offset = 0 + for line in enhanced_content.split("\n"): + token_count = len(_NUM_TOKEN_RE.findall(line)) + if token_count >= _DENSE_LINE_MIN_TOKENS: + idx = bisect.bisect_right(breakpoints, offset) - 1 + page = char_page_map[max(0, idx)][1] if idx >= 0 else 1 + if page not in covered_pages: + results.append({ + "page": page, + "content": line[:_TARGETED_TABLE_MAX_CHARS], + "source": "content_scan:enhanced", + }) + covered_pages.add(page) + offset += len(line) + 1 # +1 for '\n' + return results + @staticmethod + def _build_char_to_page_map( + tree_root: Any, + total_pages: int, + ) -> list[tuple[int, int]]: + """Build a sorted (char_start, page_number) list from tree leaves. + + Enables efficient binary-search lookup from any character offset + in the ENHANCED content to the corresponding page number. 
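+
+        Returns:
+            Sorted list of ``(char_start, page_number)`` tuples, or the
+            fallback ``[(0, 1)]`` when no leaf provides both ranges.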
+ """ + entries: list[tuple[int, int]] = [] + + def _collect(node: Any) -> None: + children = getattr(node, "children", None) or [] + if isinstance(node, dict): + children = node.get("children", []) + pr = ( + getattr(node, "page_range", None) + if not isinstance(node, dict) + else node.get("page_range") + ) + cr = ( + getattr(node, "char_range", None) + if not isinstance(node, dict) + else node.get("char_range") + ) + if not children and cr and pr: + page = pr[0] if isinstance(pr, (list, tuple)) else pr + char_start = cr[0] if isinstance(cr, (list, tuple)) else cr + if page and char_start is not None: + entries.append((int(char_start), int(page))) + for ch in children: + _collect(ch) + + _collect(tree_root) + + if not entries: + return [(0, 1)] + entries.sort() + return entries + def _find_force_ocr_candidates( self, file_path: str, From b760119624bc22c7d4eb6e3b03cff588002e12b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 27 Apr 2026 14:37:55 +0800 Subject: [PATCH 44/56] fix table extraction --- src/sirchmunk/learnings/compiler.py | 379 +++++++++++++++++++++++++++- src/sirchmunk/llm/prompts.py | 18 ++ src/sirchmunk/search.py | 39 ++- 3 files changed, 432 insertions(+), 4 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 3316e54..8b08357 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -84,6 +84,28 @@ # pypdf flattens the entire page to one or two lines. _DENSE_LINE_MIN_TOKENS = 15 +# --------------------------------------------------------------------------- +# Heading normalisation: candidate extraction patterns +# --------------------------------------------------------------------------- +# kreuzberg sometimes renders section titles as ``**bold text**`` or bare +# short standalone lines instead of ``## heading``. The tree indexer can +# only split on markdown headings, so these "invisible" titles get absorbed +# into parent nodes. +# +# We extract *candidates* via lightweight regexes and let the LLM classify +# which ones are genuine section headings (language/domain-agnostic). 
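+#
+# Illustrative (hypothetical) candidates:
+#   "**Consolidated Statements of Operations**"  -> bold candidate
+#   "Notes to the Financial Statements"          -> standalone candidate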
+ +_BOLD_LINE_RE = re.compile( + r"^\*\*((?:(?!\*\*).)+)\*\*\s*$", + re.MULTILINE, +) + +_STANDALONE_LINE_RE = re.compile( + r"(?:^|\n\n)([^\n]{5,100})\n\n", +) + +_HEADING_CANDIDATE_CAP = 40 + # Excel table-level adaptive sampling constants _XLSX_TOTAL_ROW_BUDGET = 100 # Total sampled rows budget across all sheets _XLSX_MIN_ROWS_PER_SHEET = 3 # Minimum sampled rows per sheet @@ -590,6 +612,7 @@ async def _compile_single_file( entry.path, DocumentExtractor.ENHANCED, ) content = extraction.content + content = await self._normalize_bold_headings(content) if not content or len(content.strip()) < 100: result.error = "Insufficient text content" return result @@ -741,6 +764,12 @@ async def _compile_single_file( source_label="Selective force-OCR", ) + # Phase 2.8: Enrich targeted-extraction tables with ENHANCED content + if ext == ".pdf" and result.has_table_digest: + self._enrich_table_digest_content( + entry.path, content, tree_root=None, + ) + except Exception as exc: result.error = str(exc) await self._log.warning(f"[Compile] Failed: {entry.path}: {exc}") @@ -1617,6 +1646,177 @@ def _walk(node: "TreeNode") -> None: _walk(root) return candidates + # ------------------------------------------------------------------ # + # LLM-based heading normalisation # + # ------------------------------------------------------------------ # + + @staticmethod + def _extract_heading_candidates( + content: str, + ) -> list[tuple[re.Match, str, str]]: + """Extract candidate lines that *might* be section headings. + + Returns a list of ``(match, title_text, source_tag)`` triples + where *source_tag* is ``"bold"`` or ``"standalone"``. + + Bold lines (``**Title**``) are always candidates. Short + standalone lines (surrounded by blank lines, 10-100 chars) are + included only when they pass structural heuristics that filter + out data rows, sentences, and existing headings. + """ + occupied: list[tuple[int, int]] = [] + candidates: list[tuple[re.Match, str, str]] = [] + + def _overlaps(start: int, end: int) -> bool: + return any(s < end and start < e for s, e in occupied) + + for m in _BOLD_LINE_RE.finditer(content): + title = m.group(1).strip() + if title and not _overlaps(m.start(), m.end()): + occupied.append((m.start(), m.end())) + candidates.append((m, title, "bold")) + + for m in _STANDALONE_LINE_RE.finditer(content): + text = m.group(1).strip() + if len(text) < 10: + continue + text_offset = m.start() + m.group(0).index(m.group(1)) + if _overlaps(text_offset, text_offset + len(m.group(1))): + continue + if text.startswith(("#", "**")): + continue + if _NUM_TOKEN_RE.search(text): + continue + if text.endswith((".", "。", "!", "?", "!", "?")): + continue + if len(text.split()) > 12: + continue + occupied.append((text_offset, text_offset + len(m.group(1)))) + candidates.append((m, text, "standalone")) + + candidates.sort(key=lambda t: t[0].start()) + return candidates[:_HEADING_CANDIDATE_CAP] + + async def _normalize_bold_headings(self, content: str) -> str: + """Detect and promote bold/standalone section titles to headings. + + Three-phase pipeline: + 1. **Extract** candidate lines via regex (deterministic). + 2. **Classify** candidates with a single LLM call — the LLM + returns which indices are section headings and their level. + 3. **Replace** confirmed headings deterministically. + + Short-circuits when no candidates are found (zero LLM calls). + On any LLM / parse failure, returns the original content unchanged + (graceful degradation — equivalent to no-op). 
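+
+        Example (illustrative, not from the source): a bold candidate
+        ``**Liquidity and Capital Resources**`` confirmed at level 2 is
+        rewritten in place as ``## Liquidity and Capital Resources``.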
+ + The transformation is idempotent: existing ``#`` headings never + enter the candidate set. + """ + if not content: + return content + + candidates = self._extract_heading_candidates(content) + if not candidates: + return content + + listing = "\n".join( + f"{i}: \"{title}\"" for i, (_, title, _tag) in enumerate(candidates) + ) + + from sirchmunk.llm.prompts import COMPILE_CLASSIFY_HEADINGS + prompt = COMPILE_CLASSIFY_HEADINGS.format(candidates=listing) + + try: + resp = await self._llm.achat( + [{"role": "user", "content": prompt}], + ) + raw = resp.content.strip() + headings = self._parse_heading_classifications(raw, len(candidates)) + except Exception: + return content + + if not headings: + return content + + return self._apply_heading_promotions(content, candidates, headings) + + @staticmethod + def _parse_heading_classifications( + raw: str, + num_candidates: int, + ) -> list[tuple[int, int]]: + """Parse LLM JSON response into a list of ``(idx, level)`` pairs. + + Robustly handles markdown code fences, trailing commas, and + out-of-range indices. Returns an empty list on any parse failure. + """ + cleaned = raw.strip() + if cleaned.startswith("```"): + lines = cleaned.splitlines() + lines = [ln for ln in lines if not ln.strip().startswith("```")] + cleaned = "\n".join(lines).strip() + + try: + items = json.loads(cleaned) + except json.JSONDecodeError: + m = re.search(r"\[.*\]", cleaned, re.DOTALL) + if not m: + return [] + try: + items = json.loads(m.group()) + except json.JSONDecodeError: + return [] + + if not isinstance(items, list): + return [] + + result: list[tuple[int, int]] = [] + for item in items: + if isinstance(item, dict): + idx = item.get("idx") + level = item.get("level", 2) + elif isinstance(item, int): + idx, level = item, 2 + else: + continue + if not isinstance(idx, int) or not (0 <= idx < num_candidates): + continue + level = max(2, min(4, int(level))) + result.append((idx, level)) + return result + + @staticmethod + def _apply_heading_promotions( + content: str, + candidates: list[tuple[re.Match, str, str]], + headings: list[tuple[int, int]], + ) -> str: + """Apply heading promotions to *content* in reverse-offset order. + + Processes replacements from end-to-start so that earlier offsets + remain valid after each substitution. + """ + heading_map: dict[int, int] = dict(headings) + + replacements: list[tuple[int, int, str]] = [] + for idx, (match, title, tag) in enumerate(candidates): + if idx not in heading_map: + continue + level = heading_map[idx] + prefix = "#" * level + if tag == "bold": + replacements.append((match.start(), match.end(), f"{prefix} {title}")) + else: + text_start = match.start() + match.group(0).index(match.group(1)) + text_end = text_start + len(match.group(1)) + replacements.append((text_start, text_end, f"{prefix} {title}")) + + replacements.sort(key=lambda r: r[0], reverse=True) + for start, end, replacement in replacements: + content = content[:start] + replacement + content[end:] + return content + @staticmethod def _page_has_table_density(page_text: str) -> bool: """Return True if *page_text* likely contains tabular numeric data. 
@@ -1818,7 +2018,184 @@ def _get_covered_table_pages(self, file_path: str) -> Set[int]: return set() # ------------------------------------------------------------------ # - # Tree-independent content-based table scanning (P1) # + # P1: Enrich table digest with ENHANCED content # + # ------------------------------------------------------------------ # + + @staticmethod + def _build_page_char_map( + tree_root: Any, + max_page_span: int = _TABLE_PAGE_SPAN_LIMIT, + ) -> Dict[int, Tuple[int, int]]: + """Map page numbers to ``(start_char, end_char)`` in ENHANCED content. + + Aggregates ``char_range`` bounds from leaf nodes whose + ``page_range`` intersects a given page. To avoid inflated + ranges from wide-spanning nodes (e.g. a cover-page node + spanning pages 1–85), only nodes with a page span ≤ + *max_page_span* are used when available; wider nodes serve + as a fallback. + """ + # (char_start, char_end, page_span) per page + entries: Dict[int, List[Tuple[int, int, int]]] = {} + + def _walk(node: Any) -> None: + children = getattr(node, "children", None) or [] + if isinstance(node, dict): + children = node.get("children", []) + if not children: + pr = ( + getattr(node, "page_range", None) + if not isinstance(node, dict) + else node.get("page_range") + ) + cr = ( + getattr(node, "char_range", None) + if not isinstance(node, dict) + else node.get("char_range") + ) + if ( + pr + and cr + and len(pr) >= 2 + and len(cr) >= 2 + ): + span = int(pr[1]) - int(pr[0]) + 1 + for p in range(int(pr[0]), int(pr[1]) + 1): + entries.setdefault(p, []).append( + (int(cr[0]), int(cr[1]), span) + ) + for ch in children: + _walk(ch) + + _walk(tree_root) + + result: Dict[int, Tuple[int, int]] = {} + for page, elist in entries.items(): + narrow = [e for e in elist if e[2] <= max_page_span] + chosen = narrow if narrow else elist + result[page] = ( + min(e[0] for e in chosen), + max(e[1] for e in chosen), + ) + return result + + @staticmethod + def _find_enhanced_region( + enhanced_content: str, + pypdf_text: str, + budget: int = _TARGETED_TABLE_MAX_CHARS, + ) -> Optional[str]: + """Locate the ENHANCED content region matching *pypdf_text*. + + Uses progressively shorter text anchors extracted from the + pypdf content to find the corresponding position in the + ENHANCED (kreuzberg markdown) text. Whitespace is normalised + in the anchor to handle formatting differences (pypdf line + breaks vs kreuzberg markdown spacing). This avoids reliance + on page-number alignment, which may differ between the two + extractors. + + Returns the ENHANCED slice (up to *budget* chars) or ``None``. + """ + text = pypdf_text.strip() + for prefix in ("Table of Contents\n", "Table of Contents "): + if text.startswith(prefix): + text = text[len(prefix):] + text = text.strip() + + for anchor_len in (80, 50, 30): + raw = text[:anchor_len].strip() + if len(raw) < 15: + continue + anchor = " ".join(raw.split()) + pos = enhanced_content.find(anchor) + if pos < 0: + continue + start = max( + 0, + enhanced_content.rfind("\n", max(0, pos - 300), pos) + 1, + ) + end = min(len(enhanced_content), start + budget) + return enhanced_content[start:end].strip() + + return None + + def _enrich_table_digest_content( + self, + file_path: str, + enhanced_content: str, + tree_root: Optional[Any], + ) -> None: + """Replace pypdf-sourced table text with ENHANCED content slices. + + Targeted extraction tables use pypdf, which often produces dense + single-line text (the "2-line page" problem). 
This method + locates each table's content in the ENHANCED (kreuzberg markdown) + text via anchor matching and replaces the ``markdown`` field when + the ENHANCED version has substantially better structure. + + Only tables whose ``source`` indicates pypdf origin are + candidates; kreuzberg-detected tables already have high-quality + markdown and are left untouched. + """ + if not enhanced_content: + return + + file_hash = get_fast_hash(file_path) or "" + if not file_hash: + return + + digest_path = ( + self._compile_dir / "table_digests" / f"{file_hash}.json" + ) + if not digest_path.exists(): + return + + try: + raw = json.loads(digest_path.read_text(encoding="utf-8")) + tables = raw.get("tables", []) + except Exception: + return + + if not tables: + return + + modified = False + for table in tables: + source = table.get("source", "") + if not ( + source.startswith("targeted:") + or source == "content_scan" + ): + continue + + current = table.get("markdown", "") + if not current: + continue + + enhanced_region = self._find_enhanced_region( + enhanced_content, current, + ) + if not enhanced_region: + continue + + current_lines = len(current.strip().split("\n")) + enhanced_lines = len(enhanced_region.split("\n")) + + if enhanced_lines > max(current_lines, 3): + table["markdown"] = enhanced_region[ + :_TARGETED_TABLE_MAX_CHARS + ] + modified = True + + if modified: + digest_path.write_text( + json.dumps(raw, ensure_ascii=False), + encoding="utf-8", + ) + + # ------------------------------------------------------------------ # + # Tree-independent content-based table scanning # # ------------------------------------------------------------------ # async def _content_based_table_scan( diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 074847e..71d5836 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -565,6 +565,24 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: - Use the same language as the summary""" +COMPILE_CLASSIFY_HEADINGS = """Classify each bold text line as either a **section heading** or **non-heading**. + +A line is a *section heading* if it serves as the title of a major structural division of the document (chapter, section, subsection, exhibit, schedule, financial statement, note, etc.). +A line is *non-heading* if it is emphasis text, a label, a caption, a total/subtotal row, or any inline bold phrase that does not introduce a new document section. + +For each heading, also assign a Markdown heading level (2–4): +- Level 2: top-level sections (e.g. financial statements, major chapters) +- Level 3: sub-sections (e.g. notes to financial statements, sub-chapters) +- Level 4: sub-sub-sections + +Return ONLY a JSON array of objects for the lines that ARE headings. +Each object: {{"idx": <0-based index>, "level": <2|3|4>}} +If none are headings, return an empty array: [] + +Bold lines: +{candidates}""" + + COMPILE_MERGE_KNOWLEDGE = """You are merging new information into an existing knowledge cluster. ### Existing Knowledge diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index c38dfee..2c1c9b5 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -4144,19 +4144,52 @@ def _filter_tables_by_page_range( and page_start <= t["page_number"] <= page_end ] + _TABLE_RELEVANCE_MIN_PREFIX = 5 + @staticmethod def _score_table_relevance( markdown: str, query_tokens: frozenset, ) -> float: """Score a table's relevance to the query via token overlap. 
- Returns a value in [0, 1] representing the fraction of *query_tokens* - found in the table's markdown text (case-insensitive). + Uses two matching strategies per token: + + 1. **Exact substring** — fast check whether the token appears + anywhere in the table text (original behaviour). + 2. **Prefix match** — handles morphological variation such as + plural/singular (*inventory* ↔ *inventories*) by comparing + word prefixes of at least ``_TABLE_RELEVANCE_MIN_PREFIX`` + characters. Only attempted when the exact match misses. + + Returns a value in [0, 1] representing the fraction of + *query_tokens* matched. """ if not markdown or not query_tokens: return 0.0 + + min_pfx = AgenticSearch._TABLE_RELEVANCE_MIN_PREFIX md_lower = markdown.lower() - hits = sum(1 for tok in query_tokens if tok in md_lower) + md_words = None # lazily built on first prefix-match attempt + + hits = 0 + for tok in query_tokens: + if tok in md_lower: + hits += 1 + continue + # Prefix-match fallback + pfx_len = min(len(tok), min_pfx) + if pfx_len < 4: + continue + if md_words is None: + md_words = frozenset(md_lower.split()) + prefix = tok[:pfx_len] + if any( + w[:pfx_len] == prefix + for w in md_words + if len(w) >= pfx_len + ): + hits += 1 + return hits / len(query_tokens) @staticmethod From e55ada78b709420a36951d89e34e967b39cfbecb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sat, 9 May 2026 15:23:51 +0800 Subject: [PATCH 45/56] improve compile for summary and table --- src/sirchmunk/learnings/compiler.py | 28 +++- src/sirchmunk/learnings/tree_indexer.py | 189 +++++++++++++++++++++++- src/sirchmunk/search.py | 110 +++++++++++++- 3 files changed, 317 insertions(+), 10 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 8b08357..25868d6 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -776,6 +776,24 @@ async def _compile_single_file( return result + @staticmethod + def _is_generic_summary(summary: str, min_specificity_len: int = 80) -> bool: + """Check whether a summary is too generic to be useful for retrieval. + + A generic summary typically contains only structural descriptions + (e.g., "This document contains several sections") without specific + content indicators. Detection uses summary length and information + density as domain-agnostic proxies. + """ + if not summary: + return True + stripped = summary.strip() + if len(stripped) < min_specificity_len: + return True + # Count unique substantive words (>4 chars) as a proxy for specificity + words = set(w.lower() for w in stripped.split() if len(w) > 4) + return len(words) < 8 + async def _extract_summary( self, file_path: str, @@ -786,13 +804,19 @@ async def _extract_summary( When a tree is available its root already contains an LLM-synthesized summary (produced by ``_synthesize_root_summary`` during tree build), - so we reuse it directly — no redundant LLM call. + so we reuse it directly — unless the summary is too generic (Plan 2), + in which case we fall back to multi-section LLM summarization. For large documents without a tree, uses multi-section sampling (beginning, middle, end) to capture the full scope of the document. 
""" if tree and tree.root and tree.root.summary: - return tree.root.summary + if not self._is_generic_summary(tree.root.summary): + return tree.root.summary + await self._log.info( + f"[Compile] Root summary too generic for {Path(file_path).name}, " + f"falling back to LLM summarization" + ) preview = self._build_summary_preview(content) from sirchmunk.llm.prompts import COMPILE_DOC_SUMMARY diff --git a/src/sirchmunk/learnings/tree_indexer.py b/src/sirchmunk/learnings/tree_indexer.py index 9cf450e..10cab2b 100644 --- a/src/sirchmunk/learnings/tree_indexer.py +++ b/src/sirchmunk/learnings/tree_indexer.py @@ -46,6 +46,10 @@ _TREE_PREVIEW_MAX = 50_000 # Maximum preview window (~12K tokens) _TREE_PREVIEW_RATIO = 0.15 # Fraction of document to preview +# Structured content detection thresholds (Plan 1: generic table recognition) +_STRUCT_MD_TABLE_MIN_ROWS = 3 # Min markdown table rows to classify as structured +_STRUCT_NUMERIC_DENSITY_THRESHOLD = 0.20 # Fraction of numeric tokens in a text segment + # Extensions eligible for tree indexing _TREE_EXTENSIONS = { ".pdf", ".docx", ".doc", ".md", ".markdown", @@ -445,6 +449,9 @@ async def _build_tree_from_toc( # Merge consecutive fragment entries into virtual parents toc_entries = self._merge_fragment_entries(toc_entries) + # Plan 4: Group disproportionately large tail entries (exhibits/appendices) + toc_entries = self._merge_supplementary_entries(toc_entries) + seen_ids: set = set() children = self._toc_entries_to_nodes( toc_entries, content, len(content), seen_ids, @@ -466,6 +473,74 @@ async def _build_tree_from_toc( children=children, ) + @staticmethod + def _merge_supplementary_entries(entries: List[Any]) -> List[Any]: + """Merge tail entries with disproportionately large spans into a virtual parent. + + Detects when the last few entries collectively span much more content + than the preceding entries — a generic structural signal for exhibits, + appendices, or attachment sections. Groups them under a single + navigable node to prevent them from dominating tree navigation. + + Uses only structural signals (char span ratios, position in document) + — no domain-specific keywords. Returns original entries when the + structural pattern is not detected or when too few entries remain. + """ + if len(entries) < 4: + return entries + + def _span(e: Any) -> int: + if hasattr(e, 'char_start') and hasattr(e, 'char_end'): + if e.char_end and e.char_start is not None: + return max(0, e.char_end - e.char_start) + return 0 + + spans = [_span(e) for e in entries] + total_span = sum(spans) + if total_span == 0: + return entries + + # Scan backwards to find tail entries whose cumulative span is + # disproportionately large while individually being much larger + # than the body-section baseline. Uses 25th percentile instead of + # median so that many large tail entries cannot inflate the baseline. 
+ non_zero_spans = [s for s in spans if s > 0] + if len(non_zero_spans) < 4: + return entries + sorted_spans = sorted(non_zero_spans) + q25_idx = max(0, len(sorted_spans) // 4) + baseline_span = sorted_spans[q25_idx] + + tail_start = len(entries) + cumulative = 0 + for i in range(len(entries) - 1, 0, -1): + if spans[i] > baseline_span * 3: + cumulative += spans[i] + tail_start = i + else: + break + + tail_count = len(entries) - tail_start + # Require at least 2 tail entries spanning > 40% of total content + if tail_count < 2 or cumulative / total_span < 0.40: + return entries + + # Also ensure enough primary entries remain + if tail_start < 2: + return entries + + from copy import deepcopy + first_tail = entries[tail_start] + last_tail = entries[-1] + merged = deepcopy(first_tail) + merged.title = f"Supplementary Material ({tail_count} sections)" + if hasattr(last_tail, 'char_end') and last_tail.char_end: + merged.char_end = last_tail.char_end + merged.children = list(entries[tail_start:]) + + result = list(entries[:tail_start]) + [merged] + return result if len(result) >= 2 else entries + @staticmethod def _merge_fragment_entries(entries: List[Any]) -> List[Any]: """Merge consecutive fragment TOC entries into virtual parent nodes. @@ -591,10 +666,19 @@ def _toc_entries_to_nodes( total_pages=total_pages, ) + # Plan 1: Detect structured/tabular content and add navigation hint + # to help LLM-driven navigation prioritize data-rich sections. + # Deliberately keeps content_type="text" so _classify_leaves + # routes to kreuzberg char_range (higher fidelity than pypdf). + summary_text = section_text.strip() + section_sample = content[start:min(start + 2000, end)] + if DocumentTreeIndexer._detect_structured_content(section_sample): + summary_text = f"[Data/Tables] {summary_text}" + node = TreeNode( node_id=nid, title=entry.title, - summary=section_text.strip(), + summary=summary_text, char_range=(start, end), level=level, page_range=page_range, @@ -636,6 +720,44 @@ def _compute_adaptive_depth(content_length: int) -> int: return depth return 2 # minimum depth + @staticmethod + def _detect_structured_content(text: str, sample_size: int = 2000) -> bool: + """Detect whether text contains structured/tabular data using generic signals. + + Uses two high-precision, domain-agnostic heuristics (any triggers True): + 1. Markdown table syntax (pipe-delimited rows with separator line) + 2. High numeric token density (currency, percentages, large numbers) + + Intentionally omits lower-precision signals (multi-space alignment, + tab counts) because PDF-extracted text frequently has irregular + spacing that causes false positives. + + Args: + text: Content segment to analyze. + sample_size: Max chars to analyze (avoids scanning huge sections). 
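+
+        Returns:
+            True when either structural signal fires, otherwise False.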
+ """ + sample = text[:sample_size] + if not sample.strip(): + return False + + # Signal 1: Markdown table syntax — pipe-separated rows with header separator + pipe_lines = [ln for ln in sample.split("\n") if ln.strip().startswith("|")] + separator_lines = [ln for ln in pipe_lines if re.match(r"\|\s*[-:]+", ln)] + data_rows = len(pipe_lines) - len(separator_lines) + if data_rows >= _STRUCT_MD_TABLE_MIN_ROWS and separator_lines: + return True + + # Signal 2: Numeric token density — high ratio of numeric-pattern tokens + non_ws = re.sub(r"\s+", "", sample) + if len(non_ws) > 50: + from sirchmunk.learnings.compiler import _NUM_TOKEN_RE + num_tokens = _NUM_TOKEN_RE.findall(sample) + total_chars = sum(len(t) for t in num_tokens) + if total_chars / len(non_ws) >= _STRUCT_NUMERIC_DENSITY_THRESHOLD: + return True + + return False + async def _build_node( self, text: str, level: int, max_depth: int, offset: int = 0, @@ -691,13 +813,74 @@ async def _build_node( children=children, ) + @staticmethod + def _collect_representative_nodes( + children: List[TreeNode], + max_nodes: int = 15, + ) -> List[TreeNode]: + """Collect representative nodes from multiple tree depths. + + Gathers direct children plus a sample of deeper descendants to + ensure the summary captures actual content topics — not just + top-level structural wrappers that may be uninformative. + + Strategy: + - Layer 1: all direct children (structural overview). + - Layer 2+: BFS preferring **leaf nodes** (actual content topics) + over intermediate nodes (whose summaries overlap children). + """ + reps: List[TreeNode] = [] + seen: set = set() + + # Layer 1: all direct children (even wrappers — they provide structure) + for c in children: + if c.node_id not in seen and len(reps) < max_nodes: + reps.append(c) + seen.add(c.node_id) + + # Layer 2+: BFS collecting leaf nodes with substantive summaries. + # Leaf nodes represent actual content sections; intermediate nodes + # often have summaries that redundantly overlap their children. + queue = [] + for c in children: + for gc in c.children: + queue.append(gc) + + while queue and len(reps) < max_nodes: + node = queue.pop(0) + if node.node_id in seen: + continue + + is_leaf = not node.children + has_substance = ( + (node.summary and len(node.summary.strip()) > 20) + or node.table_count > 0 + ) + + if is_leaf and has_substance: + reps.append(node) + seen.add(node.node_id) + elif not is_leaf: + # Expand intermediate nodes without adding them — + # their content is represented by their leaf descendants. + for ch in node.children: + queue.append(ch) + + return reps + async def _synthesize_root_summary(self, children: List[TreeNode]) -> str: - """Synthesize a document-level summary from children's section summaries.""" + """Synthesize a document-level summary from multi-depth section info. + + Gathers representative nodes from multiple tree depths to produce + a summary that reflects actual document content, not just top-level + wrapper headings like "SEC Filing" or "Table of Contents". 
+ """ if not children: return "" from sirchmunk.llm.prompts import COMPILE_SYNTHESIZE_SUMMARY + representatives = self._collect_representative_nodes(children) sections_text = "\n".join( - f"- {c.title}: {c.summary}" for c in children + f"- {n.title}: {n.summary}" for n in representatives ) prompt = COMPILE_SYNTHESIZE_SUMMARY.format(sections=sections_text) resp = await self._llm.achat([{"role": "user", "content": prompt}]) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 2c1c9b5..b296b8f 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2166,7 +2166,7 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: ".css", ".bash", ".java", ".c", ".cpp", ".h", ".go", ".rs", } _FAST_CONTEXT_WINDOW = 30 # ± lines around each grep hit - _FAST_MAX_EVIDENCE_CHARS = 15_000 + _FAST_MAX_EVIDENCE_CHARS = 20_000 # Plan 5: expanded from 15K to accommodate richer table evidence _FAST_SMALL_FILE_THRESHOLD = 100_000 # 100K chars - read full file instead of grep sampling # --- Wiki-enhanced ranking constants --- @@ -2221,6 +2221,20 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _CHAR_RANGE_MAX_SPAN_RATIO: float = 0.8 """char_range spanning more than this ratio of the document is treated as invalid.""" + # --- Tree navigation retry (Plan 3) --- + _NAV_RETRY_MIN_EVIDENCE_CHARS: int = 200 + """Evidence below this length triggers a retry with expanded results.""" + _NAV_RETRY_EXPANDED_RESULTS: int = 8 + """Expanded max_results for retry navigation pass.""" + + # --- Table evidence budgets (Plan 5) --- + _TABLE_EVIDENCE_DEFAULT_CHARS: int = 10_000 + """Default max_chars for _format_table_evidence (was 6000).""" + _TABLE_EVIDENCE_PER_RANGE_CHARS: int = 8_000 + """Max chars for per-page-range table supplement in tree nav (was 4000).""" + _TABLE_EVIDENCE_STANDALONE_CHARS: int = 12_000 + """Max chars for standalone table digest fallback when tree nav evidence is thin.""" + # --- Self-correction expanded sampling --- _SELF_CORRECT_EXPANDED_NAV_RESULTS: int = 10 """Expanded tree navigation leaf count for same-file re-sampling (default nav uses 5).""" @@ -4086,6 +4100,19 @@ def _is_valid_char_range( span_ratio = (end - start) / text_len return span_ratio < self._CHAR_RANGE_MAX_SPAN_RATIO + @staticmethod + def _is_evidence_sufficient(evidence: str, min_chars: int = 0) -> bool: + """Check whether collected evidence has enough substance to answer a query. + + Uses a length threshold as a lightweight, domain-agnostic proxy. + Empty or near-empty evidence (e.g., only headers with no data) + fails the check, triggering a retry with expanded parameters. + """ + if not evidence: + return False + stripped = evidence.strip() + return len(stripped) >= min_chars + @staticmethod def _load_compile_content( work_path: Path, file_path: str, @@ -4195,7 +4222,7 @@ def _score_table_relevance( @staticmethod def _format_table_evidence( tables: List[Dict[str, Any]], - max_chars: int = 6000, + max_chars: int = 10_000, query: str = "", ) -> str: """Format table digest entries as LLM-friendly evidence text. @@ -4440,10 +4467,63 @@ async def _navigate_tree_for_evidence( parts, fname, lf, lf.summary, ) + # ── Plan 3: Retry with expanded results if evidence is insufficient ── + # Triggers on: (a) zero evidence parts, OR (b) evidence too thin. 
+ _current_ev_text = "\n\n".join(parts) + _needs_retry = ( + max_results < self._NAV_RETRY_EXPANDED_RESULTS + and not self._is_evidence_sufficient( + _current_ev_text, self._NAV_RETRY_MIN_EVIDENCE_CHARS, + ) + ) + if _needs_retry: + try: + retry_leaves = await indexer.navigate( + tree, query, + max_results=self._NAV_RETRY_EXPANDED_RESULTS, + ) + if retry_leaves: + r_page, r_char, r_summary = self._classify_leaves(retry_leaves) + for rl in r_summary: + self._append_evidence_part(parts, fname, rl, rl.summary) + + # Page-level extraction for retry (mirrors Phase 2) + if r_page: + r_all_pages: set = set() + for _rl, (rsp, rep) in r_page: + r_all_pages.update(range(rsp, rep + 1)) + try: + r_page_contents = DocumentExtractor.extract_pages( + file_path, sorted(r_all_pages), + ) + r_page_map = {pc.page_number: pc.content for pc in r_page_contents} + for rl, (rsp, rep) in r_page: + r_seg = [r_page_map[p] for p in range(rsp, rep + 1) if r_page_map.get(p, "").strip()] + if r_seg: + self._append_evidence_part(parts, fname, rl, "\n".join(r_seg)) + except Exception: + pass + + # Char-range extraction for retry (mirrors Phase 3) + if r_char: + r_text = self._load_compile_content(self.work_path, file_path) or "" + for rl in r_char: + s, e = rl.char_range + if self._is_valid_char_range(s, e, len(r_text)) and r_text: + seg = r_text[s:e] + if seg.strip(): + self._append_evidence_part(parts, fname, rl, seg) + + leaves = retry_leaves + print(f"SEARCH_WIKI_DEBUG [N3.1] retry_nav: {len(retry_leaves)} leaves", flush=True) + except Exception: + pass + if not parts: return None # Supplement with table evidence if available + _all_tables = None try: from sirchmunk.utils.file_utils import get_fast_hash _file_hash = get_fast_hash(file_path) @@ -4465,7 +4545,8 @@ async def _navigate_tree_for_evidence( ) if leaf_tables: table_text = self._format_table_evidence( - leaf_tables, max_chars=4000, + leaf_tables, + max_chars=self._TABLE_EVIDENCE_PER_RANGE_CHARS, query=query, ) if table_text: @@ -4475,9 +4556,28 @@ async def _navigate_tree_for_evidence( except Exception: pass - print(f"SEARCH_WIKI_DEBUG [N5] table_supplement: tables_loaded={len(_all_tables) if '_all_tables' in dir() and _all_tables else 0}", flush=True) - + # Plan 3: If evidence is still too thin, add full table digest as standalone evidence = "\n\n".join(parts) + if ( + not self._is_evidence_sufficient( + evidence, self._NAV_RETRY_MIN_EVIDENCE_CHARS, + ) + and _all_tables + ): + standalone_table_ev = self._format_table_evidence( + _all_tables, + max_chars=self._TABLE_EVIDENCE_STANDALONE_CHARS, + query=query, + ) + if standalone_table_ev: + parts.append( + f"[{fname} - Standalone Table Evidence]\n{standalone_table_ev}" + ) + evidence = "\n\n".join(parts) + print(f"SEARCH_WIKI_DEBUG [N5.1] standalone_table_fallback: len={len(standalone_table_ev)}", flush=True) + + print(f"SEARCH_WIKI_DEBUG [N5] table_supplement: tables_loaded={len(_all_tables) if _all_tables else 0}", flush=True) + print(f"SEARCH_WIKI_DEBUG [N6] _navigate_tree_for_evidence result: len={len(evidence) if evidence else 0}", flush=True) await self._logger.info( f"[FAST:TreeNav] Extracted {len(parts)} sections, " From fe351a1fa3d422568625ba549e09ce245ab80e7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Sat, 9 May 2026 16:19:28 +0800 Subject: [PATCH 46/56] fix tree index --- src/sirchmunk/llm/prompts.py | 18 +++-- src/sirchmunk/search.py | 152 +++++++++++++++++++++++++++++++++-- 2 files changed, 157 insertions(+), 13 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py 
b/src/sirchmunk/llm/prompts.py index 71d5836..909402d 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -423,6 +423,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. **Format**: Use Markdown (headings, bullet points, and bold text) for high readability. 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. +5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. ### Input Data - **User Input**: {user_input} @@ -443,12 +444,12 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". ### Output Format - -[If the query asks for a specific value, ratio, number, or factual answer, state ONLY the direct answer here (e.g. "0.83", "$1,832 million", "Yes", "Increased from 0.67 to 0.69"). If the query is open-ended, write a one-sentence conclusion.] - -[Generate the Markdown Briefing here with detailed analysis and supporting evidence] +[Generate the Markdown Briefing here with detailed analysis, supporting evidence, and full calculation steps. Complete all reasoning BEFORE the PRECISE_ANSWER tag.] + +[State ONLY the final verified answer here (e.g. "0.83", "$1,832 million", "Yes", "Increased from 0.67 to 0.69"). For calculations, this MUST reflect the result from your completed computation above. If the query is open-ended, write a one-sentence conclusion.] + true/false true/false """ @@ -463,6 +464,7 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 2. **Format**: Use Markdown (headings, bullet points, and bold text) for high readability. 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. +5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. ### Document Context {document_context} @@ -486,12 +488,12 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: - If evidence is insufficient or irrelevant, both SHOULD_ANSWER and SHOULD_SAVE MUST be "false". ### Output Format - -[If the query asks for a specific value, ratio, number, or factual answer, state ONLY the direct answer here (e.g. "0.83", "$1,832 million", "Yes", "Increased from 0.67 to 0.69"). If the query is open-ended, write a one-sentence conclusion.] - -[Generate the Markdown Briefing here with detailed analysis and supporting evidence] +[Generate the Markdown Briefing here with detailed analysis, supporting evidence, and full calculation steps. Complete all reasoning BEFORE the PRECISE_ANSWER tag.] + +[State ONLY the final verified answer here (e.g. "0.83", "$1,832 million", "Yes", "Increased from 0.67 to 0.69"). 
For calculations, this MUST reflect the result from your completed computation above. If the query is open-ended, write a one-sentence conclusion.] + true/false true/false """ diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index b296b8f..4b11a09 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -932,14 +932,24 @@ async def _search_by_filename( await self._logger.error(f"Traceback: {traceback.format_exc()}") return [] - @staticmethod - def _parse_summary_response(llm_response: str) -> Tuple[str, bool, bool]: + _SELF_CORRECTION_PATTERN = re.compile( + r'(?:correction|re-?verif|wait,?\s|let me re|actually|self-correction|recalcul)', + re.IGNORECASE, + ) + + @classmethod + def _parse_summary_response(cls, llm_response: str) -> Tuple[str, bool, bool]: """Parse LLM response to extract summary, precise answer, and quality decisions. When a ```` tag is present, its content is prepended to the summary so downstream consumers (evaluation judges, UIs) see the direct answer prominently without needing separate tag awareness. + The method also detects self-correction patterns in the summary text: + when the LLM revised its calculation mid-stream, the last numeric + conclusion is used if PRECISE_ANSWER is absent or matches the + pre-correction value. + Returns: Tuple of (summary_text, should_save_flag, should_answer_flag) """ @@ -2227,6 +2237,17 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _NAV_RETRY_EXPANDED_RESULTS: int = 8 """Expanded max_results for retry navigation pass.""" + _CHAR_RANGE_MIN_SPAN: int = 200 + """Minimum char_range span to trust as substantive content. + + Nodes whose char_range covers fewer characters than this threshold + (e.g. a TOC entry that only records the section title) are demoted + to page-level extraction when a valid page_range is available. + """ + + _NAV_COMPLEMENT_MIN_COMPONENTS: int = 2 + """Minimum query decomposition components to trigger complementary navigation.""" + # --- Table evidence budgets (Plan 5) --- _TABLE_EVIDENCE_DEFAULT_CHARS: int = 10_000 """Default max_chars for _format_table_evidence (was 6000).""" @@ -4029,8 +4050,8 @@ async def _tree_guided_sample( ) return evidence - @staticmethod - def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: + @classmethod + def _classify_leaves(cls, leaves: list) -> Tuple[List[tuple], List, List]: """Classify leaf nodes by preferred extraction strategy. For non-table leaves, **char_range** (kreuzberg markdown) is preferred @@ -4039,6 +4060,11 @@ def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: ``extract_text()``. page_range remains available on each leaf for table-supplement filtering even when the leaf is routed to char_leaves. + Thin char_range nodes (span < ``_CHAR_RANGE_MIN_SPAN``) are demoted + to page-level extraction when a valid page_range exists, as they + typically represent TOC entries whose char offsets only cover the + section title rather than the actual content. 
+ Returns: (page_leaves, char_leaves, summary_leaves) triple: - page_leaves: list of (leaf, page_range) — page-level extraction @@ -4048,6 +4074,7 @@ def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: page_leaves: List[tuple] = [] char_leaves: List = [] summary_leaves: List = [] + min_span = cls._CHAR_RANGE_MIN_SPAN for leaf in leaves: # Table nodes: prefer page-level extraction for raw original content @@ -4078,7 +4105,12 @@ def _classify_leaves(leaves: list) -> Tuple[List[tuple], List, List]: ) if has_char: - char_leaves.append(leaf) + start, end = leaf.char_range + span = end - start if end > start else 0 + if span < min_span and has_page: + page_leaves.append((leaf, page_range)) + else: + char_leaves.append(leaf) elif has_page: page_leaves.append((leaf, page_range)) elif getattr(leaf, 'summary', None): @@ -4113,6 +4145,64 @@ def _is_evidence_sufficient(evidence: str, min_chars: int = 0) -> bool: stripped = evidence.strip() return len(stripped) >= min_chars + _MULTI_COMPONENT_PATTERNS: Tuple[Tuple[str, ...], ...] = ( + ("balance sheet", "income statement"), + ("balance sheet", "cash flow"), + ("income statement", "cash flow"), + ("accounts payable", "cost of"), + ("accounts payable", "inventory"), + ("current assets", "current liabilities"), + ("revenue", "net income", "earnings"), + ("operating income", "depreciation"), + ) + + @staticmethod + def _decompose_query_components(query: str) -> List[str]: + """Extract distinct data-source components from a multi-part query. + + Scans for known multi-component patterns (e.g. a ratio needing data + from both Balance Sheet and Income Statement) and returns a list of + component phrases that the evidence should cover. + """ + q = query.lower() + components: List[str] = [] + for group in AgenticSearch._MULTI_COMPONENT_PATTERNS: + hits = [phrase for phrase in group if phrase in q] + if len(hits) >= 2: + components.extend(hits) + if not components: + financial_keywords = [ + "balance sheet", "income statement", "cash flow", + "accounts payable", "accounts receivable", "inventory", + "current liabilities", "current assets", "total assets", + "revenue", "cost of", "cogs", "depreciation", "amortization", + "operating income", "net income", "earnings", + ] + for kw in financial_keywords: + if kw in q: + components.append(kw) + seen: set = set() + return [c for c in components if not (c in seen or seen.add(c))] + + @staticmethod + def _check_leaf_coverage( + leaves: list, components: List[str], + ) -> Tuple[List[str], List[str]]: + """Check which query components are covered by the navigated leaves. + + Returns: + (covered, missing) — lists of component phrases. + """ + if not leaves or not components: + return [], list(components) + leaf_text = " ".join( + (getattr(l, 'title', '') or '') + " " + (getattr(l, 'summary', '') or '') + for l in leaves + ).lower() + covered = [c for c in components if c in leaf_text] + missing = [c for c in components if c not in leaf_text] + return covered, missing + @staticmethod def _load_compile_content( work_path: Path, file_path: str, @@ -4467,6 +4557,58 @@ async def _navigate_tree_for_evidence( parts, fname, lf, lf.summary, ) + # ── Phase 4: Complementary navigation for multi-component queries ── + # When a query requires data from multiple document sections (e.g. + # Balance Sheet + Income Statement for a ratio), the initial navigate + # may only reach one component. Detect missing components and run a + # focused second navigate pass with a refined query. 
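+        #
+        # Illustrative (hypothetical) query: "accounts payable turnover using
+        # cost of goods sold" decomposes into ["accounts payable", "cost of"];
+        # if the first pass only reaches the payables leaf, the retry query
+        # appends a focus hint naming the missing phrase ("cost of").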
+ _query_components = self._decompose_query_components(query) + if len(_query_components) >= self._NAV_COMPLEMENT_MIN_COMPONENTS: + _covered, _missing = self._check_leaf_coverage(leaves, _query_components) + if _missing: + _complement_query = f"{query} — focus on: {', '.join(_missing)}" + try: + _existing_ids = {id(l) for l in leaves} + comp_leaves = await indexer.navigate( + tree, _complement_query, max_results=max_results, + ) + comp_new = [l for l in (comp_leaves or []) if id(l) not in _existing_ids] + if comp_new: + c_page, c_char, c_summary = self._classify_leaves(comp_new) + for cl in c_summary: + self._append_evidence_part(parts, fname, cl, cl.summary) + if c_page: + c_all_pages: set = set() + for _cl, (csp, cep) in c_page: + c_all_pages.update(range(csp, cep + 1)) + try: + c_contents = DocumentExtractor.extract_pages( + file_path, sorted(c_all_pages), + ) + c_map = {pc.page_number: pc.content for pc in c_contents} + for cl, (csp, cep) in c_page: + c_seg = [c_map[p] for p in range(csp, cep + 1) if c_map.get(p, "").strip()] + if c_seg: + self._append_evidence_part(parts, fname, cl, "\n".join(c_seg)) + except Exception: + pass + if c_char: + c_text = self._load_compile_content(self.work_path, file_path) or "" + for cl in c_char: + s, e = cl.char_range + if self._is_valid_char_range(s, e, len(c_text)) and c_text: + seg = c_text[s:e] + if seg.strip(): + self._append_evidence_part(parts, fname, cl, seg) + leaves = list(leaves) + comp_new + print( + f"SEARCH_WIKI_DEBUG [N3.2] complement_nav: " + f"missing={_missing}, new_leaves={len(comp_new)}", + flush=True, + ) + except Exception: + pass + # ── Plan 3: Retry with expanded results if evidence is insufficient ── # Triggers on: (a) zero evidence parts, OR (b) evidence too thin. _current_ev_text = "\n\n".join(parts) From cb1ba9659c7f2f099d762ee9f714e61fe047f494 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 11 May 2026 16:03:22 +0800 Subject: [PATCH 47/56] update compiler --- src/sirchmunk/learnings/compiler.py | 31 ++++++++-- src/sirchmunk/search.py | 90 ++++++++++++++++++++++++----- 2 files changed, 100 insertions(+), 21 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 25868d6..a2a193c 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -67,6 +67,10 @@ # Selective force-OCR: max pages to re-extract with forced OCR per document _FORCE_OCR_MAX_PAGES = 30 +# Incremental manifest flush: persist manifest every N completed files +# to survive interrupted compiles without excessive I/O overhead. +_MANIFEST_FLUSH_INTERVAL = 10 + # Shared numeric-token regex for table detection heuristics. 
# Matches: $1,234 (1,234) 12.5% 3.14e-5 1,000 _NUM_TOKEN_RE = re.compile( @@ -440,6 +444,7 @@ async def compile( # Phase 2: compile files with bounded concurrency semaphore = asyncio.Semaphore(concurrency) results: List[FileCompileResult] = [] + _files_since_flush = 0 async def _bounded(entry: FileEntry) -> FileCompileResult: async with semaphore: @@ -454,7 +459,6 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: else: if result.tree: report.trees_built += 1 - # Update manifest manifest.files[result.path] = FileManifestEntry( file_hash=get_fast_hash(result.path) or "", compiled_at=datetime.now(timezone.utc).isoformat(), @@ -471,6 +475,17 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: _mentry = manifest.files[result.path] print(f"SEARCH_WIKI_DEBUG [C4] manifest_entry: has_tree={_mentry.has_tree}, has_table_digest={_mentry.has_table_digest}, file_hash={_mentry.file_hash}", flush=True) + # Incremental manifest flush to survive interrupted compiles + _files_since_flush += 1 + if _files_since_flush >= _MANIFEST_FLUSH_INTERVAL: + manifest.last_compile_at = datetime.now(timezone.utc).isoformat() + self._save_manifest(manifest) + _files_since_flush = 0 + + # Phase 2 checkpoint: persist manifest before knowledge aggregation + manifest.last_compile_at = datetime.now(timezone.utc).isoformat() + self._save_manifest(manifest) + # Phase 3: aggregate results into knowledge network await self._log.info("[Compile] Phase 3: Knowledge aggregation") for r in results: @@ -484,15 +499,15 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: await self._log.info("[Compile] Phase 4: Building cross-references") report.cross_refs_built = await self._build_cross_references(results) - # Phase 5: persist manifest + document catalog + # Phase 5: persist final manifest + derived indices + # Catalog and summary index are rebuilt from the manifest, so even + # partial compiles produce usable search-time metadata. manifest.last_compile_at = datetime.now(timezone.utc).isoformat() self._save_manifest(manifest) self._storage.force_sync() - # Generate document catalog for search-time routing self._build_document_catalog(manifest) - # Phase: Build summary index for embedding+BM25 fallback (optional, non-blocking) await self._build_summary_index(manifest) report.elapsed_seconds = time.monotonic() - t0 @@ -2553,7 +2568,13 @@ def _load_manifest(self) -> CompileManifest: return CompileManifest() def _save_manifest(self, manifest: CompileManifest) -> None: - self._manifest_path.write_text(manifest.to_json(), encoding="utf-8") + """Atomically persist the manifest via write-to-tmp + rename. + + This prevents partial JSON on disk if the process is killed mid-write. 
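# Editorial note (not part of the patch): Path.replace() maps to os.replace(),
# which is an atomic rename on POSIX as long as source and target live on the
# same filesystem - guaranteed here because the .tmp file sits next to the
# manifest. A minimal standalone illustration; the path below is hypothetical.
import json
from pathlib import Path

target = Path("/tmp/sirchmunk-demo/manifest.json")
target.parent.mkdir(parents=True, exist_ok=True)
tmp = target.with_suffix(".json.tmp")               # manifest.json.tmp, same directory
tmp.write_text(json.dumps({"files": {}}), encoding="utf-8")
tmp.replace(target)   # readers observe either the old or the new manifest, never a torn write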
+ """ + tmp_path = self._manifest_path.with_suffix(".json.tmp") + tmp_path.write_text(manifest.to_json(), encoding="utf-8") + tmp_path.replace(self._manifest_path) # ------------------------------------------------------------------ # # Document catalog for search-time routing # diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 4b11a09..207b837 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2231,6 +2231,14 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _CHAR_RANGE_MAX_SPAN_RATIO: float = 0.8 """char_range spanning more than this ratio of the document is treated as invalid.""" + # --- Hierarchical file selection for large tree pools --- + _TREE_PREFILTER_THRESHOLD: int = 15 + """Tree pool size above which rule-based pre-filtering is applied.""" + _TREE_PREFILTER_MAX_CANDIDATES: int = 10 + """Maximum candidate trees forwarded to the LLM after pre-filtering.""" + _TREE_PREFILTER_MIN_SCORE: float = 0.5 + """Minimum relevance score for a tree to survive pre-filtering.""" + # --- Tree navigation retry (Plan 3) --- _NAV_RETRY_MIN_EVIDENCE_CHARS: int = 200 """Evidence below this length triggers a retry with expanded results.""" @@ -5100,32 +5108,82 @@ def _load_cached_trees(self) -> list: except Exception: return [] + @staticmethod + def _prefilter_trees_by_query( + query: str, trees: list, max_candidates: int, min_score: float, + ) -> list: + """Rule-based pre-filter: score trees by query-token overlap with filenames. + + Extracts meaningful tokens from the query (alphanumeric words, 4-digit + years, multi-word entity fragments) and scores each tree's filename by + weighted token overlap. Returns the top-scoring candidates, or the + full list if fewer than *max_candidates* pass the threshold. + + This avoids sending hundreds of root summaries to the LLM. + """ + raw_tokens = re.findall(r"[A-Za-z0-9]+", query.lower()) + tokens = [t for t in raw_tokens if len(t) >= 2 and t not in _STOP_WORDS] + if not tokens: + return trees + + year_tokens = {t for t in tokens if re.fullmatch(r"(?:19|20)\d{2}", t)} + entity_tokens = {t for t in tokens if len(t) >= 3 and t not in year_tokens} + + scored: List[Tuple[float, int]] = [] + for idx, tree in enumerate(trees): + name_lower = Path(tree.file_path).stem.lower() + name_parts = set(re.findall(r"[a-z0-9]+", name_lower)) + + score = 0.0 + for tok in entity_tokens: + if tok in name_lower: + score += 2.0 + elif any(tok[:4] in part for part in name_parts if len(tok) >= 4): + score += 0.5 + for yr in year_tokens: + if yr in name_lower: + score += 3.0 + + scored.append((score, idx)) + + scored.sort(key=lambda x: -x[0]) + + candidates = [trees[idx] for sc, idx in scored if sc >= min_score] + if not candidates: + return [trees[idx] for _, idx in scored[:max_candidates]] + return candidates[:max_candidates] + async def _llm_select_from_trees( self, query: str, trees: list, max_select: int, ) -> List[str]: - """LLM-driven file selection from tree root summaries. - - Presents root summaries to the LLM and returns the selected file - paths. When the number of trees is at most *max_select*, returns - all paths without an LLM call. + """Two-stage LLM-driven file selection from tree root summaries. - Args: - query: User query string. - trees: List of ``DocumentTree`` objects (pre-loaded). - max_select: Maximum number of files to select. + Stage 1 (rule-based): when the pool exceeds ``_TREE_PREFILTER_THRESHOLD``, + narrow candidates by query-token / filename overlap. 
+ Stage 2 (LLM): present root summaries of the narrowed set for precise selection. - Returns: - Selected file paths, or empty list. + When the number of trees is at most *max_select*, returns all paths + without an LLM call. """ if not trees: return [] if len(trees) <= max_select: return [t.file_path for t in trees] + pool = trees + if len(pool) > self._TREE_PREFILTER_THRESHOLD: + pool = self._prefilter_trees_by_query( + query, pool, + max_candidates=self._TREE_PREFILTER_MAX_CANDIDATES, + min_score=self._TREE_PREFILTER_MIN_SCORE, + ) + if len(pool) <= max_select: + return [t.file_path for t in pool] + listing = "\n".join( f"[{i}] {Path(t.file_path).name}: " f"{(t.root.summary or '')[:self._CATALOG_SUMMARY_TRUNCATE]}" - for i, t in enumerate(trees) + for i, t in enumerate(pool) ) prompt = ( f'Given the query: "{query}"\n\n' @@ -5143,18 +5201,18 @@ async def _llm_select_from_trees( if m: selected_indices = [ idx for idx in json.loads(m.group()) - if isinstance(idx, int) and 0 <= idx < len(trees) + if isinstance(idx, int) and 0 <= idx < len(pool) ] except (json.JSONDecodeError, TypeError): pass if not selected_indices: - selected_indices = list(range(min(max_select, len(trees)))) + selected_indices = list(range(min(max_select, len(pool)))) return [ - trees[idx].file_path + pool[idx].file_path for idx in selected_indices[:max_select] - if Path(trees[idx].file_path).exists() + if Path(pool[idx].file_path).exists() ] async def _probe_tree_index(self, query: str) -> List[str]: From 93d4a1fc0c3a1b40978317cf727235c29ea96fa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 11 May 2026 16:13:11 +0800 Subject: [PATCH 48/56] improve compile efficiency --- src/sirchmunk/learnings/compiler.py | 157 ++++++++++++++++++---------- 1 file changed, 99 insertions(+), 58 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index a2a193c..f21939a 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -71,6 +71,10 @@ # to survive interrupted compiles without excessive I/O overhead. _MANIFEST_FLUSH_INTERVAL = 10 +# Page-level extraction: max pages to load into memory per batch. +# Prevents loading all 200-400 pages of a large PDF at once. +_PAGE_SCAN_BATCH_SIZE = 50 + # Shared numeric-token regex for table detection heuristics. # Matches: $1,234 (1,234) 12.5% 3.14e-5 1,000 _NUM_TOKEN_RE = re.compile( @@ -374,6 +378,29 @@ def __init__( self._compile_dir.mkdir(parents=True, exist_ok=True) self._manifest_path = self._compile_dir / "manifest.json" + # ------------------------------------------------------------------ # + # Resource management # + # ------------------------------------------------------------------ # + + @staticmethod + def _configure_thread_limits() -> None: + """Cap PyTorch / OpenMP / MKL thread count to avoid runaway CPU and memory. + + Only sets defaults when the user has not already configured them via + environment variables, so explicit overrides are always respected. + The cap is half the available CPU cores, clamped to [1, 4]. 
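# Editorial sketch (not part of the patch): the cap formula described above,
# max(1, min(cpu_count // 2, 4)), evaluated for a few core counts.
for cores in (1, 2, 6, 16):
    print(cores, "->", max(1, min(cores // 2, 4)))   # 1 -> 1, 2 -> 1, 6 -> 3, 16 -> 4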
+ """ + cpu_count = os.cpu_count() or 4 + cap = str(max(1, min(cpu_count // 2, 4))) + for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"): + if var not in os.environ: + os.environ[var] = cap + try: + import torch + torch.set_num_threads(int(cap)) + except ImportError: + pass + # ------------------------------------------------------------------ # # Public API # # ------------------------------------------------------------------ # @@ -398,6 +425,9 @@ async def compile( concurrency: Max parallel file compilations. """ import time + + self._configure_thread_limits() + t0 = time.monotonic() report = CompileReport() @@ -441,9 +471,11 @@ async def compile( f"(concurrency={concurrency})" ) - # Phase 2: compile files with bounded concurrency + # Phase 2 + 3 (fused): compile files, aggregate inline, release heavy objects + # Fusing Phase 3 into the completion loop avoids retaining all + # DocumentTree / EvidenceUnit objects until the end of the pipeline. semaphore = asyncio.Semaphore(concurrency) - results: List[FileCompileResult] = [] + _xref_pairs: List[Tuple[str, List[str]]] = [] # lightweight (path, cluster_ids) for Phase 4 _files_since_flush = 0 async def _bounded(entry: FileEntry) -> FileCompileResult: @@ -453,7 +485,6 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: tasks = [_bounded(f) for f in to_compile] for coro in asyncio.as_completed(tasks): result = await coro - results.append(result) if result.error: report.errors.append(f"{result.path}: {result.error}") else: @@ -475,6 +506,16 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: _mentry = manifest.files[result.path] print(f"SEARCH_WIKI_DEBUG [C4] manifest_entry: has_tree={_mentry.has_tree}, has_table_digest={_mentry.has_table_digest}, file_hash={_mentry.file_hash}", flush=True) + # Phase 3 inline: aggregate while the result is still alive + if not result.error and result.summary: + created, merged = await self._aggregate_to_knowledge_network(result) + report.clusters_created += created + report.clusters_merged += merged + + # Retain only lightweight cross-ref data, then drop the heavy result + _xref_pairs.append((result.path, list(result.cluster_ids))) + del result + # Incremental manifest flush to survive interrupted compiles _files_since_flush += 1 if _files_since_flush >= _MANIFEST_FLUSH_INTERVAL: @@ -482,22 +523,15 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: self._save_manifest(manifest) _files_since_flush = 0 - # Phase 2 checkpoint: persist manifest before knowledge aggregation + # Phase 2 checkpoint: persist manifest before cross-references manifest.last_compile_at = datetime.now(timezone.utc).isoformat() self._save_manifest(manifest) - # Phase 3: aggregate results into knowledge network - await self._log.info("[Compile] Phase 3: Knowledge aggregation") - for r in results: - if r.error or not r.summary: - continue - created, merged = await self._aggregate_to_knowledge_network(r) - report.clusters_created += created - report.clusters_merged += merged - - # Phase 4: cross-references + # Phase 4: cross-references (uses only lightweight path+cluster_ids pairs) await self._log.info("[Compile] Phase 4: Building cross-references") - report.cross_refs_built = await self._build_cross_references(results) + report.cross_refs_built = await self._build_cross_references_from_pairs( + _xref_pairs, manifest, + ) # Phase 5: persist final manifest + derived indices # Catalog and summary index are rebuilt from the manifest, so even @@ -1265,25 +1299,23 @@ async def _create_cluster( # 
Cross-references # # ------------------------------------------------------------------ # - async def _build_cross_references( - self, results: List[FileCompileResult], + async def _build_cross_references_from_pairs( + self, + pairs: List[Tuple[str, List[str]]], + manifest: CompileManifest, ) -> int: """Build co-occurrence edges between clusters that share source files. - Two clusters are co-occurring when the same source file contributed - evidence to both (e.g., different sections compiled into different - clusters). Includes historical data from the manifest. + Accepts lightweight ``(path, cluster_ids)`` pairs instead of full + ``FileCompileResult`` objects to avoid retaining heavy compile results. + Includes historical data from the manifest. """ - # Build a complete map: cluster_id -> set of source file paths cluster_to_files: Dict[str, Set[str]] = {} - # From current compile results - for r in results: - for cid in r.cluster_ids: - cluster_to_files.setdefault(cid, set()).add(r.path) + for path, cluster_ids in pairs: + for cid in cluster_ids: + cluster_to_files.setdefault(cid, set()).add(path) - # From manifest (historical data) - manifest = self._load_manifest() for fp, entry in manifest.files.items(): for cid in entry.cluster_ids: cluster_to_files.setdefault(cid, set()).add(fp) @@ -2288,37 +2320,44 @@ async def _pypdf_page_scan( total_pages: int, covered_pages: Set[int], ) -> list[dict]: - """Primary scan: per-page pypdf extraction with density heuristics.""" - all_page_nums = list(range(1, total_pages + 1)) - try: - pages = DocumentExtractor.extract_pages(file_path, all_page_nums) - except Exception as exc: - await self._log.warning( - f"[Compile] Content-based scan: page read failed for " - f"{Path(file_path).name}: {exc}" - ) - return [] + """Primary scan: per-page pypdf extraction with density heuristics. + Pages are loaded in batches of ``_PAGE_SCAN_BATCH_SIZE`` to bound + peak memory when processing large PDFs (200-400+ pages). 
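# Editorial sketch (not part of the patch): how the batching below walks a
# 120-page document with _PAGE_SCAN_BATCH_SIZE = 50; page counts are invented.
total_pages, batch_size = 120, 50
for batch_start in range(1, total_pages + 1, batch_size):
    batch_end = min(batch_start + batch_size, total_pages + 1)
    print(f"pages {batch_start}-{batch_end - 1}")   # pages 1-50, 51-100, 101-120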
+ """ results: list[dict] = [] poor_line_count = 0 - for pc in pages: - if len(pc.content.split("\n")) <= 3: - poor_line_count += 1 - if pc.page_number in covered_pages: - continue - if not self._page_has_table_density(pc.content): - continue - for region in self._identify_table_regions(pc.content): - results.append({ - "page": pc.page_number, - "content": region[:_TARGETED_TABLE_MAX_CHARS], - "source": "content_scan", - }) + + for batch_start in range(1, total_pages + 1, _PAGE_SCAN_BATCH_SIZE): + batch_end = min(batch_start + _PAGE_SCAN_BATCH_SIZE, total_pages + 1) + batch_pages = list(range(batch_start, batch_end)) + try: + pages = DocumentExtractor.extract_pages(file_path, batch_pages) + except Exception as exc: + await self._log.warning( + f"[Compile] Content-based scan: page read failed for " + f"{Path(file_path).name}: {exc}" + ) + return [] + + for pc in pages: + if len(pc.content.split("\n")) <= 3: + poor_line_count += 1 + if pc.page_number in covered_pages: + continue + if not self._page_has_table_density(pc.content): + continue + for region in self._identify_table_regions(pc.content): + results.append({ + "page": pc.page_number, + "content": region[:_TARGETED_TABLE_MAX_CHARS], + "source": "content_scan", + }) + del pages if results: return results - # Signal that pypdf line quality is poor — caller should try fallback if poor_line_count > total_pages * 0.5: return [] @@ -2498,7 +2537,8 @@ async def _build_summary_index(self, manifest: CompileManifest) -> None: The index is saved to .cache/compile/summary_index.json and consumed by search.py as a last-resort fallback when rga keyword search fails. - Skips gracefully if dependencies (EmbeddingUtil/TokenizerUtil) are unavailable. + Reuses ``self._embedding`` when available to avoid loading a duplicate + model into memory. Falls back to a fresh instance otherwise. 
""" try: from sirchmunk.utils.tokenizer_util import TokenizerUtil @@ -2518,7 +2558,6 @@ async def _build_summary_index(self, manifest: CompileManifest) -> None: if not entries: return - # Tokenize summaries + compute TF (always available) tokenizer = TokenizerUtil() for idx, entry in enumerate(entries): tokens = tokenizer.segment(entry.summary) @@ -2527,12 +2566,14 @@ async def _build_summary_index(self, manifest: CompileManifest) -> None: for t in tokens: entry.token_freqs[t] = entry.token_freqs.get(t, 0) + 1 - # Compute embeddings (optional — requires EmbeddingUtil) + # Reuse the compiler's embedding client to avoid duplicate model load try: - from sirchmunk.utils.embedding_util import EmbeddingUtil - embedding_util = EmbeddingUtil() - embedding_util.start_loading() - # Wait up to 60 seconds for model load + embedding_util = self._embedding + if embedding_util is None: + from sirchmunk.utils.embedding_util import EmbeddingUtil + embedding_util = EmbeddingUtil() + embedding_util.start_loading() + await embedding_util._ensure_model_async(timeout=60) if embedding_util.is_ready(): From 929fbc523250c276990d2cc53ab3aa6334836129 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 11 May 2026 16:45:36 +0800 Subject: [PATCH 49/56] improve compile mem usage --- src/sirchmunk/cli/cli.py | 20 +++ src/sirchmunk/learnings/compiler.py | 182 +++++++++++++++------------- 2 files changed, 118 insertions(+), 84 deletions(-) diff --git a/src/sirchmunk/cli/cli.py b/src/sirchmunk/cli/cli.py index 99d6843..4aec43f 100644 --- a/src/sirchmunk/cli/cli.py +++ b/src/sirchmunk/cli/cli.py @@ -1242,6 +1242,22 @@ def cmd_mcp_version(args: argparse.Namespace) -> int: # sirchmunk compile # ------------------------------------------------------------------ + +def _configure_compile_threads() -> None: + """Set sensible thread-count defaults for CPU-bound ML workloads. + + Must be called early — before PyTorch, OpenMP, or kreuzberg's Rust + core are imported — so the environment variables are read at library + init time. User-provided overrides are always respected. + """ + cpu_count = os.cpu_count() or 4 + cap = str(max(1, min(cpu_count // 2, 4))) + for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", + "RAYON_NUM_THREADS"): + if var not in os.environ: + os.environ[var] = cap + + def cmd_compile(args: argparse.Namespace) -> int: """Compile document collections into structured knowledge indices. @@ -1254,6 +1270,10 @@ def cmd_compile(args: argparse.Namespace) -> int: Returns: Exit code (0 for success, non-zero for failure) """ + # Cap thread counts BEFORE heavy libraries are imported, so OpenMP/MKL + # read the correct values at init time. User-set vars are respected. + _configure_compile_threads() + try: work_path = Path( getattr(args, "work_path", None) or str(_get_default_work_path()) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index f21939a..5bd5bef 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -384,21 +384,19 @@ def __init__( @staticmethod def _configure_thread_limits() -> None: - """Cap PyTorch / OpenMP / MKL thread count to avoid runaway CPU and memory. + """Cap PyTorch thread count to reduce per-thread memory allocation. - Only sets defaults when the user has not already configured them via - environment variables, so explicit overrides are always respected. - The cap is half the available CPU cores, clamped to [1, 4]. + Environment variables (OMP_NUM_THREADS, etc.) 
are set in the CLI + entry point before libraries are imported. This method handles the + PyTorch-specific runtime API that works retroactively. """ cpu_count = os.cpu_count() or 4 - cap = str(max(1, min(cpu_count // 2, 4))) - for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"): - if var not in os.environ: - os.environ[var] = cap + cap = max(1, min(cpu_count // 2, 4)) try: import torch - torch.set_num_threads(int(cap)) - except ImportError: + torch.set_num_threads(cap) + torch.set_num_interop_threads(max(1, cap // 2)) + except (ImportError, RuntimeError): pass # ------------------------------------------------------------------ # @@ -651,6 +649,10 @@ async def _compile_single_file( When *shallow* is True (or file is ineligible for tree indexing), the pipeline skips tree building and summarises via a direct LLM call. + + Large intermediate objects (extraction output, enriched content, + raw tables) are explicitly released after their last use to keep + per-file peak memory bounded. """ result = FileCompileResult(path=entry.path) print(f"SEARCH_WIKI_DEBUG [C1] _compile_single_file: file_path={entry.path}, file_hash={entry.file_hash}", flush=True) @@ -666,6 +668,11 @@ async def _compile_single_file( result.error = "Insufficient text content" return result + # Extract scalar metadata from extraction before releasing it + page_count = extraction.page_count + raw_tables = extraction.tables + del extraction + use_tree = ( not shallow and DocumentTreeIndexer.should_build_tree(entry.path, len(content)) @@ -677,7 +684,7 @@ async def _compile_single_file( from sirchmunk.learnings.toc_extractor import TOCExtractor toc_entries = await TOCExtractor.extract( entry.path, content, - total_pages=extraction.page_count, + total_pages=page_count, ) if toc_entries: await self._log.info( @@ -689,47 +696,53 @@ async def _compile_single_file( result.tree = await self._tree_indexer.build_tree( entry.path, content, toc_entries=toc_entries, - total_pages=extraction.page_count, + total_pages=page_count, ) - # Record TOC / tree metrics on the result for manifest persistence - result.has_explicit_toc = toc_entries is not None and len(toc_entries) > 0 + result.has_explicit_toc = bool(toc_entries) + del toc_entries result.tree_node_count = self._count_tree_nodes(result.tree) print(f"SEARCH_WIKI_DEBUG [C2] tree_build: success={result.tree is not None}, nodes={result.tree_node_count}, tree.file_path={result.tree.file_path if result.tree else 'N/A'}", flush=True) - # Enrich content with structural metadata for non-text types + # --- Summary + topics + evidence (needs content) --- ext = Path(entry.path).suffix.lower() evidence_digest = "" if ext in (".xlsx", ".xls"): - # Excel: use adaptive sampling for both metadata and evidence metadata_prefix, evidence_digest = self._extract_xlsx_sampling(entry.path) - enriched_content = metadata_prefix + content if metadata_prefix else content else: metadata_prefix = self._extract_structured_metadata(entry.path, content) - enriched_content = metadata_prefix + content if metadata_prefix else content - result.summary = await self._extract_summary( - entry.path, enriched_content, result.tree, - ) + # Build enriched_content only for the summary LLM call, then release + if metadata_prefix: + result.summary = await self._extract_summary( + entry.path, metadata_prefix + content, result.tree, + ) + else: + result.summary = await self._extract_summary( + entry.path, content, result.tree, + ) + del metadata_prefix + result.topics = await self._extract_topics(result.summary) 
result.evidence = self._build_evidence(entry, content, result) - # Persist Excel evidence digest for search-time consumption + # Persist Excel evidence digest if evidence_digest.strip(): try: digest_dir = self._compile_dir / "xlsx_digests" digest_dir.mkdir(parents=True, exist_ok=True) file_hash = get_fast_hash(entry.path) or "" if file_hash: - digest_path = digest_dir / f"{file_hash}.txt" - digest_path.write_text(evidence_digest, encoding="utf-8") + (digest_dir / f"{file_hash}.txt").write_text( + evidence_digest, encoding="utf-8", + ) result.has_xlsx_digest = True except Exception: pass + del evidence_digest - # Cache compile-time ENHANCED content so search can slice - # char_range from the same text the tree was built from. + # Cache ENHANCED content to disk try: file_hash_content = get_fast_hash(entry.path) or "" if file_hash_content and content: @@ -741,83 +754,84 @@ async def _compile_single_file( except Exception: pass - # Persist table digest for documents with extracted tables - if extraction.tables: + # --- Table digest + integration (needs raw_tables, then release) --- + if raw_tables: try: - table_digest = self._build_table_digest(extraction.tables) + table_digest = self._build_table_digest(raw_tables) if table_digest: digest_dir = self._compile_dir / "table_digests" digest_dir.mkdir(parents=True, exist_ok=True) file_hash = get_fast_hash(entry.path) or "" if file_hash: - digest_path = digest_dir / f"{file_hash}.json" - digest_path.write_text( + (digest_dir / f"{file_hash}.json").write_text( json.dumps(table_digest, ensure_ascii=False), encoding="utf-8", ) result.has_table_digest = True - result.table_count = len(extraction.tables) + result.table_count = len(raw_tables) except Exception: pass - print(f"SEARCH_WIKI_DEBUG [C3] table_digest: generated={result.has_table_digest}, count={result.table_count}", flush=True) - - # Integrate tables into tree: annotate counts + create table child nodes - if result.tree and result.tree.root and extraction.tables: - self._integrate_tables_into_tree( - result.tree.root, extraction.tables, - content=content, total_pages=extraction.page_count, - ) - - # Phase 2.5: Targeted table extraction via tree-node structural signals - if result.tree and result.tree.root and ext == ".pdf": - targeted_tables = await self._targeted_table_extraction( - entry.path, result.tree, - ) - await self._supplement_table_digest( - entry.path, targeted_tables, result, - source_label="Targeted extraction", - ) + if result.tree and result.tree.root: + self._integrate_tables_into_tree( + result.tree.root, raw_tables, + content=content, total_pages=page_count, + ) - # Phase 2.6: Content-based full-page table scan (tree-independent) - if ext == ".pdf" and extraction.page_count: - covered_pages = self._get_covered_table_pages(entry.path) - tree_root = ( - result.tree.root - if result.tree and result.tree.root else None - ) - content_tables = await self._content_based_table_scan( - entry.path, - extraction.page_count, - covered_pages, - enhanced_content=content, - tree_root=tree_root, - ) - await self._supplement_table_digest( - entry.path, content_tables, result, - source_label="Content-based scan", - ) + print(f"SEARCH_WIKI_DEBUG [C3] table_digest: generated={result.has_table_digest}, count={result.table_count}", flush=True) + del raw_tables + + # --- Phases 2.5-2.8: secondary table extraction (PDF only) --- + # These phases re-read from the PDF file; `content` is only + # needed for Phase 2.6 fallback and Phase 2.8 enrichment. 
+ if ext == ".pdf": + if result.tree and result.tree.root: + targeted_tables = await self._targeted_table_extraction( + entry.path, result.tree, + ) + await self._supplement_table_digest( + entry.path, targeted_tables, result, + source_label="Targeted extraction", + ) + del targeted_tables - # Phase 2.7: Selective force-OCR for high-density gap pages - if ext == ".pdf" and extraction.page_count: - covered_after_scan = self._get_covered_table_pages(entry.path) - gap_pages = self._find_force_ocr_candidates( - entry.path, extraction.page_count, covered_after_scan, - ) - if gap_pages: - ocr_tables = await self._selective_force_ocr_tables( - entry.path, gap_pages, + if page_count: + covered_pages = self._get_covered_table_pages(entry.path) + tree_root = ( + result.tree.root + if result.tree and result.tree.root else None + ) + content_tables = await self._content_based_table_scan( + entry.path, page_count, covered_pages, + enhanced_content=content, tree_root=tree_root, ) await self._supplement_table_digest( - entry.path, ocr_tables, result, - source_label="Selective force-OCR", + entry.path, content_tables, result, + source_label="Content-based scan", ) + del content_tables - # Phase 2.8: Enrich targeted-extraction tables with ENHANCED content - if ext == ".pdf" and result.has_table_digest: - self._enrich_table_digest_content( - entry.path, content, tree_root=None, - ) + covered_after_scan = self._get_covered_table_pages(entry.path) + gap_pages = self._find_force_ocr_candidates( + entry.path, page_count, covered_after_scan, + ) + if gap_pages: + ocr_tables = await self._selective_force_ocr_tables( + entry.path, gap_pages, + ) + await self._supplement_table_digest( + entry.path, ocr_tables, result, + source_label="Selective force-OCR", + ) + del ocr_tables + + if result.has_table_digest: + self._enrich_table_digest_content( + entry.path, content, tree_root=None, + ) + + # Content is no longer needed — release before returning + del content except Exception as exc: result.error = str(exc) From 207fe59fcac81f3fe004fd15c29027bae9de45a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 11 May 2026 20:42:23 +0800 Subject: [PATCH 50/56] improve extractor multi-processing --- src/sirchmunk/learnings/compiler.py | 74 ++++++++++------- src/sirchmunk/utils/document_extractor.py | 96 +++++++++++++++++++++++ 2 files changed, 143 insertions(+), 27 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index 5bd5bef..ee44f8e 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -9,9 +9,12 @@ import asyncio import bisect +import ctypes +import gc import json import math import os +import platform import random import re import hashlib @@ -75,6 +78,20 @@ # Prevents loading all 200-400 pages of a large PDF at once. _PAGE_SCAN_BATCH_SIZE = 50 +# How often to run gc.collect() inside the compile loop (every N files). +_GC_INTERVAL = 5 + + +def _force_gc() -> None: + """Aggressively reclaim Python-managed memory and nudge the C allocator.""" + gc.collect() + if platform.system() == "Linux": + try: + ctypes.CDLL("libc.so.6").malloc_trim(0) + except (OSError, AttributeError): + pass + + # Shared numeric-token regex for table detection heuristics. 
# Matches: $1,234 (1,234) 12.5% 3.14e-5 1,000 _NUM_TOKEN_RE = re.compile( @@ -475,6 +492,7 @@ async def compile( semaphore = asyncio.Semaphore(concurrency) _xref_pairs: List[Tuple[str, List[str]]] = [] # lightweight (path, cluster_ids) for Phase 4 _files_since_flush = 0 + _files_since_gc = 0 async def _bounded(entry: FileEntry) -> FileCompileResult: async with semaphore: @@ -521,6 +539,11 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: self._save_manifest(manifest) _files_since_flush = 0 + _files_since_gc += 1 + if _files_since_gc >= _GC_INTERVAL: + _force_gc() + _files_since_gc = 0 + # Phase 2 checkpoint: persist manifest before cross-references manifest.last_compile_at = datetime.now(timezone.utc).isoformat() self._save_manifest(manifest) @@ -659,7 +682,7 @@ async def _compile_single_file( try: await self._log.info(f"[Compile] Processing: {Path(entry.path).name}") - extraction = await DocumentExtractor.extract( + extraction = await DocumentExtractor.extract_isolated( entry.path, DocumentExtractor.ENHANCED, ) content = extraction.content @@ -1218,10 +1241,11 @@ async def _aggregate_to_knowledge_network( def _encode_text(self, text: str) -> Optional[Any]: """Encode text to embedding vector, returns None on failure.""" - if not self._embedding: + if not self._embedding or not self._embedding.is_ready(): return None try: - return self._embedding.encode(text) + vectors = self._embedding._encode_sync([text]) + return vectors[0] if len(vectors) > 0 else None except Exception: return None @@ -2497,44 +2521,40 @@ async def _selective_force_ocr_tables( file_path: str, gap_pages: List[int], ) -> list[dict[str, Any]]: - """Re-extract specific pages with forced OCR + layout detection. + """Extract text from gap pages using pypdf (no kreuzberg re-call). - For pages where the native text layer was not recognized as tables - by kreuzberg's RT-DETR model, re-rendering as images may yield - better layout detection results. Uses ``force_ocr_pages`` so only - the targeted pages are OCR'd (no full-document penalty). + Earlier versions spawned a second kreuzberg extraction with + ``force_ocr_pages``, which doubled native memory pressure. + Using pypdf instead avoids Rust/native allocations entirely + while still capturing page text for the table digest. Args: file_path: Path to the PDF. - gap_pages: 0-indexed page numbers to force OCR on. Capped at - :data:`_FORCE_OCR_MAX_PAGES` to bound compile time. + gap_pages: 0-indexed page numbers. Returns: - List of kreuzberg-format table dicts (with ``markdown``, - ``cells``, ``page_number``). + List of table-compatible dicts (``markdown``, ``page_number``). 
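# Editorial sketch (not part of the patch): _page_has_table_density is called
# below but implemented outside this diff. A minimal version, using a
# numeric-token regex in the spirit of _NUM_TOKEN_RE, might look like this;
# the thresholds are illustrative, not the project's real values.
import re

NUM_TOKEN = re.compile(r'\(?\$?-?\d[\d,]*(?:\.\d+)?\)?%?')

def page_has_table_density(text: str, min_numeric_lines: int = 5,
                           min_ratio: float = 0.3) -> bool:
    """Treat a page as table-like when enough lines carry several numeric tokens."""
    lines = [ln for ln in text.splitlines() if ln.strip()]
    if not lines:
        return False
    numeric_lines = sum(1 for ln in lines if len(NUM_TOKEN.findall(ln)) >= 2)
    return numeric_lines >= min_numeric_lines and numeric_lines / len(lines) >= min_ratio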
""" - from sirchmunk.utils.document_extractor import ExtractionProfile - if not gap_pages: return [] capped = sorted(gap_pages)[:_FORCE_OCR_MAX_PAGES] - - profile = ExtractionProfile( - output_format="markdown", - extract_tables=True, - force_ocr_pages=tuple(capped), - ) + one_indexed = [p + 1 for p in capped] try: - extraction = await DocumentExtractor.extract(file_path, profile) - except Exception as exc: - await self._log.warning( - f"[Compile] Selective force-OCR failed for " - f"{Path(file_path).name}: {exc}" - ) + pages = DocumentExtractor.extract_pages(file_path, one_indexed) + except Exception: return [] - return extraction.tables + tables: list[dict[str, Any]] = [] + for pc in pages: + text = (pc.content or "").strip() + if text and self._page_has_table_density(text): + tables.append({ + "markdown": text, + "cells": [], + "page_number": pc.page_number, + }) + return tables # ------------------------------------------------------------------ # # Summary index for embedding + BM25 fallback # diff --git a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py index b2835f5..68670a3 100644 --- a/src/sirchmunk/utils/document_extractor.py +++ b/src/sirchmunk/utils/document_extractor.py @@ -11,6 +11,9 @@ from __future__ import annotations import asyncio +import concurrent.futures +import dataclasses +import os from dataclasses import dataclass, field from pathlib import Path from typing import Any, ClassVar, List, Optional, Sequence, Union @@ -18,6 +21,41 @@ from loguru import logger +# --------------------------------------------------------------------------- +# Top-level helper for subprocess-based extraction (must be picklable) +# --------------------------------------------------------------------------- + +def _extract_in_worker( + file_path: str, + profile_dict: dict[str, Any], +) -> dict[str, Any]: + """Run kreuzberg extraction inside a worker process. + + Returns a plain dict so the result crosses the process boundary + without dragging native kreuzberg objects (and their Rust allocations) + back into the parent process. + """ + import asyncio as _aio + + async def _run() -> dict[str, Any]: + from sirchmunk.utils.document_extractor import ( + DocumentExtractor, + ExtractionProfile, + ) + profile = ExtractionProfile(**profile_dict) + output = await DocumentExtractor.extract(file_path, profile) + return { + "content": output.content, + "mime_type": output.mime_type, + "metadata": output.metadata, + "tables": output.tables, + "detected_languages": output.detected_languages, + "page_count": output.page_count, + } + + return _aio.run(_run()) + + # --------------------------------------------------------------------------- # Configuration profile # --------------------------------------------------------------------------- @@ -231,6 +269,64 @@ async def extract( ) raise + # Shared process pool — lazily created, workers exit after every task + # so the OS reclaims all native memory (Rust arenas, layout-model caches). 
+ _process_pool: ClassVar[Optional[concurrent.futures.ProcessPoolExecutor]] = None + _POOL_WORKERS: ClassVar[int] = max(1, min(os.cpu_count() or 4, 3)) + + @classmethod + def _get_process_pool(cls) -> concurrent.futures.ProcessPoolExecutor: + if cls._process_pool is None: + cls._process_pool = concurrent.futures.ProcessPoolExecutor( + max_workers=cls._POOL_WORKERS, + max_tasks_per_child=1, + ) + return cls._process_pool + + @staticmethod + async def extract_isolated( + file_path: Union[str, Path], + profile: Optional[ExtractionProfile] = None, + ) -> ExtractionOutput: + """Extract content in an isolated subprocess. + + Identical to :meth:`extract` but runs kreuzberg inside a child + process. ``max_tasks_per_child=1`` ensures each worker exits + after one extraction, allowing the OS to reclaim all native + memory (Rust arenas, layout-model buffers, image caches). + + Falls back to in-process extraction on subprocess failure. + """ + profile = profile or DocumentExtractor.BASIC + profile_dict = { + f.name: getattr(profile, f.name) + for f in dataclasses.fields(profile) + } + + loop = asyncio.get_event_loop() + pool = DocumentExtractor._get_process_pool() + try: + raw = await loop.run_in_executor( + pool, + _extract_in_worker, + str(file_path), + profile_dict, + ) + return ExtractionOutput( + content=raw["content"], + mime_type=raw.get("mime_type", ""), + metadata=raw.get("metadata", {}), + tables=raw.get("tables", []), + detected_languages=raw.get("detected_languages", {}), + page_count=raw.get("page_count"), + ) + except Exception as exc: + logger.warning( + "Subprocess extraction failed for {}, falling back to in-process: {}", + file_path, exc, + ) + return await DocumentExtractor.extract(file_path, profile) + @staticmethod async def extract_bytes( data: bytes, From bbc2bbd6d3d2bd955b60d5fe7510c63c2a0daa6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 11 May 2026 21:02:57 +0800 Subject: [PATCH 51/56] fix ProcessPoolExecutor --- src/sirchmunk/utils/document_extractor.py | 135 ++++++++++++++-------- 1 file changed, 89 insertions(+), 46 deletions(-) diff --git a/src/sirchmunk/utils/document_extractor.py b/src/sirchmunk/utils/document_extractor.py index 68670a3..f115687 100644 --- a/src/sirchmunk/utils/document_extractor.py +++ b/src/sirchmunk/utils/document_extractor.py @@ -11,8 +11,8 @@ from __future__ import annotations import asyncio -import concurrent.futures import dataclasses +import multiprocessing as mp import os from dataclasses import dataclass, field from pathlib import Path @@ -22,38 +22,93 @@ # --------------------------------------------------------------------------- -# Top-level helper for subprocess-based extraction (must be picklable) +# Subprocess extraction helpers (module-level for picklability) # --------------------------------------------------------------------------- -def _extract_in_worker( +_EXTRACT_TIMEOUT_S = 600 + + +def _extraction_worker( file_path: str, profile_dict: dict[str, Any], -) -> dict[str, Any]: - """Run kreuzberg extraction inside a worker process. + pipe_w: mp.connection.Connection, +) -> None: + """Child process entry point: run kreuzberg, send result via pipe, exit. - Returns a plain dict so the result crosses the process boundary - without dragging native kreuzberg objects (and their Rust allocations) - back into the parent process. + Sends a plain dict so no native kreuzberg/Rust objects cross the + process boundary. On failure sends ``{"_error": ""}``. 
""" - import asyncio as _aio + try: + import asyncio as _aio + + async def _run() -> dict[str, Any]: + from sirchmunk.utils.document_extractor import ( + DocumentExtractor, + ExtractionProfile, + ) + profile = ExtractionProfile(**profile_dict) + output = await DocumentExtractor.extract(file_path, profile) + return { + "content": output.content, + "mime_type": output.mime_type, + "metadata": output.metadata, + "tables": output.tables, + "detected_languages": output.detected_languages, + "page_count": output.page_count, + } + + pipe_w.send(_aio.run(_run())) + except BaseException as exc: + try: + pipe_w.send({"_error": str(exc)}) + except Exception: + pass + finally: + pipe_w.close() + + +def _run_extraction_in_child( + file_path: str, + profile_dict: dict[str, Any], +) -> dict[str, Any]: + """Spawn an isolated child process, wait for its result. - async def _run() -> dict[str, Any]: - from sirchmunk.utils.document_extractor import ( - DocumentExtractor, - ExtractionProfile, + Unlike ``ProcessPoolExecutor``, a crash in one child never + poisons future extractions — each call spawns a fresh process. + """ + pipe_r, pipe_w = mp.Pipe(duplex=False) + proc = mp.Process( + target=_extraction_worker, + args=(file_path, profile_dict, pipe_w), + daemon=True, + ) + proc.start() + pipe_w.close() + + try: + if not pipe_r.poll(timeout=_EXTRACT_TIMEOUT_S): + proc.kill() + proc.join(timeout=10) + raise RuntimeError( + f"Extraction timed out after {_EXTRACT_TIMEOUT_S}s" + ) + result = pipe_r.recv() + except EOFError: + proc.join(timeout=10) + raise RuntimeError( + f"Worker crashed (exit code {proc.exitcode})" ) - profile = ExtractionProfile(**profile_dict) - output = await DocumentExtractor.extract(file_path, profile) - return { - "content": output.content, - "mime_type": output.mime_type, - "metadata": output.metadata, - "tables": output.tables, - "detected_languages": output.detected_languages, - "page_count": output.page_count, - } + finally: + pipe_r.close() + + proc.join(timeout=30) + if proc.is_alive(): + proc.kill() + proc.join() - return _aio.run(_run()) + if isinstance(result, dict) and "_error" in result: + raise RuntimeError(result["_error"]) + return result # --------------------------------------------------------------------------- @@ -269,31 +324,20 @@ async def extract( ) raise - # Shared process pool — lazily created, workers exit after every task - # so the OS reclaims all native memory (Rust arenas, layout-model caches). - _process_pool: ClassVar[Optional[concurrent.futures.ProcessPoolExecutor]] = None - _POOL_WORKERS: ClassVar[int] = max(1, min(os.cpu_count() or 4, 3)) - - @classmethod - def _get_process_pool(cls) -> concurrent.futures.ProcessPoolExecutor: - if cls._process_pool is None: - cls._process_pool = concurrent.futures.ProcessPoolExecutor( - max_workers=cls._POOL_WORKERS, - max_tasks_per_child=1, - ) - return cls._process_pool - @staticmethod async def extract_isolated( file_path: Union[str, Path], profile: Optional[ExtractionProfile] = None, ) -> ExtractionOutput: - """Extract content in an isolated subprocess. + """Extract content in a fully isolated child process. + + Each call spawns a fresh ``multiprocessing.Process``. When the + child exits (normally or via crash), the OS reclaims **all** of + its native memory — Rust arenas, layout-model buffers, image + caches — guaranteeing zero accumulation in the parent. - Identical to :meth:`extract` but runs kreuzberg inside a child - process. 
``max_tasks_per_child=1`` ensures each worker exits - after one extraction, allowing the OS to reclaim all native - memory (Rust arenas, layout-model buffers, image caches). + Unlike ``ProcessPoolExecutor``, a crash in one extraction never + poisons future calls. Falls back to in-process extraction on subprocess failure. """ @@ -304,11 +348,10 @@ async def extract_isolated( } loop = asyncio.get_event_loop() - pool = DocumentExtractor._get_process_pool() try: raw = await loop.run_in_executor( - pool, - _extract_in_worker, + None, + _run_extraction_in_child, str(file_path), profile_dict, ) From 5af51df12884aa5798ed23fc618ef0f11e579029 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Mon, 11 May 2026 21:19:39 +0800 Subject: [PATCH 52/56] clean methods for compiler --- src/sirchmunk/learnings/compiler.py | 56 ++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/src/sirchmunk/learnings/compiler.py b/src/sirchmunk/learnings/compiler.py index ee44f8e..e812bef 100644 --- a/src/sirchmunk/learnings/compiler.py +++ b/src/sirchmunk/learnings/compiler.py @@ -458,8 +458,10 @@ async def compile( to_compile = changes.added + changes.modified report.files_skipped = len(changes.unchanged) report.files_deleted = len(changes.deleted) - for deleted_path in changes.deleted: - manifest.files.pop(deleted_path, None) + + stale_paths = changes.deleted + [e.path for e in changes.modified] + if stale_paths: + await self._purge_stale_artifacts(stale_paths, manifest) else: to_compile = discovered report.files_skipped = 0 @@ -519,8 +521,6 @@ async def _bounded(entry: FileEntry) -> FileCompileResult: has_table_digest=result.has_table_digest, table_count=result.table_count, ) - _mentry = manifest.files[result.path] - print(f"SEARCH_WIKI_DEBUG [C4] manifest_entry: has_tree={_mentry.has_tree}, has_table_digest={_mentry.has_table_digest}, file_hash={_mentry.file_hash}", flush=True) # Phase 3 inline: aggregate while the result is still alive if not result.error and result.summary: @@ -658,6 +658,53 @@ def _detect_changes( return changes + # ------------------------------------------------------------------ # + # Stale artifact cleanup # + # ------------------------------------------------------------------ # + + async def _purge_stale_artifacts( + self, + file_paths: List[str], + manifest: CompileManifest, + ) -> None: + """Remove disk artifacts and DuckDB clusters for deleted/modified files. + + Called before recompilation so that modified files start with a + clean slate and deleted files leave no residue. + """ + artifact_dirs = { + "trees": ".json", + "content": ".txt", + "table_digests": ".json", + "xlsx_digests": ".txt", + } + + for file_path in file_paths: + entry = manifest.files.get(file_path) + if entry is None: + continue + + file_hash = entry.file_hash + + # 1. Remove disk artifacts keyed by file_hash + if file_hash: + for subdir, ext in artifact_dirs.items(): + artifact = self._compile_dir / subdir / f"{file_hash}{ext}" + try: + artifact.unlink(missing_ok=True) + except OSError: + pass + + # 2. Remove associated knowledge clusters from DuckDB + for cluster_id in entry.cluster_ids: + try: + await self._storage.remove(cluster_id) + except Exception: + pass + + # 3. 
Drop the manifest entry + manifest.files.pop(file_path, None) + # ------------------------------------------------------------------ # # Single-file compilation # # ------------------------------------------------------------------ # @@ -678,7 +725,6 @@ async def _compile_single_file( per-file peak memory bounded. """ result = FileCompileResult(path=entry.path) - print(f"SEARCH_WIKI_DEBUG [C1] _compile_single_file: file_path={entry.path}, file_hash={entry.file_hash}", flush=True) try: await self._log.info(f"[Compile] Processing: {Path(entry.path).name}") From af5f7e16fd10cfb226c2e9799a8aa156ad8b5e84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Tue, 12 May 2026 14:21:04 +0800 Subject: [PATCH 53/56] improve all corpus --- src/sirchmunk/search.py | 126 +++++++++++++++++++++++++++++++++++----- 1 file changed, 113 insertions(+), 13 deletions(-) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 207b837..64a702d 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -87,6 +87,54 @@ # Soft-similarity threshold for gradient cluster reuse (P2) _SOFT_SIM_THRESHOLD = 0.65 + +class _PathScope: + """Immutable search-path scope for filtering compile artifacts. + + Resolves the provided search paths into absolute file paths and + directory prefixes, then offers ``contains()`` to test whether a + given artifact path falls within this scope. + + When the scope is empty (no paths provided), ``contains()`` always + returns True — i.e. *no filtering* is applied. + """ + + __slots__ = ("_files", "_dirs", "_empty") + + def __init__(self, search_paths: Optional[List[str]] = None) -> None: + files: Set[str] = set() + dirs: List[str] = [] + if search_paths: + for p in search_paths: + resolved = str(Path(p).expanduser().resolve()) + if Path(resolved).is_file(): + files.add(resolved) + elif Path(resolved).is_dir(): + dirs.append( + resolved if resolved.endswith(os.sep) + else resolved + os.sep + ) + else: + files.add(resolved) + self._files = frozenset(files) + self._dirs = tuple(dirs) + self._empty = not files and not dirs + + def contains(self, file_path: str) -> bool: + """Return True when *file_path* falls within the search scope.""" + if self._empty: + return True + if not file_path: + return False + resolved = str(Path(file_path).expanduser().resolve()) + if resolved in self._files: + return True + return any(resolved.startswith(d) for d in self._dirs) + + @property + def is_empty(self) -> bool: + return self._empty + # Pure tree search mode for ablation experiments. # When enabled, search relies solely on tree index navigation, skipping rga keyword search. 
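# Editorial sketch (not part of the patch): how the _PathScope defined above
# behaves. The paths are hypothetical; the directory branch assumes
# /data/reports exists as a directory on the machine running the sketch.
scope = _PathScope(["/data/reports", "/tmp/notes.pdf"])
print(scope.contains("/data/reports/2018/3M_10K.pdf"))  # True  - under a scoped directory
print(scope.contains("/tmp/notes.pdf"))                 # True  - explicitly scoped file
print(scope.contains("/data/other/misc.pdf"))           # False - outside the scope
print(_PathScope(None).contains("/any/file.pdf"))       # True  - empty scope means no filtering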
_PURE_TREE_SEARCH: bool = os.getenv("SIRCHMUNK_PURE_TREE_SEARCH", "false").lower() == "true" @@ -1556,7 +1604,8 @@ async def _search_deep( _llm_usage_start = len(self.llm_usages) # --- Adaptive compile artifact detection (shared with FAST) --- - artifacts = self._detect_compile_artifacts() + _scope = _PathScope(paths) + artifacts = self._detect_compile_artifacts(paths) # ============================================================== # Phase 0a: Direct document analysis (intent-gated short-circuit) @@ -1591,8 +1640,8 @@ async def _search_deep( self._probe_knowledge_cache(query), self._load_spec_context(paths, stale_hours=spec_stale_hours), self._probe_tree_index(query), - self._probe_compile_hints([query]), # query-level hints; keyword-level runs post-Phase 1 - self._probe_summary_index(query, artifacts), # GAP 2: zero-LLM BM25 + self._probe_compile_hints([query], scope=_scope), # query-level hints; keyword-level runs post-Phase 1 + self._probe_summary_index(query, artifacts, scope=_scope), # GAP 2: zero-LLM BM25 self._probe_catalog_for_deep(query, artifacts), # GAP 4: zero-LLM keyword overlap return_exceptions=True, ) @@ -2322,7 +2371,8 @@ async def _search_fast( self._tree_nav_cache = _TreeNavCache() # --- Adaptive compile artifact detection (one-shot, zero LLM) --- - artifacts = self._detect_compile_artifacts() + _scope = _PathScope(paths) + artifacts = self._detect_compile_artifacts(paths) if artifacts.catalog or artifacts.tree_available_paths: await self._logger.info( f"[FAST:Artifacts] catalog={'yes' if artifacts.catalog else 'no'} " @@ -2375,7 +2425,7 @@ async def _search_fast( messages=[{"role": "user", "content": prompt}], stream=False, ) - _compile_hints_task = self._probe_compile_hints([query]) + _compile_hints_task = self._probe_compile_hints([query], scope=_scope) _tree_probe_task = self._probe_tree_for_fast(query, artifacts) _parallel_results = await asyncio.gather( @@ -2494,7 +2544,7 @@ async def _search_fast( keyword_idfs.setdefault(p, 0.6) # P4: compile hints — pre-fetched (query-level) + keyword-level supplement - _kw_compile_hints = await self._probe_compile_hints(primary + fallback) + _kw_compile_hints = await self._probe_compile_hints(primary + fallback, scope=_scope) compile_hints = self._merge_compile_hints(_early_compile_hints, _kw_compile_hints) for kw in compile_hints.extra_keywords: if kw not in all_kw_set: @@ -2515,7 +2565,7 @@ async def _search_fast( seen_hint_paths.add(fp) compile_hint_files.append(fp) # Summary index BM25 files: proactive zero-LLM discovery (GAP 2) - _summary_hint_files = await self._probe_summary_index(query, artifacts) + _summary_hint_files = await self._probe_summary_index(query, artifacts, scope=_scope) for fp in _summary_hint_files: if fp not in seen_hint_paths: seen_hint_paths.add(fp) @@ -3551,7 +3601,10 @@ def _load_document_catalog(self) -> Optional[List[Dict[str, str]]]: pass return None - def _detect_compile_artifacts(self) -> CompileArtifacts: + def _detect_compile_artifacts( + self, + search_paths: Optional[List[str]] = None, + ) -> CompileArtifacts: """One-shot probe of all compile artifacts for adaptive FAST activation. Reads the document catalog and scans the tree cache directory to @@ -3559,12 +3612,19 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: start of ``_search_fast()``; the result is passed to downstream helpers so they can enable enhanced logic only when artifacts exist. 
+ When *search_paths* is provided, all returned artifacts are filtered + to only include entries whose file paths fall within the search scope. + This ensures downstream consumers (catalog routing, tree probing, + summary index) never see documents outside the requested scope. + Cost: one JSON read (catalog) + one directory listing (tree cache). Tree path results are cached in ``_tree_paths_cache`` so subsequent calls within the same instance avoid re-parsing every JSON file. Returns a ``CompileArtifacts`` with ``None``/empty fields when compile has not been run. """ + scope = _PathScope(search_paths) + catalog = self._load_document_catalog() catalog_map: Dict[str, Dict[str, str]] = {} if catalog: @@ -3623,6 +3683,14 @@ def _detect_compile_artifacts(self) -> CompileArtifacts: except Exception: pass + # --- Apply search-path scope filtering --- + if not scope.is_empty: + if catalog: + catalog = [e for e in catalog if scope.contains(e.get("path", ""))] + catalog_map = {p: e for p, e in catalog_map.items() if scope.contains(p)} + tree_paths = {p for p in tree_paths if scope.contains(p)} + manifest_map = {p: e for p, e in manifest_map.items() if scope.contains(p)} + print(f"SEARCH_WIKI_DEBUG [D1] manifest_map: {len(manifest_map)} entries, keys={list(manifest_map.keys())[:3]}", flush=True) print(f"SEARCH_WIKI_DEBUG [D2] tree_available_paths: {tree_paths}", flush=True) print(f"SEARCH_WIKI_DEBUG [D3] manifest_fallback_executed: {manifest_map and not tree_paths}", flush=True) @@ -5126,8 +5194,16 @@ def _prefilter_trees_by_query( if not tokens: return trees - year_tokens = {t for t in tokens if re.fullmatch(r"(?:19|20)\d{2}", t)} - entity_tokens = {t for t in tokens if len(t) >= 3 and t not in year_tokens} + # Extract years: bare "2018" and compound prefixed forms "fy2018", "cy2023" + year_tokens: Set[str] = set() + for t in tokens: + if re.fullmatch(r"(?:19|20)\d{2}", t): + year_tokens.add(t) + else: + m = re.search(r"((?:19|20)\d{2})", t) + if m: + year_tokens.add(m.group(1)) + entity_tokens = {t for t in tokens if len(t) >= 2 and t not in year_tokens} scored: List[Tuple[float, int]] = [] for idx, tree in enumerate(trees): @@ -5238,12 +5314,20 @@ async def _probe_tree_index(self, query: str) -> List[str]: except Exception: return [] - async def _probe_compile_hints(self, keywords: List[str]) -> CompileHints: + async def _probe_compile_hints( + self, + keywords: List[str], + *, + scope: Optional["_PathScope"] = None, + ) -> CompileHints: """Zero-LLM enrichment from compile manifest and tree cache. Scans the compile manifest for clusters whose patterns overlap with the query keywords, and scans cached tree root summaries for keyword matches. No LLM calls — only local JSON reads and in-memory DB lookups. + + When *scope* is provided, only file paths falling within the scope + are included in the returned hints. 
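# Editorial sketch (not part of the patch): the compound-year handling added
# to _prefilter_trees_by_query earlier in this patch, shown on a few tokens.
import re

for token in ("2018", "fy2018", "cy2023", "10k"):
    if re.fullmatch(r"(?:19|20)\d{2}", token):
        year = token                                   # bare year
    else:
        m = re.search(r"((?:19|20)\d{2})", token)
        year = m.group(1) if m else None               # fy2018 -> 2018, 10k -> None
    print(token, "->", year)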
""" empty = CompileHints([], []) if not keywords: @@ -5255,6 +5339,11 @@ async def _probe_compile_hints(self, keywords: List[str]) -> CompileHints: seen_paths: set = set() seen_kw: set = set(kw_lower) + def _accept(fp: str) -> bool: + return bool(fp) and fp not in seen_paths and Path(fp).exists() and ( + scope is None or scope.contains(fp) + ) + # --- Cluster pattern matching via manifest --- manifest_path = self.work_path / ".cache" / "compile" / "manifest.json" if manifest_path.exists(): @@ -5280,7 +5369,7 @@ async def _probe_compile_hints(self, keywords: List[str]) -> CompileHints: if kw_lower & set(cluster_patterns): for ev in getattr(c, "evidences", []): fp = str(getattr(ev, "file_or_url", "")) - if fp and fp not in seen_paths and Path(fp).exists(): + if _accept(fp): seen_paths.add(fp) file_paths.append(fp) for p in cluster_patterns: @@ -5307,7 +5396,7 @@ async def _probe_compile_hints(self, keywords: List[str]) -> CompileHints: summary_lower = (tree.root.summary or "").lower() if any(kw in summary_lower for kw in kw_lower): fp = tree.file_path - if fp not in seen_paths and Path(fp).exists(): + if _accept(fp): seen_paths.add(fp) file_paths.append(fp) except Exception: @@ -5339,6 +5428,8 @@ async def _probe_summary_index( self, query: str, artifacts: Optional["CompileArtifacts"] = None, + *, + scope: Optional["_PathScope"] = None, ) -> List[str]: """Zero-LLM file discovery via compile-time summary index (BM25 only). @@ -5346,9 +5437,13 @@ async def _probe_summary_index( summaries are lexically similar to the query. No LLM or embedding calls — pure local computation. + When *scope* is provided, results are post-filtered to only include + file paths within the search scope. + Args: query: User query string. artifacts: Compile artifacts (uses summary_index field). + scope: Optional path scope for filtering results. Returns: File paths of top-k matching documents, or empty list. 
@@ -5374,6 +5469,7 @@ async def _probe_summary_index( file_paths = [ fp for fp, score in results if score > 0.0 and Path(fp).exists() + and (scope is None or scope.contains(fp)) ] if file_paths: @@ -5459,6 +5555,10 @@ async def _probe_tree_for_fast( try: trees = self._load_cached_trees() + # Scope-filter: only keep trees whose files are in artifacts + if artifacts and artifacts.tree_available_paths: + scoped = artifacts.tree_available_paths + trees = [t for t in trees if t.file_path in scoped] print(f"SEARCH_WIKI_DEBUG [D5] loaded_trees: {len(trees)} trees, paths={[t.file_path for t in trees][:3]}", flush=True) if not trees: return [] From 7439521a810db8fe897ccc8cf520eabc0700a83b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Tue, 12 May 2026 16:04:10 +0800 Subject: [PATCH 54/56] tree index and rga fusion --- src/sirchmunk/search.py | 154 +++++++++++++++++++++++++++++++++++----- 1 file changed, 138 insertions(+), 16 deletions(-) diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 64a702d..3507577 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2225,7 +2225,7 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: ".css", ".bash", ".java", ".c", ".cpp", ".h", ".go", ".rs", } _FAST_CONTEXT_WINDOW = 30 # ± lines around each grep hit - _FAST_MAX_EVIDENCE_CHARS = 20_000 # Plan 5: expanded from 15K to accommodate richer table evidence + _FAST_MAX_EVIDENCE_CHARS = 40_000 _FAST_SMALL_FILE_THRESHOLD = 100_000 # 100K chars - read full file instead of grep sampling # --- Wiki-enhanced ranking constants --- @@ -2261,7 +2261,7 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Maximum files returned by catalog keyword-overlap probe in DEEP mode.""" # --- Tree-guided sampling constants --- - _TREE_SAMPLE_MAX_SECTIONS = 5 + _TREE_SAMPLE_MAX_SECTIONS = 8 """Max tree sections to include per file in tree-guided sampling.""" _TREE_SAMPLE_SECTION_MAX_CHARS = 3000 """Max chars per tree section.""" @@ -2280,6 +2280,10 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _CHAR_RANGE_MAX_SPAN_RATIO: float = 0.8 """char_range spanning more than this ratio of the document is treated as invalid.""" + # --- Tree probe / RGA fusion --- + _TREE_PROBE_RANKING_BOOST: float = 3.0 + """Score boost (0-10 scale) for files selected by LLM tree probing.""" + # --- Hierarchical file selection for large tree pools --- _TREE_PREFILTER_THRESHOLD: int = 15 """Tree pool size above which rule-based pre-filtering is applied.""" @@ -2288,10 +2292,12 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _TREE_PREFILTER_MIN_SCORE: float = 0.5 """Minimum relevance score for a tree to survive pre-filtering.""" - # --- Tree navigation retry (Plan 3) --- + # --- Tree navigation --- + _TREE_NAV_MAX_RESULTS: int = 8 + """Primary max_results for LLM-driven tree navigation.""" _NAV_RETRY_MIN_EVIDENCE_CHARS: int = 200 """Evidence below this length triggers a retry with expanded results.""" - _NAV_RETRY_EXPANDED_RESULTS: int = 8 + _NAV_RETRY_EXPANDED_RESULTS: int = 12 """Expanded max_results for retry navigation pass.""" _CHAR_RANGE_MIN_SPAN: int = 200 @@ -2305,12 +2311,12 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _NAV_COMPLEMENT_MIN_COMPONENTS: int = 2 """Minimum query decomposition components to trigger complementary navigation.""" - # --- Table evidence budgets (Plan 5) --- - _TABLE_EVIDENCE_DEFAULT_CHARS: int = 10_000 - """Default max_chars for 
_format_table_evidence (was 6000).""" + # --- Table evidence budgets --- + _TABLE_EVIDENCE_DEFAULT_CHARS: int = 20_000 + """Default max_chars for _format_table_evidence.""" _TABLE_EVIDENCE_PER_RANGE_CHARS: int = 8_000 - """Max chars for per-page-range table supplement in tree nav (was 4000).""" - _TABLE_EVIDENCE_STANDALONE_CHARS: int = 12_000 + """Max chars for per-page-range table supplement in tree nav.""" + _TABLE_EVIDENCE_STANDALONE_CHARS: int = 20_000 """Max chars for standalone table digest fallback when tree nav evidence is thin.""" # --- Self-correction expanded sampling --- @@ -2445,6 +2451,7 @@ async def _search_fast( if isinstance(_tree_probed_files, Exception): await self._logger.warning(f"[FAST:Step1] Tree probe failed: {_tree_probed_files}") _tree_probed_files = [] + _tree_probed_set: frozenset[str] = frozenset(_tree_probed_files) self.llm_usages.append(resp.usage) if resp.usage and isinstance(resp.usage, dict): @@ -2671,10 +2678,26 @@ async def _search_fast( for p in catalog_routed_files[:top_k_files] ] + # Narrow-scope RGA: search within tree-probed files first + if not best_files and _tree_probed_set and primary: + best_files = await self._fast_find_best_file( + primary, paths=list(_tree_probed_set), + top_k=top_k_files, keyword_idfs=keyword_idfs, + query=query, artifacts=artifacts, + ) + if best_files: + used_level = "tree_rga" + await self._logger.info( + f"[FAST:Step2] Narrow-scope tree+rga hit → " + f"{[Path(f['path']).name for f in best_files]}" + ) + + # Full-scope RGA with tree probe boost if not best_files and primary: best_files = await self._fast_find_best_file( primary, top_k=top_k_files, keyword_idfs=keyword_idfs, query=query, artifacts=artifacts, + tree_probed_paths=_tree_probed_set or None, **rga_kwargs, ) @@ -2686,6 +2709,7 @@ async def _search_fast( best_files = await self._fast_find_best_file( fallback, top_k=top_k_files, keyword_idfs=keyword_idfs, query=query, artifacts=artifacts, + tree_probed_paths=_tree_probed_set or None, **rga_kwargs, ) @@ -2756,7 +2780,11 @@ async def _search_fast( print(f"SEARCH_WIKI_DEBUG [D11] MISMATCH! tree_available_paths={artifacts.tree_available_paths}", flush=True) if artifacts and tree_nav_target in artifacts.tree_available_paths: - tree_task = self._navigate_tree_for_evidence(tree_nav_target, query) + tree_task = self._navigate_tree_for_evidence( + tree_nav_target, query, + max_results=self._TREE_NAV_MAX_RESULTS, + match_objects=best_files[0].get("matches"), + ) tree_nav_done.add(tree_nav_target) else: tree_task = self._async_noop(None) @@ -2818,12 +2846,15 @@ async def _rga_evidence() -> str: if _all_tables: _table_ev = self._format_table_evidence( - _all_tables, query=query, + _all_tables, + max_chars=self._TABLE_EVIDENCE_DEFAULT_CHARS, + query=query, ) if _table_ev: ev = f"[{fn} - Table Evidence]\n{_table_ev}" - # 1. Tree-guided sampling FIRST for tree-indexed files + # 1. 
Tree-guided sampling for tree-indexed files + # (skipped when a parallel tree_task already covers this file) _tree_cond = artifacts and fp in artifacts.tree_available_paths and fp not in tree_nav_done print(f"SEARCH_WIKI_DEBUG [D14] tree_sample: cond={_tree_cond}, in_tree_paths={fp in (artifacts.tree_available_paths if artifacts else set())}, in_nav_done={fp in tree_nav_done}", flush=True) if ( @@ -2839,7 +2870,10 @@ async def _rga_evidence() -> str: artifacts=artifacts, ) if tree_ev_inner: - ev = tree_ev_inner + if ev: + ev = ev + "\n\n" + tree_ev_inner + else: + ev = tree_ev_inner await self._logger.info( f"[FAST:Step3] Tree-guided sample for {fn} " f"({len(tree_ev_inner)} chars)" @@ -2883,6 +2917,8 @@ async def _rga_evidence() -> str: rga_ev, tree_ev = await asyncio.gather(rga_task, tree_task) # Merge: tree evidence first (highest quality), then rga + if tree_ev and rga_ev: + rga_ev = self._deduplicate_table_sections(tree_ev, rga_ev) evidence_parts_final: List[str] = [] if tree_ev: evidence_parts_final.append(tree_ev) @@ -3257,11 +3293,16 @@ async def _fast_find_best_file( keyword_idfs: Optional[Dict[str, float]] = None, query: str = "", artifacts: Optional["CompileArtifacts"] = None, + tree_probed_paths: Optional[Set[str]] = None, ) -> Optional[List[Dict[str, Any]]]: """Search per keyword via rga and return the top-k best-matching files ranked by IDF-weighted log-TF scoring, optionally enhanced with wiki-derived relevance from compile artifacts. + When *tree_probed_paths* is provided, files that were selected by + LLM-driven tree probing receive a ranking boost, ensuring the tree + probe's high-quality signal influences the final file ordering. + Args: keywords: Search keywords from FAST Step 1. paths: Search paths. @@ -3272,6 +3313,7 @@ async def _fast_find_best_file( keyword_idfs: Pre-computed IDF values for keywords. query: Original user query (used for wiki relevance scoring). artifacts: Compile artifacts for adaptive wiki-enhanced ranking. + tree_probed_paths: File paths selected by tree probing (receive boost). Returns: List of merged file dicts (path, matches, lines, total_matches, weighted_score) or None. @@ -3427,6 +3469,11 @@ async def _fast_find_best_file( + (1 - self._WIKI_BLEND_ALPHA) * wiki_score ) + if tree_probed_paths: + for f in merged: + if f["path"] in tree_probed_paths: + f["weighted_score"] += self._TREE_PROBE_RANKING_BOOST + merged.sort(key=lambda f: f["weighted_score"], reverse=True) pruned = self._prune_by_score(merged, top_k=top_k) @@ -4385,10 +4432,44 @@ def _score_table_relevance( return hits / len(query_tokens) + @staticmethod + def _deduplicate_table_sections( + primary_ev: str, secondary_ev: str, + ) -> str: + """Remove table sections from *secondary_ev* whose pages already + appear in *primary_ev*. + + Matching is based on ``[Table from page N]`` and ``[Tables pp.X-Y]`` + headers. Non-table content in *secondary_ev* is preserved intact. 
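Returning to the ranking change in `_fast_find_best_file` above: the final ordering blends the IDF-weighted grep score with the wiki-derived relevance and then adds a flat boost for files the LLM tree probe already selected. A compact sketch of that blend; only `_TREE_PROBE_RANKING_BOOST = 3.0` is stated in this patch, so the blend weight here is an assumed placeholder for `_WIKI_BLEND_ALPHA`:

from typing import Dict, List, Set


def blend_ranking_sketch(
    files: List[dict],              # each item: {"path": str, "weighted_score": float, ...}
    wiki_scores: Dict[str, float],  # path -> wiki relevance on a 0-10 scale
    tree_probed: Set[str],          # paths chosen by LLM tree probing
    alpha: float = 0.7,             # assumed value standing in for _WIKI_BLEND_ALPHA
    tree_boost: float = 3.0,        # cf. _TREE_PROBE_RANKING_BOOST
) -> List[dict]:
    for f in files:
        wiki = wiki_scores.get(f["path"], 0.0)
        # Blend the grep-derived score with wiki relevance, then apply the
        # tree-probe boost so structurally confirmed files rise to the top.
        f["weighted_score"] = alpha * f["weighted_score"] + (1 - alpha) * wiki
        if f["path"] in tree_probed:
            f["weighted_score"] += tree_boost
    return sorted(files, key=lambda f: f["weighted_score"], reverse=True)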
+ """ + if not primary_ev or not secondary_ev: + return secondary_ev + + covered: Set[int] = { + int(m.group(1)) + for m in re.finditer(r"\[Table from page (\d+)\]", primary_ev) + } + for m in re.finditer(r"\[Tables pp\.(\d+)-(\d+)\]", primary_ev): + covered.update(range(int(m.group(1)), int(m.group(2)) + 1)) + + if not covered: + return secondary_ev + + blocks = secondary_ev.split("\n\n") + kept: List[str] = [] + for block in blocks: + page_m = re.search(r"\[Table from page (\d+)\]", block) + if page_m and int(page_m.group(1)) in covered: + continue + kept.append(block) + + result = "\n\n".join(kept) + return result if result.strip() else "" + @staticmethod def _format_table_evidence( tables: List[Dict[str, Any]], - max_chars: int = 10_000, + max_chars: int = 20_000, query: str = "", ) -> str: """Format table digest entries as LLM-friendly evidence text. @@ -4410,7 +4491,7 @@ def _format_table_evidence( ordered = tables if query: query_tokens = frozenset( - tok for tok in query.lower().split() if len(tok) > 2 + tok for tok in query.lower().split() if len(tok) >= 2 ) if query_tokens: scored = [ @@ -4465,7 +4546,12 @@ def _append_evidence_part( parts.append(f"{header}\n{text}") async def _navigate_tree_for_evidence( - self, file_path: str, query: str, *, max_results: int = 5, + self, + file_path: str, + query: str, + *, + max_results: int = 8, + match_objects: Optional[List[Dict[str, Any]]] = None, ) -> Optional[str]: """LLM-driven tree navigation: select relevant sections and read leaf content. @@ -4473,6 +4559,10 @@ async def _navigate_tree_for_evidence( *file_path*, returning concatenated leaf content as evidence. Returns None when no tree cache is available. + When *match_objects* (RGA hit dicts) are provided, keyword-level + context windows are appended as supplementary evidence after tree + navigation, fusing structural and keyword signals. + Extraction priority (highest first): 1. char_range – compile-time ENHANCED content slice (preserves tables) 2. 
page_range – page-level extraction via DocumentExtractor (fallback) @@ -4796,6 +4886,38 @@ async def _navigate_tree_for_evidence( print(f"SEARCH_WIKI_DEBUG [N5] table_supplement: tables_loaded={len(_all_tables) if _all_tables else 0}", flush=True) + # --- RGA keyword supplement: fuse keyword hits into tree evidence --- + if match_objects: + _ev_len = sum(len(p) for p in parts) + _rga_budget = max(0, self._FAST_MAX_EVIDENCE_CHARS - _ev_len) + if _rga_budget > 200: + hit_lines: List[int] = [ + m.get("data", {}).get("line_number") + for m in match_objects + if isinstance(m.get("data", {}).get("line_number"), int) + ] + ext = Path(file_path).suffix.lower() + rga_ctx: Optional[str] = None + if ext in self._FAST_TEXT_EXTENSIONS and hit_lines: + rga_ctx = self._read_context_windows( + file_path, hit_lines, + window=self._FAST_CONTEXT_WINDOW, + max_chars=_rga_budget, + ) + else: + snippet_parts: List[str] = [] + snippet_total = 0 + for m in match_objects: + text = m.get("data", {}).get("lines", {}).get("text", "").rstrip() + if text and snippet_total + len(text) < _rga_budget: + snippet_parts.append(text) + snippet_total += len(text) + if snippet_parts: + rga_ctx = "\n".join(snippet_parts) + if rga_ctx: + parts.append(f"[{fname} \u2192 keyword hits]\n{rga_ctx}") + evidence = "\n\n".join(parts) + print(f"SEARCH_WIKI_DEBUG [N6] _navigate_tree_for_evidence result: len={len(evidence) if evidence else 0}", flush=True) await self._logger.info( f"[FAST:TreeNav] Extracted {len(parts)} sections, " From cec209d9f5c57ad6b0de933793786f10d0c1ac14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Tue, 12 May 2026 18:12:22 +0800 Subject: [PATCH 55/56] fallback hybrid tree indexing --- src/sirchmunk/llm/prompts.py | 4 ++ src/sirchmunk/search.py | 82 +++++++++++++++++++++++++++++++++--- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 909402d..89ea9e8 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -424,6 +424,8 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. +6. **Rounding**: Match the precision implied by the query. If the question specifies units (e.g. "in USD millions", "in billions", "as a percentage") or expects a rounded figure, round your final result accordingly rather than reporting raw calculated values with excessive decimal places. +7. **Best-effort answering**: Always attempt to answer based on available evidence. When evidence is partial or indirect, derive the best possible answer and note any assumptions. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. ### Input Data - **User Input**: {user_input} @@ -465,6 +467,8 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. 
**Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. +6. **Rounding**: Match the precision implied by the query. If the question specifies units (e.g. "in USD millions", "in billions", "as a percentage") or expects a rounded figure, round your final result accordingly rather than reporting raw calculated values with excessive decimal places. +7. **Best-effort answering**: Always attempt to answer based on available evidence. When evidence is partial or indirect, derive the best possible answer and note any assumptions. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. ### Document Context {document_context} diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 3507577..55d25c7 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2311,6 +2311,12 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: _NAV_COMPLEMENT_MIN_COMPONENTS: int = 2 """Minimum query decomposition components to trigger complementary navigation.""" + _NAV_PAGE_MARGIN: int = 1 + """Extra pages to extract on each side of a leaf's page_range.""" + + _NAV_REF_PAGE_MAX: int = 5 + """Maximum referenced-but-uncovered pages to extract as gap-fill.""" + # --- Table evidence budgets --- _TABLE_EVIDENCE_DEFAULT_CHARS: int = 20_000 """Default max_chars for _format_table_evidence.""" @@ -2911,10 +2917,8 @@ async def _rga_evidence() -> str: print(f"SEARCH_WIKI_DEBUG [D15] ev_source={_ev_source}, ev_len={len(ev) if ev else 0}", flush=True) return "\n\n---\n\n".join(parts) - # Launch tree navigation for the primary file alongside rga - rga_task = _rga_evidence() - - rga_ev, tree_ev = await asyncio.gather(rga_task, tree_task) + # Launch tree navigation alongside rga evidence collection. + rga_ev, tree_ev = await asyncio.gather(_rga_evidence(), tree_task) # Merge: tree evidence first (highest quality), then rga if tree_ev and rga_ev: @@ -4326,6 +4330,29 @@ def _check_leaf_coverage( missing = [c for c in components if c not in leaf_text] return covered, missing + @staticmethod + def _extract_referenced_pages(text: str) -> Set[int]: + """Extract page numbers referenced in evidence text. + + Detects cross-references like 'page 60', 'pages 45-47', 'pp. 12-15' + that hint at data-bearing pages not yet included in evidence. 
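The `_NAV_PAGE_MARGIN` constant introduced earlier in this patch pads each leaf's inclusive `(start, end)` page range before extraction. A minimal sketch of that expansion (the `DocumentExtractor.extract_pages` call itself is not reproduced here):

from typing import Iterable, Set, Tuple


def expand_page_ranges_sketch(
    ranges: Iterable[Tuple[int, int]],
    margin: int = 1,  # cf. _NAV_PAGE_MARGIN
) -> Set[int]:
    # Pad each inclusive (start, end) range by `margin` pages on both sides,
    # clamp at page 1, and merge everything into one page set for extraction.
    pages: Set[int] = set()
    for start, end in ranges:
        pages.update(range(max(1, start - margin), end + margin + 1))
    return pages

# expand_page_ranges_sketch([(3, 4), (10, 10)]) -> {2, 3, 4, 5, 9, 10, 11}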
+ """ + pages: Set[int] = set() + for m in re.finditer( + r"\b(?:pages?|pp?\.)\s*(\d+)\s*[-\u2013]\s*(\d+)", + text, re.IGNORECASE, + ): + start, end = int(m.group(1)), int(m.group(2)) + if 0 < start <= end and end - start <= 10: + pages.update(range(start, end + 1)) + for m in re.finditer( + r"\b(?:pages?|pp?\.)\s*(\d+)\b", text, re.IGNORECASE, + ): + p = int(m.group(1)) + if 0 < p <= 500: + pages.add(p) + return pages + @staticmethod def _load_compile_content( work_path: Path, file_path: str, @@ -4602,7 +4629,10 @@ async def _navigate_tree_for_evidence( if page_leaves: all_pages: set = set() for _leaf, (sp, ep) in page_leaves: - all_pages.update(range(sp, ep + 1)) + all_pages.update(range( + max(1, sp - self._NAV_PAGE_MARGIN), + ep + self._NAV_PAGE_MARGIN + 1, + )) try: page_contents = DocumentExtractor.extract_pages( file_path, sorted(all_pages), @@ -4697,7 +4727,10 @@ async def _navigate_tree_for_evidence( if page_fallback_leaves: all_fb_pages: set = set() for _lf, (sp, ep) in page_fallback_leaves: - all_fb_pages.update(range(sp, ep + 1)) + all_fb_pages.update(range( + max(1, sp - self._NAV_PAGE_MARGIN), + ep + self._NAV_PAGE_MARGIN + 1, + )) try: fb_contents = DocumentExtractor.extract_pages( file_path, sorted(all_fb_pages), @@ -4886,6 +4919,43 @@ async def _navigate_tree_for_evidence( print(f"SEARCH_WIKI_DEBUG [N5] table_supplement: tables_loaded={len(_all_tables) if _all_tables else 0}", flush=True) + # ── Phase 6: Referenced-page gap-fill ── + # Scan evidence for page cross-references (e.g. TOC entries + # pointing to financial statements) and extract any that were + # not covered by the navigated leaves. + if parts: + _covered_pages: Set[int] = set() + for leaf in leaves: + pr = getattr(leaf, "page_range", None) + if pr and len(pr) == 2 and pr[0] is not None: + _covered_pages.update(range( + max(1, pr[0] - self._NAV_PAGE_MARGIN), + pr[1] + self._NAV_PAGE_MARGIN + 1, + )) + _referenced = self._extract_referenced_pages("\n\n".join(parts)) + _gap_pages = sorted(_referenced - _covered_pages)[ + : self._NAV_REF_PAGE_MAX + ] + if _gap_pages: + try: + _gap_contents = DocumentExtractor.extract_pages( + file_path, _gap_pages, + ) + for pc in _gap_contents: + if pc.content and pc.content.strip(): + parts.append( + f"[{fname} \u2192 referenced p.{pc.page_number}]" + f"\n{pc.content}" + ) + evidence = "\n\n".join(parts) + print( + f"SEARCH_WIKI_DEBUG [N5.2] ref_page_gap_fill: " + f"pages={_gap_pages}", + flush=True, + ) + except Exception: + pass + # --- RGA keyword supplement: fuse keyword hits into tree evidence --- if match_objects: _ev_len = sum(len(p) for p in parts) From 59beaeac766a99da0aa2753b1c8878e05f195255 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=AD=E6=89=AC?= Date: Wed, 13 May 2026 01:50:10 +0800 Subject: [PATCH 56/56] improve search pipeline for hybrid --- src/sirchmunk/llm/prompts.py | 10 ++-- src/sirchmunk/search.py | 111 +++++++++++++++++++++++++++++++++-- 2 files changed, 113 insertions(+), 8 deletions(-) diff --git a/src/sirchmunk/llm/prompts.py b/src/sirchmunk/llm/prompts.py index 89ea9e8..8c8c049 100644 --- a/src/sirchmunk/llm/prompts.py +++ b/src/sirchmunk/llm/prompts.py @@ -424,8 +424,9 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. 
**Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. -6. **Rounding**: Match the precision implied by the query. If the question specifies units (e.g. "in USD millions", "in billions", "as a percentage") or expects a rounded figure, round your final result accordingly rather than reporting raw calculated values with excessive decimal places. -7. **Best-effort answering**: Always attempt to answer based on available evidence. When evidence is partial or indirect, derive the best possible answer and note any assumptions. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. +6. **Rounding**: Match the precision implied by the query. When the question asks for a value in specific units (e.g. "in USD millions"), round the final result to match the expected granularity. For percentages, use at most one decimal place unless the query explicitly asks for more. For dollar amounts, round to the nearest whole number in the stated unit. Example: if the raw calculation yields $8.738 billion and the expected unit is "USD billions", report $8.7 billion or $8.74 billion, not $8.738 billion. +7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. +8. **Binary/judgment questions**: For questions expecting a Yes/No or directional answer, briefly list evidence supporting each side before stating your conclusion. Base your answer on the quantitative evidence rather than subjective assessments. ### Input Data - **User Input**: {user_input} @@ -467,8 +468,9 @@ def generate_keyword_extraction_prompt(num_levels: int = 3) -> str: 3. **Style**: Keep it professional, objective, and clear. Avoid fluff. 4. **Precision**: When the query asks for a specific value, ratio, number, percentage, or yes/no determination, you MUST compute it and state the precise result. Show key calculation steps when applicable. 5. **Verify before answering**: For numerical calculations, complete ALL computation steps in the SUMMARY section FIRST. Only write the PRECISE_ANSWER tag AFTER you have verified the final result. If you discover an error during computation, use the corrected value in PRECISE_ANSWER. -6. **Rounding**: Match the precision implied by the query. If the question specifies units (e.g. "in USD millions", "in billions", "as a percentage") or expects a rounded figure, round your final result accordingly rather than reporting raw calculated values with excessive decimal places. -7. **Best-effort answering**: Always attempt to answer based on available evidence. When evidence is partial or indirect, derive the best possible answer and note any assumptions. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. +6. **Rounding**: Match the precision implied by the query. When the question asks for a value in specific units (e.g. "in USD millions"), round the final result to match the expected granularity. 
For percentages, use at most one decimal place unless the query explicitly asks for more. For dollar amounts, round to the nearest whole number in the stated unit. Example: if the raw calculation yields $8.738 billion and the expected unit is "USD billions", report $8.7 billion or $8.74 billion, not $8.738 billion. +7. **Best-effort answering**: Always attempt to answer based on available evidence. When the query requests a specific metric, ratio, or calculation, compute it from whatever relevant data is available — even if the data is partial. Do not refuse to calculate a metric solely because you believe it is unconventional or less applicable for a given entity type. Only mark SHOULD_ANSWER as "false" when the evidence is entirely unrelated to the query. +8. **Binary/judgment questions**: For questions expecting a Yes/No or directional answer, briefly list evidence supporting each side before stating your conclusion. Base your answer on the quantitative evidence rather than subjective assessments. ### Document Context {document_context} diff --git a/src/sirchmunk/search.py b/src/sirchmunk/search.py index 55d25c7..58aa0a7 100644 --- a/src/sirchmunk/search.py +++ b/src/sirchmunk/search.py @@ -2324,6 +2324,16 @@ async def _llm_chunked_summarize(self, combined: str, query: str) -> str: """Max chars for per-page-range table supplement in tree nav.""" _TABLE_EVIDENCE_STANDALONE_CHARS: int = 20_000 """Max chars for standalone table digest fallback when tree nav evidence is thin.""" + _TABLE_CROSS_SECTION_CHARS: int = 6_000 + """Max chars for cross-section table supplement drawn from pages outside + the navigated leaf ranges. Ensures data-dense tables in distant + document sections (e.g. financial statements when leaves are in + management discussion) are included.""" + _TABLE_EVIDENCE_NAV_OVERLAP_CHARS: int = 8_000 + """Reduced table evidence budget for files that are already receiving + parallel tree navigation. Since tree_ev will provide targeted evidence, + the RGA path uses a smaller budget to supply incremental tables, + leaving room for more diverse evidence.""" # --- Self-correction expanded sampling --- _SELF_CORRECT_EXPANDED_NAV_RESULTS: int = 10 @@ -2851,9 +2861,14 @@ async def _rga_evidence() -> str: print(f"SEARCH_WIKI_DEBUG [D13] table_digest: manifest_lookup={'found' if artifacts.manifest_map and artifacts.manifest_map.get(fp) else 'miss'}, has_table_digest={getattr(artifacts.manifest_map.get(fp), 'has_table_digest', False) if artifacts.manifest_map else 'N/A'}, hash_fallback={'tried' if not _all_tables else 'skipped'}, tables_count={len(_all_tables) if _all_tables else 0}", flush=True) if _all_tables: + _td_budget = ( + self._TABLE_EVIDENCE_NAV_OVERLAP_CHARS + if fp in tree_nav_done + else self._TABLE_EVIDENCE_DEFAULT_CHARS + ) _table_ev = self._format_table_evidence( _all_tables, - max_chars=self._TABLE_EVIDENCE_DEFAULT_CHARS, + max_chars=_td_budget, query=query, ) if _table_ev: @@ -4412,6 +4427,16 @@ def _filter_tables_by_page_range( ] _TABLE_RELEVANCE_MIN_PREFIX = 5 + _TABLE_STRUCTURE_BONUS: float = 0.25 + """Bonus score for tables exhibiting structured data characteristics + (high row count, numeric density). 
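The `_TABLE_EVIDENCE_NAV_OVERLAP_CHARS` budget above is applied in the `_td_budget` hunk: a file already covered by a parallel tree-navigation task gets the smaller table budget, because `tree_ev` will supply the targeted sections. A trivial sketch of that selection, with the constant values taken from this patch:

from typing import Set


def table_budget_sketch(
    file_path: str,
    tree_nav_done: Set[str],
    default_chars: int = 20_000,     # cf. _TABLE_EVIDENCE_DEFAULT_CHARS
    nav_overlap_chars: int = 8_000,  # cf. _TABLE_EVIDENCE_NAV_OVERLAP_CHARS
) -> int:
    # Smaller budget for files already receiving parallel tree navigation,
    # leaving room for more diverse evidence from other sources.
    return nav_overlap_chars if file_path in tree_nav_done else default_chars

The intent, per the constant's docstring, is to let tree navigation carry the targeted tables while the rga path only tops up with incremental ones.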
Applied additively to the keyword + relevance score so that data-rich tables are preferred when keyword + scores tie.""" + _TABLE_STRUCTURE_MIN_ROWS: int = 5 + """Minimum ``|``-delimited rows for a table to qualify for the + structure bonus.""" + _TABLE_STRUCTURE_MIN_NUMERIC_RATIO: float = 0.15 + """Minimum ratio of numeric tokens to total tokens for the bonus.""" @staticmethod def _score_table_relevance( @@ -4459,6 +4484,39 @@ def _score_table_relevance( return hits / len(query_tokens) + @staticmethod + def _score_table_structure(markdown: str) -> float: + """Score a table's structural richness (row count + numeric density). + + Data-dense tables (financial statements, balance sheets) score + higher than narrative paragraphs that happen to contain a small + embedded table. The score is in [0, 1] and is added as a bonus + to the keyword relevance score during table ranking. + """ + if not markdown: + return 0.0 + + rows = markdown.count("\n") + if rows < AgenticSearch._TABLE_STRUCTURE_MIN_ROWS: + return 0.0 + + tokens = markdown.split() + if not tokens: + return 0.0 + + numeric_count = sum( + 1 for t in tokens + if any(c.isdigit() for c in t) + ) + numeric_ratio = numeric_count / len(tokens) + + if numeric_ratio < AgenticSearch._TABLE_STRUCTURE_MIN_NUMERIC_RATIO: + return 0.0 + + row_score = min(rows / 30.0, 1.0) + num_score = min(numeric_ratio / 0.4, 1.0) + return (row_score * 0.5 + num_score * 0.5) + @staticmethod def _deduplicate_table_sections( primary_ev: str, secondary_ev: str, @@ -4521,10 +4579,18 @@ def _format_table_evidence( tok for tok in query.lower().split() if len(tok) >= 2 ) if query_tokens: + struct_bonus = AgenticSearch._TABLE_STRUCTURE_BONUS scored = [ - (AgenticSearch._score_table_relevance( - t.get("markdown", ""), query_tokens, - ), idx, t) + ( + AgenticSearch._score_table_relevance( + t.get("markdown", ""), query_tokens, + ) + + struct_bonus * AgenticSearch._score_table_structure( + t.get("markdown", ""), + ), + idx, + t, + ) for idx, t in enumerate(tables) ] scored.sort(key=lambda x: (-x[0], x[1])) @@ -4897,6 +4963,43 @@ async def _navigate_tree_for_evidence( except Exception: pass + # ── Phase 5.5: Cross-section table supplement ── + # The leaf-scoped supplement (above) only includes tables from + # pages matching selected leaves. When leaves cluster in one + # region (e.g. management discussion), data-dense tables from + # other sections (e.g. financial statements) are missed. + # Fix: include top-ranked tables from UNCOVERED pages. + if _all_tables and leaves: + _leaf_page_set: Set[int] = set() + for _lf in leaves: + _pr = getattr(_lf, "page_range", None) + if _pr and len(_pr) == 2 and _pr[0] is not None: + _leaf_page_set.update(range( + max(1, _pr[0] - self._NAV_PAGE_MARGIN), + _pr[1] + self._NAV_PAGE_MARGIN + 1, + )) + _cross_tables = [ + t for t in _all_tables + if t.get("page_number") is not None + and t["page_number"] not in _leaf_page_set + ] + if _cross_tables: + _cross_ev = self._format_table_evidence( + _cross_tables, + max_chars=self._TABLE_CROSS_SECTION_CHARS, + query=query, + ) + if _cross_ev: + parts.append( + f"[{fname} - Cross-section Tables]\n{_cross_ev}" + ) + print( + f"SEARCH_WIKI_DEBUG [N5.3] cross_section_tables: " + f"uncovered_tables={len(_cross_tables)}, " + f"ev_len={len(_cross_ev)}", + flush=True, + ) + # Plan 3: If evidence is still too thin, add full table digest as standalone evidence = "\n\n".join(parts) if (
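A quick usage illustration of the structure-scoring idea above: a data-dense, pipe-delimited table clears both the row-count and numeric-ratio thresholds and earns a bonus, while a short narrative passage scores zero. The scorer below is a standalone re-implementation for demonstration only; it mirrors, but is not, the class method and its constants.

def structure_score_sketch(
    markdown: str,
    min_rows: int = 5,                # cf. _TABLE_STRUCTURE_MIN_ROWS
    min_numeric_ratio: float = 0.15,  # cf. _TABLE_STRUCTURE_MIN_NUMERIC_RATIO
) -> float:
    # Row count and numeric-token density, each capped and equally weighted.
    rows = markdown.count("\n")
    tokens = markdown.split()
    if rows < min_rows or not tokens:
        return 0.0
    ratio = sum(1 for t in tokens if any(c.isdigit() for c in t)) / len(tokens)
    if ratio < min_numeric_ratio:
        return 0.0
    return 0.5 * min(rows / 30.0, 1.0) + 0.5 * min(ratio / 0.4, 1.0)


dense_table = "\n".join(
    f"| FY{2015 + i} | {1000 + 50 * i} | {120 + 7 * i} |" for i in range(8)
)
narrative = "Management believes the outlook remains strong.\nNo material changes occurred."

print(structure_score_sketch(dense_table))  # ~0.62: many rows, numeric-dense
print(structure_score_sketch(narrative))    # 0.0: too few rows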