-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Fix MPS memory pathologies on Apple Silicon #340
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
35a3720
9608d94
a69ac65
0086f8b
f9c74aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,6 @@ | |
| import time | ||
| import warnings | ||
| from abc import ABC, abstractmethod | ||
| from collections import Counter, defaultdict | ||
| from dataclasses import dataclass, field | ||
| from pathlib import Path | ||
| from typing import Any, Literal, Optional, Union | ||
|
|
@@ -279,12 +278,7 @@ def __len__(self) -> int: | |
|
|
||
|
|
||
| class BM25Index(ABC): | ||
| """Minimal contract for a BM25-style sparse index over LEANN passages. | ||
|
|
||
| Concrete implementations today: `BM25Scorer` (in-memory, fit-on-search). | ||
| Planned: an FTS5-backed implementation that builds at index-build time | ||
| and queries memory-bounded. See the design issue for the broader plan. | ||
| """ | ||
| """Minimal contract for a BM25-style sparse index over LEANN passages.""" | ||
|
|
||
| @abstractmethod | ||
| def fit(self, documents: list[dict[str, Any]]) -> None: | ||
|
|
@@ -304,88 +298,11 @@ def search(self, query: str, top_k: int = 5) -> list["SearchResult"]: | |
| """ | ||
|
|
||
|
|
||
| class BM25Scorer(BM25Index): | ||
| def __init__(self, k1: float = 1.2, b: float = 0.75): | ||
| self.k1 = k1 | ||
| self.b = b | ||
| self.doc_freqs = None # How many docs contain each term (DF) | ||
| self.doc_lengths = {} # How long each doc is (in words) | ||
| self.word_counts = {} # How many times each word appears in each doc (TF) | ||
| self.avg_doc_length = None | ||
| self.corpus_size = None | ||
| self.idlist = set() # List of all document IDs for easier searching | ||
|
|
||
| def _tokenize(self, text: str) -> list[str]: | ||
| return re.sub(r"[^\w\s]", "", text).lower().split() | ||
|
|
||
| def fit(self, documents: list[dict[str, Any]]): | ||
| """ | ||
| Build BM25 statistics from a document corpus. | ||
| Must be called before scoring. | ||
| """ | ||
| self.corpus_size = len(documents) | ||
| self.doc_lengths = {} | ||
| self.word_counts = {} | ||
| self.idlist = set() | ||
| doc_freqs = defaultdict(int) | ||
|
|
||
| for doc_data in documents: | ||
| doc_id = doc_data["id"] | ||
| words = self._tokenize(doc_data["text"]) | ||
| doc_length = len(words) | ||
| self.doc_lengths[doc_id] = doc_length | ||
|
|
||
| unique_words = set(words) | ||
| for word in unique_words: | ||
| doc_freqs[word] += 1 | ||
| self.word_counts[doc_id] = dict(Counter(words)) | ||
| self.idlist.add(doc_id) | ||
|
|
||
| self.doc_freqs = dict(doc_freqs) | ||
| self.avg_doc_length = sum(self.doc_lengths.values()) / len(self.doc_lengths) | ||
|
|
||
| def score(self, query_words: list[str], document_id: str) -> float: | ||
| if ( | ||
| self.doc_freqs is None | ||
| or self.doc_lengths == {} | ||
| or self.word_counts == {} | ||
| or self.avg_doc_length is None | ||
| or self.corpus_size is None | ||
| ): | ||
| raise ValueError("BM25 model not fitted. Call fit() before scoring.") | ||
|
|
||
| passage_words = self.word_counts[document_id] | ||
| passage_length = sum(passage_words.values()) | ||
| score = 0.0 | ||
| for word in query_words: | ||
| if word not in self.doc_freqs: | ||
| continue | ||
| word_freq = passage_words[word] if word in passage_words else 0 | ||
| idf = np.log( | ||
| (self.corpus_size - self.doc_freqs[word] + 0.5) / (self.doc_freqs[word] + 0.5) + 1 | ||
| ) | ||
| tf = (word_freq * (self.k1 + 1)) / ( | ||
| word_freq + self.k1 * (1 - self.b + self.b * (passage_length / self.avg_doc_length)) | ||
| ) | ||
| score += idf * tf | ||
| return score | ||
|
|
||
| def search(self, query: str, top_k: int = 5) -> list[SearchResult]: | ||
| query_words = self._tokenize(query) | ||
| scores = {doc_id: self.score(query_words, doc_id) for doc_id in self.idlist} | ||
| sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True) | ||
| return [ | ||
| SearchResult(id=doc_id, score=score, text="", metadata={}) | ||
| for doc_id, score in sorted_scores[:top_k] | ||
| ] | ||
|
|
||
|
|
||
| class Fts5BM25Index(BM25Index): | ||
| """BM25 over a SQLite FTS5 virtual table, persisted on disk. | ||
|
|
||
| Built once at `leann build` time, queried memory-bounded at search time. | ||
| Avoids the in-memory term-frequency table that BM25Scorer keeps in RAM and | ||
| that gets re-fit on every cold start. See #327 for the broader plan. | ||
| SQLite owns the on-disk term/posting data; queries hit `bm25()` directly. | ||
| """ | ||
|
|
||
| # SQLite's FTS5 bm25() returns lower-is-better. We negate so the rest of | ||
|
|
@@ -425,9 +342,8 @@ def fit(self, documents: list[dict[str, Any]]) -> None: | |
| conn.close() | ||
|
|
||
| def search(self, query: str, top_k: int = 5) -> list["SearchResult"]: | ||
| # Match the BM25Scorer tokenization for query consistency: strip | ||
| # punctuation, lowercase, OR the terms together. Avoids FTS5 query | ||
| # syntax surprises (`:`, `*`, etc.). | ||
| # Strip punctuation, lowercase, OR the terms together. Avoids FTS5 | ||
| # query syntax surprises (`:`, `*`, etc.) for natural-language queries. | ||
| terms = re.sub(r"[^\w\s]", "", query).lower().split() | ||
| if not terms: | ||
| return [] | ||
|
|
@@ -459,16 +375,13 @@ def __init__( | |
| embedding_mode: str = "sentence-transformers", | ||
| embedding_options: Optional[dict[str, Any]] = None, | ||
| prebuild_bm25: bool = False, | ||
| bm25_backend: str = "memory", | ||
| bm25_backend: str = "fts5", | ||
| **backend_kwargs, | ||
| ): | ||
| if bm25_backend not in ("memory", "fts5"): | ||
| raise ValueError( | ||
| f"Unknown bm25_backend: {bm25_backend!r}. Expected 'memory' or 'fts5'." | ||
| ) | ||
| if bm25_backend != "fts5": | ||
| logger.warning(f"bm25_backend={bm25_backend!r} is deprecated; using 'fts5'.") | ||
| bm25_backend = "fts5" | ||
| self.bm25_backend = bm25_backend | ||
| # If user picked fts5 explicitly, treat that as opting into prebuild — | ||
| # FTS5 only makes sense as a build-time artifact. | ||
| self.prebuild_bm25 = prebuild_bm25 or bm25_backend == "fts5" | ||
| self.backend_name = backend_name | ||
| # Normalize incompatible combinations early (for consistent metadata) | ||
|
|
@@ -687,14 +600,9 @@ def build_index(self, index_path: str): | |
| meta_data["is_pruned"] = bool(is_recompute) | ||
|
|
||
| if self.prebuild_bm25: | ||
| if self.bm25_backend == "fts5": | ||
| self._build_bm25_fts5(index_dir, index_name) | ||
| meta_data["bm25_backend"] = "fts5" | ||
| meta_data["bm25_db"] = f"{index_name}.bm25.sqlite" | ||
| else: | ||
| self._build_bm25_snapshot(index_dir, index_name) | ||
| meta_data["bm25_snapshot"] = f"{index_name}.bm25.pkl" | ||
| meta_data["bm25_backend"] = "memory" | ||
| self._build_bm25_fts5(index_dir, index_name) | ||
| meta_data["bm25_backend"] = "fts5" | ||
| meta_data["bm25_db"] = f"{index_name}.bm25.sqlite" | ||
|
|
||
| with open(leann_meta_path, "w", encoding="utf-8") as f: | ||
| json.dump(meta_data, f, indent=2) | ||
|
|
@@ -712,21 +620,6 @@ def _build_bm25_fts5(self, index_dir: Path, index_name: str) -> None: | |
| index.close() | ||
| logger.info(f"Wrote BM25 FTS5 index to {db_path}") | ||
|
|
||
| def _build_bm25_snapshot(self, index_dir: Path, index_name: str) -> None: | ||
| """Fit BM25Scorer on self.chunks and pickle alongside the index. | ||
|
|
||
| Lets LeannSearcher._init_bm25 load the fitted scorer on first BM25 query | ||
| instead of re-fitting against the passage JSONL — which scans every | ||
| passage and builds the full TF table in RAM, dominating first-search | ||
| latency on larger corpora. | ||
| """ | ||
| bm25_path = index_dir / f"{index_name}.bm25.pkl" | ||
| scorer = BM25Scorer() | ||
| scorer.fit(self.chunks) | ||
| with open(bm25_path, "wb") as f: | ||
| pickle.dump(scorer, f, protocol=pickle.HIGHEST_PROTOCOL) | ||
| logger.info(f"Wrote BM25 snapshot to {bm25_path}") | ||
|
|
||
| def build_index_from_arrays(self, index_path: str, ids: list, embeddings: np.ndarray): | ||
| """Build an index from pre-computed embedding arrays. | ||
|
|
||
|
|
@@ -1559,31 +1452,41 @@ def _init_bm25(self) -> None: | |
| f"falling back to fit-on-search." | ||
| ) | ||
|
|
||
| snapshot_name = self.meta_data.get("bm25_snapshot") | ||
| if snapshot_name: | ||
| snapshot_path = meta_dir / snapshot_name | ||
| if snapshot_path.exists(): | ||
| try: | ||
| with open(snapshot_path, "rb") as f: | ||
| self.bm25_scorer = pickle.load(f) | ||
| logger.info(f"Loaded BM25 snapshot from {snapshot_path}") | ||
| return | ||
| except Exception as exc: | ||
| logger.warning( | ||
| f"Failed to load BM25 snapshot at {snapshot_path}, " | ||
| f"falling back to fit-on-search: {exc}" | ||
| ) | ||
|
|
||
| # No artifact (older indexes) or load failed: fit on the fly. | ||
| self.bm25_scorer = BM25Scorer() | ||
| # No FTS5 artifact: build one on the fly from passages. | ||
| db_path = meta_dir / (Path(self.meta_path_str).stem.replace(".meta", "") + ".bm25.sqlite") | ||
| index = Fts5BM25Index(str(db_path)) | ||
| passages = [] | ||
| for passage_file in self.passage_manager.passage_files.values(): | ||
| with open(passage_file, encoding="utf-8") as f: | ||
| for line in f: | ||
| if line.strip(): | ||
| data = json.loads(line) | ||
| passages.append(data) | ||
| self.bm25_scorer.fit(passages) | ||
| try: | ||
| with open(passage_file, encoding="utf-8") as f: | ||
| for line in f: | ||
| if line.strip(): | ||
| try: | ||
| passages.append(json.loads(line)) | ||
| except json.JSONDecodeError as exc: | ||
| logger.warning(f"Skipping malformed JSONL in {passage_file}: {exc}") | ||
| except FileNotFoundError: | ||
| logger.warning(f"Passage file missing: {passage_file}") | ||
|
|
||
| if not passages: | ||
| logger.error( | ||
| "No passages found for on-demand BM25 index. " | ||
| "BM25/hybrid search will return empty results. " | ||
| "Re-run 'leann build' to regenerate passage files." | ||
| ) | ||
| return | ||
|
|
||
| try: | ||
| index.fit(passages) | ||
| except (PermissionError, OSError) as exc: | ||
| logger.error( | ||
| f"Cannot write BM25 index to {db_path}: {exc}. " | ||
| f"Ensure the index directory is writable, or rebuild with prebuild_bm25=True." | ||
| ) | ||
| return | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: When
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Confirmed safe — |
||
| self.bm25_scorer = index | ||
| logger.info(f"Built FTS5 BM25 index on-demand at {db_path}") | ||
|
|
||
| def _bm25_search(self, query: str, top_k: int = 5) -> list[SearchResult]: | ||
| """Perform BM25 search on raw passages""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: After logging the error, this falls through to `index.fit([])`, which creates an empty FTS5 table. Consider adding a `return` here to bail out early — an empty index silently returns no results on search, which is harder to debug than skipping BM25 init entirely.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch — fixed in f9c74aa. The empty-passages path now returns early, leaving `self.bm25_scorer = None`. `_bm25_search` (lines 1496-1497) checks for that and raises `RuntimeError("BM25 scorer failed to initialize")`, so it surfaces loudly instead of returning empty hits.