diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index 2c8f243f..e22633ad 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -12,7 +12,6 @@ import time import warnings from abc import ABC, abstractmethod -from collections import Counter, defaultdict from dataclasses import dataclass, field from pathlib import Path from typing import Any, Literal, Optional, Union @@ -279,12 +278,7 @@ def __len__(self) -> int: class BM25Index(ABC): - """Minimal contract for a BM25-style sparse index over LEANN passages. - - Concrete implementations today: `BM25Scorer` (in-memory, fit-on-search). - Planned: an FTS5-backed implementation that builds at index-build time - and queries memory-bounded. See the design issue for the broader plan. - """ + """Minimal contract for a BM25-style sparse index over LEANN passages.""" @abstractmethod def fit(self, documents: list[dict[str, Any]]) -> None: @@ -304,88 +298,11 @@ def search(self, query: str, top_k: int = 5) -> list["SearchResult"]: """ -class BM25Scorer(BM25Index): - def __init__(self, k1: float = 1.2, b: float = 0.75): - self.k1 = k1 - self.b = b - self.doc_freqs = None # How many docs contain each term (DF) - self.doc_lengths = {} # How long each doc is (in words) - self.word_counts = {} # How many times each word appears in each doc (TF) - self.avg_doc_length = None - self.corpus_size = None - self.idlist = set() # List of all document IDs for easier searching - - def _tokenize(self, text: str) -> list[str]: - return re.sub(r"[^\w\s]", "", text).lower().split() - - def fit(self, documents: list[dict[str, Any]]): - """ - Build BM25 statistics from a document corpus. - Must be called before scoring. - """ - self.corpus_size = len(documents) - self.doc_lengths = {} - self.word_counts = {} - self.idlist = set() - doc_freqs = defaultdict(int) - - for doc_data in documents: - doc_id = doc_data["id"] - words = self._tokenize(doc_data["text"]) - doc_length = len(words) - self.doc_lengths[doc_id] = doc_length - - unique_words = set(words) - for word in unique_words: - doc_freqs[word] += 1 - self.word_counts[doc_id] = dict(Counter(words)) - self.idlist.add(doc_id) - - self.doc_freqs = dict(doc_freqs) - self.avg_doc_length = sum(self.doc_lengths.values()) / len(self.doc_lengths) - - def score(self, query_words: list[str], document_id: str) -> float: - if ( - self.doc_freqs is None - or self.doc_lengths == {} - or self.word_counts == {} - or self.avg_doc_length is None - or self.corpus_size is None - ): - raise ValueError("BM25 model not fitted. Call fit() before scoring.") - - passage_words = self.word_counts[document_id] - passage_length = sum(passage_words.values()) - score = 0.0 - for word in query_words: - if word not in self.doc_freqs: - continue - word_freq = passage_words[word] if word in passage_words else 0 - idf = np.log( - (self.corpus_size - self.doc_freqs[word] + 0.5) / (self.doc_freqs[word] + 0.5) + 1 - ) - tf = (word_freq * (self.k1 + 1)) / ( - word_freq + self.k1 * (1 - self.b + self.b * (passage_length / self.avg_doc_length)) - ) - score += idf * tf - return score - - def search(self, query: str, top_k: int = 5) -> list[SearchResult]: - query_words = self._tokenize(query) - scores = {doc_id: self.score(query_words, doc_id) for doc_id in self.idlist} - sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True) - return [ - SearchResult(id=doc_id, score=score, text="", metadata={}) - for doc_id, score in sorted_scores[:top_k] - ] - - class Fts5BM25Index(BM25Index): """BM25 over a SQLite FTS5 virtual table, persisted on disk. Built once at `leann build` time, queried memory-bounded at search time. - Avoids the in-memory term-frequency table that BM25Scorer keeps in RAM and - that gets re-fit on every cold start. See #327 for the broader plan. + SQLite owns the on-disk term/posting data; queries hit `bm25()` directly. """ # SQLite's FTS5 bm25() returns lower-is-better. We negate so the rest of @@ -425,9 +342,8 @@ def fit(self, documents: list[dict[str, Any]]) -> None: conn.close() def search(self, query: str, top_k: int = 5) -> list["SearchResult"]: - # Match the BM25Scorer tokenization for query consistency: strip - # punctuation, lowercase, OR the terms together. Avoids FTS5 query - # syntax surprises (`:`, `*`, etc.). + # Strip punctuation, lowercase, OR the terms together. Avoids FTS5 + # query syntax surprises (`:`, `*`, etc.) for natural-language queries. terms = re.sub(r"[^\w\s]", "", query).lower().split() if not terms: return [] @@ -459,16 +375,13 @@ def __init__( embedding_mode: str = "sentence-transformers", embedding_options: Optional[dict[str, Any]] = None, prebuild_bm25: bool = False, - bm25_backend: str = "memory", + bm25_backend: str = "fts5", **backend_kwargs, ): - if bm25_backend not in ("memory", "fts5"): - raise ValueError( - f"Unknown bm25_backend: {bm25_backend!r}. Expected 'memory' or 'fts5'." - ) + if bm25_backend != "fts5": + logger.warning(f"bm25_backend={bm25_backend!r} is deprecated; using 'fts5'.") + bm25_backend = "fts5" self.bm25_backend = bm25_backend - # If user picked fts5 explicitly, treat that as opting into prebuild — - # FTS5 only makes sense as a build-time artifact. self.prebuild_bm25 = prebuild_bm25 or bm25_backend == "fts5" self.backend_name = backend_name # Normalize incompatible combinations early (for consistent metadata) @@ -687,14 +600,9 @@ def build_index(self, index_path: str): meta_data["is_pruned"] = bool(is_recompute) if self.prebuild_bm25: - if self.bm25_backend == "fts5": - self._build_bm25_fts5(index_dir, index_name) - meta_data["bm25_backend"] = "fts5" - meta_data["bm25_db"] = f"{index_name}.bm25.sqlite" - else: - self._build_bm25_snapshot(index_dir, index_name) - meta_data["bm25_snapshot"] = f"{index_name}.bm25.pkl" - meta_data["bm25_backend"] = "memory" + self._build_bm25_fts5(index_dir, index_name) + meta_data["bm25_backend"] = "fts5" + meta_data["bm25_db"] = f"{index_name}.bm25.sqlite" with open(leann_meta_path, "w", encoding="utf-8") as f: json.dump(meta_data, f, indent=2) @@ -712,21 +620,6 @@ def _build_bm25_fts5(self, index_dir: Path, index_name: str) -> None: index.close() logger.info(f"Wrote BM25 FTS5 index to {db_path}") - def _build_bm25_snapshot(self, index_dir: Path, index_name: str) -> None: - """Fit BM25Scorer on self.chunks and pickle alongside the index. - - Lets LeannSearcher._init_bm25 load the fitted scorer on first BM25 query - instead of re-fitting against the passage JSONL — which scans every - passage and builds the full TF table in RAM, dominating first-search - latency on larger corpora. - """ - bm25_path = index_dir / f"{index_name}.bm25.pkl" - scorer = BM25Scorer() - scorer.fit(self.chunks) - with open(bm25_path, "wb") as f: - pickle.dump(scorer, f, protocol=pickle.HIGHEST_PROTOCOL) - logger.info(f"Wrote BM25 snapshot to {bm25_path}") - def build_index_from_arrays(self, index_path: str, ids: list, embeddings: np.ndarray): """Build an index from pre-computed embedding arrays. @@ -1559,31 +1452,41 @@ def _init_bm25(self) -> None: f"falling back to fit-on-search." ) - snapshot_name = self.meta_data.get("bm25_snapshot") - if snapshot_name: - snapshot_path = meta_dir / snapshot_name - if snapshot_path.exists(): - try: - with open(snapshot_path, "rb") as f: - self.bm25_scorer = pickle.load(f) - logger.info(f"Loaded BM25 snapshot from {snapshot_path}") - return - except Exception as exc: - logger.warning( - f"Failed to load BM25 snapshot at {snapshot_path}, " - f"falling back to fit-on-search: {exc}" - ) - - # No artifact (older indexes) or load failed: fit on the fly. - self.bm25_scorer = BM25Scorer() + # No FTS5 artifact: build one on the fly from passages. + db_path = meta_dir / (Path(self.meta_path_str).stem.replace(".meta", "") + ".bm25.sqlite") + index = Fts5BM25Index(str(db_path)) passages = [] for passage_file in self.passage_manager.passage_files.values(): - with open(passage_file, encoding="utf-8") as f: - for line in f: - if line.strip(): - data = json.loads(line) - passages.append(data) - self.bm25_scorer.fit(passages) + try: + with open(passage_file, encoding="utf-8") as f: + for line in f: + if line.strip(): + try: + passages.append(json.loads(line)) + except json.JSONDecodeError as exc: + logger.warning(f"Skipping malformed JSONL in {passage_file}: {exc}") + except FileNotFoundError: + logger.warning(f"Passage file missing: {passage_file}") + + if not passages: + logger.error( + "No passages found for on-demand BM25 index. " + "BM25/hybrid search will return empty results. " + "Re-run 'leann build' to regenerate passage files." + ) + return + + try: + index.fit(passages) + except (PermissionError, OSError) as exc: + logger.error( + f"Cannot write BM25 index to {db_path}: {exc}. " + f"Ensure the index directory is writable, or rebuild with prebuild_bm25=True." + ) + return + + self.bm25_scorer = index + logger.info(f"Built FTS5 BM25 index on-demand at {db_path}") def _bm25_search(self, query: str, top_k: int = 5) -> list[SearchResult]: """Perform BM25 search on raw passages""" diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py index 00f7c843..34c02e5e 100644 --- a/packages/leann-core/src/leann/embedding_compute.py +++ b/packages/leann-core/src/leann/embedding_compute.py @@ -481,11 +481,10 @@ def compute_embeddings_sentence_transformers( torch.backends.cudnn.deterministic = False torch.cuda.set_per_process_memory_fraction(0.9) elif device == "mps": - try: - if hasattr(torch.mps, "set_per_process_memory_fraction"): - torch.mps.set_per_process_memory_fraction(0.9) - except AttributeError: - logger.warning("Some MPS optimizations not available in this PyTorch version") + # No device-level init for MPS. set_per_process_memory_fraction causes + # greedy allocation; torch.compile causes graph buffer bloat. Cache + # clearing is handled per-batch in the compute loop below. + pass elif device == "cpu": # TODO: Haven't tested this yet torch.set_num_threads(min(8, os.cpu_count() or 4)) @@ -586,7 +585,7 @@ def compute_embeddings_sentence_transformers( logger.warning(f"FP16 optimization failed: {e}") # Apply torch.compile optimization - if device in ["cuda", "mps"]: + if device == "cuda": try: model = torch.compile(model, mode="reduce-overhead", dynamic=True) logger.info(f"Applied torch.compile optimization: {model_name}") @@ -659,7 +658,7 @@ def compute_embeddings_sentence_transformers( hf_model.eval() # Optional compile on supported devices - if device in ["cuda", "mps"]: + if device == "cuda": try: hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) logger.info( @@ -714,6 +713,11 @@ def compute_embeddings_sentence_transformers( pooled = masked.sum(dim=1) / lengths batch_embeddings = pooled.detach().to("cpu").float().numpy() all_embeddings.append(batch_embeddings) + if device == "mps": + try: + torch.mps.empty_cache() + except (RuntimeError, AttributeError): + pass embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False) try: