-
-
Notifications
You must be signed in to change notification settings - Fork 4
feat(beir-bench): --ef-search flag for HNSW (was hardcoded const) #265
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fc4c26f
cc7753a
c06270e
d00313b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,7 +38,6 @@ use hnsw_rs::prelude::*; | |
| // HNSW hyper-parameters (faithful to the prior "hnswlib M=32" comparison). | ||
| const HNSW_M: usize = 32; | ||
| const HNSW_EF_CONSTRUCTION: usize = 200; | ||
| const HNSW_EF_SEARCH: usize = 128; | ||
| const HNSW_MAX_LAYER: usize = 16; | ||
|
|
||
| // --------------------------------------------------------------------------- | ||
|
|
@@ -56,6 +55,7 @@ struct Config { | |
| out_dir: String, | ||
| threads: usize, // 0 = all cores | ||
| max_docs: Option<usize>, // None = full corpus | ||
| ef_search: usize, // HNSW query-time recall/latency knob (default 128) | ||
| } | ||
|
|
||
| fn parse_args() -> Config { | ||
|
|
@@ -76,10 +76,18 @@ fn parse_args() -> Config { | |
| let mut out_dir = String::from("results/beir"); | ||
| let mut threads = 0usize; | ||
| let mut max_docs: Option<usize> = None; | ||
| let mut ef_search = 128usize; | ||
|
|
||
| let mut args = std::env::args().skip(1); | ||
| while let Some(a) = args.next() { | ||
| match a.as_str() { | ||
| "--ef-search" => { | ||
| ef_search = args | ||
| .next() | ||
| .expect("--ef-search requires a value") | ||
| .parse() | ||
| .expect("--ef-search must be an integer") | ||
| } | ||
| "--cache-dir" => cache_dir = args.next().expect("--cache-dir requires a value"), | ||
| "--dataset" => dataset = args.next().expect("--dataset requires a value"), | ||
| "--split" => split = args.next().expect("--split requires a value"), | ||
|
|
@@ -136,6 +144,19 @@ fn parse_args() -> Config { | |
| assert!(batch >= 1, "--batch must be >= 1"); | ||
| assert!(top_k >= 1, "--top-k must be >= 1"); | ||
| assert!(candidates >= 1, "--candidates must be >= 1"); | ||
| // hnsw_rs requires ef_search >= the requested neighbour count (it internally | ||
| // clamps ef = max(ef, knbn)). An --ef-search below --top-k would otherwise be | ||
| // silently bumped, flattening an ef sweep at the low end. Clamp explicitly + | ||
| // warn so the sweep stays meaningful and the recorded ef matches what ran. | ||
| let ef_search = if ef_search < top_k { | ||
| eprintln!( | ||
| "warning: --ef-search {ef_search} < --top-k {top_k}; clamping ef_search to {top_k} \ | ||
| (hnsw_rs requires ef >= k)" | ||
| ); | ||
| top_k | ||
| } else { | ||
| ef_search | ||
| }; | ||
|
|
||
| Config { | ||
| cache_dir, | ||
|
|
@@ -148,6 +169,7 @@ fn parse_args() -> Config { | |
| out_dir, | ||
| threads, | ||
| max_docs, | ||
| ef_search, | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -1031,8 +1053,14 @@ fn run_hnsw( | |
| write_topk: bool, | ||
| timing_writer: &mut dyn Write, | ||
| ) { | ||
| let slug = "hnsw"; | ||
| eprintln!(" building HNSW M={HNSW_M} ef_c={HNSW_EF_CONSTRUCTION} ({n_docs} docs) ..."); | ||
| // ef in the slug so an ef-sweep does not overwrite topk/summary/timing rows | ||
| // (each operating point on the recall/latency frontier is recorded distinctly). | ||
| let slug = format!("hnsw_ef{}", cfg.ef_search); | ||
|
Fieldnote-Echo marked this conversation as resolved.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
When rerunning the standard Makefile targets in a results directory that already contains the old default … [review comment truncated in this capture]. Useful? React with 👍 / 👎. |
||
| let slug = slug.as_str(); | ||
| eprintln!( | ||
| " building HNSW M={HNSW_M} ef_c={HNSW_EF_CONSTRUCTION} ef_s={} ({n_docs} docs) ...", | ||
| cfg.ef_search | ||
| ); | ||
| // DistL2 (not DistDot): embeddings are unit-normalized, so min-L2 ≡ max-dot ≡ | ||
| // max-cosine — identical neighbors — but DistL2 avoids anndists' DistDot | ||
| // `1-dot` distance assert, which panics on near-duplicate pairs whose float | ||
|
|
@@ -1070,7 +1098,7 @@ fn run_hnsw( | |
| // Single-thread: serial search per query. | ||
| (bs..be) | ||
| .map(|qi| { | ||
| hnsw.search(query_rows[qi], top_k, HNSW_EF_SEARCH) | ||
| hnsw.search(query_rows[qi], top_k, cfg.ef_search) | ||
|
toadkicker marked this conversation as resolved.
|
||
| .into_iter() | ||
| .map(|nb| (nb.d_id as i64, -nb.distance)) | ||
| .collect() | ||
|
|
@@ -1080,7 +1108,7 @@ fn run_hnsw( | |
| // Threaded: batched parallel search (rayon, this pool). | ||
| let batch_slice: Vec<Vec<f32>> = | ||
| (bs..be).map(|qi| query_rows[qi].to_vec()).collect(); | ||
| hnsw.parallel_search(&batch_slice, top_k, HNSW_EF_SEARCH) | ||
| hnsw.parallel_search(&batch_slice, top_k, cfg.ef_search) | ||
| .into_iter() | ||
| .map(|nbs| { | ||
| nbs.into_iter() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,6 +47,53 @@ | |
| ORDER = [s for s, _, _ in METHOD_STYLE] | ||
|
|
||
|
|
||
| def _hnsw_ef(slug: str) -> int | None: | ||
| prefix = "hnsw_ef" | ||
| if not slug.startswith(prefix): | ||
| return None | ||
| suffix = slug[len(prefix):] | ||
| return int(suffix) if suffix.isdigit() else None | ||
|
|
||
|
|
||
| def _method_family(slug: str) -> str: | ||
| if _hnsw_ef(slug) is not None: | ||
| return "hnsw" | ||
| return slug | ||
|
|
||
|
|
||
| def _method_label(slug: str) -> str: | ||
| ef = _hnsw_ef(slug) | ||
| if ef is not None: | ||
| return f"HNSW M=32 ef={ef} (4096 B + graph)" | ||
| return LABEL.get(slug, slug) | ||
|
|
||
|
|
||
| def _method_short_label(slug: str) -> str: | ||
| ef = _hnsw_ef(slug) | ||
| if ef is not None: | ||
| return f"HNSW ef={ef}" | ||
| return _method_label(slug).split(" (")[0] | ||
|
|
||
|
|
||
| def _method_color(slug: str) -> str: | ||
| return COLOR.get(_method_family(slug), "#777777") | ||
|
|
||
|
|
||
| def _method_order_key(slug: str) -> tuple[int, int, str]: | ||
| family = _method_family(slug) | ||
| try: | ||
| base_order = ORDER.index(family) | ||
| except ValueError: | ||
| base_order = len(ORDER) | ||
| ef = _hnsw_ef(slug) | ||
| return (base_order, ef if ef is not None else -1, slug) | ||
|
|
||
|
|
||
| def _ordered_methods(records: list[dict]) -> list[str]: | ||
| slugs = {r["method"] for r in records if "method" in r} | ||
| return sorted(slugs, key=_method_order_key) | ||
|
Comment on lines
+93
to
+94
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Because this now returns every method slug present in an append-only … [review comment truncated in this capture]. Useful? React with 👍 / 👎. |
||
|
|
||
|
|
||
| def _read_timing(path: pathlib.Path) -> list[dict]: | ||
| records: list[dict] = [] | ||
| with path.open("r", encoding="utf-8") as fh: | ||
|
|
@@ -82,7 +129,7 @@ def plot_scaling(records: list[dict], dataset: str, threads: int, batch: int, | |
|
|
||
| mode = "single-query (batch=1)" if batch == 1 else f"batched (batch={batch})" | ||
| fig, ax = plt.subplots(figsize=(8.2, 5.0)) | ||
| for slug in ORDER: | ||
| for slug in _ordered_methods(recs): | ||
| pts = sorted( | ||
| ((r["n_docs"], r["query_latency_ms_p50"]) for r in recs if r["method"] == slug), | ||
| key=lambda t: t[0], | ||
|
|
@@ -92,9 +139,22 @@ def plot_scaling(records: list[dict], dataset: str, threads: int, batch: int, | |
| if len(xs) < 2: | ||
| continue | ||
| if slug == "flat": | ||
| ax.axhline(1.0, color=COLOR[slug], ls="--", lw=1.2, label=LABEL[slug]) | ||
| ax.axhline( | ||
| 1.0, | ||
| color=_method_color(slug), | ||
| ls="--", | ||
| lw=1.2, | ||
| label=_method_label(slug), | ||
| ) | ||
| else: | ||
| ax.plot(xs, ys, marker="o", lw=2.0, color=COLOR[slug], label=LABEL[slug]) | ||
| ax.plot( | ||
| xs, | ||
| ys, | ||
| marker="o", | ||
| lw=2.0, | ||
| color=_method_color(slug), | ||
| label=_method_label(slug), | ||
| ) | ||
|
|
||
| ax.set_xscale("log") | ||
| ax.set_yscale("log") | ||
|
|
@@ -124,15 +184,15 @@ def plot_bars(records: list[dict], dataset: str, threads: int, batch: int, n_doc | |
| lambda r: r["method"], | ||
| ) | ||
| by_method = {r["method"]: r for r in recs} | ||
| slugs = [s for s in ORDER if s in by_method] | ||
| slugs = [s for s in _ordered_methods(recs) if s in by_method] | ||
| if not slugs: | ||
| print(f"[plot] no records for {fname} (threads={threads}, n={n_docs})", file=sys.stderr) | ||
| return | ||
|
|
||
| p50 = [by_method[s]["query_latency_ms_p50"] for s in slugs] | ||
| qps = [by_method[s]["queries_per_second"] for s in slugs] | ||
| colors = [COLOR[s] for s in slugs] | ||
| labels = [LABEL[s].split(" (")[0] for s in slugs] | ||
| colors = [_method_color(s) for s in slugs] | ||
| labels = [_method_short_label(s) for s in slugs] | ||
|
|
||
| flat_p50 = by_method.get("flat", {}).get("query_latency_ms_p50") | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.