diff --git a/benchmarks/beir-bench/src/main.rs b/benchmarks/beir-bench/src/main.rs index b931c6a..a12acc7 100644 --- a/benchmarks/beir-bench/src/main.rs +++ b/benchmarks/beir-bench/src/main.rs @@ -38,7 +38,6 @@ use hnsw_rs::prelude::*; // HNSW hyper-parameters (faithful to the prior "hnswlib M=32" comparison). const HNSW_M: usize = 32; const HNSW_EF_CONSTRUCTION: usize = 200; -const HNSW_EF_SEARCH: usize = 128; const HNSW_MAX_LAYER: usize = 16; // --------------------------------------------------------------------------- @@ -56,6 +55,7 @@ struct Config { out_dir: String, threads: usize, // 0 = all cores max_docs: Option, // None = full corpus + ef_search: usize, // HNSW query-time recall/latency knob (default 128) } fn parse_args() -> Config { @@ -76,10 +76,18 @@ fn parse_args() -> Config { let mut out_dir = String::from("results/beir"); let mut threads = 0usize; let mut max_docs: Option = None; + let mut ef_search = 128usize; let mut args = std::env::args().skip(1); while let Some(a) = args.next() { match a.as_str() { + "--ef-search" => { + ef_search = args + .next() + .expect("--ef-search requires a value") + .parse() + .expect("--ef-search must be an integer") + } "--cache-dir" => cache_dir = args.next().expect("--cache-dir requires a value"), "--dataset" => dataset = args.next().expect("--dataset requires a value"), "--split" => split = args.next().expect("--split requires a value"), @@ -136,6 +144,19 @@ fn parse_args() -> Config { assert!(batch >= 1, "--batch must be >= 1"); assert!(top_k >= 1, "--top-k must be >= 1"); assert!(candidates >= 1, "--candidates must be >= 1"); + // hnsw_rs requires ef_search >= the requested neighbour count (it internally + // clamps ef = max(ef, knbn)). An --ef-search below --top-k would otherwise be + // silently bumped, flattening an ef sweep at the low end. Clamp explicitly + + // warn so the sweep stays meaningful and the recorded ef matches what ran. + let ef_search = if ef_search < top_k { + eprintln!( + "warning: --ef-search {ef_search} < --top-k {top_k}; clamping ef_search to {top_k} \ + (hnsw_rs requires ef >= k)" + ); + top_k + } else { + ef_search + }; Config { cache_dir, @@ -148,6 +169,7 @@ fn parse_args() -> Config { out_dir, threads, max_docs, + ef_search, } } @@ -1031,8 +1053,14 @@ fn run_hnsw( write_topk: bool, timing_writer: &mut dyn Write, ) { - let slug = "hnsw"; - eprintln!(" building HNSW M={HNSW_M} ef_c={HNSW_EF_CONSTRUCTION} ({n_docs} docs) ..."); + // ef in the slug so an ef-sweep does not overwrite topk/summary/timing rows + // (each operating point on the recall/latency frontier is recorded distinctly). + let slug = format!("hnsw_ef{}", cfg.ef_search); + let slug = slug.as_str(); + eprintln!( + " building HNSW M={HNSW_M} ef_c={HNSW_EF_CONSTRUCTION} ef_s={} ({n_docs} docs) ...", + cfg.ef_search + ); // DistL2 (not DistDot): embeddings are unit-normalized, so min-L2 ≡ max-dot ≡ // max-cosine — identical neighbors — but DistL2 avoids anndists' DistDot // `1-dot` distance assert, which panics on near-duplicate pairs whose float @@ -1070,7 +1098,7 @@ fn run_hnsw( // Single-thread: serial search per query. (bs..be) .map(|qi| { - hnsw.search(query_rows[qi], top_k, HNSW_EF_SEARCH) + hnsw.search(query_rows[qi], top_k, cfg.ef_search) .into_iter() .map(|nb| (nb.d_id as i64, -nb.distance)) .collect() @@ -1080,7 +1108,7 @@ fn run_hnsw( // Threaded: batched parallel search (rayon, this pool). let batch_slice: Vec> = (bs..be).map(|qi| query_rows[qi].to_vec()).collect(); - hnsw.parallel_search(&batch_slice, top_k, HNSW_EF_SEARCH) + hnsw.parallel_search(&batch_slice, top_k, cfg.ef_search) .into_iter() .map(|nbs| { nbs.into_iter() diff --git a/benchmarks/beir/beir_eval.py b/benchmarks/beir/beir_eval.py index f8ef8b5..22b3580 100644 --- a/benchmarks/beir/beir_eval.py +++ b/benchmarks/beir/beir_eval.py @@ -577,6 +577,11 @@ def write_csv( def method_stem(method_slug: str) -> str: """Strip ``-m`` / ``-b`` parameter suffixes from a method slug.""" + hnsw_ef_prefix = "hnsw_ef" + if method_slug.startswith(hnsw_ef_prefix): + suffix = method_slug[len(hnsw_ef_prefix):] + if suffix.isdigit(): + return "hnsw" parts = method_slug.split("-") kept = [ p diff --git a/benchmarks/beir/beir_plot.py b/benchmarks/beir/beir_plot.py index 74ddcf1..f83f801 100644 --- a/benchmarks/beir/beir_plot.py +++ b/benchmarks/beir/beir_plot.py @@ -47,6 +47,53 @@ ORDER = [s for s, _, _ in METHOD_STYLE] +def _hnsw_ef(slug: str) -> int | None: + prefix = "hnsw_ef" + if not slug.startswith(prefix): + return None + suffix = slug[len(prefix):] + return int(suffix) if suffix.isdigit() else None + + +def _method_family(slug: str) -> str: + if _hnsw_ef(slug) is not None: + return "hnsw" + return slug + + +def _method_label(slug: str) -> str: + ef = _hnsw_ef(slug) + if ef is not None: + return f"HNSW M=32 ef={ef} (4096 B + graph)" + return LABEL.get(slug, slug) + + +def _method_short_label(slug: str) -> str: + ef = _hnsw_ef(slug) + if ef is not None: + return f"HNSW ef={ef}" + return _method_label(slug).split(" (")[0] + + +def _method_color(slug: str) -> str: + return COLOR.get(_method_family(slug), "#777777") + + +def _method_order_key(slug: str) -> tuple[int, int, str]: + family = _method_family(slug) + try: + base_order = ORDER.index(family) + except ValueError: + base_order = len(ORDER) + ef = _hnsw_ef(slug) + return (base_order, ef if ef is not None else -1, slug) + + +def _ordered_methods(records: list[dict]) -> list[str]: + slugs = {r["method"] for r in records if "method" in r} + return sorted(slugs, key=_method_order_key) + + def _read_timing(path: pathlib.Path) -> list[dict]: records: list[dict] = [] with path.open("r", encoding="utf-8") as fh: @@ -82,7 +129,7 @@ def plot_scaling(records: list[dict], dataset: str, threads: int, batch: int, mode = "single-query (batch=1)" if batch == 1 else f"batched (batch={batch})" fig, ax = plt.subplots(figsize=(8.2, 5.0)) - for slug in ORDER: + for slug in _ordered_methods(recs): pts = sorted( ((r["n_docs"], r["query_latency_ms_p50"]) for r in recs if r["method"] == slug), key=lambda t: t[0], @@ -92,9 +139,22 @@ def plot_scaling(records: list[dict], dataset: str, threads: int, batch: int, if len(xs) < 2: continue if slug == "flat": - ax.axhline(1.0, color=COLOR[slug], ls="--", lw=1.2, label=LABEL[slug]) + ax.axhline( + 1.0, + color=_method_color(slug), + ls="--", + lw=1.2, + label=_method_label(slug), + ) else: - ax.plot(xs, ys, marker="o", lw=2.0, color=COLOR[slug], label=LABEL[slug]) + ax.plot( + xs, + ys, + marker="o", + lw=2.0, + color=_method_color(slug), + label=_method_label(slug), + ) ax.set_xscale("log") ax.set_yscale("log") @@ -124,15 +184,15 @@ def plot_bars(records: list[dict], dataset: str, threads: int, batch: int, n_doc lambda r: r["method"], ) by_method = {r["method"]: r for r in recs} - slugs = [s for s in ORDER if s in by_method] + slugs = [s for s in _ordered_methods(recs) if s in by_method] if not slugs: print(f"[plot] no records for {fname} (threads={threads}, n={n_docs})", file=sys.stderr) return p50 = [by_method[s]["query_latency_ms_p50"] for s in slugs] qps = [by_method[s]["queries_per_second"] for s in slugs] - colors = [COLOR[s] for s in slugs] - labels = [LABEL[s].split(" (")[0] for s in slugs] + colors = [_method_color(s) for s in slugs] + labels = [_method_short_label(s) for s in slugs] flat_p50 = by_method.get("flat", {}).get("query_latency_ms_p50")