Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions benchmarks/beir-bench/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ use hnsw_rs::prelude::*;
// HNSW hyper-parameters (faithful to the prior "hnswlib M=32" comparison).
const HNSW_M: usize = 32;
const HNSW_EF_CONSTRUCTION: usize = 200;
const HNSW_EF_SEARCH: usize = 128;
const HNSW_MAX_LAYER: usize = 16;

// ---------------------------------------------------------------------------
Expand All @@ -56,6 +55,7 @@ struct Config {
out_dir: String,
threads: usize, // 0 = all cores
max_docs: Option<usize>, // None = full corpus
ef_search: usize, // HNSW query-time recall/latency knob (default 128)
}

fn parse_args() -> Config {
Expand All @@ -76,10 +76,18 @@ fn parse_args() -> Config {
let mut out_dir = String::from("results/beir");
let mut threads = 0usize;
let mut max_docs: Option<usize> = None;
let mut ef_search = 128usize;

let mut args = std::env::args().skip(1);
while let Some(a) = args.next() {
match a.as_str() {
"--ef-search" => {
ef_search = args
.next()
.expect("--ef-search requires a value")
.parse()
.expect("--ef-search must be an integer")
Comment thread
toadkicker marked this conversation as resolved.
}
"--cache-dir" => cache_dir = args.next().expect("--cache-dir requires a value"),
"--dataset" => dataset = args.next().expect("--dataset requires a value"),
"--split" => split = args.next().expect("--split requires a value"),
Expand Down Expand Up @@ -136,6 +144,19 @@ fn parse_args() -> Config {
assert!(batch >= 1, "--batch must be >= 1");
assert!(top_k >= 1, "--top-k must be >= 1");
assert!(candidates >= 1, "--candidates must be >= 1");
// hnsw_rs requires ef_search >= the requested neighbour count (it internally
// clamps ef = max(ef, knbn)). An --ef-search below --top-k would otherwise be
// silently bumped, flattening an ef sweep at the low end. Clamp explicitly +
// warn so the sweep stays meaningful and the recorded ef matches what ran.
let ef_search = if ef_search < top_k {
eprintln!(
"warning: --ef-search {ef_search} < --top-k {top_k}; clamping ef_search to {top_k} \
(hnsw_rs requires ef >= k)"
);
top_k
} else {
ef_search
};

Config {
cache_dir,
Expand All @@ -148,6 +169,7 @@ fn parse_args() -> Config {
out_dir,
threads,
max_docs,
ef_search,
}
}

Expand Down Expand Up @@ -1031,8 +1053,14 @@ fn run_hnsw(
write_topk: bool,
timing_writer: &mut dyn Write,
) {
let slug = "hnsw";
eprintln!(" building HNSW M={HNSW_M} ef_c={HNSW_EF_CONSTRUCTION} ({n_docs} docs) ...");
// ef in the slug so an ef-sweep does not overwrite topk/summary/timing rows
// (each operating point on the recall/latency frontier is recorded distinctly).
let slug = format!("hnsw_ef{}", cfg.ef_search);
Comment thread
Fieldnote-Echo marked this conversation as resolved.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Remove stale default HNSW artifacts

When rerunning the standard Makefile targets in a results directory that already contains the old default hnsw.topk.jsonl/hnsw.summary.json, this new default slug writes hnsw_ef128.* instead of overwriting those files. beir_eval.discover_runs() globs every *.topk.jsonl, and bench-beir-quality/bench-beir-scaling only remove timing.jsonl, so the regenerated report includes a stale hnsw row alongside the new hnsw_ef128 row. Please remove/migrate the legacy HNSW artifacts when emitting the ef-qualified slug, or keep the default output name and record ef separately.

Useful? React with 👍 / 👎.

let slug = slug.as_str();
eprintln!(
" building HNSW M={HNSW_M} ef_c={HNSW_EF_CONSTRUCTION} ef_s={} ({n_docs} docs) ...",
cfg.ef_search
);
// DistL2 (not DistDot): embeddings are unit-normalized, so min-L2 ≡ max-dot ≡
// max-cosine — identical neighbors — but DistL2 avoids anndists' DistDot
// `1-dot` distance assert, which panics on near-duplicate pairs whose float
Expand Down Expand Up @@ -1070,7 +1098,7 @@ fn run_hnsw(
// Single-thread: serial search per query.
(bs..be)
.map(|qi| {
hnsw.search(query_rows[qi], top_k, HNSW_EF_SEARCH)
hnsw.search(query_rows[qi], top_k, cfg.ef_search)
Comment thread
toadkicker marked this conversation as resolved.
.into_iter()
.map(|nb| (nb.d_id as i64, -nb.distance))
.collect()
Expand All @@ -1080,7 +1108,7 @@ fn run_hnsw(
// Threaded: batched parallel search (rayon, this pool).
let batch_slice: Vec<Vec<f32>> =
(bs..be).map(|qi| query_rows[qi].to_vec()).collect();
hnsw.parallel_search(&batch_slice, top_k, HNSW_EF_SEARCH)
hnsw.parallel_search(&batch_slice, top_k, cfg.ef_search)
.into_iter()
.map(|nbs| {
nbs.into_iter()
Expand Down
5 changes: 5 additions & 0 deletions benchmarks/beir/beir_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,11 @@ def write_csv(

def method_stem(method_slug: str) -> str:
"""Strip ``-m<N>`` / ``-b<N>`` parameter suffixes from a method slug."""
hnsw_ef_prefix = "hnsw_ef"
if method_slug.startswith(hnsw_ef_prefix):
suffix = method_slug[len(hnsw_ef_prefix):]
if suffix.isdigit():
return "hnsw"
parts = method_slug.split("-")
kept = [
p
Expand Down
72 changes: 66 additions & 6 deletions benchmarks/beir/beir_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,53 @@
ORDER = [s for s, _, _ in METHOD_STYLE]


def _hnsw_ef(slug: str) -> int | None:
prefix = "hnsw_ef"
if not slug.startswith(prefix):
return None
suffix = slug[len(prefix):]
return int(suffix) if suffix.isdigit() else None


def _method_family(slug: str) -> str:
    """Map a method slug to its family name.

    Slugs for which ``_hnsw_ef`` yields a value collapse to ``"hnsw"``;
    any other slug is its own family.
    """
    is_hnsw_variant = _hnsw_ef(slug) is not None
    return "hnsw" if is_hnsw_variant else slug


def _method_label(slug: str) -> str:
    """Return the full legend label for a method slug.

    ef-qualified HNSW slugs get a generated label that includes the ef
    value; other slugs are looked up in ``LABEL``, falling back to the
    slug itself when no entry exists.
    """
    ef_value = _hnsw_ef(slug)
    if ef_value is None:
        return LABEL.get(slug, slug)
    return f"HNSW M=32 ef={ef_value} (4096 B + graph)"


def _method_short_label(slug: str) -> str:
    """Return a compact label for a method slug.

    ef-qualified HNSW slugs become ``"HNSW ef=<N>"``; for other slugs the
    full label is cut at the first ``" ("`` so the parenthesised detail is
    dropped.
    """
    ef_value = _hnsw_ef(slug)
    if ef_value is None:
        full_label = _method_label(slug)
        return full_label.split(" (")[0]
    return f"HNSW ef={ef_value}"


def _method_color(slug: str) -> str:
    """Return the plot colour for a method slug.

    The colour is looked up in ``COLOR`` by the slug's family, so all
    ef-qualified HNSW variants share one colour; families without an entry
    get the gray fallback ``"#777777"``.
    """
    family = _method_family(slug)
    return COLOR.get(family, "#777777")


def _method_order_key(slug: str) -> tuple[int, int, str]:
    """Build the sort key used to order method slugs.

    The key is ``(family position in ORDER, ef value, slug)``: families
    missing from ``ORDER`` sort after all known ones, slugs without an ef
    value use ``-1`` in the second slot, and the slug itself breaks ties.
    """
    family = _method_family(slug)
    family_rank = ORDER.index(family) if family in ORDER else len(ORDER)
    ef_value = _hnsw_ef(slug)
    ef_rank = -1 if ef_value is None else ef_value
    return (family_rank, ef_rank, slug)


def _ordered_methods(records: list[dict]) -> list[str]:
    """Return the distinct, plottable method slugs in display order.

    Args:
        records: Timing records; entries without a ``"method"`` key are
            ignored.

    Returns:
        Unique slugs whose family (per ``_method_family``) is listed in
        ``ORDER``, sorted with ``_method_order_key``. ef-qualified HNSW
        slugs are kept through their ``"hnsw"`` family.
    """
    # Only keep families that ORDER knows about. Without this filter, any
    # other slug present in the records (e.g. a leftover probe run) would be
    # plotted with the fallback colour next to the supported methods.
    slugs = {
        r["method"]
        for r in records
        if "method" in r and _method_family(r["method"]) in ORDER
    }
    return sorted(slugs, key=_method_order_key)
Comment on lines +93 to +94

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Filter probe methods out of default plots

Because this now returns every method slug present in an append-only timing.jsonl, plot_scaling and plot_bars no longer restrict the README figures to METHOD_STYLE/BENCH_METHODS. If the file contains a prior probe run such as sign-rq2-threaded (which the Makefile explicitly excludes from committed figures until reviewed), the plotter will render that stale/probe method with the gray fallback instead of only adding ef-qualified HNSW variants. Please filter to the supported/default families while normalizing hnsw_ef* to the HNSW family.

Useful? React with 👍 / 👎.



def _read_timing(path: pathlib.Path) -> list[dict]:
records: list[dict] = []
with path.open("r", encoding="utf-8") as fh:
Expand Down Expand Up @@ -82,7 +129,7 @@ def plot_scaling(records: list[dict], dataset: str, threads: int, batch: int,

mode = "single-query (batch=1)" if batch == 1 else f"batched (batch={batch})"
fig, ax = plt.subplots(figsize=(8.2, 5.0))
for slug in ORDER:
for slug in _ordered_methods(recs):
pts = sorted(
((r["n_docs"], r["query_latency_ms_p50"]) for r in recs if r["method"] == slug),
key=lambda t: t[0],
Expand All @@ -92,9 +139,22 @@ def plot_scaling(records: list[dict], dataset: str, threads: int, batch: int,
if len(xs) < 2:
continue
if slug == "flat":
ax.axhline(1.0, color=COLOR[slug], ls="--", lw=1.2, label=LABEL[slug])
ax.axhline(
1.0,
color=_method_color(slug),
ls="--",
lw=1.2,
label=_method_label(slug),
)
else:
ax.plot(xs, ys, marker="o", lw=2.0, color=COLOR[slug], label=LABEL[slug])
ax.plot(
xs,
ys,
marker="o",
lw=2.0,
color=_method_color(slug),
label=_method_label(slug),
)

ax.set_xscale("log")
ax.set_yscale("log")
Expand Down Expand Up @@ -124,15 +184,15 @@ def plot_bars(records: list[dict], dataset: str, threads: int, batch: int, n_doc
lambda r: r["method"],
)
by_method = {r["method"]: r for r in recs}
slugs = [s for s in ORDER if s in by_method]
slugs = [s for s in _ordered_methods(recs) if s in by_method]
if not slugs:
print(f"[plot] no records for {fname} (threads={threads}, n={n_docs})", file=sys.stderr)
return

p50 = [by_method[s]["query_latency_ms_p50"] for s in slugs]
qps = [by_method[s]["queries_per_second"] for s in slugs]
colors = [COLOR[s] for s in slugs]
labels = [LABEL[s].split(" (")[0] for s in slugs]
colors = [_method_color(s) for s in slugs]
labels = [_method_short_label(s) for s in slugs]

flat_p50 = by_method.get("flat", {}).get("query_latency_ms_p50")

Expand Down
Loading