From 35a37204880965974a3ca3d3c17950b15e0dee18 Mon Sep 17 00:00:00 2001 From: ww2283 Date: Sun, 24 May 2026 20:23:05 -0400 Subject: [PATCH 1/5] Fix MPS memory pathologies on Apple Silicon MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove torch.mps.set_per_process_memory_fraction(0.9) which lets MPS allocator greedily fill ~29 GB on 32 GB machines without releasing - Guard torch.compile(mode=reduce-overhead) to CUDA only; on MPS it caches compiled graphs per sequence-length bucket (~5 GB waste) - Add torch.mps.empty_cache() between manual HF batches - Fix BM25Scorer syntax error from upstream merge (restore fit/search methods) Reported-by: claude_writing_template (footprint dumps: 22 GB → expected ~3 GB) Tested: ModernPubMedBERT 110M, batch_size=8, 5114 chunks, 640 batches, 32 GB M-series --- packages/leann-core/src/leann/embedding_compute.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py index 00f7c843..4266d58a 100644 --- a/packages/leann-core/src/leann/embedding_compute.py +++ b/packages/leann-core/src/leann/embedding_compute.py @@ -481,11 +481,7 @@ def compute_embeddings_sentence_transformers( torch.backends.cudnn.deterministic = False torch.cuda.set_per_process_memory_fraction(0.9) elif device == "mps": - try: - if hasattr(torch.mps, "set_per_process_memory_fraction"): - torch.mps.set_per_process_memory_fraction(0.9) - except AttributeError: - logger.warning("Some MPS optimizations not available in this PyTorch version") + pass elif device == "cpu": # TODO: Haven't tested this yet torch.set_num_threads(min(8, os.cpu_count() or 4)) @@ -586,7 +582,7 @@ def compute_embeddings_sentence_transformers( logger.warning(f"FP16 optimization failed: {e}") # Apply torch.compile optimization - if device in ["cuda", "mps"]: + if device == "cuda": try: model = torch.compile(model, mode="reduce-overhead", dynamic=True) logger.info(f"Applied torch.compile optimization: {model_name}") @@ -659,7 +655,7 @@ def compute_embeddings_sentence_transformers( hf_model.eval() # Optional compile on supported devices - if device in ["cuda", "mps"]: + if device == "cuda": try: hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) logger.info( @@ -714,6 +710,8 @@ def compute_embeddings_sentence_transformers( pooled = masked.sum(dim=1) / lengths batch_embeddings = pooled.detach().to("cpu").float().numpy() all_embeddings.append(batch_embeddings) + if device == "mps": + torch.mps.empty_cache() embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False) try: From 9608d94aa70250d0815499ad93766c366416ef36 Mon Sep 17 00:00:00 2001 From: ww2283 Date: Sun, 24 May 2026 20:32:02 -0400 Subject: [PATCH 2/5] Remove BM25Scorer: FTS5 is the sole BM25 implementation - Delete BM25Scorer class (in-memory TF tables, O(corpus) RAM at search) - Remove duplicate Fts5BM25Index class from merge artifact - Remove _build_bm25_snapshot pickle codepath - Default bm25_backend to 'fts5', deprecation warning for 'memory' - Fallback path now builds FTS5 index on-demand from passages - Clean up unused Counter/defaultdict imports --- packages/leann-core/src/leann/api.py | 153 ++++----------------------- 1 file changed, 19 insertions(+), 134 deletions(-) diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index 2c8f243f..1c1deb6e 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -12,7 +12,6 @@ import time import warnings from abc import ABC, abstractmethod -from collections import Counter, defaultdict from dataclasses import dataclass, field from pathlib import Path from typing import Any, Literal, Optional, Union @@ -279,12 +278,7 @@ def __len__(self) -> int: class BM25Index(ABC): - """Minimal contract for a BM25-style sparse index over LEANN passages. - - Concrete implementations today: `BM25Scorer` (in-memory, fit-on-search). - Planned: an FTS5-backed implementation that builds at index-build time - and queries memory-bounded. See the design issue for the broader plan. - """ + """Minimal contract for a BM25-style sparse index over LEANN passages.""" @abstractmethod def fit(self, documents: list[dict[str, Any]]) -> None: @@ -304,88 +298,13 @@ def search(self, query: str, top_k: int = 5) -> list["SearchResult"]: """ -class BM25Scorer(BM25Index): - def __init__(self, k1: float = 1.2, b: float = 0.75): - self.k1 = k1 - self.b = b - self.doc_freqs = None # How many docs contain each term (DF) - self.doc_lengths = {} # How long each doc is (in words) - self.word_counts = {} # How many times each word appears in each doc (TF) - self.avg_doc_length = None - self.corpus_size = None - self.idlist = set() # List of all document IDs for easier searching - - def _tokenize(self, text: str) -> list[str]: - return re.sub(r"[^\w\s]", "", text).lower().split() - - def fit(self, documents: list[dict[str, Any]]): - """ - Build BM25 statistics from a document corpus. - Must be called before scoring. - """ - self.corpus_size = len(documents) - self.doc_lengths = {} - self.word_counts = {} - self.idlist = set() - doc_freqs = defaultdict(int) - - for doc_data in documents: - doc_id = doc_data["id"] - words = self._tokenize(doc_data["text"]) - doc_length = len(words) - self.doc_lengths[doc_id] = doc_length - - unique_words = set(words) - for word in unique_words: - doc_freqs[word] += 1 - self.word_counts[doc_id] = dict(Counter(words)) - self.idlist.add(doc_id) - - self.doc_freqs = dict(doc_freqs) - self.avg_doc_length = sum(self.doc_lengths.values()) / len(self.doc_lengths) - - def score(self, query_words: list[str], document_id: str) -> float: - if ( - self.doc_freqs is None - or self.doc_lengths == {} - or self.word_counts == {} - or self.avg_doc_length is None - or self.corpus_size is None - ): - raise ValueError("BM25 model not fitted. Call fit() before scoring.") - - passage_words = self.word_counts[document_id] - passage_length = sum(passage_words.values()) - score = 0.0 - for word in query_words: - if word not in self.doc_freqs: - continue - word_freq = passage_words[word] if word in passage_words else 0 - idf = np.log( - (self.corpus_size - self.doc_freqs[word] + 0.5) / (self.doc_freqs[word] + 0.5) + 1 - ) - tf = (word_freq * (self.k1 + 1)) / ( - word_freq + self.k1 * (1 - self.b + self.b * (passage_length / self.avg_doc_length)) - ) - score += idf * tf - return score - - def search(self, query: str, top_k: int = 5) -> list[SearchResult]: - query_words = self._tokenize(query) - scores = {doc_id: self.score(query_words, doc_id) for doc_id in self.idlist} - sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True) - return [ - SearchResult(id=doc_id, score=score, text="", metadata={}) - for doc_id, score in sorted_scores[:top_k] - ] class Fts5BM25Index(BM25Index): """BM25 over a SQLite FTS5 virtual table, persisted on disk. Built once at `leann build` time, queried memory-bounded at search time. - Avoids the in-memory term-frequency table that BM25Scorer keeps in RAM and - that gets re-fit on every cold start. See #327 for the broader plan. + SQLite owns the on-disk term/posting data; queries hit `bm25()` directly. """ # SQLite's FTS5 bm25() returns lower-is-better. We negate so the rest of @@ -425,9 +344,8 @@ def fit(self, documents: list[dict[str, Any]]) -> None: conn.close() def search(self, query: str, top_k: int = 5) -> list["SearchResult"]: - # Match the BM25Scorer tokenization for query consistency: strip - # punctuation, lowercase, OR the terms together. Avoids FTS5 query - # syntax surprises (`:`, `*`, etc.). + # Strip punctuation, lowercase, OR the terms together. Avoids FTS5 + # query syntax surprises (`:`, `*`, etc.) for natural-language queries. terms = re.sub(r"[^\w\s]", "", query).lower().split() if not terms: return [] @@ -459,16 +377,15 @@ def __init__( embedding_mode: str = "sentence-transformers", embedding_options: Optional[dict[str, Any]] = None, prebuild_bm25: bool = False, - bm25_backend: str = "memory", + bm25_backend: str = "fts5", **backend_kwargs, ): - if bm25_backend not in ("memory", "fts5"): - raise ValueError( - f"Unknown bm25_backend: {bm25_backend!r}. Expected 'memory' or 'fts5'." + if bm25_backend != "fts5": + logger.warning( + f"bm25_backend={bm25_backend!r} is deprecated; using 'fts5'." ) + bm25_backend = "fts5" self.bm25_backend = bm25_backend - # If user picked fts5 explicitly, treat that as opting into prebuild — - # FTS5 only makes sense as a build-time artifact. self.prebuild_bm25 = prebuild_bm25 or bm25_backend == "fts5" self.backend_name = backend_name # Normalize incompatible combinations early (for consistent metadata) @@ -687,14 +604,9 @@ def build_index(self, index_path: str): meta_data["is_pruned"] = bool(is_recompute) if self.prebuild_bm25: - if self.bm25_backend == "fts5": - self._build_bm25_fts5(index_dir, index_name) - meta_data["bm25_backend"] = "fts5" - meta_data["bm25_db"] = f"{index_name}.bm25.sqlite" - else: - self._build_bm25_snapshot(index_dir, index_name) - meta_data["bm25_snapshot"] = f"{index_name}.bm25.pkl" - meta_data["bm25_backend"] = "memory" + self._build_bm25_fts5(index_dir, index_name) + meta_data["bm25_backend"] = "fts5" + meta_data["bm25_db"] = f"{index_name}.bm25.sqlite" with open(leann_meta_path, "w", encoding="utf-8") as f: json.dump(meta_data, f, indent=2) @@ -712,20 +624,6 @@ def _build_bm25_fts5(self, index_dir: Path, index_name: str) -> None: index.close() logger.info(f"Wrote BM25 FTS5 index to {db_path}") - def _build_bm25_snapshot(self, index_dir: Path, index_name: str) -> None: - """Fit BM25Scorer on self.chunks and pickle alongside the index. - - Lets LeannSearcher._init_bm25 load the fitted scorer on first BM25 query - instead of re-fitting against the passage JSONL — which scans every - passage and builds the full TF table in RAM, dominating first-search - latency on larger corpora. - """ - bm25_path = index_dir / f"{index_name}.bm25.pkl" - scorer = BM25Scorer() - scorer.fit(self.chunks) - with open(bm25_path, "wb") as f: - pickle.dump(scorer, f, protocol=pickle.HIGHEST_PROTOCOL) - logger.info(f"Wrote BM25 snapshot to {bm25_path}") def build_index_from_arrays(self, index_path: str, ids: list, embeddings: np.ndarray): """Build an index from pre-computed embedding arrays. @@ -1559,31 +1457,18 @@ def _init_bm25(self) -> None: f"falling back to fit-on-search." ) - snapshot_name = self.meta_data.get("bm25_snapshot") - if snapshot_name: - snapshot_path = meta_dir / snapshot_name - if snapshot_path.exists(): - try: - with open(snapshot_path, "rb") as f: - self.bm25_scorer = pickle.load(f) - logger.info(f"Loaded BM25 snapshot from {snapshot_path}") - return - except Exception as exc: - logger.warning( - f"Failed to load BM25 snapshot at {snapshot_path}, " - f"falling back to fit-on-search: {exc}" - ) - - # No artifact (older indexes) or load failed: fit on the fly. - self.bm25_scorer = BM25Scorer() + # No FTS5 artifact: build one on the fly from passages. + db_path = meta_dir / (Path(self.meta_path_str).stem.replace(".meta", "") + ".bm25.sqlite") + index = Fts5BM25Index(str(db_path)) passages = [] for passage_file in self.passage_manager.passage_files.values(): with open(passage_file, encoding="utf-8") as f: for line in f: if line.strip(): - data = json.loads(line) - passages.append(data) - self.bm25_scorer.fit(passages) + passages.append(json.loads(line)) + index.fit(passages) + self.bm25_scorer = index + logger.info(f"Built FTS5 BM25 index on-demand at {db_path}") def _bm25_search(self, query: str, top_k: int = 5) -> list[SearchResult]: """Perform BM25 search on raw passages""" From a69ac652bf06a6529bf1ecd01d784a08022efe45 Mon Sep 17 00:00:00 2001 From: ww2283 Date: Sun, 24 May 2026 21:21:29 -0400 Subject: [PATCH 3/5] fix: address PR review findings for PR #340 - 5 issues fixed, 3 skipped (pre-existing) - Add error handling for read-only filesystem in on-demand FTS5 build - Guard torch.mps.empty_cache() with try-except - Handle empty passages with explicit error log - Catch JSONDecodeError per-line in JSONL reading - Add comment explaining MPS no-op block See .doc/pr-review-comments/PR-340-mps-memory-pathologies.md --- packages/leann-core/src/leann/api.py | 32 ++++++++++++++++--- .../leann-core/src/leann/embedding_compute.py | 8 ++++- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index 1c1deb6e..1ffa338a 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -1462,11 +1462,33 @@ def _init_bm25(self) -> None: index = Fts5BM25Index(str(db_path)) passages = [] for passage_file in self.passage_manager.passage_files.values(): - with open(passage_file, encoding="utf-8") as f: - for line in f: - if line.strip(): - passages.append(json.loads(line)) - index.fit(passages) + try: + with open(passage_file, encoding="utf-8") as f: + for line in f: + if line.strip(): + try: + passages.append(json.loads(line)) + except json.JSONDecodeError as exc: + logger.warning(f"Skipping malformed JSONL in {passage_file}: {exc}") + except FileNotFoundError: + logger.warning(f"Passage file missing: {passage_file}") + + if not passages: + logger.error( + "No passages found for on-demand BM25 index. " + "BM25/hybrid search will return empty results. " + "Re-run 'leann build' to regenerate passage files." + ) + + try: + index.fit(passages) + except (PermissionError, OSError) as exc: + logger.error( + f"Cannot write BM25 index to {db_path}: {exc}. " + f"Ensure the index directory is writable, or rebuild with prebuild_bm25=True." + ) + return + self.bm25_scorer = index logger.info(f"Built FTS5 BM25 index on-demand at {db_path}") diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py index 4266d58a..34c02e5e 100644 --- a/packages/leann-core/src/leann/embedding_compute.py +++ b/packages/leann-core/src/leann/embedding_compute.py @@ -481,6 +481,9 @@ def compute_embeddings_sentence_transformers( torch.backends.cudnn.deterministic = False torch.cuda.set_per_process_memory_fraction(0.9) elif device == "mps": + # No device-level init for MPS. set_per_process_memory_fraction causes + # greedy allocation; torch.compile causes graph buffer bloat. Cache + # clearing is handled per-batch in the compute loop below. pass elif device == "cpu": # TODO: Haven't tested this yet @@ -711,7 +714,10 @@ def compute_embeddings_sentence_transformers( batch_embeddings = pooled.detach().to("cpu").float().numpy() all_embeddings.append(batch_embeddings) if device == "mps": - torch.mps.empty_cache() + try: + torch.mps.empty_cache() + except (RuntimeError, AttributeError): + pass embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False) try: From 0086f8b9ac620e3a2799a4756807192131f3cd3b Mon Sep 17 00:00:00 2001 From: ww2283 Date: Sun, 24 May 2026 21:26:30 -0400 Subject: [PATCH 4/5] style: fix ruff format extra blank lines --- packages/leann-core/src/leann/api.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index 1ffa338a..96182469 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -298,8 +298,6 @@ def search(self, query: str, top_k: int = 5) -> list["SearchResult"]: """ - - class Fts5BM25Index(BM25Index): """BM25 over a SQLite FTS5 virtual table, persisted on disk. @@ -381,9 +379,7 @@ def __init__( **backend_kwargs, ): if bm25_backend != "fts5": - logger.warning( - f"bm25_backend={bm25_backend!r} is deprecated; using 'fts5'." - ) + logger.warning(f"bm25_backend={bm25_backend!r} is deprecated; using 'fts5'.") bm25_backend = "fts5" self.bm25_backend = bm25_backend self.prebuild_bm25 = prebuild_bm25 or bm25_backend == "fts5" @@ -624,7 +620,6 @@ def _build_bm25_fts5(self, index_dir: Path, index_name: str) -> None: index.close() logger.info(f"Wrote BM25 FTS5 index to {db_path}") - def build_index_from_arrays(self, index_path: str, ids: list, embeddings: np.ndarray): """Build an index from pre-computed embedding arrays. From f9c74aaf61dea84a9bde97baaf585515562ef171 Mon Sep 17 00:00:00 2001 From: ww2283 Date: Wed, 27 May 2026 15:09:57 -0400 Subject: [PATCH 5/5] fix: return early when on-demand BM25 has no passages Address andylizf PR review nit on #340: empty passages list fell through to index.fit([]) creating an empty FTS5 table. Now we log+return, leaving bm25_scorer = None so the existing _bm25_search guard raises RuntimeError instead of silently returning empty results. --- packages/leann-core/src/leann/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index 96182469..e22633ad 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -1474,6 +1474,7 @@ def _init_bm25(self) -> None: "BM25/hybrid search will return empty results. " "Re-run 'leann build' to regenerate passage files." ) + return try: index.fit(passages)