diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3d7ecd49..07cb825c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -433,6 +433,9 @@ jobs: if: steps.sde.outputs.sde-available == 'true' env: CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: ${{ steps.sde.outputs.sde-path }} -spr -- + # Cause any AVX-512 test that would silently skip on a non-AVX-512 host + # to panic loudly here, so this job genuinely enforces the kernels. + ORDVEC_REQUIRE_AVX512: "1" run: | set -euo pipefail cargo test diff --git a/CHANGELOG.md b/CHANGELOG.md index d5fd4138..f330e02f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Performance + +- **AVX-512 VPOPCNTDQ scan kernels now cover every `dim` (a multiple of 64), not + just multiples of 512 bits.** Previously the `SignBitmap` and `Bitmap` scan + kernels took the AVX-512 path only when the per-vector 64-bit word count was a + multiple of 8 (`dim` a multiple of 512), silently falling back to the scalar + loop otherwise — so common embedding widths like **768 (BGE) and 384 + (bge-small / MiniLM)** ran the entire stage-1 candidate scan scalar. The + kernels now process the trailing `(dim / 64) % 8` words with a masked load + (`_mm512_maskz_loadu_epi64`), so any supported `dim` uses VPOPCNTDQ. Measured + **~4× faster** stage-1 scan at dim=768 on a Zen5 / AVX-512 host (609 → 153 + µs/query, n=100k; see `examples/bge_kernel_bench`); 1024/1536 unchanged. + Results are byte-identical to the scalar path — parity tests cover qpv tail + residues 0..7 plus 384/512/768/1024/1536 for all six SignBitmap/Bitmap scan + kernels. This is stage-1 scan-kernel throughput, not a whole-pipeline figure. + +### Added + +- `avx512vpop_supported()` (`#[doc(hidden)]`) — reports whether the AVX-512 + VPOPCNTDQ scan kernels are active on the current CPU. The scan dispatch reads + only this predicate (no per-dimension gate). + ### Fixed - **`ordvec-manifest` crate and wheel now ship license text.** Both declared diff --git a/examples/bge_kernel_bench.rs b/examples/bge_kernel_bench.rs new file mode 100644 index 00000000..a0201e1d --- /dev/null +++ b/examples/bge_kernel_bench.rs @@ -0,0 +1,56 @@ +// Stage-1 SignBitmap scan-kernel A/B for BGE-style dims. +// Times `score_all_batched_flat` (the per-query dense Hamming scan, the +// stage-1 candidate-gen kernel) at a given dim. On origin/main, dim=768 +// (qpv=12) takes the SCALAR fallback; on the avx512-tail branch it takes +// AVX-512 VPOPCNTDQ with a masked tail. Same public call, same inputs. +// +// cargo run --release --example bge_kernel_bench -- +use ordvec::SignBitmap; +use rand::{RngExt, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use std::time::Instant; + +fn median(mut v: Vec) -> f64 { + v.sort_by(|a, b| a.partial_cmp(b).unwrap()); + v[v.len() / 2] +} + +fn main() { + let a: Vec = std::env::args().collect(); + let dim: usize = a.get(1).and_then(|s| s.parse().ok()).unwrap_or(768); + let n: usize = a.get(2).and_then(|s| s.parse().ok()).unwrap_or(100_000); + let batch: usize = a.get(3).and_then(|s| s.parse().ok()).unwrap_or(256); + let reps: usize = a.get(4).and_then(|s| s.parse().ok()).unwrap_or(40); + + let mut rng = ChaCha8Rng::seed_from_u64(42); + let corpus: Vec = (0..n * dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + let mut idx = SignBitmap::new(dim); + idx.add(&corpus); + let queries: Vec = (0..batch * dim) + .map(|_| rng.random_range(-1.0..1.0)) + .collect(); + + // Warmup. + for _ in 0..3 { + std::hint::black_box(idx.score_all_batched_flat(&queries)); + } + + let mut samples = Vec::with_capacity(reps); + for _ in 0..reps { + let t = Instant::now(); + let s = idx.score_all_batched_flat(&queries); + let us = t.elapsed().as_secs_f64() * 1e6 / batch as f64; + std::hint::black_box(&s); + samples.push(us); + } + let med = median(samples.clone()); + let p10 = { + let mut v = samples.clone(); + v.sort_by(|a, b| a.partial_cmp(b).unwrap()); + v[v.len() / 10] + }; + println!( + "dim={dim} n={n} batch={batch} reps={reps} qpv={} -> scan median {med:.2} us/query (p10 {p10:.2})", + dim / 64, + ); +} diff --git a/src/bitmap.rs b/src/bitmap.rs index 35ca9626..35087fb3 100644 --- a/src/bitmap.rs +++ b/src/bitmap.rs @@ -23,6 +23,16 @@ //! //! Intended primary use: candidate generator for two-stage retrieval //! (bitmap probe → top-M candidates → exact RankQuant rerank). +//! +//! # Dimensions and the AVX-512 kernel +//! +//! `dim` must be a multiple of 64. On a host with AVX-512 VPOPCNTDQ **every** +//! such `dim` runs the vectorized AND-popcount scan: whole 512-bit (8 × u64) +//! groups, then any trailing `(dim / 64) % 8` words via a single masked load +//! (`_mm512_maskz_loadu_epi64`). Dimensions whose word count is a multiple of 8 +//! (512, 1024, 1536, …) have no tail; others (e.g. **384, 768**) pay **one +//! extra masked chunk** — a few percent, so 768 ≈ 1024 — rather than dropping +//! to scalar. See [`crate::avx512vpop_supported`]. use rayon::prelude::*; @@ -424,12 +434,7 @@ impl Bitmap { // performance preference, so a debug_assert here would wrongly // panic on valid-but-unsorted input. - #[cfg(target_arch = "x86_64")] - let use_avx512vpop = is_x86_feature_detected!("avx512f") - && is_x86_feature_detected!("avx512vpopcntdq") - && qpv.is_multiple_of(8); - #[cfg(not(target_arch = "x86_64"))] - let use_avx512vpop = false; + let use_avx512vpop = crate::avx512vpop_supported(); if use_avx512vpop { #[cfg(target_arch = "x86_64")] @@ -560,12 +565,7 @@ impl Bitmap { fn bitmap_scan(bitmaps: &[u64], n: usize, qpv: usize, q: &[u64], top: &mut TopK) { debug_assert_eq!(q.len(), qpv); - #[cfg(target_arch = "x86_64")] - let use_avx512vpop = is_x86_feature_detected!("avx512f") - && is_x86_feature_detected!("avx512vpopcntdq") - && qpv.is_multiple_of(8); - #[cfg(not(target_arch = "x86_64"))] - let use_avx512vpop = false; + let use_avx512vpop = crate::avx512vpop_supported(); if use_avx512vpop { #[cfg(target_arch = "x86_64")] @@ -588,27 +588,33 @@ fn bitmap_scan_scalar(bitmaps: &[u64], n: usize, qpv: usize, q: &[u64], top: &mu #[target_feature(enable = "avx512f,avx512vpopcntdq")] unsafe fn bitmap_scan_avx512vpop(bitmaps: &[u64], n: usize, qpv: usize, q: &[u64], top: &mut TopK) { use std::arch::x86_64::*; - // SAFETY: every raw 512-bit load is in-bounds under the caller's contract - // (`bitmap_scan`): `qpv % 8 == 0` (gated by the `qpv.is_multiple_of(8)` - // dispatch check, so `lanes = qpv / 8` tiles `q` and each doc row exactly), - // `q.len() == qpv` (one full query row), and `bitmaps.len() == n * qpv` (the - // index stores `n` contiguous `qpv`-word rows). Thus `q.as_ptr().add(l*8)` - // (`l < qpv/8`) and `doc_ptr.add(l)` at `doc_ptr = bitmaps + di*qpv` - // (`di < n`) each stay within their slice. AVX-512 F/VPOPCNTDQ are confirmed - // by the `#[target_feature]` gate plus the caller's runtime - // `is_x86_feature_detected!`. + // SAFETY: the caller (`bitmap_scan`) guarantees `q.len() == qpv` (one full + // query row) and `bitmaps.len() == n * qpv` (n contiguous qpv-word rows). + // Full 8-word groups use `loadu`; the trailing `rem = qpv % 8` words use + // `maskz_loadu`, which only accesses the `rem` valid low lanes + // (fault-suppressed), so loads never over-read the query row or the doc + // buffer. AVX-512 F/VPOPCNTDQ are confirmed by the `#[target_feature]` gate + // plus the caller's runtime `is_x86_feature_detected!`. // The explicit block is required by `#![deny(unsafe_op_in_unsafe_fn)]`. unsafe { - debug_assert_eq!(qpv % 8, 0, "AVX-512 bitmap scan needs qpv % 8 == 0"); + debug_assert!(qpv > 0); let lanes = qpv / 8; + let rem = qpv % 8; + let tail_mask: __mmask8 = if rem != 0 { (1u8 << rem) - 1 } else { 0 }; let mut q_zmms: Vec<__m512i> = Vec::with_capacity(lanes); #[allow(clippy::needless_range_loop)] // indexed access is clearer / matches the kernel layout for l in 0..lanes { q_zmms.push(_mm512_loadu_si512(q.as_ptr().add(l * 8) as *const __m512i)); } + let q_tail = if rem != 0 { + _mm512_maskz_loadu_epi64(tail_mask, q.as_ptr().add(lanes * 8) as *const i64) + } else { + _mm512_setzero_si512() + }; for di in 0..n { - let doc_ptr = bitmaps.as_ptr().add(di * qpv) as *const __m512i; + let doc_base = bitmaps.as_ptr().add(di * qpv); + let doc_ptr = doc_base as *const __m512i; let mut acc_zmm = _mm512_setzero_si512(); #[allow(clippy::needless_range_loop)] // indexed access is clearer / matches the kernel layout @@ -618,6 +624,12 @@ unsafe fn bitmap_scan_avx512vpop(bitmaps: &[u64], n: usize, qpv: usize, q: &[u64 let pop_zmm = _mm512_popcnt_epi64(and_zmm); acc_zmm = _mm512_add_epi64(acc_zmm, pop_zmm); } + if rem != 0 { + let d_tail = + _mm512_maskz_loadu_epi64(tail_mask, doc_base.add(lanes * 8) as *const i64); + let and_zmm = _mm512_and_si512(d_tail, q_tail); + acc_zmm = _mm512_add_epi64(acc_zmm, _mm512_popcnt_epi64(and_zmm)); + } let acc_sum: i64 = _mm512_reduce_add_epi64(acc_zmm); top.maybe_insert(acc_sum as f32, di); } @@ -632,12 +644,7 @@ fn bitmap_scan_collect(bitmaps: &[u64], n: usize, qpv: usize, q: &[u64], scores: debug_assert_eq!(scores.len(), n); debug_assert_eq!(q.len(), qpv); - #[cfg(target_arch = "x86_64")] - let use_avx512vpop = is_x86_feature_detected!("avx512f") - && is_x86_feature_detected!("avx512vpopcntdq") - && qpv.is_multiple_of(8); - #[cfg(not(target_arch = "x86_64"))] - let use_avx512vpop = false; + let use_avx512vpop = crate::avx512vpop_supported(); if use_avx512vpop { #[cfg(target_arch = "x86_64")] @@ -664,23 +671,33 @@ unsafe fn bitmap_scan_collect_avx512vpop( ) { use std::arch::x86_64::*; // SAFETY: same contract as the sibling `bitmap_scan_avx512vpop` — the caller - // (`bitmap_scan_collect`) gates dispatch on `qpv.is_multiple_of(8)`, - // `q.len() == qpv`, and `bitmaps.len() == n * qpv`, bounding all raw loads. + // (`bitmap_scan_collect`) guarantees `q.len() == qpv` and + // `bitmaps.len() == n * qpv`. Full 8-word groups use `loadu`; the trailing + // `rem = qpv % 8` words use `maskz_loadu` (only the `rem` valid low lanes are + // accessed, fault-suppressed), so loads never over-read. // AVX-512 F/VPOPCNTDQ confirmed by `#[target_feature]` + runtime detection. // The explicit block is required by `#![deny(unsafe_op_in_unsafe_fn)]`. unsafe { - debug_assert_eq!(qpv % 8, 0); + debug_assert!(qpv > 0); let lanes = qpv / 8; + let rem = qpv % 8; + let tail_mask: __mmask8 = if rem != 0 { (1u8 << rem) - 1 } else { 0 }; let mut q_zmms: Vec<__m512i> = Vec::with_capacity(lanes); #[allow(clippy::needless_range_loop)] // indexed access is clearer / matches the kernel layout for l in 0..lanes { q_zmms.push(_mm512_loadu_si512(q.as_ptr().add(l * 8) as *const __m512i)); } + let q_tail = if rem != 0 { + _mm512_maskz_loadu_epi64(tail_mask, q.as_ptr().add(lanes * 8) as *const i64) + } else { + _mm512_setzero_si512() + }; #[allow(clippy::needless_range_loop)] // indexed access is clearer / matches the kernel layout for di in 0..n { - let doc_ptr = bitmaps.as_ptr().add(di * qpv) as *const __m512i; + let doc_base = bitmaps.as_ptr().add(di * qpv); + let doc_ptr = doc_base as *const __m512i; let mut acc_zmm = _mm512_setzero_si512(); for l in 0..lanes { let d_zmm = _mm512_loadu_si512(doc_ptr.add(l)); @@ -688,6 +705,12 @@ unsafe fn bitmap_scan_collect_avx512vpop( let pop_zmm = _mm512_popcnt_epi64(and_zmm); acc_zmm = _mm512_add_epi64(acc_zmm, pop_zmm); } + if rem != 0 { + let d_tail = + _mm512_maskz_loadu_epi64(tail_mask, doc_base.add(lanes * 8) as *const i64); + let and_zmm = _mm512_and_si512(d_tail, q_tail); + acc_zmm = _mm512_add_epi64(acc_zmm, _mm512_popcnt_epi64(and_zmm)); + } let acc_sum: i64 = _mm512_reduce_add_epi64(acc_zmm); scores[di] = acc_sum as u32; } @@ -710,8 +733,8 @@ unsafe fn bitmap_scan_collect_avx512vpop( // is paid once. // ------------------------------------------------------------------- -/// Scalar fallback for the batched scan. Used when AVX-512 VPOPCNTDQ -/// is unavailable or when `qpv % 8 != 0`. +/// Scalar fallback for the batched scan. Used only when AVX-512 VPOPCNTDQ is +/// unavailable (the kernel handles any `qpv` via a masked tail). fn bitmap_scan_collect_batched_scalar( bitmaps: &[u64], n: usize, @@ -755,16 +778,21 @@ unsafe fn bitmap_scan_collect_batched_avx512vpop( ) { use std::arch::x86_64::*; // SAFETY: same contract as the sibling `bitmap_scan_avx512vpop` — the caller - // (`bitmap_scan_collect_batched`) gates dispatch on `qpv.is_multiple_of(8)`, - // `q_batch.len() == batch * qpv`, `bitmaps.len() == n * qpv`, and - // `scores.len() == batch * n`, bounding all raw loads and `scores[…]` writes. - // AVX-512 F/VPOPCNTDQ confirmed by `#[target_feature]` + runtime detection. + // (`bitmap_scan_collect_batched`) guarantees `q_batch.len() == batch * qpv`, + // `bitmaps.len() == n * qpv`, and `scores.len() == batch * n`. Full 8-word + // groups use `loadu`; the trailing `rem = qpv % 8` words use `maskz_loadu` + // (only the `rem` valid low lanes accessed, fault-suppressed), so loads never + // over-read a per-vector slice or the buffer end; `scores[…]` writes are + // bounded as before. AVX-512 F/VPOPCNTDQ confirmed by `#[target_feature]` + + // runtime detection. // The explicit block is required by `#![deny(unsafe_op_in_unsafe_fn)]`. unsafe { - debug_assert_eq!(qpv % 8, 0); + debug_assert!(qpv > 0); debug_assert_eq!(q_batch.len(), batch * qpv); debug_assert_eq!(scores.len(), batch * n); let lanes = qpv / 8; + let rem = qpv % 8; + let tail_mask: __mmask8 = if rem != 0 { (1u8 << rem) - 1 } else { 0 }; const CHUNK: usize = BATCHED_AVX512_CHUNK; // Pre-load all batch * lanes query ZMMs once. For typical @@ -779,6 +807,16 @@ unsafe fn bitmap_scan_collect_batched_avx512vpop( )); } } + // Per-query masked tail (trailing `rem` words); empty when qpv % 8 == 0. + let mut q_tails: Vec<__m512i> = Vec::with_capacity(if rem != 0 { batch } else { 0 }); + if rem != 0 { + for bi in 0..batch { + q_tails.push(_mm512_maskz_loadu_epi64( + tail_mask, + q_batch.as_ptr().add(bi * qpv + lanes * 8) as *const i64, + )); + } + } // Hot path: process whole CHUNK-sized groups. The inner `for bi // in 0..CHUNK` is bounded by a *const*, so LLVM unrolls it and @@ -792,7 +830,8 @@ unsafe fn bitmap_scan_collect_batched_avx512vpop( while chunk_start + CHUNK <= batch { for di in 0..n { let mut accs: [__m512i; CHUNK] = [_mm512_setzero_si512(); CHUNK]; - let doc_ptr = bitmaps.as_ptr().add(di * qpv) as *const __m512i; + let doc_base = bitmaps.as_ptr().add(di * qpv); + let doc_ptr = doc_base as *const __m512i; for l in 0..lanes { let d_zmm = _mm512_loadu_si512(doc_ptr.add(l)); for bi in 0..CHUNK { @@ -802,6 +841,14 @@ unsafe fn bitmap_scan_collect_batched_avx512vpop( accs[bi] = _mm512_add_epi64(accs[bi], pop_zmm); } } + if rem != 0 { + let d_tail = + _mm512_maskz_loadu_epi64(tail_mask, doc_base.add(lanes * 8) as *const i64); + for bi in 0..CHUNK { + let and_zmm = _mm512_and_si512(d_tail, q_tails[chunk_start + bi]); + accs[bi] = _mm512_add_epi64(accs[bi], _mm512_popcnt_epi64(and_zmm)); + } + } for bi in 0..CHUNK { let acc_sum: i64 = _mm512_reduce_add_epi64(accs[bi]); scores[(chunk_start + bi) * n + di] = acc_sum as u32; @@ -818,7 +865,8 @@ unsafe fn bitmap_scan_collect_batched_avx512vpop( if tail > 0 { for di in 0..n { let mut accs: [__m512i; CHUNK] = [_mm512_setzero_si512(); CHUNK]; - let doc_ptr = bitmaps.as_ptr().add(di * qpv) as *const __m512i; + let doc_base = bitmaps.as_ptr().add(di * qpv); + let doc_ptr = doc_base as *const __m512i; for l in 0..lanes { let d_zmm = _mm512_loadu_si512(doc_ptr.add(l)); for bi in 0..tail { @@ -828,6 +876,14 @@ unsafe fn bitmap_scan_collect_batched_avx512vpop( accs[bi] = _mm512_add_epi64(accs[bi], pop_zmm); } } + if rem != 0 { + let d_tail = + _mm512_maskz_loadu_epi64(tail_mask, doc_base.add(lanes * 8) as *const i64); + for bi in 0..tail { + let and_zmm = _mm512_and_si512(d_tail, q_tails[chunk_start + bi]); + accs[bi] = _mm512_add_epi64(accs[bi], _mm512_popcnt_epi64(and_zmm)); + } + } for bi in 0..tail { let acc_sum: i64 = _mm512_reduce_add_epi64(accs[bi]); scores[(chunk_start + bi) * n + di] = acc_sum as u32; @@ -839,8 +895,8 @@ unsafe fn bitmap_scan_collect_batched_avx512vpop( /// Batched bitmap scan: writes `scores[bi * n + di]` = popcount overlap /// for query `bi` against doc `di`, for all `bi ∈ [0, batch)` and -/// `di ∈ [0, n)`. Dispatches to the AVX-512 VPOPCNTDQ kernel when -/// available (qpv % 8 == 0), else falls back to scalar. +/// `di ∈ [0, n)`. Dispatches to the AVX-512 VPOPCNTDQ kernel when available +/// (any `qpv`; non-multiples of 8 are handled by a masked tail), else scalar. fn bitmap_scan_collect_batched( bitmaps: &[u64], n: usize, @@ -849,12 +905,7 @@ fn bitmap_scan_collect_batched( batch: usize, scores: &mut [u32], ) { - #[cfg(target_arch = "x86_64")] - let use_avx512vpop = is_x86_feature_detected!("avx512f") - && is_x86_feature_detected!("avx512vpopcntdq") - && qpv.is_multiple_of(8); - #[cfg(not(target_arch = "x86_64"))] - let use_avx512vpop = false; + let use_avx512vpop = crate::avx512vpop_supported(); if use_avx512vpop { #[cfg(target_arch = "x86_64")] @@ -877,18 +928,20 @@ unsafe fn body_overlap_scores_subset_avx512vpop( ) { use std::arch::x86_64::*; // SAFETY: in-bounds under the public `body_overlap_scores_subset` - // pre-dispatch asserts: `q_bitmap.len() == qpv` and `qpv % 8 == 0` (the - // latter also gated by `qpv.is_multiple_of(8)` in the dispatch), so the - // `lanes = qpv/8` loads `q_bitmap.as_ptr().add(l*8)` tile `q_bitmap` - // exactly; every `di ∈ doc_ids` is hard-asserted `< n_vectors` *before* - // dispatch, so `bitmaps + di*qpv` plus the `lanes` loads stay within the + // pre-dispatch asserts: `q_bitmap.len() == qpv`, so full 8-word groups + // (`loadu`) plus the trailing `rem = qpv % 8` words (`maskz_loadu`, only the + // `rem` valid low lanes accessed, fault-suppressed) tile `q_bitmap` without + // over-read; every `di ∈ doc_ids` is hard-asserted `< n_vectors` *before* + // dispatch, so `bitmaps + di*qpv` plus the loads stay within the // `n_vectors*qpv`-word buffer; and `out.len() == doc_ids.len()` bounds the // `out[i]` writes. AVX-512 F/VPOPCNTDQ confirmed by `#[target_feature]` + // runtime detection. // The explicit block is required by `#![deny(unsafe_op_in_unsafe_fn)]`. unsafe { - debug_assert_eq!(qpv % 8, 0); + debug_assert!(qpv > 0); let lanes = qpv / 8; + let rem = qpv % 8; + let tail_mask: __mmask8 = if rem != 0 { (1u8 << rem) - 1 } else { 0 }; let mut q_zmms: Vec<__m512i> = Vec::with_capacity(lanes); #[allow(clippy::needless_range_loop)] // indexed access is clearer / matches the kernel layout @@ -897,8 +950,14 @@ unsafe fn body_overlap_scores_subset_avx512vpop( q_bitmap.as_ptr().add(l * 8) as *const __m512i )); } + let q_tail = if rem != 0 { + _mm512_maskz_loadu_epi64(tail_mask, q_bitmap.as_ptr().add(lanes * 8) as *const i64) + } else { + _mm512_setzero_si512() + }; for (i, &di) in doc_ids.iter().enumerate() { - let doc_ptr = bitmaps.as_ptr().add((di as usize) * qpv) as *const __m512i; + let doc_base = bitmaps.as_ptr().add((di as usize) * qpv); + let doc_ptr = doc_base as *const __m512i; let mut acc_zmm = _mm512_setzero_si512(); #[allow(clippy::needless_range_loop)] // indexed access is clearer / matches the kernel layout @@ -908,8 +967,196 @@ unsafe fn body_overlap_scores_subset_avx512vpop( let pop_zmm = _mm512_popcnt_epi64(and_zmm); acc_zmm = _mm512_add_epi64(acc_zmm, pop_zmm); } + if rem != 0 { + let d_tail = + _mm512_maskz_loadu_epi64(tail_mask, doc_base.add(lanes * 8) as *const i64); + let and_zmm = _mm512_and_si512(d_tail, q_tail); + acc_zmm = _mm512_add_epi64(acc_zmm, _mm512_popcnt_epi64(and_zmm)); + } let acc_sum: i64 = _mm512_reduce_add_epi64(acc_zmm); out[i] = acc_sum as u32; } } } + +#[cfg(test)] +mod tests { + use super::*; + use rand::{RngExt, SeedableRng}; + use rand_chacha::ChaCha8Rng; + + fn scalar_overlap(doc: &[u64], q: &[u64]) -> u32 { + doc.iter().zip(q).map(|(d, qq)| (d & qq).count_ones()).sum() + } + + /// Returns `true` if the host supports AVX-512 VPOPCNTDQ and the test should + /// proceed. When AVX-512 is absent: + /// + /// - If `ORDVEC_REQUIRE_AVX512` is set to `"1"` or `"true"` (used by the + /// Intel SDE CI job), this panics so the job fails loudly instead of + /// silently treating a skipped test as green coverage. + /// - Otherwise it emits a skip notice to stderr and returns `false`; the + /// caller should return immediately. + fn require_avx512_or_skip(test_name: &str) -> bool { + if crate::avx512vpop_supported() { + return true; + } + let required = std::env::var("ORDVEC_REQUIRE_AVX512") + .map(|v| v == "1" || v == "true") + .unwrap_or(false); + if required { + panic!( + "SKIP {test_name}: host lacks AVX-512 VPOPCNTDQ but \ + ORDVEC_REQUIRE_AVX512 is set — AVX-512 kernels are not enforced" + ); + } + eprintln!( + "SKIP {test_name}: host lacks AVX-512 VPOPCNTDQ; \ + set ORDVEC_REQUIRE_AVX512=1 to enforce" + ); + false + } + + // Dims covering every qwords-per-vec tail residue (qpv % 8 ∈ 0..=7), the + // lanes==0 all-tail cases (qpv < 8: 64/384/448), and the common embedding + // dims 384/512/768/1024/1536. qpv = dim / 64. + const PARITY_DIMS: [usize; 13] = [ + 64, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1536, + ]; + + #[test] + fn avx512_path_matches_scalar_across_residues_and_common_dims() { + if !require_avx512_or_skip("avx512_path_matches_scalar_across_residues_and_common_dims") { + return; + } + for &dim in &PARITY_DIMS { + let n = 300usize; + let n_top = (dim / 4).max(1); + let m = 32usize; + let nq = 4usize; + let mut rng = ChaCha8Rng::seed_from_u64(9000 + dim as u64); + let corpus: Vec = (0..n * dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + let mut idx = Bitmap::new(dim, n_top); + idx.add(&corpus); + let qpv = idx.qwords_per_vec; + let queries: Vec = (0..nq * dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + + let batched = idx.top_m_candidates_batched(&queries, m); + for qi in 0..nq { + let q = &queries[qi * dim..(qi + 1) * dim]; + let qbm = idx.build_query_bitmap_fp32(q); + + // (1) body_overlap_scores_subset kernel: exact overlap for ALL + // ids vs an independent scalar over the stored bitmaps. + let all_ids: Vec = (0..n as u32).collect(); + let mut out = vec![0u32; n]; + idx.body_overlap_scores_subset(&qbm, &all_ids, &mut out); + let mut ref_pairs: Vec<(u32, u32)> = Vec::with_capacity(n); + #[allow(clippy::needless_range_loop)] + for di in 0..n { + let off = di * qpv; + let ov = scalar_overlap(&idx.bitmaps[off..off + qpv], &qbm); + assert_eq!(out[di], ov, "body_overlap dim={dim} qi={qi} di={di}"); + ref_pairs.push((ov, di as u32)); + } + // Reference top-m under the library's (overlap desc, id asc) key. + ref_pairs.sort_by(|a, b| b.0.cmp(&a.0).then_with(|| a.1.cmp(&b.1))); + let reference: Vec = ref_pairs.iter().take(m).map(|&(_, d)| d).collect(); + + // (2) bitmap_scan_collect kernel. + assert_eq!( + idx.top_m_candidates(q, m), + reference, + "top_m dim={dim} qi={qi}" + ); + // (3) bitmap_scan_collect_batched kernel. + assert_eq!(batched[qi], reference, "batched dim={dim} qi={qi}"); + + // (4) bitmap_scan (TopK) kernel via search: the returned m docs + // must be a valid top-m by overlap (tie-policy-independent; + // a wrong scan score would admit an out-of-top-m doc). + let res = idx.search(q, m); + let got = res.indices_for_query(0); + assert_eq!(got.len(), m, "search len dim={dim} qi={qi}"); + let got_set: std::collections::HashSet = got.iter().copied().collect(); + let ov_of = + |di: usize| scalar_overlap(&idx.bitmaps[di * qpv..(di + 1) * qpv], &qbm); + let min_in = got.iter().map(|&id| ov_of(id as usize)).min().unwrap(); + let max_out = (0..n) + .filter(|di| !got_set.contains(&(*di as i64))) + .map(ov_of) + .max() + .unwrap_or(0); + assert!( + min_in >= max_out, + "search not a valid top-m: dim={dim} qi={qi} min_in={min_in} max_out={max_out}" + ); + } + } + } + + #[test] + fn unchanged_at_512bit_multiple_dims() { + // 1024 (qpv=16) and 1536 (qpv=24) were always on the AVX-512 path + // (qpv % 8 == 0). They must stay byte-identical to scalar — this pins + // "no behavior change" for the previously-fast dims. + for &dim in &[1024usize, 1536] { + let n = 200usize; + let n_top = dim / 4; + let mut rng = ChaCha8Rng::seed_from_u64(123 + dim as u64); + let corpus: Vec = (0..n * dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + let mut idx = Bitmap::new(dim, n_top); + idx.add(&corpus); + let qpv = idx.qwords_per_vec; + let qbm = idx.build_query_bitmap_fp32(&corpus[..dim]); + let all_ids: Vec = (0..n as u32).collect(); + let mut out = vec![0u32; n]; + idx.body_overlap_scores_subset(&qbm, &all_ids, &mut out); + #[allow(clippy::needless_range_loop)] + for di in 0..n { + let off = di * qpv; + let ov = scalar_overlap(&idx.bitmaps[off..off + qpv], &qbm); + assert_eq!( + out[di], ov, + "512-bit-multiple dim={dim} regressed at di={di}" + ); + } + } + } + + #[test] + fn masked_tail_kernel_matches_scalar_when_avx512_present() { + // Directly exercise the masked-tail AVX-512 kernel. This is ONLY + // meaningful on a host (or Intel SDE) with AVX-512 VPOPCNTDQ: there the + // qpv % 8 != 0 dims below force the masked tail and the dispatch routes + // to the AVX-512 path, so this asserts the tail kernel itself. On a + // non-AVX-512 host it SKIPS with a notice (rather than silently passing + // on the scalar path), so a green run on such a host is not mistaken for + // tail-kernel coverage. (Replaces a tautological predicate self-equality + // check; the cross-platform scalar parity lives in the test above.) + if !require_avx512_or_skip("masked_tail_kernel_matches_scalar_when_avx512_present") { + return; + } + // qpv % 8 != 0 (384->6, 768->12, 832->13) -> the masked tail runs. + for &dim in &[384usize, 768, 832] { + let n = 200usize; + let n_top = (dim / 4).max(1); + let mut rng = ChaCha8Rng::seed_from_u64(424_242 + dim as u64); + let corpus: Vec = (0..n * dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + let mut idx = Bitmap::new(dim, n_top); + idx.add(&corpus); + let qpv = idx.qwords_per_vec; + assert_ne!(qpv % 8, 0, "dim={dim} must force the masked tail"); + let qbm = idx.build_query_bitmap_fp32(&corpus[..dim]); + let all_ids: Vec = (0..n as u32).collect(); + let mut out = vec![0u32; n]; + idx.body_overlap_scores_subset(&qbm, &all_ids, &mut out); + #[allow(clippy::needless_range_loop)] + for di in 0..n { + let off = di * qpv; + let ov = scalar_overlap(&idx.bitmaps[off..off + qpv], &qbm); + assert_eq!(out[di], ov, "masked-tail AVX-512 dim={dim} di={di}"); + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 5e1e0110..32b87459 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -94,6 +94,28 @@ pub use multi_bucket::MultiBucketBitmap; #[doc(hidden)] pub use fastscan::RankQuantFastscan; +/// Whether the AVX-512 VPOPCNTDQ bitmap/sign scan kernels are active on this +/// CPU. `#[doc(hidden)]` — a diagnostic for tests and downstream probes, not a +/// stability surface. +/// +/// The scan dispatch ([`SignBitmap`] and [`Bitmap`]) consults this and +/// **nothing else** — it takes no dimension. So once VPOPCNTDQ is present, +/// *every* `dim` (a multiple of 64) runs the kernel, including dims whose +/// 64-bit word count is not a multiple of 8 (e.g. 384, 768): those are handled +/// by a masked tail, not by falling back to the scalar path. +#[doc(hidden)] +#[must_use] +pub fn avx512vpop_supported() -> bool { + #[cfg(target_arch = "x86_64")] + { + is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vpopcntdq") + } + #[cfg(not(target_arch = "x86_64"))] + { + false + } +} + // Pre-0.2 names (the `Index` suffix was dropped in the OrdVec ontology // rebrand). Retained as deprecated type aliases for back-compat; remove // in a future release. `pub type` (rather than `pub use … as`) causes diff --git a/src/sign_bitmap.rs b/src/sign_bitmap.rs index 2cca6cc5..04081ab3 100644 --- a/src/sign_bitmap.rs +++ b/src/sign_bitmap.rs @@ -26,6 +26,17 @@ //! only material difference is `_mm512_xor_si512` in place of //! `_mm512_and_si512` and an ascending tie-broken composite-key //! selection on Hamming distance. +//! +//! # Dimensions and the AVX-512 kernel +//! +//! `dim` must be a multiple of 64. On a host with AVX-512 VPOPCNTDQ **every** +//! such `dim` runs the vectorized scan: the kernel processes whole 512-bit +//! (8 × u64) groups, then handles any trailing `(dim / 64) % 8` words with a +//! single masked load (`_mm512_maskz_loadu_epi64`). Dimensions whose 64-bit +//! word count is a multiple of 8 — 512, 1024, 1536, … — have no tail; others +//! (e.g. **384, 768**, the common BGE/MiniLM widths) pay **one extra masked +//! chunk** — a few percent, so 768 ≈ 1024 — instead of falling back to the +//! scalar path. See [`crate::avx512vpop_supported`]. use rayon::prelude::*; @@ -485,12 +496,7 @@ fn sign_scan_collect(bitmaps: &[u64], n: usize, qpv: usize, q: &[u64], scores: & debug_assert_eq!(scores.len(), n); debug_assert_eq!(q.len(), qpv); - #[cfg(target_arch = "x86_64")] - let use_avx512vpop = is_x86_feature_detected!("avx512f") - && is_x86_feature_detected!("avx512vpopcntdq") - && qpv.is_multiple_of(8); - #[cfg(not(target_arch = "x86_64"))] - let use_avx512vpop = false; + let use_avx512vpop = crate::avx512vpop_supported(); if use_avx512vpop { #[cfg(target_arch = "x86_64")] @@ -516,24 +522,36 @@ unsafe fn sign_scan_collect_avx512vpop( scores: &mut [u32], ) { use std::arch::x86_64::*; - // SAFETY: mirrors `bitmap_scan_collect_avx512vpop` — the caller - // (`sign_scan_collect`) gates dispatch on `qpv.is_multiple_of(8)`, - // `q.len() == qpv`, and `bitmaps.len() == n * qpv`, bounding all raw loads. - // AVX-512 F/VPOPCNTDQ confirmed by `#[target_feature]` + runtime detection. - // The explicit block is required by `#![deny(unsafe_op_in_unsafe_fn)]`. + // SAFETY: mirrors `bitmap_scan_collect_avx512vpop`. The caller + // (`sign_scan_collect`) guarantees `q.len() == qpv` and + // `bitmaps.len() == n * qpv`. Full 8-word groups use `loadu`; the trailing + // `rem = qpv % 8` words use `maskz_loadu`, which only accesses the `rem` + // valid low lanes (fault-suppressed), so loads never over-read the qpv + // slice or the buffer end. AVX-512 F/VPOPCNTDQ confirmed by + // `#[target_feature]` + runtime detection. The explicit block is required + // by `#![deny(unsafe_op_in_unsafe_fn)]`. unsafe { - debug_assert_eq!(qpv % 8, 0); + debug_assert!(qpv > 0); let lanes = qpv / 8; + let rem = qpv % 8; + let tail_mask: __mmask8 = if rem != 0 { (1u8 << rem) - 1 } else { 0 }; let mut q_zmms: Vec<__m512i> = Vec::with_capacity(lanes); #[allow(clippy::needless_range_loop)] // indexed access is clearer / matches the kernel layout for l in 0..lanes { q_zmms.push(_mm512_loadu_si512(q.as_ptr().add(l * 8) as *const __m512i)); } + // Trailing `rem` query words, masked (high lanes read as 0). + let q_tail = if rem != 0 { + _mm512_maskz_loadu_epi64(tail_mask, q.as_ptr().add(lanes * 8) as *const i64) + } else { + _mm512_setzero_si512() + }; #[allow(clippy::needless_range_loop)] // indexed access is clearer / matches the kernel layout for di in 0..n { - let doc_ptr = bitmaps.as_ptr().add(di * qpv) as *const __m512i; + let doc_base = bitmaps.as_ptr().add(di * qpv); + let doc_ptr = doc_base as *const __m512i; let mut acc_zmm = _mm512_setzero_si512(); for l in 0..lanes { let d_zmm = _mm512_loadu_si512(doc_ptr.add(l)); @@ -541,6 +559,14 @@ unsafe fn sign_scan_collect_avx512vpop( let pop_zmm = _mm512_popcnt_epi64(xor_zmm); acc_zmm = _mm512_add_epi64(acc_zmm, pop_zmm); } + if rem != 0 { + // Masked tail: masked-off lanes are not loaded and XOR/popcnt to + // 0, so they leave the Hamming sum unchanged. + let d_tail = + _mm512_maskz_loadu_epi64(tail_mask, doc_base.add(lanes * 8) as *const i64); + let xor_zmm = _mm512_xor_si512(d_tail, q_tail); + acc_zmm = _mm512_add_epi64(acc_zmm, _mm512_popcnt_epi64(xor_zmm)); + } let acc_sum: i64 = _mm512_reduce_add_epi64(acc_zmm); scores[di] = acc_sum as u32; } @@ -563,12 +589,7 @@ fn sign_scan_collect_batched( batch: usize, scores: &mut [u32], ) { - #[cfg(target_arch = "x86_64")] - let use_avx512vpop = is_x86_feature_detected!("avx512f") - && is_x86_feature_detected!("avx512vpopcntdq") - && qpv.is_multiple_of(8); - #[cfg(not(target_arch = "x86_64"))] - let use_avx512vpop = false; + let use_avx512vpop = crate::avx512vpop_supported(); if use_avx512vpop { #[cfg(target_arch = "x86_64")] @@ -598,17 +619,21 @@ unsafe fn sign_scan_collect_batched_avx512vpop( scores: &mut [u32], ) { use std::arch::x86_64::*; - // SAFETY: mirrors `bitmap_scan_collect_batched_avx512vpop` — the caller - // (`sign_scan_collect_batched`) gates dispatch on `qpv.is_multiple_of(8)`, - // `q_batch.len() == batch * qpv`, `bitmaps.len() == n * qpv`, and - // `scores.len() == batch * n`, bounding all raw loads and `scores[…]` writes. - // AVX-512 F/VPOPCNTDQ confirmed by `#[target_feature]` + runtime detection. - // The explicit block is required by `#![deny(unsafe_op_in_unsafe_fn)]`. + // SAFETY: mirrors `bitmap_scan_collect_batched_avx512vpop`. The caller + // (`sign_scan_collect_batched`) guarantees `q_batch.len() == batch * qpv`, + // `bitmaps.len() == n * qpv`, and `scores.len() == batch * n`. Full 8-word + // groups use `loadu`; the trailing `rem = qpv % 8` words use `maskz_loadu`, + // which only accesses the `rem` valid low lanes (fault-suppressed), so loads + // never over-read a per-vector slice or the buffer end. AVX-512 F/VPOPCNTDQ + // confirmed by `#[target_feature]` + runtime detection. The explicit block + // is required by `#![deny(unsafe_op_in_unsafe_fn)]`. unsafe { - debug_assert_eq!(qpv % 8, 0); + debug_assert!(qpv > 0); debug_assert_eq!(q_batch.len(), batch * qpv); debug_assert_eq!(scores.len(), batch * n); let lanes = qpv / 8; + let rem = qpv % 8; + let tail_mask: __mmask8 = if rem != 0 { (1u8 << rem) - 1 } else { 0 }; const CHUNK: usize = BATCHED_AVX512_CHUNK; let mut q_zmms: Vec<__m512i> = Vec::with_capacity(batch * lanes); @@ -619,6 +644,16 @@ unsafe fn sign_scan_collect_batched_avx512vpop( )); } } + // Per-query masked tail (trailing `rem` words); empty when qpv % 8 == 0. + let mut q_tails: Vec<__m512i> = Vec::with_capacity(if rem != 0 { batch } else { 0 }); + if rem != 0 { + for bi in 0..batch { + q_tails.push(_mm512_maskz_loadu_epi64( + tail_mask, + q_batch.as_ptr().add(bi * qpv + lanes * 8) as *const i64, + )); + } + } // Hot path: CHUNK-sized groups; const-bounded inner bi loop so // LLVM unrolls and promotes the accs array to ZMM registers. @@ -626,7 +661,8 @@ unsafe fn sign_scan_collect_batched_avx512vpop( while chunk_start + CHUNK <= batch { for di in 0..n { let mut accs: [__m512i; CHUNK] = [_mm512_setzero_si512(); CHUNK]; - let doc_ptr = bitmaps.as_ptr().add(di * qpv) as *const __m512i; + let doc_base = bitmaps.as_ptr().add(di * qpv); + let doc_ptr = doc_base as *const __m512i; for l in 0..lanes { let d_zmm = _mm512_loadu_si512(doc_ptr.add(l)); for bi in 0..CHUNK { @@ -636,6 +672,14 @@ unsafe fn sign_scan_collect_batched_avx512vpop( accs[bi] = _mm512_add_epi64(accs[bi], pop_zmm); } } + if rem != 0 { + let d_tail = + _mm512_maskz_loadu_epi64(tail_mask, doc_base.add(lanes * 8) as *const i64); + for bi in 0..CHUNK { + let xor_zmm = _mm512_xor_si512(d_tail, q_tails[chunk_start + bi]); + accs[bi] = _mm512_add_epi64(accs[bi], _mm512_popcnt_epi64(xor_zmm)); + } + } for bi in 0..CHUNK { let acc_sum: i64 = _mm512_reduce_add_epi64(accs[bi]); scores[(chunk_start + bi) * n + di] = acc_sum as u32; @@ -643,12 +687,13 @@ unsafe fn sign_scan_collect_batched_avx512vpop( } chunk_start += CHUNK; } - // Tail. + // Tail over the query batch. let tail = batch - chunk_start; if tail > 0 { for di in 0..n { let mut accs: [__m512i; CHUNK] = [_mm512_setzero_si512(); CHUNK]; - let doc_ptr = bitmaps.as_ptr().add(di * qpv) as *const __m512i; + let doc_base = bitmaps.as_ptr().add(di * qpv); + let doc_ptr = doc_base as *const __m512i; for l in 0..lanes { let d_zmm = _mm512_loadu_si512(doc_ptr.add(l)); for bi in 0..tail { @@ -658,6 +703,14 @@ unsafe fn sign_scan_collect_batched_avx512vpop( accs[bi] = _mm512_add_epi64(accs[bi], pop_zmm); } } + if rem != 0 { + let d_tail = + _mm512_maskz_loadu_epi64(tail_mask, doc_base.add(lanes * 8) as *const i64); + for bi in 0..tail { + let xor_zmm = _mm512_xor_si512(d_tail, q_tails[chunk_start + bi]); + accs[bi] = _mm512_add_epi64(accs[bi], _mm512_popcnt_epi64(xor_zmm)); + } + } for bi in 0..tail { let acc_sum: i64 = _mm512_reduce_add_epi64(accs[bi]); scores[(chunk_start + bi) * n + di] = acc_sum as u32; @@ -687,6 +740,34 @@ mod tests { .sum() } + /// Returns `true` if the host supports AVX-512 VPOPCNTDQ and the test should + /// proceed. When AVX-512 is absent: + /// + /// - If `ORDVEC_REQUIRE_AVX512` is set to `"1"` or `"true"` (used by the + /// Intel SDE CI job), this panics so the job fails loudly instead of + /// silently treating a skipped test as green coverage. + /// - Otherwise it emits a skip notice to stderr and returns `false`; the + /// caller should return immediately. + fn require_avx512_or_skip(test_name: &str) -> bool { + if crate::avx512vpop_supported() { + return true; + } + let required = std::env::var("ORDVEC_REQUIRE_AVX512") + .map(|v| v == "1" || v == "true") + .unwrap_or(false); + if required { + panic!( + "SKIP {test_name}: host lacks AVX-512 VPOPCNTDQ but \ + ORDVEC_REQUIRE_AVX512 is set — AVX-512 kernels are not enforced" + ); + } + eprintln!( + "SKIP {test_name}: host lacks AVX-512 VPOPCNTDQ; \ + set ORDVEC_REQUIRE_AVX512=1 to enforce" + ); + false + } + #[test] fn candidate_batch_helpers() { use super::CandidateBatch; @@ -965,6 +1046,9 @@ mod tests { #[test] fn avx512_path_matches_scalar_at_production_dim() { + if !require_avx512_or_skip("avx512_path_matches_scalar_at_production_dim") { + return; + } const PROD_D: usize = 1024; let n = 256; let mut rng = ChaCha8Rng::seed_from_u64(31); @@ -996,4 +1080,61 @@ mod tests { ); } } + + #[test] + fn avx512_path_matches_scalar_across_residues_and_common_dims() { + if !require_avx512_or_skip("avx512_path_matches_scalar_across_residues_and_common_dims") { + return; + } + // Covers every qpv tail residue (qpv % 8 ∈ 0..=7), the lanes==0 all-tail + // cases (qpv < 8: 64/384/448), and the common embedding dims + // 384/512/768/1024/1536. The AVX-512 path (masked tail for non-multiples + // of 8) must stay byte-identical to a scalar Hamming reference. + for &dim in &[ + 64usize, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1536, + ] { + let n = 200usize; + let m = 32usize; + let nq = 5usize; + let mut rng = ChaCha8Rng::seed_from_u64(7000 + dim as u64); + let corpus: Vec = (0..n * dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + let mut idx = SignBitmap::new(dim); + idx.add(&corpus); + let queries: Vec = (0..nq * dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + + let qpv = idx.qwords_per_vec; + let dimu = dim as u32; + let batched = idx.top_m_candidates_batched(&queries, m); + let scores_flat = idx.score_all_batched_flat(&queries); + for qi in 0..nq { + let q = &queries[qi * dim..(qi + 1) * dim]; + let qbm = idx.build_query_bitmap(q); + let single_scores = idx.score_all(q); + let mut ref_pairs: Vec<(u32, u32)> = Vec::with_capacity(n); + for di in 0..n { + let off = di * qpv; + let ham = scalar_hamming(&qbm, &idx.bitmaps[off..off + qpv]); + let agree = dimu - ham; + assert_eq!( + single_scores[di], agree, + "score_all dim={dim} qi={qi} di={di}" + ); + assert_eq!( + scores_flat[qi * n + di], + agree, + "score_all_batched_flat dim={dim} qi={qi} di={di}" + ); + ref_pairs.push((ham, di as u32)); + } + ref_pairs.sort_by_key(|&(h, did)| (h, did)); + let reference: Vec = ref_pairs.iter().take(m).map(|&(_, did)| did).collect(); + assert_eq!( + idx.top_m_candidates(q, m), + reference, + "single dim={dim} qi={qi}" + ); + assert_eq!(batched[qi], reference, "batched dim={dim} qi={qi}"); + } + } + } }