Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 44 additions & 141 deletions packages/leann-core/src/leann/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import time
import warnings
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional, Union
Expand Down Expand Up @@ -279,12 +278,7 @@ def __len__(self) -> int:


class BM25Index(ABC):
"""Minimal contract for a BM25-style sparse index over LEANN passages.

Concrete implementations today: `BM25Scorer` (in-memory, fit-on-search).
Planned: an FTS5-backed implementation that builds at index-build time
and queries memory-bounded. See the design issue for the broader plan.
"""
"""Minimal contract for a BM25-style sparse index over LEANN passages."""

@abstractmethod
def fit(self, documents: list[dict[str, Any]]) -> None:
Expand All @@ -304,88 +298,11 @@ def search(self, query: str, top_k: int = 5) -> list["SearchResult"]:
"""


class BM25Scorer(BM25Index):
    """In-memory BM25 scorer over a list of passage dicts.

    ``fit`` tokenizes every document and records per-document term counts,
    per-document lengths and corpus-wide document frequencies. ``search``
    scores every fitted document against the query and returns the best
    ``top_k`` as ``SearchResult`` objects (with empty ``text``/``metadata``).

    Args:
        k1: The ``k1`` constant of the BM25 term-frequency component.
        b: The ``b`` constant weighting document length against the average.
    """

    def __init__(self, k1: float = 1.2, b: float = 0.75):
        self.k1 = k1
        self.b = b
        self.doc_freqs = None  # term -> number of docs containing it (DF)
        self.doc_lengths = {}  # doc id -> number of tokens in the doc
        self.word_counts = {}  # doc id -> {term: occurrences in the doc} (TF)
        self.avg_doc_length = None
        self.corpus_size = None
        self.idlist = set()  # all fitted document IDs

    def _tokenize(self, text: str) -> list[str]:
        """Drop punctuation, lowercase, and split on whitespace."""
        return re.sub(r"[^\w\s]", "", text).lower().split()

    def fit(self, documents: list[dict[str, Any]]) -> None:
        """Build BM25 statistics from a document corpus.

        Must be called before scoring. Any previously fitted state is replaced.

        Args:
            documents: Dicts that each provide an ``"id"`` and a ``"text"`` key.
        """
        self.corpus_size = len(documents)
        self.doc_lengths = {}
        self.word_counts = {}
        self.idlist = set()
        doc_freqs = defaultdict(int)

        for doc_data in documents:
            doc_id = doc_data["id"]
            words = self._tokenize(doc_data["text"])
            self.doc_lengths[doc_id] = len(words)

            term_counts = Counter(words)
            # Each distinct term counts once per document for DF.
            for word in term_counts:
                doc_freqs[word] += 1
            self.word_counts[doc_id] = dict(term_counts)
            self.idlist.add(doc_id)

        self.doc_freqs = dict(doc_freqs)
        # An empty corpus has no mean length; the unguarded division used to
        # raise ZeroDivisionError here.
        self.avg_doc_length = (
            sum(self.doc_lengths.values()) / len(self.doc_lengths) if self.doc_lengths else 0.0
        )

    def score(self, query_words: list[str], document_id: str) -> float:
        """Return the BM25 score of one fitted document for tokenized query terms.

        Args:
            query_words: Query tokens, normally produced by ``_tokenize``.
            document_id: ID of a document that was passed to ``fit``.

        Returns:
            The summed ``idf * tf`` contribution of every query term that
            occurs in the document; ``0.0`` when none do.

        Raises:
            ValueError: If no (non-empty) corpus has been fitted.
            KeyError: If ``document_id`` was not part of the fitted corpus.
        """
        if (
            self.doc_freqs is None
            or not self.doc_lengths
            or not self.word_counts
            or self.avg_doc_length is None
            or self.corpus_size is None
        ):
            raise ValueError("BM25 model not fitted. Call fit() before scoring.")

        passage_words = self.word_counts[document_id]
        # Same value as summing the term counts, already computed in fit().
        passage_length = self.doc_lengths[document_id]
        score = 0.0
        for word in query_words:
            word_freq = passage_words.get(word, 0)
            if not word_freq:
                # Absent terms contribute exactly zero; skipping them also
                # avoids a 0/0 in the tf term when k1 == 0.
                continue
            doc_freq = self.doc_freqs[word]
            idf = np.log((self.corpus_size - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
            tf = (word_freq * (self.k1 + 1)) / (
                word_freq + self.k1 * (1 - self.b + self.b * (passage_length / self.avg_doc_length))
            )
            score += idf * tf
        # Return a plain float, as annotated, rather than a NumPy scalar.
        return float(score)

    def search(self, query: str, top_k: int = 5) -> list[SearchResult]:
        """Score every fitted document against ``query`` and return the top hits.

        Args:
            query: Raw query string; tokenized the same way as the documents.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` results in descending score order. Ties keep the
            order in which documents were fitted.
        """
        query_words = self._tokenize(query)
        # Walk the insertion-ordered dict rather than the ``idlist`` set: set
        # iteration order is arbitrary, which made tie-breaking unstable.
        scores = {doc_id: self.score(query_words, doc_id) for doc_id in self.word_counts}
        sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        return [
            SearchResult(id=doc_id, score=score, text="", metadata={})
            for doc_id, score in sorted_scores[:top_k]
        ]


class Fts5BM25Index(BM25Index):
"""BM25 over a SQLite FTS5 virtual table, persisted on disk.

Built once at `leann build` time, queried memory-bounded at search time.
Avoids the in-memory term-frequency table that BM25Scorer keeps in RAM and
that gets re-fit on every cold start. See #327 for the broader plan.
SQLite owns the on-disk term/posting data; queries hit `bm25()` directly.
"""

# SQLite's FTS5 bm25() returns lower-is-better. We negate so the rest of
Expand Down Expand Up @@ -425,9 +342,8 @@ def fit(self, documents: list[dict[str, Any]]) -> None:
conn.close()

def search(self, query: str, top_k: int = 5) -> list["SearchResult"]:
# Match the BM25Scorer tokenization for query consistency: strip
# punctuation, lowercase, OR the terms together. Avoids FTS5 query
# syntax surprises (`:`, `*`, etc.).
# Strip punctuation, lowercase, OR the terms together. Avoids FTS5
# query syntax surprises (`:`, `*`, etc.) for natural-language queries.
terms = re.sub(r"[^\w\s]", "", query).lower().split()
if not terms:
return []
Expand Down Expand Up @@ -459,16 +375,13 @@ def __init__(
embedding_mode: str = "sentence-transformers",
embedding_options: Optional[dict[str, Any]] = None,
prebuild_bm25: bool = False,
bm25_backend: str = "memory",
bm25_backend: str = "fts5",
**backend_kwargs,
):
if bm25_backend not in ("memory", "fts5"):
raise ValueError(
f"Unknown bm25_backend: {bm25_backend!r}. Expected 'memory' or 'fts5'."
)
if bm25_backend != "fts5":
logger.warning(f"bm25_backend={bm25_backend!r} is deprecated; using 'fts5'.")
bm25_backend = "fts5"
self.bm25_backend = bm25_backend
# If user picked fts5 explicitly, treat that as opting into prebuild —
# FTS5 only makes sense as a build-time artifact.
self.prebuild_bm25 = prebuild_bm25 or bm25_backend == "fts5"
self.backend_name = backend_name
# Normalize incompatible combinations early (for consistent metadata)
Expand Down Expand Up @@ -687,14 +600,9 @@ def build_index(self, index_path: str):
meta_data["is_pruned"] = bool(is_recompute)

if self.prebuild_bm25:
if self.bm25_backend == "fts5":
self._build_bm25_fts5(index_dir, index_name)
meta_data["bm25_backend"] = "fts5"
meta_data["bm25_db"] = f"{index_name}.bm25.sqlite"
else:
self._build_bm25_snapshot(index_dir, index_name)
meta_data["bm25_snapshot"] = f"{index_name}.bm25.pkl"
meta_data["bm25_backend"] = "memory"
self._build_bm25_fts5(index_dir, index_name)
meta_data["bm25_backend"] = "fts5"
meta_data["bm25_db"] = f"{index_name}.bm25.sqlite"

with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
Expand All @@ -712,21 +620,6 @@ def _build_bm25_fts5(self, index_dir: Path, index_name: str) -> None:
index.close()
logger.info(f"Wrote BM25 FTS5 index to {db_path}")

def _build_bm25_snapshot(self, index_dir: Path, index_name: str) -> None:
    """Fit a BM25Scorer on ``self.chunks`` and pickle it next to the index.

    The snapshot is written to ``<index_dir>/<index_name>.bm25.pkl`` so the
    searcher can unpickle an already-fitted scorer on its first BM25 query
    rather than re-fitting from the passage JSONL, which means scanning every
    passage and rebuilding the whole term-frequency table in memory.
    """
    snapshot_file = index_dir / f"{index_name}.bm25.pkl"

    fitted = BM25Scorer()
    fitted.fit(self.chunks)

    with snapshot_file.open("wb") as sink:
        pickle.dump(fitted, sink, protocol=pickle.HIGHEST_PROTOCOL)

    logger.info(f"Wrote BM25 snapshot to {snapshot_file}")

def build_index_from_arrays(self, index_path: str, ids: list, embeddings: np.ndarray):
"""Build an index from pre-computed embedding arrays.

Expand Down Expand Up @@ -1559,31 +1452,41 @@ def _init_bm25(self) -> None:
f"falling back to fit-on-search."
)

snapshot_name = self.meta_data.get("bm25_snapshot")
if snapshot_name:
snapshot_path = meta_dir / snapshot_name
if snapshot_path.exists():
try:
with open(snapshot_path, "rb") as f:
self.bm25_scorer = pickle.load(f)
logger.info(f"Loaded BM25 snapshot from {snapshot_path}")
return
except Exception as exc:
logger.warning(
f"Failed to load BM25 snapshot at {snapshot_path}, "
f"falling back to fit-on-search: {exc}"
)

# No artifact (older indexes) or load failed: fit on the fly.
self.bm25_scorer = BM25Scorer()
# No FTS5 artifact: build one on the fly from passages.
db_path = meta_dir / (Path(self.meta_path_str).stem.replace(".meta", "") + ".bm25.sqlite")
index = Fts5BM25Index(str(db_path))
passages = []
for passage_file in self.passage_manager.passage_files.values():
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
passages.append(data)
self.bm25_scorer.fit(passages)
try:
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
try:
passages.append(json.loads(line))
except json.JSONDecodeError as exc:
logger.warning(f"Skipping malformed JSONL in {passage_file}: {exc}")
except FileNotFoundError:
logger.warning(f"Passage file missing: {passage_file}")

if not passages:
logger.error(
"No passages found for on-demand BM25 index. "
"BM25/hybrid search will return empty results. "
"Re-run 'leann build' to regenerate passage files."
)
return

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: After logging the error, this falls through to index.fit([]) which creates an empty FTS5 table. Consider adding a return here to bail out early — an empty index silently returns no results on search, which is harder to debug than skipping BM25 init entirely.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — fixed in f9c74aa. The empty-passages path now returns early, leaving `self.bm25_scorer = None`. `_bm25_search` (lines 1496-1497) checks for that and raises `RuntimeError("BM25 scorer failed to initialize")`, so it surfaces loudly instead of returning empty hits.

try:
index.fit(passages)
except (PermissionError, OSError) as exc:
logger.error(
f"Cannot write BM25 index to {db_path}: {exc}. "
f"Ensure the index directory is writable, or rebuild with prebuild_bm25=True."
)
return

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: When fit() raises PermissionError/OSError, self.bm25_scorer is never assigned. Worth confirming the caller guards against bm25_scorer being unset — otherwise _bm25_search will AttributeError on the next hybrid query.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirmed safe — `_bm25_search` at lines 1496-1497 already does `if scorer is None: raise RuntimeError("BM25 scorer failed to initialize")`. The `PermissionError`/`OSError` path returns without assigning `self.bm25_scorer`, so it stays `None` and the next call surfaces as `RuntimeError`, not `AttributeError`. With f9c74aa, the empty-passages path now behaves the same way.

self.bm25_scorer = index
logger.info(f"Built FTS5 BM25 index on-demand at {db_path}")

def _bm25_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
"""Perform BM25 search on raw passages"""
Expand Down
18 changes: 11 additions & 7 deletions packages/leann-core/src/leann/embedding_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,11 +481,10 @@ def compute_embeddings_sentence_transformers(
torch.backends.cudnn.deterministic = False
torch.cuda.set_per_process_memory_fraction(0.9)
elif device == "mps":
try:
if hasattr(torch.mps, "set_per_process_memory_fraction"):
torch.mps.set_per_process_memory_fraction(0.9)
except AttributeError:
logger.warning("Some MPS optimizations not available in this PyTorch version")
# No device-level init for MPS. set_per_process_memory_fraction causes
# greedy allocation; torch.compile causes graph buffer bloat. Cache
# clearing is handled per-batch in the compute loop below.
pass
elif device == "cpu":
# TODO: Haven't tested this yet
torch.set_num_threads(min(8, os.cpu_count() or 4))
Expand Down Expand Up @@ -586,7 +585,7 @@ def compute_embeddings_sentence_transformers(
logger.warning(f"FP16 optimization failed: {e}")

# Apply torch.compile optimization
if device in ["cuda", "mps"]:
if device == "cuda":
try:
model = torch.compile(model, mode="reduce-overhead", dynamic=True)
logger.info(f"Applied torch.compile optimization: {model_name}")
Expand Down Expand Up @@ -659,7 +658,7 @@ def compute_embeddings_sentence_transformers(

hf_model.eval()
# Optional compile on supported devices
if device in ["cuda", "mps"]:
if device == "cuda":
try:
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True)
logger.info(
Expand Down Expand Up @@ -714,6 +713,11 @@ def compute_embeddings_sentence_transformers(
pooled = masked.sum(dim=1) / lengths
batch_embeddings = pooled.detach().to("cpu").float().numpy()
all_embeddings.append(batch_embeddings)
if device == "mps":
try:
torch.mps.empty_cache()
except (RuntimeError, AttributeError):
pass

embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False)
try:
Expand Down
Loading