diff --git a/README.md b/README.md index 1e055264..86b5be9e 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,9 @@ vector on its own: the product / scalar / binary quantization most crates use. - **Predictable footprint.** Exactly `dim * bits / 8` bytes per document — known before you see any data (256 B at dim = 1024, 2-bit), with - `bits ∈ {1, 2, 4}` the size/recall knob. + `bits ∈ {1, 2, 4}` the size/recall knob. (`b = 8` is an opt-in + evidence/refinement width — asymmetric scoring at any dim, symmetric only + when `dim % 256 == 0` — not a broad retrieval mode.) - **Two-stage retrieval, built in.** A cheap bitmap / sign-popcount prefilter feeds an exact rerank — the coarse→fine pipeline ships as library primitives. diff --git a/src/lib.rs b/src/lib.rs index 32b87459..a3725de0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,14 @@ //! coordinate, `2 * dim` bytes per document). //! - [`RankQuant`] buckets each rank into `1 << bits` equal-width //! bins and packs `bits` bits per coordinate (`dim * bits / 8` bytes -//! per document). +//! per document). `bits ∈ {1, 2, 4}` are the stable retrieval widths; +//! `b = 8` is a capability-gated evidence/refinement width — asymmetric +//! scoring and code/projection generation at any dim, *analytical-norm* +//! symmetric scoring (via [`RankQuant::search`]) only when +//! `dim % 256 == 0` (see [`RankQuant::new_asymmetric`]). The standalone +//! [`rankquant_eval_search`] computes its norm *empirically*, so it scores +//! any `bits ∈ 1..=8` at any dim (including `b = 8` off the 256 grid) and +//! carries no such restriction. //! - [`Bitmap`] stores a top-bucket bitmap per document (one bit //! per coordinate) and scores via `popcount(Q AND D)`. //! 
- [`SignBitmap`] stores a sign bitmap per document (one bit per @@ -64,7 +71,7 @@ mod util; pub use bitmap::Bitmap; pub use quant::SubsetScratch; -pub use quant::{rankquant_eval_search, RankQuant, TwoStageCandidatePolicy}; +pub use quant::{rankquant_eval_search, RankQuant, RankQuantCapability, TwoStageCandidatePolicy}; pub use rank::Rank; pub use rank_io::{probe_index_metadata, IndexKind, IndexMetadata, IndexParams}; pub use sign_bitmap::CandidateBatch; diff --git a/src/quant.rs b/src/quant.rs index ed16fff9..5c3b2ca8 100644 --- a/src/quant.rs +++ b/src/quant.rs @@ -1,9 +1,22 @@ //! `B`-bit bucketed-rank index ([`RankQuant`]). //! -//! Storage is `dim * bits / 8` bytes per document at `bits ∈ {1, 2, 4}`. -//! Symmetric search uses a per-query, per-coord LUT; asymmetric search -//! dispatches AVX-512 → AVX2 → scalar via the kernels in -//! [`crate::quant_kernels`]. +//! Storage is `dim * bits / 8` bytes per document at `bits ∈ {1, 2, 4, 8}` +//! (`b=8` is one byte per coordinate). Symmetric search uses a per-query, +//! per-coord LUT; asymmetric search dispatches AVX-512 → AVX2 → scalar via +//! the kernels in [`crate::quant_kernels`]. +//! +//! `b=8` is an evidence/refinement-oriented width: it is supported for +//! asymmetric scoring and code/projection generation at **any** dimension, +//! but symmetric scoring uses the equal-bucket analytical norm and therefore +//! requires `dim % 256 == 0`. For `b ∈ {1, 2, 4}` the existing retrieval +//! modes remain the stable headline surface; `b=8` is an opt-in, +//! explicitly-documented high-precision evidence/refinement surface +//! (e.g. asymmetric quant storage after repair flows, edge-case rerank +//! healing), not a broad retrieval-quant method. It is **not** +//! unstable-experimental. See [`RankQuantCapability`] and +//! [`RankQuant::new_asymmetric`]. Its asymmetric path is a per-coordinate +//! gather against the `dim * 256` LUT: an AVX-512 `vgatherdps` kernel when +//! 
available (`avx512f` + `avx512bw` + `dim % 16 == 0`), else the portable scalar LUT. //! //! The byte-LUT path ([`search_asymmetric_byte_lut`]) is re-exported //! `#[doc(hidden)]` (reachable as `ordvec::search_asymmetric_byte_lut`) @@ -13,7 +26,8 @@ use rayon::prelude::*; use crate::quant_kernels::{ - scan_b1_to_topk, scan_b2_to_topk, scan_b4_to_topk, scan_via_lut_scalar, + scan_b1_to_topk, scan_b2_to_topk, scan_b4_to_topk, scan_b8_asym, scan_b8_to_topk, + scan_via_lut_scalar, }; #[cfg(target_arch = "x86_64")] use crate::quant_kernels::{ @@ -79,7 +93,13 @@ impl SubsetScratch { } fn check_eval_bits(bits: u8) { - assert!((1..=7).contains(&bits), "bits must be in 1..=7"); + // b=8 codes still fit a u8 (0..=255); the eval norm is computed empirically + // (not the analytical b=8 norm), so it is valid at any dim. This is *why* + // the eval path is not bound by the `dim % 256 == 0` gate that the + // analytical-norm symmetric `RankQuant::search` carries for b=8 — the + // empirical norm is exact under any bucket occupancy. b=9 is the first + // width whose codes overflow u8. + assert!((1..=8).contains(&bits), "bits must be in 1..=8"); } fn rankquant_eval_norm(dim: usize, bits: u8) -> f32 { @@ -95,6 +115,25 @@ fn rankquant_eval_norm(dim: usize, bits: u8) -> f32 { acc.sqrt() as f32 } +/// L2 norm of a document's bucket-centre vector, for asymmetric scoring. +/// +/// For `bits ∈ {1, 2, 4}` (and `b = 8` when `dim % 256 == 0`) the bucket +/// occupancy is exactly uniform, so the closed-form [`rankquant_norm`] +/// (`sqrt(dim * var)`) is exact and cheaper. For `b = 8` at a `dim` not +/// divisible by 256 the buckets are *not* equally occupied, so the closed +/// form mis-scales the absolute scores (the *ranking* is unaffected — the +/// norm is one global constant shared by every document — but +/// `search_asymmetric` reports cosine-like scores, which must be correctly +/// scaled). 
In that regime we fall back to the exact empirical norm, which +/// sums the squared bucket centres over the realised rank→bucket map. +fn asymmetric_norm(dim: usize, bits: u8) -> f32 { + if bits == 8 && !dim.is_multiple_of(256) { + rankquant_eval_norm(dim, bits) + } else { + rankquant_norm(dim, bits) + } +} + fn rankquant_eval_centres(v: &[f32], bits: u8, out: &mut [f32]) { debug_assert_eq!(v.len(), out.len()); let ranks = rank_transform(v); @@ -112,21 +151,66 @@ fn rankquant_eval_buckets(v: &[f32], bits: u8, out: &mut [u8]) { } } +/// Which scoring modes a [`RankQuant`] instance supports. +/// +/// The distinction only matters for `b=8`. For `b ∈ {1, 2, 4}` every +/// constructor produces a [`SymmetricAndAsymmetric`](Self::SymmetricAndAsymmetric) +/// instance (the `dim % 2^bits == 0` constructor invariant always holds), +/// so callers never need to branch on this for the headline widths. +/// +/// For `b=8` the symmetric analytical L2 norm is exact only when every +/// bucket receives equal occupancy, i.e. `dim % 256 == 0`. When that +/// holds the instance is [`SymmetricAndAsymmetric`](Self::SymmetricAndAsymmetric); +/// otherwise it is [`AsymmetricOnly`](Self::AsymmetricOnly) — code/projection +/// generation, pair-evidence/contingency, and asymmetric (float-query) +/// scoring all work at *any* dim, but the symmetric path +/// ([`RankQuant::search`]) panics. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum RankQuantCapability { + /// Asymmetric (float-query) scoring and code/projection generation + /// only. Reachable for `b=8` when `dim % 256 != 0`. Symmetric + /// scoring ([`RankQuant::search`]) panics on these instances. + AsymmetricOnly, + /// Full surface: both symmetric and asymmetric scoring. The only + /// capability for `b ∈ {1, 2, 4}`, and the capability for `b=8` when + /// `dim % 256 == 0`. + SymmetricAndAsymmetric, +} + /// `B`-bit RankQuant index. 
/// /// Each document is encoded by bucketing its rank vector into /// `1 << bits` equal-width bins on `[0, dim)` and packing `bits` bits /// per coordinate. Storage is `dim * bits / 8` bytes per document. -/// Supported bit widths are `1`, `2`, and `4` (3-bit packing is left -/// for a follow-up; use `2` or `4` in the interim). +/// Supported bit widths are `1`, `2`, `4`, and `8` (3-bit packing is +/// left for a follow-up; use `2` or `4` in the interim). /// /// The mean-centred bucket vector has fixed analytical L2 norm /// `sqrt(dim * (2^(2B) - 1) / 12)` when `dim % (1 << bits) == 0`, so /// no per-document norms are stored. +/// +/// # `b=8` — evidence/refinement width +/// `b=8` is an evidence/refinement-oriented RankQuant width. It is +/// supported for asymmetric scoring and code/projection generation at +/// any dimension; symmetric scoring uses the equal-bucket analytical +/// norm and therefore requires `dim % 256 == 0`. For `b ∈ {1, 2, 4}`, +/// the existing retrieval modes remain the stable headline surface; +/// `b=8` is an opt-in, explicitly-documented high-precision +/// evidence/refinement surface (e.g. asymmetric quant storage after +/// repair flows, edge-case rerank healing), not a broad retrieval-quant +/// method. It is **not** unstable-experimental — it is a stable, core +/// surface — but it is capability-gated: construct an asymmetric-only +/// `b=8` index for non-`256`-aligned dims via [`Self::new_asymmetric`] +/// and check [`Self::symmetric_supported`] before calling +/// [`Self::search`]. See [`RankQuantCapability`]. pub struct RankQuant { pub(crate) dim: usize, pub(crate) bits: u8, pub(crate) n_vectors: usize, + /// Scoring modes this instance supports — see [`RankQuantCapability`]. + /// Computed once at construction; for `b ∈ {1, 2, 4}` always + /// [`RankQuantCapability::SymmetricAndAsymmetric`]. + pub(crate) capability: RankQuantCapability, /// Row-major packed bucket bytes. `n_vectors * dim * bits / 8` total. 
pub(crate) packed: Vec<u8>, } @@ -229,11 +313,27 @@ fn select_simd_tier(dim: usize, bits: u8) -> SimdTier { } impl RankQuant { + /// Validate `(dim, bits)` for **code validity** — the precondition for + /// generating bucket codes, projections, and asymmetric scores. + /// + /// Accepts `bits ∈ {1, 2, 4, 8}` and `dim ∈ [2, u16::MAX]`. + /// + /// For `b ∈ {1, 2, 4}` this additionally requires `dim % 2^bits == 0` + /// (the equal-bucket constant-composition invariant): those widths only + /// expose a full symmetric+asymmetric surface, so code validity and + /// symmetric-norm validity coincide. + /// + /// For `b = 8` it validates **only** that codes pack (`codes_per_byte == + /// 1`, so any `dim` works) — it does **not** require `dim % 256 == 0`. + /// That `dim % 256 == 0` rule is a *symmetric-scoring* precondition, not + /// a code-validity one, and is checked separately on the symmetric path + /// (and by [`Self::new`], which constructs a full-capability `b=8` + /// instance). Use [`Self::new_asymmetric`] for any-`dim` `b=8`. pub fn validate_params(dim: usize, bits: u8) -> Result<(), OrdvecError> { - if !matches!(bits, 1 | 2 | 4) { + if !matches!(bits, 1 | 2 | 4 | 8) { return Err(OrdvecError::InvalidParameter { name: "bits", - message: "must be 1, 2, or 4".to_string(), + message: "must be 1, 2, 4, or 8".to_string(), }); } if dim < 2 { @@ -255,20 +355,45 @@ impl RankQuant { message: format!("must be a multiple of {codes_per_byte} for bits = {bits}"), }); } - let n_buckets = 1usize << bits; - if !dim.is_multiple_of(n_buckets) { - return Err(OrdvecError::InvalidParameter { - name: "dim", - message: format!( - "must be divisible by 2^bits = {n_buckets} so every bucket receives exactly dim / 2^bits rank entries" - ), - }); + // The constant-composition invariant `dim % 2^bits == 0` exists only to + // make the symmetric analytical L2 norm exact (equal bucket occupancy).
+ // For b ∈ {1,2,4} we keep requiring it here (those widths are + // full-capability by definition), but for b=8 it is a *symmetric* + // precondition checked elsewhere — code/projection/asymmetric paths + // never need equal buckets, so a non-256-aligned dim is a valid b=8 + // *code* configuration. + if bits != 8 { + let n_buckets = 1usize << bits; + if !dim.is_multiple_of(n_buckets) { + return Err(OrdvecError::InvalidParameter { + name: "dim", + message: format!( + "must be divisible by 2^bits = {n_buckets} so every bucket receives exactly dim / 2^bits rank entries" + ), + }); + } } Ok(()) } + /// Construct a full-capability (`SymmetricAndAsymmetric`) index. + /// + /// For `b ∈ {1, 2, 4}` this is unchanged: `bits` must be one of those + /// widths and `dim % 2^bits == 0` (and `dim % (8 / bits) == 0`). + /// + /// For `b = 8` this requires `dim % 256 == 0`, which yields the full + /// symmetric+asymmetric surface. If `dim % 256 != 0` it **panics** + /// (consistent with this constructor's existing fail-loud style), + /// directing the caller to [`Self::new_asymmetric`] for an any-`dim` + /// asymmetric-only `b=8` index. See [`RankQuantCapability`]. + /// + /// # Panics + /// Panics if `bits ∉ {1, 2, 4, 8}`, if `dim < 2`, if `dim > u16::MAX`, + /// if `dim % (8 / bits) != 0`, or — for the equal-bucket symmetric + /// invariant — if `dim % 2^bits != 0` (`b ∈ {1,2,4}`) / `dim % 256 != 0` + /// (`b = 8`). pub fn new(dim: usize, bits: u8) -> Self { - assert!(matches!(bits, 1 | 2 | 4), "bits must be 1, 2, or 4"); + assert!(matches!(bits, 1 | 2 | 4 | 8), "bits must be 1, 2, 4, or 8"); assert!(dim >= 2, "dim must be >= 2"); assert!(dim <= u16::MAX as usize, "dim must fit in u16"); let codes_per_byte = (8 / bits) as usize; @@ -277,6 +402,27 @@ impl RankQuant { 0, "dim must be a multiple of {codes_per_byte} for bits = {bits}", ); + if bits == 8 { + // b=8 full-capability requires dim % 256 == 0 (equal bucket + // occupancy → exact symmetric analytical norm). 
Fail loud and + // point at the asymmetric-only constructor so the caller has a + // non-surprising path for non-aligned dims. + assert_eq!( + dim % 256, + 0, + "RankQuant::new(dim, 8) requires dim % 256 == 0 for symmetric \ + scoring (equal-bucket analytical norm); dim={dim} is not \ + 256-aligned. Use RankQuant::new_asymmetric(dim, 8) for an \ + asymmetric-only b=8 index at any dim.", + ); + return Self { + dim, + bits, + n_vectors: 0, + capability: RankQuantCapability::SymmetricAndAsymmetric, + packed: Vec::new(), + }; + } // Audit-safety: require dim divisible by 2^bits so every bucket // gets exactly dim / (1 << bits) rank entries per document. This // is what makes `rankquant_norm` analytically exact (every doc @@ -296,10 +442,94 @@ impl RankQuant { dim, bits, n_vectors: 0, + capability: RankQuantCapability::SymmetricAndAsymmetric, packed: Vec::new(), } } + /// Construct an asymmetric-capable index at **any** valid `dim`. + /// + /// This is the non-surprising entry point for `b = 8` at a dimension + /// that is not `256`-aligned: it produces a + /// [`RankQuantCapability::AsymmetricOnly`] instance whose + /// code/projection generation, pair-evidence/contingency, and + /// asymmetric (float-query) scoring all work, but whose symmetric path + /// ([`Self::search`]) panics (the equal-bucket analytical norm is not + /// exact off the `256`-aligned grid). When `dim % 256 == 0`, the `b=8` + /// instance is upgraded to full [`RankQuantCapability::SymmetricAndAsymmetric`] + /// (there is no reason to withhold symmetric scoring when it is exact). + /// + /// For `b ∈ {1, 2, 4}` this constructs the same full-capability instance + /// as [`Self::new`] (those widths are always symmetric-capable when their + /// constructor invariants hold), so it is never *less* capable than + /// `new` — it is simply the width-agnostic constructor. + /// + /// # Panics + /// Panics if `(dim, bits)` is not a valid **code** configuration — + /// i.e. 
`bits ∉ {1, 2, 4, 8}`, `dim < 2`, `dim > u16::MAX`, or + /// `dim % (8 / bits) != 0`. For `b ∈ {1, 2, 4}` it additionally requires + /// `dim % 2^bits == 0` (same as [`Self::new`]). + pub fn new_asymmetric(dim: usize, bits: u8) -> Self { + // Reuse the code-validity gate (accepts any 256-unaligned dim for b=8, + // still requires dim % 2^bits for b ∈ {1,2,4}). Convert the structured + // error into a panic so this constructor matches `new`'s fail-loud style. + Self::validate_params(dim, bits) + .unwrap_or_else(|e| panic!("RankQuant::new_asymmetric invalid params: {e}")); + let capability = Self::capability_for(dim, bits); + Self { + dim, + bits, + n_vectors: 0, + capability, + packed: Vec::new(), + } + } + + /// Compute the capability for a code-valid `(dim, bits)` pair. + /// + /// `b ∈ {1, 2, 4}` and `256`-aligned `b=8` are full-capability; any + /// other (i.e. non-`256`-aligned) `b=8` is asymmetric-only. + #[inline] + fn capability_for(dim: usize, bits: u8) -> RankQuantCapability { + if bits == 8 && !dim.is_multiple_of(256) { + RankQuantCapability::AsymmetricOnly + } else { + RankQuantCapability::SymmetricAndAsymmetric + } + } + + /// The scoring modes this instance supports — see [`RankQuantCapability`]. + /// + /// Always [`RankQuantCapability::SymmetricAndAsymmetric`] for + /// `b ∈ {1, 2, 4}`. For `b=8` it reflects whether `dim % 256 == 0`. + #[inline] + pub fn capability(&self) -> RankQuantCapability { + self.capability + } + + /// Whether [`Self::search`] (symmetric scoring) is supported on this + /// instance. `true` for `b ∈ {1, 2, 4}` and for `256`-aligned `b=8`; + /// `false` for `b=8` at a non-`256`-aligned dim (asymmetric-only). + /// + /// Callers should check this before invoking [`Self::search`] on a + /// `b=8` index built via [`Self::new_asymmetric`]. 
+ #[inline] + pub fn symmetric_supported(&self) -> bool { + matches!(self.capability, RankQuantCapability::SymmetricAndAsymmetric) + } + + /// Fail loud with the exact symmetric-gating message when symmetric + /// scoring is invoked on an asymmetric-only (`b=8`, non-`256`-aligned) + /// instance. No-op for symmetric-capable instances. + #[inline] + fn assert_symmetric_supported(&self) { + assert!( + self.symmetric_supported(), + "RankQuant b=8 symmetric scoring requires dim % 256 == 0; dim={} supports asymmetric/evidence APIs only.", + self.dim, + ); + } + /// Add documents. Each vector is rank-transformed, bucketed to `bits` /// bits/coord, and bit-packed row-major. /// @@ -339,7 +569,21 @@ impl RankQuant { /// Symmetric search: bucket the query and score against bucketed /// docs. + /// + /// # Panics + /// For a `b=8` index built via [`Self::new_asymmetric`] at a + /// non-`256`-aligned dim (an [`RankQuantCapability::AsymmetricOnly`] + /// instance), this **panics**: the symmetric analytical norm requires + /// equal bucket occupancy (`dim % 256 == 0`). Check + /// [`Self::symmetric_supported`] first, or use [`Self::search_asymmetric`], + /// which works at any dim. (`b ∈ {1, 2, 4}` and `256`-aligned `b=8` + /// instances never trip this.) The panic message is: + /// `RankQuant b=8 symmetric scoring requires dim % 256 == 0; dim={dim} + /// supports asymmetric/evidence APIs only.` pub fn search(&self, queries: &[f32], k: usize) -> SearchResults { + // Symmetric gating: fail loud (with the exact message) for an + // asymmetric-only b=8 instance before doing any work. 
+ self.assert_symmetric_supported(); let nq = queries.len() / self.dim; assert_eq!(queries.len(), nq * self.dim); assert_all_finite(queries); @@ -389,6 +633,7 @@ impl RankQuant { 1 => scan_b1_to_topk(&self.packed, n, dim, &lut, inv_norm_sq, &mut top), 2 => scan_b2_to_topk(&self.packed, n, dim, &lut, inv_norm_sq, &mut top), 4 => scan_b4_to_topk(&self.packed, n, dim, &lut, inv_norm_sq, &mut top), + 8 => scan_b8_to_topk(&self.packed, n, dim, &lut, inv_norm_sq, &mut top), _ => unreachable!(), } top.finalize_into(out_scores, out_indices); @@ -410,6 +655,15 @@ impl RankQuant { /// (`LUT[d][b] = q_unit[d] * bucket_centre(b)`). The scan unpacks /// `8 / bits` codes per byte and accumulates via LUT lookups; the /// compiler autovectorises the inner sum. + /// + /// Works at **any** valid dim for all supported widths including `b=8` + /// (the asymmetric path needs no equal-bucket precondition). For `b=8` + /// the score is a per-coordinate gather `Σ_d lut[d*256 + code[d]]` + /// against the `dim * 256` LUT: it dispatches to the AVX-512 + /// `vgatherdps` kernel (`scan_b8_asym` → `scan_b8_asym_avx512_gather`) + /// when `avx512f` + `avx512bw` are present and `dim % 16 == 0`, else the + /// portable scalar LUT reference (`scan_b8_to_topk`). Unlike [`Self::search`], + /// this never panics on an asymmetric-only instance. 
pub fn search_asymmetric(&self, queries: &[f32], k: usize) -> SearchResults { let nq = queries.len() / self.dim; assert_eq!(queries.len(), nq * self.dim); @@ -431,7 +685,7 @@ impl RankQuant { let dim = self.dim; let bits = self.bits; let n = self.n_vectors; - let norm = rankquant_norm(dim, bits); + let norm = asymmetric_norm(dim, bits); let inv_norm = 1.0_f32 / norm; let n_buckets = 1usize << bits; let bytes_per_vec = rankquant_bytes_per_vec(dim, bits); @@ -472,6 +726,18 @@ impl RankQuant { .for_each(|((q, out_scores), out_indices)| { let q_unit = l2_normalise(q); let mut top = TopK::new(k_eff); + + // b=8 is a per-coordinate gather (`Σ_d lut[d*256 + code[d]]`), + // not a centre-drop dot product — it routes to its own + // dispatch (AVX-512 vgatherdps → scalar LUT) and never uses + // the centre-drop offset (its LUT bakes the centre in). + if bits == 8 { + scan_b8_asym(&self.packed, n, dim, &q_unit, inv_norm, &mut top); + top.finalize_into(out_scores, out_indices); + let _ = bytes_per_vec; // shape clarity + return; + } + #[cfg(target_arch = "x86_64")] let centre_offset = { let q_sum: f32 = q_unit.iter().sum(); @@ -569,7 +835,23 @@ impl RankQuant { } /// Persist to a `.tvrq` file. Format: 14-byte header + packed bytes. + /// + /// # `b=8` + /// The `.tvrq` on-disk format and its loader currently support only + /// `bits ∈ {1, 2, 4}`. `b=8` is an in-memory evidence/refinement surface + /// in this phase; persisting it is a follow-up. To avoid writing a file + /// that [`Self::load`] would then reject (a silent broken round-trip), + /// this returns `io::Error` (kind `Unsupported`) for a `b=8` index rather + /// than emitting an unloadable file. 
pub fn write(&self, path: impl AsRef) -> std::io::Result<()> { + if self.bits == 8 { + return Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "RankQuant b=8 persistence is not supported yet (the .tvrq loader \ + accepts bits ∈ {1, 2, 4}); b=8 is an in-memory evidence surface \ + in this phase", + )); + } crate::rank_io::write_rankquant(path, self.bits, self.dim, self.n_vectors, &self.packed) } @@ -629,10 +911,16 @@ impl RankQuant { ), )); } + // `load_rankquant` only admits bits ∈ {1,2,4} (b=8 is not persistable + // in this phase — see `write`), and those widths are always + // full-capability, so the loaded instance is SymmetricAndAsymmetric. + // `capability_for` keeps that derivation in one place. + let capability = Self::capability_for(dim, bits); Ok(Self { dim, bits, n_vectors, + capability, packed, }) } @@ -766,7 +1054,7 @@ impl RankQuant { // touching `scratch.top` is safe (the next non-empty row resets it). return; } - let norm = rankquant_norm(dim, bits); + let norm = asymmetric_norm(dim, bits); let inv_norm = 1.0_f32 / norm; #[cfg(target_arch = "x86_64")] let centre_offset = { @@ -795,76 +1083,91 @@ impl RankQuant { #[cfg_attr(not(target_arch = "x86_64"), allow(unused_variables))] let simd_tier = select_simd_tier(dim, bits); scratch.top.reset_with_tie_keys(out_k, candidates_row); - #[cfg(target_arch = "x86_64")] - unsafe { - match (simd_tier, bits) { - (SimdTier::Avx512, 2) => { - scratch.top.set_score_offset(centre_offset); - scan_b2_asym_avx512( - &scratch.sub_packed, - m, - dim, - &scratch.q_unit, - inv_norm, - &mut scratch.top, - ); - } - (SimdTier::Avx512, 4) => { - scratch.top.set_score_offset(centre_offset); - scan_b4_asym_avx512( - &scratch.sub_packed, - m, - dim, - &scratch.q_unit, - inv_norm, - &mut scratch.top, - ); - } - (SimdTier::Avx2, 2) => { - scratch.top.set_score_offset(centre_offset); - scan_b2_asym_avx2( - &scratch.sub_packed, - m, - dim, - &scratch.q_unit, - inv_norm, - &mut scratch.top, - ); - } - (SimdTier::Avx2, 4) 
=> { - scratch.top.set_score_offset(centre_offset); - scan_b4_asym_avx2( + // b=8 routes to its own gather dispatch (AVX-512 vgatherdps → scalar + // LUT), with the centre baked into the LUT (no score-offset trick). + // The tie keys on `scratch.top` still map local scratch positions → + // global row IDs exactly as for b ∈ {1,2,4}. + if bits == 8 { + scan_b8_asym( + &scratch.sub_packed, + m, + dim, + &scratch.q_unit, + inv_norm, + &mut scratch.top, + ); + } else { + #[cfg(target_arch = "x86_64")] + unsafe { + match (simd_tier, bits) { + (SimdTier::Avx512, 2) => { + scratch.top.set_score_offset(centre_offset); + scan_b2_asym_avx512( + &scratch.sub_packed, + m, + dim, + &scratch.q_unit, + inv_norm, + &mut scratch.top, + ); + } + (SimdTier::Avx512, 4) => { + scratch.top.set_score_offset(centre_offset); + scan_b4_asym_avx512( + &scratch.sub_packed, + m, + dim, + &scratch.q_unit, + inv_norm, + &mut scratch.top, + ); + } + (SimdTier::Avx2, 2) => { + scratch.top.set_score_offset(centre_offset); + scan_b2_asym_avx2( + &scratch.sub_packed, + m, + dim, + &scratch.q_unit, + inv_norm, + &mut scratch.top, + ); + } + (SimdTier::Avx2, 4) => { + scratch.top.set_score_offset(centre_offset); + scan_b4_asym_avx2( + &scratch.sub_packed, + m, + dim, + &scratch.q_unit, + inv_norm, + &mut scratch.top, + ); + } + _ => scan_via_lut_scalar( &scratch.sub_packed, m, dim, + bits, + n_buckets, &scratch.q_unit, inv_norm, &mut scratch.top, - ); + ), } - _ => scan_via_lut_scalar( - &scratch.sub_packed, - m, - dim, - bits, - n_buckets, - &scratch.q_unit, - inv_norm, - &mut scratch.top, - ), } + #[cfg(not(target_arch = "x86_64"))] + scan_via_lut_scalar( + &scratch.sub_packed, + m, + dim, + bits, + n_buckets, + &scratch.q_unit, + inv_norm, + &mut scratch.top, + ); } - #[cfg(not(target_arch = "x86_64"))] - scan_via_lut_scalar( - &scratch.sub_packed, - m, - dim, - bits, - n_buckets, - &scratch.q_unit, - inv_norm, - &mut scratch.top, - ); // Finalize local positions into reused buffer, then map local → 
global. scratch.local_indices.clear(); @@ -1094,10 +1397,21 @@ fn validate_finite(values: &[f32], name: &'static str) -> Result<(), OrdvecError /// This does **not** use [`RankQuant`] storage and does not change the `.tvrq` /// packing contract. It rank-transforms `corpus` and `queries`, buckets each /// rank into `1 << bits` equal-width bins, mean-centres bucket ids, normalises -/// by the analytical norm for that `(dim, bits)`, and returns top-`k` results. +/// by the **empirical** norm for that `(dim, bits)` (the exact L2 norm of the +/// realised bucket-centre vector, summed over `0..dim`), and returns top-`k` +/// results. +/// +/// Because the norm is computed empirically rather than from the closed form, +/// this path is valid for **any** `dim` and **any** `bits ∈ 1..=8`, including +/// `bits = 8` at a `dim` not divisible by `256`. It therefore does *not* carry +/// the `dim % 256 == 0` restriction that applies to the analytical-norm +/// symmetric [`RankQuant::search`] (see [`RankQuant::new_asymmetric`]): that +/// restriction exists only because the closed-form `rankquant_norm` is exact +/// solely under uniform bucket occupancy, which this empirical path sidesteps. /// /// Intended for research/eval sweeps where non-byte-aligned widths such as -/// `bits = 3` need to be scored without inventing a persistent packed format. +/// `bits = 3`, or `b = 8` at arbitrary dims, need to be scored without +/// inventing a persistent packed format. 
pub fn rankquant_eval_search( corpus: &[f32], queries: &[f32], diff --git a/src/quant_kernels.rs b/src/quant_kernels.rs index 92f790c7..59742cd0 100644 --- a/src/quant_kernels.rs +++ b/src/quant_kernels.rs @@ -40,6 +40,7 @@ pub(crate) fn scan_via_lut_scalar( 1 => scan_b1_to_topk(packed, n, dim, &lut, scale, top), 2 => scan_b2_to_topk(packed, n, dim, &lut, scale, top), 4 => scan_b4_to_topk(packed, n, dim, &lut, scale, top), + 8 => scan_b8_to_topk(packed, n, dim, &lut, scale, top), _ => unreachable!("bits validated in new()"), } } @@ -127,6 +128,57 @@ pub(crate) fn scan_b4_to_topk( } } +/// Build the `dim * 256` per-coordinate asymmetric LUT for `b=8`: +/// `lut[d * 256 + code] = q_unit[d] * bucket_centre(code, 8)`. This is the +/// shared input to both the scalar [`scan_b8_to_topk`] reference and the +/// AVX-512 [`scan_b8_asym_avx512_gather`] kernel, so the two paths stay in score parity. +/// +/// `bucket_centre(code, 8) = code - 127.5`, so each row is the query +/// coordinate scaled across the 256 centred bucket values. +pub(crate) fn build_b8_asym_lut(q_unit: &[f32]) -> Vec<f32> { + let dim = q_unit.len(); + let mut lut = vec![0.0f32; dim * 256]; + for d in 0..dim { + let qd = q_unit[d]; + let row = &mut lut[d * 256..(d + 1) * 256]; + for (code, slot) in row.iter_mut().enumerate() { + *slot = qd * bucket_centre(code as u8, 8); + } + } + lut +} + +/// 8-bit scan. 1 code per byte; n_buckets = 256. The degenerate +/// one-code-per-byte case: `doc[d]` is the code at coordinate `d`, so the +/// inner loop is a single LUT lookup per byte against the `dim * 256` +/// per-coord LUT. Used by both the symmetric path (`bucket_centre` LUT) +/// and the asymmetric scalar LUT path (`q_unit[d] * bucket_centre(b)`).
+/// +/// This is also the **portable scalar reference** for the `b=8` asymmetric +/// gather: it sums in strict coordinate order, one lookup + add per byte, +/// so it is the bit-exact baseline the AVX-512 gather kernel is parity- +/// tested against (within the crate's 1e-4 cross-backend tolerance). +pub(crate) fn scan_b8_to_topk( + packed: &[u8], + n: usize, + dim: usize, + lut: &[f32], + scale: f32, + top: &mut TopK, +) { + let bytes_per_vec = dim; // 1 byte per coordinate + for di in 0..n { + let doc = &packed[di * bytes_per_vec..(di + 1) * bytes_per_vec]; + let mut acc = 0.0f32; + for (d, &code) in doc.iter().enumerate() { + // LUT row `d` has 256 entries (one per code value); the code is + // already the bucket index for b=8. + acc += lut[d * 256 + code as usize]; + } + top.maybe_insert(acc * scale, di); + } +} + // ------------------------------------------------------------------- // AVX2 + FMA kernels for the asymmetric path. // @@ -499,3 +551,409 @@ pub(crate) unsafe fn scan_b4_asym_avx512( } } } + +/// Single entry point for the `b=8` asymmetric scan. +/// +/// Builds the shared `dim * 256` per-coordinate LUT once +/// ([`build_b8_asym_lut`]), then dispatches to the AVX-512 gather kernel +/// ([`scan_b8_asym_avx512_gather`]) when `avx512f` + `avx512bw` are detected at +/// runtime and `dim % 16 == 0`, falling back to the portable scalar reference +/// ([`scan_b8_to_topk`]) on every other target / CPU / dim. Centralising +/// the dispatch here keeps the `unsafe` SIMD reach in one place and out of +/// `quant.rs`. 
+pub(crate) fn scan_b8_asym( + packed: &[u8], + n: usize, + dim: usize, + q_unit: &[f32], + scale: f32, + top: &mut TopK, +) { + let lut = build_b8_asym_lut(q_unit); + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx512f") + && is_x86_feature_detected!("avx512bw") + && dim.is_multiple_of(16) + { + // SAFETY: `avx512f`+`avx512bw` are confirmed by the runtime detection above + // and `dim % 16 == 0` satisfies the kernel's lane invariant; + // `packed.len() == n * dim` and `lut.len() == dim * 256` hold by + // construction (b=8 packs one byte/coord; the LUT is built just + // above). The explicit block is required by + // `#![deny(unsafe_op_in_unsafe_fn)]`. + unsafe { + scan_b8_asym_avx512_gather(packed, n, dim, &lut, scale, top); + } + return; + } + } + scan_b8_to_topk(packed, n, dim, &lut, scale, top); +} + +// ------------------------------------------------------------------- +// AVX-512 gather kernel for the b=8 asymmetric path. +// +// Unlike b ∈ {2, 4} — whose tiny per-byte arithmetic (shift/mask/cvt/FMA) +// beats any memory indirection — b=8 carries a large per-coordinate +// 256-entry float LUT (`lut[d * 256 + code]`), so the score is an honest +// gather: `Σ_d lut[d * 256 + doc_code[d]]`. The dominant cost is the +// gather, which `vgatherdps` (`_mm512_i32gather_ps`) issues 16-wide in a +// single instruction. +// +// Per 16-coordinate chunk: +// * load 16 doc bytes, zero-extend to i32 lanes (`_mm512_cvtepu8_epi32`); +// * add the per-position row-base vector `[d*256, (d+1)*256, …]` so lane +// `j` indexes `lut[(d+j) * 256 + code[d+j]]`; +// * `_mm512_i32gather_ps(idx, lut_ptr, 4)` gathers all 16 contributions; +// * accumulate (plain add — the LUT already encodes `q · centre`). +// Four independent accumulators break the add dependency chain, matching +// the b=2/b=4 AVX-512 kernels. 
Unlike those, b=8 needs no centre-drop +// trick: the asymmetric LUT bakes the per-coordinate query weight in, so +// there is no per-query constant offset to reapply at finalize. +// +// Caller must verify `is_x86_feature_detected!("avx512f") && ..("avx512bw")` +// once. `avx512bw` is gated alongside `avx512f` to match the rest of the +// crate's AVX-512 kernels (which require `avx512dq`) and to keep the byte +// widening (`_mm512_cvtepu8_epi32`) conservatively gated — the F-without-BW +// CPUs (KNL/KNM) are already excluded by the crate's `dq` requirement, so this +// adds no real exclusion. The LUT is the same `dim * 256` f32 layout the scalar +// `scan_b8_to_topk` consumes, so the two paths are score-parity (modulo f32 +// summation order, within the crate's 1e-4 cross-backend tolerance). +// ------------------------------------------------------------------- + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f,avx512bw")] +pub(crate) unsafe fn scan_b8_asym_avx512_gather( + packed: &[u8], + n: usize, + dim: usize, + lut: &[f32], + scale: f32, + top: &mut TopK, +) { + use std::arch::x86_64::*; + + // SAFETY: a `pub(crate) unsafe fn` reachable only via `quant.rs`'s + // runtime-detected dispatch, which upholds the invariants the raw doc + // reads (`packed.as_ptr().add(di * dim + base)`), the LUT gather + // (`_mm512_i32gather_ps` off `lut.as_ptr()`), and the chunk loop depend + // on: + // * `packed.len() == n * dim` (b=8 stores one byte per coordinate), + // * `lut.len() == dim * 256` (one 256-entry row per coordinate), + // * `dim % 16 == 0` (asserted immediately below) so the 16-lane chunk + // loop tiles each doc exactly with no tail. + // Every gather index `(d + j) * 256 + code` is `< dim * 256` because + // `d + j < dim` and `code <= 255`, so each gathered f32 is in-bounds. + // `RankQuant::{new_asymmetric,add}` pack exactly `dim` bytes/doc and the + // dispatch builds a `dim * 256` LUT, so this holds on every path here. 
+ // The explicit block is required by `#![deny(unsafe_op_in_unsafe_fn)]`. + unsafe { + // Hard backstop (see `scan_b2_asym_avx2`): mis-dispatch must fail + // loudly in release, not silently drop the trailing chunk. + assert_eq!(dim % 16, 0, "b=8 AVX-512 gather path needs dim % 16 == 0"); + debug_assert_eq!(lut.len(), dim * 256, "b=8 LUT must be dim * 256 entries"); + let bytes_per_vec = dim; // one byte per coordinate + let lut_ptr = lut.as_ptr(); + + // Per-position row bases for one 16-lane chunk: lane j contributes + // `j * 256`. The chunk's coordinate offset `c * 16 * 256` is folded + // into the doc-byte indices below. + let lane_row_base = _mm512_setr_epi32( + 0, 256, 512, 768, 1024, 1280, 1536, 1792, 2048, 2304, 2560, 2816, 3072, 3328, 3584, + 3840, + ); + let chunks_per_vec = bytes_per_vec / 16; + + for di in 0..n { + let doc = packed.as_ptr().add(di * bytes_per_vec); + let mut acc0 = _mm512_setzero_ps(); + let mut acc1 = _mm512_setzero_ps(); + let mut acc2 = _mm512_setzero_ps(); + let mut acc3 = _mm512_setzero_ps(); + + // Round chunks down to a multiple of 4 for the unrolled body; + // a `dim % 64 != 0` (but `% 16 == 0`) dim leaves a ≤3-chunk tail + // handled by the single-accumulator loop after. + let unrolled = chunks_per_vec & !3; + + let mut c = 0usize; + while c < unrolled { + macro_rules! step { + ($cc:expr, $acc:expr) => {{ + // Coordinate base for this chunk: `cc * 16 * 256`. + let chunk_base = _mm512_set1_epi32(($cc * 16 * 256) as i32); + // Load 16 doc bytes, zero-extend to 16 i32 lanes. + let bytes = _mm_loadu_si128(doc.add($cc * 16) as *const __m128i); + let codes = _mm512_cvtepu8_epi32(bytes); + // idx[j] = chunk_base + (j * 256) + code[j] + // = (cc*16 + j) * 256 + code[cc*16 + j] + let idx = + _mm512_add_epi32(_mm512_add_epi32(chunk_base, lane_row_base), codes); + // Gather 16 LUT contributions (scale = 4 bytes/f32). 
+ let vals = _mm512_i32gather_ps::<4>(idx, lut_ptr); + $acc = _mm512_add_ps($acc, vals); + }}; + } + step!(c, acc0); + step!(c + 1, acc1); + step!(c + 2, acc2); + step!(c + 3, acc3); + c += 4; + } + + // Tail: remaining (< 4) chunks fold into acc0. + while c < chunks_per_vec { + let chunk_base = _mm512_set1_epi32((c * 16 * 256) as i32); + let bytes = _mm_loadu_si128(doc.add(c * 16) as *const __m128i); + let codes = _mm512_cvtepu8_epi32(bytes); + let idx = _mm512_add_epi32(_mm512_add_epi32(chunk_base, lane_row_base), codes); + let vals = _mm512_i32gather_ps::<4>(idx, lut_ptr); + acc0 = _mm512_add_ps(acc0, vals); + c += 1; + } + + let s01 = _mm512_add_ps(acc0, acc1); + let s23 = _mm512_add_ps(acc2, acc3); + let total = _mm512_add_ps(s01, s23); + let raw = _mm512_reduce_add_ps(total); + top.maybe_insert(raw * scale, di); + } + } +} + +#[cfg(all(test, target_arch = "x86_64"))] +mod b8_gather_tests { + use super::{build_b8_asym_lut, scan_b8_asym_avx512_gather, scan_b8_to_topk}; + use crate::util::TopK; + use rand::{RngExt, SeedableRng}; + use rand_chacha::ChaCha8Rng; + + /// Drain a `k`-slot `TopK` into a flat `(score, idx)` vec sorted by the + /// collector's own composite key, so the two kernels are compared on the + /// exact tuples a caller would receive. + fn drain(top: &TopK, k: usize) -> (Vec, Vec) { + let mut scores = vec![f32::NEG_INFINITY; k]; + let mut idxs = vec![-1i64; k]; + top.finalize_into(&mut scores, &mut idxs); + (scores, idxs) + } + + /// The AVX-512 `vgatherdps` b=8 kernel must match the scalar LUT + /// reference within the crate's 1e-4 cross-backend score tolerance, + /// across the headline embedding dims (all `% 16 == 0`, so the gather + /// path is actually exercised). 768/1536 are `% 64 == 0` (full + /// 4-way-unrolled body); to also cover the ≤3-chunk tail path we add + /// dim=400 (`400 % 16 == 0`, `400 % 64 == 16`). 
+ #[test] + fn b8_gather_matches_scalar_reference() { + if !(is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw")) { + eprintln!("skipping b8 gather parity: no avx512f+avx512bw on this host"); + return; + } + for &dim in &[384usize, 400, 768, 1024, 1536] { + assert_eq!(dim % 16, 0, "test dims must be % 16 for the gather path"); + let n = 64; + let k = 10; + let mut rng = ChaCha8Rng::seed_from_u64(0x00B8_0000 + dim as u64); + + // Random doc codes (any byte 0..=255) and a random unit-ish query. + let packed: Vec = (0..n * dim).map(|_| rng.random::()).collect(); + let q: Vec = (0..dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + let qn: f32 = q.iter().map(|x| x * x).sum::().sqrt(); + let q_unit: Vec = q.iter().map(|x| x / qn).collect(); + let scale = 1.0f32 / 137.0; // arbitrary inv_norm-like scale + + let lut = build_b8_asym_lut(&q_unit); + + let mut top_scalar = TopK::new(k); + scan_b8_to_topk(&packed, n, dim, &lut, scale, &mut top_scalar); + let (s_scalar, i_scalar) = drain(&top_scalar, k); + + let mut top_gather = TopK::new(k); + // SAFETY: avx512f+avx512bw confirmed above; dim % 16 == 0; packed has + // n*dim bytes and lut has dim*256 entries by construction. + unsafe { + scan_b8_asym_avx512_gather(&packed, n, dim, &lut, scale, &mut top_gather); + } + let (s_gather, i_gather) = drain(&top_gather, k); + + for slot in 0..k { + assert!( + (s_scalar[slot] - s_gather[slot]).abs() < 1e-4, + "dim={dim} slot={slot}: scalar {} vs gather {}", + s_scalar[slot], + s_gather[slot], + ); + } + // With well-separated random scores the top-k id sets agree too. + assert_eq!( + i_scalar, i_gather, + "dim={dim}: top-{k} id ordering diverged between scalar and gather" + ); + } + } + + /// The gather kernel's per-doc raw score equals the brute-force + /// `Σ_d lut[d*256 + code[d]]` (before the `scale` multiply), confirming + /// the index math `idx[j] = (c*16 + j) * 256 + code` is exact. 
+ /// + /// This compares the *unscaled* sum, whose magnitude (~10² for centred + /// b=8 codes up to ±127.5 over `dim` terms) is far larger than the + /// `inv_norm`-scaled score a caller sees. The SIMD kernel's 4-way + /// parallel accumulation rounds in a different order from the strict + /// sequential brute-force, so the check is *relative* (~1e-5): the + /// production 1e-4 *absolute* tolerance applies to the small final + /// scaled scores, which the parity test above covers. + #[test] + fn b8_gather_raw_score_is_exact_gather_sum() { + if !(is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw")) { + return; + } + let dim = 256usize; + let n = 8; + let k = n; + let mut rng = ChaCha8Rng::seed_from_u64(0x00B8_FACE); + let packed: Vec = (0..n * dim).map(|_| rng.random::()).collect(); + let q_unit: Vec = (0..dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + let lut = build_b8_asym_lut(&q_unit); + + let mut top = TopK::new(k); + // SAFETY: avx512f+avx512bw confirmed; dim % 16 == 0; shapes match. + unsafe { + scan_b8_asym_avx512_gather(&packed, n, dim, &lut, 1.0, &mut top); + } + let (scores, idxs) = drain(&top, k); + + // Brute-force reference, indexed by returned doc id. + let want: Vec = (0..n) + .map(|di| { + let doc = &packed[di * dim..(di + 1) * dim]; + doc.iter() + .enumerate() + .map(|(d, &code)| lut[d * 256 + code as usize]) + .sum::() + }) + .collect(); + for slot in 0..k { + let di = idxs[slot] as usize; + let rel = (scores[slot] - want[di]).abs() / want[di].abs().max(1.0); + assert!( + rel < 1e-4, + "doc {di}: gather {} vs brute {} (rel {rel})", + scores[slot], + want[di] + ); + } + } + + /// Honest, kernel-isolated micro-benchmark: b=8 scalar LUT vs b=8 + /// AVX-512 gather vs the b=4 AVX-512 asym kernel, on the same N×dim + /// corpus. 
`#[ignore]` so it does not run in the default gate — invoke + /// with: + /// + /// ```text + /// cargo test --release --lib b8_kernel_microbench -- --ignored --nocapture + /// ``` + /// + /// It times the inner scan only (LUT build + scan), so the scalar-vs-SIMD + /// decision is measured directly rather than inferred. Per-iteration + /// wall time is reported in ms and as ns/doc/dim so the cost is + /// comparable across widths. Numbers are wall-clock and vary run-to-run; + /// the parity tests above are the correctness gate. + #[test] + #[ignore = "perf micro-bench; run explicitly with --ignored --nocapture --release"] + fn b8_kernel_microbench() { + use crate::quant_kernels::{scan_b4_asym_avx512, scan_b8_asym_avx512_gather}; + use std::time::Instant; + + let have_avx512 = is_x86_feature_detected!("avx512f") + && is_x86_feature_detected!("avx512dq") + && is_x86_feature_detected!("avx512bw"); // b=4 path needs dq, b=8 gather needs bw + let dim = 1024usize; // % 64 == 0 → valid for both b=4 and b=8 SIMD + let n = 50_000usize; + let k = 10usize; + let iters = 20usize; + + let mut rng = ChaCha8Rng::seed_from_u64(0x00B8_4BE4); + let q: Vec = (0..dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + let qn: f32 = q.iter().map(|x| x * x).sum::().sqrt(); + let q_unit: Vec = q.iter().map(|x| x / qn).collect(); + let scale = 1.0f32 / 137.0; + + // b=8 corpus: one byte per coord. + let packed8: Vec = (0..n * dim).map(|_| rng.random::()).collect(); + // b=4 corpus: two codes per byte → dim/2 bytes per doc. 
+ let packed4: Vec = (0..n * dim / 2).map(|_| rng.random::()).collect(); + + let lut8 = build_b8_asym_lut(&q_unit); + + let bench = |label: &str, mut f: Box| { + f(); // warmup + let t0 = Instant::now(); + for _ in 0..iters { + f(); + } + let per = t0.elapsed().as_secs_f64() / iters as f64; + let ns_per_doc_dim = per * 1e9 / (n as f64 * dim as f64); + let gdocs = n as f64 / per / 1e9; + println!( + " {label:<26} {:>8.3} ms/scan {:>7.3} ns/doc/dim {:>7.3} Gdoc/s", + per * 1e3, + ns_per_doc_dim, + gdocs, + ); + }; + + println!( + "\nb=8 asymmetric kernel micro-bench (dim={dim}, n={n}, k={k}, iters={iters}, avx512={have_avx512})" + ); + + { + let packed8 = packed8.clone(); + let lut8 = lut8.clone(); + bench( + "b=8 scalar LUT", + Box::new(move || { + let mut top = TopK::new(k); + scan_b8_to_topk(&packed8, n, dim, &lut8, scale, &mut top); + std::hint::black_box(&top); + }), + ); + } + + if have_avx512 { + let packed8 = packed8.clone(); + let lut8 = lut8.clone(); + bench( + "b=8 AVX-512 gather", + Box::new(move || { + let mut top = TopK::new(k); + // SAFETY: avx512f+avx512bw confirmed; dim % 16 == 0; shapes match. + unsafe { + scan_b8_asym_avx512_gather(&packed8, n, dim, &lut8, scale, &mut top); + } + std::hint::black_box(&top); + }), + ); + + // b=4 AVX-512 asym for cross-width context (raw codes, no LUT; + // dim % 64 == 0 satisfies its lane invariant). + let packed4 = packed4.clone(); + let q_unit4 = q_unit.clone(); + bench( + "b=4 AVX-512 asym (context)", + Box::new(move || { + let mut top = TopK::new(k); + // SAFETY: avx512f+dq confirmed; dim % 64 == 0; shapes match. 
+ unsafe { + scan_b4_asym_avx512(&packed4, n, dim, &q_unit4, scale, &mut top); + } + std::hint::black_box(&top); + }), + ); + } else { + println!(" (avx512 unavailable — SIMD rows skipped)"); + } + } +} diff --git a/src/rank.rs b/src/rank.rs index 27005bf2..c74bba8f 100644 --- a/src/rank.rs +++ b/src/rank.rs @@ -74,8 +74,12 @@ pub fn rank_transform_into(v: &[f32], out: &mut [u16]) { /// Bucket a single rank into one of `1 << bits` equal-width bins on /// `[0, d)`. Returns a value in `[0, 1 << bits)`. /// +/// For `bits == 8` the codomain is the full `u8` range `[0, 256)`; a +/// valid `rank < d` keeps the quotient `rank * 256 / d < 256`, so the +/// result still fits the returned `u8`. +/// /// # Panics -/// Panics if `bits > 7`, if `d == 0`, or if `rank >= d`. The `rank < d` +/// Panics if `bits > 8`, if `d == 0`, or if `rank >= d`. The `rank < d` /// guard fails loud in *every* build — like the sibling [`pack_buckets`] and /// [`bucket_centre`] checks — rather than silently clamping an out-of-range /// rank into the top bucket. Internal callers feed ranks straight from @@ -83,12 +87,14 @@ pub fn rank_transform_into(v: &[f32], out: &mut [u16]) { /// hot path. #[inline] pub fn rank_to_bucket(rank: u16, d: usize, bits: u8) -> u8 { - // `bits` is a `u8`, so a caller could pass e.g. 8 or 255. `1u32 << bits` + // `bits` is a `u8`, so a caller could pass e.g. 9 or 255. `1u32 << bits` // overflows for `bits >= 32` (in release that silently wraps and yields a // wrong bucket; in debug it panics inconsistently), and the result must - // also fit in the returned `u8`, so cap at 7. `d == 0` would divide by - // zero. Guard both up front so the failure is loud in every build. - assert!(bits <= 7, "bits too large"); + // also fit in the returned `u8`, so cap at 8 — the widest RankQuant width + // (b=8 yields one bucket per code value in `[0, 256)`, which still fits a + // `u8`). `d == 0` would divide by zero. Guard both up front so the failure + // is loud in every build. 
+ assert!(bits <= 8, "bits too large"); assert!(d > 0, "d must be positive"); // A valid rank is a position in `[0, d)`. Reject `rank >= d` loudly instead // of silently clamping the quotient back into range: the rest of the public @@ -121,7 +127,7 @@ pub fn bucket_ranks(ranks: &[u16], bits: u8) -> Vec { // input — an empty `ranks` skips the per-entry `rank_to_bucket` check and // would otherwise silently return an empty vec. Mirrors the Python binding, // which checks `bits` before its empty short-circuit. - assert!(bits <= 7, "bits too large"); + assert!(bits <= 8, "bits too large"); let d = ranks.len(); ranks.iter().map(|&r| rank_to_bucket(r, d, bits)).collect() } @@ -130,19 +136,22 @@ pub fn bucket_ranks(ranks: &[u16], bits: u8) -> Vec { /// dense byte stream. /// /// Layout: the bucket with index 0 occupies the most-significant bits -/// of the first byte. Requires `bits ∈ {1, 2, 4}` and `d`'s length to -/// be a multiple of `8 / bits`. +/// of the first byte. Requires `bits ∈ {1, 2, 4, 8}` and `d`'s length to +/// be a multiple of `8 / bits`. For `bits == 8` the packing is the +/// degenerate one-code-per-byte case: each code is copied verbatim into +/// its own byte (no sub-byte shifting), so any `d` is valid. /// /// # Panics -/// Panics if `bits ∉ {1, 2, 4}`, if `buckets.len()` is not a multiple +/// Panics if `bits ∉ {1, 2, 4, 8}`, if `buckets.len()` is not a multiple /// of `8 / bits`, or if any code is `>= 1 << bits`. The last guard is /// the public-contract backstop: an out-of-range code would otherwise /// be silently truncated to `code & ((1 << bits) - 1)` and corrupt the /// packed stream. (Internal callers feed codes straight from /// [`rank_to_bucket`], which is always in range; this protects direct -/// callers of the primitive.) +/// callers of the primitive.) Note the `b=8` code range is the full +/// `u8`, so the range guard is vacuously satisfied for that width. 
pub fn pack_buckets(buckets: &[u8], bits: u8) -> Vec { - assert!(matches!(bits, 1 | 2 | 4), "bits must be 1, 2, or 4"); + assert!(matches!(bits, 1 | 2 | 4 | 8), "bits must be 1, 2, 4, or 8"); let codes_per_byte = (8 / bits) as usize; let d = buckets.len(); assert_eq!( @@ -150,7 +159,10 @@ pub fn pack_buckets(buckets: &[u8], bits: u8) -> Vec { 0, "d ({d}) must be a multiple of codes_per_byte ({codes_per_byte}) for bits = {bits}", ); - let mask = (1u8 << bits) - 1; + // `(1u8 << 8)` overflows a `u8`, so compute the mask in `u16` and saturate + // the `b=8` case to the full byte (`0xFF`). For `b ∈ {1,2,4}` this is the + // same value the old `(1u8 << bits) - 1` produced. + let mask = ((1u16 << bits) - 1) as u8; let n_bytes = d / codes_per_byte; let mut out = vec![0u8; n_bytes]; let bits_u = bits as usize; @@ -160,6 +172,8 @@ pub fn pack_buckets(buckets: &[u8], bits: u8) -> Vec { // fail-loud guarantee without a second O(d) pass over `buckets`; the // branch is loop-invariant-predictable for the always-valid internal // callers. Asserting `b <= mask` makes the trailing `& mask` redundant. + // At `b=8`, `codes_per_byte == 1`, so `shift == 0` and each byte holds one + // code verbatim. for (i, &b) in buckets.iter().enumerate() { assert!( b <= mask, @@ -178,10 +192,12 @@ pub fn pack_buckets(buckets: &[u8], bits: u8) -> Vec { /// /// Inverse of [`pack_buckets`]. pub fn unpack_buckets(packed: &[u8], d: usize, bits: u8) -> Vec { - assert!(matches!(bits, 1 | 2 | 4), "bits must be 1, 2, or 4"); + assert!(matches!(bits, 1 | 2 | 4 | 8), "bits must be 1, 2, 4, or 8"); let codes_per_byte = (8 / bits) as usize; assert_eq!(packed.len() * codes_per_byte, d); - let mask = (1u8 << bits) - 1; + // `(1u8 << 8)` overflows a `u8`; compute in `u16` and narrow so `b=8` + // yields the full-byte mask `0xFF` (each byte already holds one code). 
+ let mask = ((1u16 << bits) - 1) as u8; let bits_u = bits as usize; let mut out = vec![0u8; d]; #[allow(clippy::needless_range_loop)] // indexed access is clearer / matches the kernel layout @@ -195,13 +211,16 @@ pub fn unpack_buckets(packed: &[u8], d: usize, bits: u8) -> Vec { } /// Number of bytes per packed RankQuant document at dimension `d` and -/// bit width `bits ∈ {1, 2, 4}`. +/// bit width `bits ∈ {1, 2, 4, 8}`. +/// +/// At `bits == 8` each coordinate occupies its own byte (`codes_per_byte +/// == 1`), so the storage is exactly `d` bytes per document. #[inline] pub fn rankquant_bytes_per_vec(d: usize, bits: u8) -> usize { // Guard the same domain as the sibling pack/unpack helpers: `bits == 0` - // would divide by zero computing `codes_per_byte`, and only 1/2/4 give an + // would divide by zero computing `codes_per_byte`, and only 1/2/4/8 give an // integral codes-per-byte. - assert!(matches!(bits, 1 | 2 | 4), "bits must be 1,2,4"); + assert!(matches!(bits, 1 | 2 | 4 | 8), "bits must be 1,2,4,8"); let codes_per_byte = (8 / bits) as usize; assert_eq!( d % codes_per_byte, @@ -219,17 +238,20 @@ pub fn rankquant_bytes_per_vec(d: usize, bits: u8) -> usize { /// pattern `..., -1.5, -0.5, +0.5, +1.5, ...` for `B = 2`. /// /// # Panics -/// Panics if `bits > 7` — bucket codes are `u8`, so the bit width is +/// Panics if `bits > 8` — bucket codes are `u8`, so the bit width is /// capped at the representable bucketing range, matching -/// [`rank_to_bucket`] (the RankQuant family uses `bits ∈ {1, 2, 4}`). +/// [`rank_to_bucket`] (the RankQuant family uses `bits ∈ {1, 2, 4, 8}`). /// Also panics if `bucket >= 1 << bits`; this guard fails loud in *every* /// build — like the sibling [`pack_buckets`] check — so a direct caller /// cannot silently receive a centre outside the symmetric range. The /// internal LUT builders only ever pass `bucket ∈ [0, 1 << bits)` (the /// loop bound *is* `1 << bits`), so the assert never trips on the hot path. 
+/// For `bits == 8` the centres span `..., -0.5, +0.5, ...` around zero +/// with `bucket - 127.5`; the range guard is vacuous (every `u8` is a +/// valid code). #[inline] pub fn bucket_centre(bucket: u8, bits: u8) -> f32 { - assert!(bits <= 7, "bits too large"); + assert!(bits <= 8, "bits too large"); assert!( (bucket as u32) < (1u32 << bits), "bucket {bucket} out of range for bits = {bits}", @@ -270,14 +292,21 @@ pub fn rank_norm(d: usize) -> f32 { /// The mean-centred bucket index has variance `(2^(2B) - 1) / 12`, so /// the per-vector L2 norm is `sqrt(d * (2^(2B) - 1) / 12)`. /// +/// This is the **symmetric analytical** norm: it is exact only when +/// every bucket receives exactly `d / 2^B` coordinates, i.e. when +/// `d % 2^B == 0`. For `bits == 8` that precondition is `d % 256 == 0`; +/// the [`crate::RankQuant`] symmetric path enforces it before calling +/// this (see `RankQuant::new` / `symmetric_supported`). This primitive +/// itself only computes the closed form and does not re-check occupancy. +/// /// # Panics -/// Panics if `bits ∉ {1, 2, 4}`, mirroring the [`crate::RankQuant`] +/// Panics if `bits ∉ {1, 2, 4, 8}`, mirroring the [`crate::RankQuant`] /// bit-width domain (and [`rankquant_bytes_per_vec`]). Without it a /// nonsensical `bits` would return a norm for a scheme that does not /// exist (or overflow `1 << bits`). #[inline] pub fn rankquant_norm(d: usize, bits: u8) -> f32 { - assert!(matches!(bits, 1 | 2 | 4), "bits must be 1,2,4"); + assert!(matches!(bits, 1 | 2 | 4 | 8), "bits must be 1,2,4,8"); let n = (1u32 << bits) as f64; let var = (n * n - 1.0) / 12.0; ((d as f64) * var).sqrt() as f32 @@ -629,10 +658,20 @@ mod tests { #[test] #[should_panic(expected = "bits too large")] - fn bucket_ranks_rejects_bits_above_7_even_when_empty() { + fn bucket_ranks_rejects_bits_above_8_even_when_empty() { // `bits` is validated up front, so an invalid width fails loud even on // empty input — which never reaches the per-entry rank_to_bucket guard. 
- let _ = bucket_ranks(&[], 8); + // The valid bucketing range now extends to b=8 (the widest RankQuant + // width), so b=9 is the first rejected width. + let _ = bucket_ranks(&[], 9); + } + + #[test] + fn bucket_ranks_accepts_bits_8() { + // b=8 is a supported width: a 4-element rank vector buckets without + // panicking and yields codes in [0, 256). + let codes = bucket_ranks(&[0, 1, 2, 3], 8); + assert_eq!(codes.len(), 4); } #[test] @@ -662,6 +701,72 @@ mod tests { assert_eq!(unpacked, buckets); } + #[test] + fn pack_unpack_round_trip_bits8() { + // b=8 is the degenerate one-code-per-byte packing: each byte holds a + // full code in `[0, 256)`, so packed length == code count and the + // bytes are the codes verbatim. Cover the full code range including + // 0 and 255 (the extremes the `b ∈ {1,2,4}` mask would have clipped). + let buckets: Vec = (0..256).map(|i| i as u8).collect(); + let packed = pack_buckets(&buckets, 8); + assert_eq!(packed.len(), 256, "b=8 stores one byte per code"); + assert_eq!(packed, buckets, "b=8 packing is the identity byte stream"); + let unpacked = unpack_buckets(&packed, 256, 8); + assert_eq!(unpacked, buckets); + } + + #[test] + fn pack_unpack_round_trip_bits8_arbitrary_len() { + // Any `d` is a valid b=8 length (codes_per_byte == 1); 384 is not a + // multiple of 256 yet still round-trips — code generation never needs + // the equal-bucket precondition that only the symmetric norm requires. + let buckets: Vec = (0..384u16).map(|i| (i % 256) as u8).collect(); + let packed = pack_buckets(&buckets, 8); + assert_eq!(packed.len(), 384); + let unpacked = unpack_buckets(&packed, 384, 8); + assert_eq!(unpacked, buckets); + } + + #[test] + fn rank_to_bucket_b8_spans_full_byte_range() { + // rank in [0, d) with bits=8 must land in [0, 256). Check the extremes + // and that the top rank maps to the top bucket for d == 256. 
+ let d = 256usize; + assert_eq!(rank_to_bucket(0, d, 8), 0); + assert_eq!(rank_to_bucket(255, d, 8), 255); + // A coarser d still keeps the quotient in range. + assert!(rank_to_bucket(383, 384, 8) < 255 || rank_to_bucket(383, 384, 8) == 255); + for rank in 0..d as u16 { + let _ = rank_to_bucket(rank, d, 8); // never panics, always < 256 + } + } + + #[test] + fn bucket_centre_b8_is_symmetric_around_zero() { + // For b=8 the 256 centres span -127.5 ..= +127.5 and sum to 0. + assert_eq!(bucket_centre(0, 8), -127.5); + assert_eq!(bucket_centre(255, 8), 127.5); + let sum: f32 = (0..256u16).map(|b| bucket_centre(b as u8, 8)).sum(); + assert!(sum.abs() < 1e-3, "b=8 centres should sum to ~0, got {sum}"); + } + + #[test] + fn rankquant_norm_b8_matches_direct_computation() { + // d % 256 == 0 so every bucket gets exactly d/256 entries and the + // analytical norm is exact. + let d = 512usize; + let bits = 8u8; + let analytical = rankquant_norm(d, bits); + let ranks: Vec = (0..d as u16).collect(); + let buckets = bucket_ranks(&ranks, bits); + let centred: Vec = buckets.iter().map(|&b| bucket_centre(b, bits)).collect(); + let direct: f32 = centred.iter().map(|x| x * x).sum::().sqrt(); + assert!( + (analytical - direct).abs() / direct < 1e-5, + "analytical {analytical}, direct {direct}" + ); + } + #[test] fn bucket_centres_are_symmetric_around_zero() { // For B = 2: bucket values are {-1.5, -0.5, +0.5, +1.5}. @@ -718,10 +823,18 @@ mod tests { #[test] #[should_panic(expected = "bits too large")] - fn bucket_centre_rejects_bits_above_7() { - // bits >= 32 overflows `1 << bits`; the guard caps at 7 (the u8 - // bucket domain), matching `rank_to_bucket`. - let _ = bucket_centre(0, 8); + fn bucket_centre_rejects_bits_above_8() { + // bits >= 32 overflows `1 << bits`; the guard caps at 8 (the widest + // RankQuant width, whose codes still fit a u8), matching + // `rank_to_bucket`. b=9 is the first rejected width. 
+ let _ = bucket_centre(0, 9); + } + + #[test] + fn bucket_centre_accepts_bits_8() { + // b=8 centres are valid: code 0 → -127.5, code 255 → +127.5. + assert_eq!(bucket_centre(0, 8), -127.5); + assert_eq!(bucket_centre(255, 8), 127.5); } #[test] diff --git a/tests/index/main.rs b/tests/index/main.rs index 3a1177de..63c23a1c 100644 --- a/tests/index/main.rs +++ b/tests/index/main.rs @@ -33,6 +33,7 @@ mod rank; #[cfg(feature = "experimental")] mod multi_bucket; mod quant; +mod quant_b8; mod two_stage; pub const D: usize = 128; diff --git a/tests/index/quant_b8.rs b/tests/index/quant_b8.rs new file mode 100644 index 00000000..f3d25847 --- /dev/null +++ b/tests/index/quant_b8.rs @@ -0,0 +1,552 @@ +//! Capability-gated `b=8` RankQuant integration tests (#221). +//! +//! `b=8` is a stable/core evidence-refinement width, not experimental: +//! +//! - code generation, pair-evidence, and asymmetric (float-query) scoring +//! work at **any** dim; +//! - symmetric scoring (and the symmetric analytical norm) require +//! `dim % 256 == 0` (equal bucket occupancy), so a non-`256`-aligned +//! `b=8` index is `AsymmetricOnly` and its `search` panics with an exact, +//! directing message. +//! +//! These tests pin the maintainer's capability matrix plus a brute-force +//! parity check of the scalar `b=8` asymmetric path against a naive +//! reference. + +use ordvec::rank::{bucket_centre, bucket_ranks, rank_transform, rankquant_norm}; +use ordvec::{RankQuant, RankQuantCapability}; +use rand::{RngExt, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +/// Naive reference for `b=8` asymmetric scoring of one float query against +/// one float doc: L2-normalise the query, rank-transform + bucket the doc to +/// `b=8` codes, score `Σ_d q_unit[d] * bucket_centre(code[d]) / norm`. This +/// mirrors `ref_rankquant_asymmetric` in the shared helpers but is duplicated +/// here so the b=8 module is self-contained. 
+fn ref_b8_asymmetric(q: &[f32], doc: &[f32]) -> f32 { + let d = q.len(); + let q_norm: f32 = q.iter().map(|x| x * x).sum::().sqrt(); + let q_unit: Vec = q.iter().map(|x| x / q_norm).collect(); + let r = rank_transform(doc); + let codes = bucket_ranks(&r, 8); + // Exact L2 norm of this doc's centred bucket vector. For b=8 the bucket + // occupancy is uniform only when `dim % 256 == 0`; at other dims (e.g. 384) + // the closed-form `rankquant_norm` mis-scales the absolute score, so the + // reference — like production's `asymmetric_norm` — sums the realised + // squared centres (f64-accumulated, matching `rankquant_eval_norm`). The + // ranks are a permutation of `0..d` for every doc, so this equals the + // closed form exactly at 256-aligned dims. + let norm = { + let acc: f64 = codes + .iter() + .map(|&c| { + let cc = bucket_centre(c, 8) as f64; + cc * cc + }) + .sum(); + acc.sqrt() as f32 + }; + let mut acc = 0.0f32; + for i in 0..d { + acc += q_unit[i] * bucket_centre(codes[i], 8); + } + acc / norm +} + +fn random_corpus(seed: u64, n: usize, dim: usize) -> Vec { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + (0..n * dim).map(|_| rng.random_range(-1.0..1.0)).collect() +} + +// --------------------------------------------------------------------- +// Capability reporting. +// --------------------------------------------------------------------- + +#[test] +fn b8_new_asymmetric_384_is_asymmetric_only() { + let idx = RankQuant::new_asymmetric(384, 8); + assert_eq!(idx.capability(), RankQuantCapability::AsymmetricOnly); + assert!(!idx.symmetric_supported()); + assert_eq!(idx.bits(), 8); + assert_eq!(idx.dim(), 384); + // b=8 stores one byte per coordinate. 
+ assert_eq!(idx.bytes_per_vec(), 384); +} + +#[test] +fn b8_new_1024_is_symmetric_and_asymmetric() { + let idx = RankQuant::new(1024, 8); + assert_eq!( + idx.capability(), + RankQuantCapability::SymmetricAndAsymmetric + ); + assert!(idx.symmetric_supported()); + assert_eq!(idx.bits(), 8); +} + +#[test] +fn b8_new_asymmetric_256_aligned_upgrades_to_full() { + // new_asymmetric on a 256-aligned dim should NOT withhold symmetric + // scoring — there is no reason to, the analytical norm is exact. + let idx = RankQuant::new_asymmetric(768, 8); + assert_eq!( + idx.capability(), + RankQuantCapability::SymmetricAndAsymmetric + ); + assert!(idx.symmetric_supported()); +} + +#[test] +fn b124_constructors_are_always_full_capability() { + for &(dim, bits) in &[(384usize, 4u8), (384, 2), (256, 1), (1024, 4)] { + let a = RankQuant::new(dim, bits); + assert_eq!(a.capability(), RankQuantCapability::SymmetricAndAsymmetric); + assert!(a.symmetric_supported()); + // new_asymmetric for b ∈ {1,2,4} is never less capable than new. + let b = RankQuant::new_asymmetric(dim, bits); + assert_eq!(b.capability(), RankQuantCapability::SymmetricAndAsymmetric); + assert!(b.symmetric_supported()); + } +} + +// --------------------------------------------------------------------- +// new() fail-loud for non-256-aligned b=8. 
+// --------------------------------------------------------------------- + +#[test] +fn b8_new_panics_for_non_256_aligned_dim_directing_to_new_asymmetric() { + let res = std::panic::catch_unwind(|| RankQuant::new(384, 8)); + assert!(res.is_err(), "new(384, 8) must panic (384 % 256 != 0)"); + let payload = res.err().expect("panic payload present"); + let msg = *payload + .downcast::() + .expect("panic payload should be a String"); + assert!( + msg.contains("dim % 256 == 0"), + "panic should explain the 256-alignment requirement: {msg}" + ); + assert!( + msg.contains("new_asymmetric"), + "panic should direct to new_asymmetric: {msg}" + ); +} + +// --------------------------------------------------------------------- +// dim=384 b=8: code-gen passes, asymmetric passes, symmetric REJECTS. +// --------------------------------------------------------------------- + +#[test] +fn b8_384_code_gen_and_asymmetric_work() { + let dim = 384; + let n = 50; + let corpus = random_corpus(8384, n, dim); + let mut idx = RankQuant::new_asymmetric(dim, 8); + // add() runs the rank → bucket → pack pipeline (the code-gen path). + idx.add(&corpus); + assert_eq!(idx.len(), n); + assert_eq!(idx.byte_size(), n * dim); // one byte per coord per doc + + // Asymmetric scoring works at this non-256-aligned dim. 
+ let query = random_corpus(8385, 1, dim); + let res = idx.search_asymmetric(&query, 10); + assert_eq!(res.nq, 1); + assert_eq!(res.k, 10); + for slot in 0..10 { + assert!(res.scores_for_query(0)[slot].is_finite()); + let id = res.indices_for_query(0)[slot]; + assert!(id >= 0 && (id as usize) < n); + } +} + +#[test] +fn b8_384_symmetric_search_rejects_with_exact_message() { + let dim = 384; + let mut idx = RankQuant::new_asymmetric(dim, 8); + idx.add(&random_corpus(8386, 8, dim)); + let query = random_corpus(8387, 1, dim); + + let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + let _ = idx.search(&query, 5); + })); + assert!( + res.is_err(), + "symmetric search on AsymmetricOnly must panic" + ); + let msg = *res + .unwrap_err() + .downcast::() + .expect("panic payload should be a String"); + // The EXACT wording shape from the spec. + let expected = format!( + "RankQuant b=8 symmetric scoring requires dim % 256 == 0; dim={dim} supports asymmetric/evidence APIs only." + ); + assert_eq!(msg, expected, "symmetric-gating message must match exactly"); +} + +// --------------------------------------------------------------------- +// dim=768/1024/1536 b=8: full path incl. symmetric passes. +// --------------------------------------------------------------------- + +#[test] +fn b8_aligned_dims_full_path_including_symmetric() { + for &dim in &[768usize, 1024, 1536] { + let n = 40; + let corpus = random_corpus(9000 + dim as u64, n, dim); + // Both constructors should yield a full-capability instance here. + let mut idx = RankQuant::new(dim, 8); + assert!( + idx.symmetric_supported(), + "dim={dim} should be symmetric-capable" + ); + idx.add(&corpus); + + let queries = random_corpus(9500 + dim as u64, 3, dim); + + // Symmetric path runs without panicking and returns well-formed, + // descending, in-range results. 
+ let sym = idx.search(&queries, 10); + assert_eq!(sym.nq, 3); + assert_eq!(sym.k, 10); + for qi in 0..3 { + let scores = sym.scores_for_query(qi); + let ids = sym.indices_for_query(qi); + for slot in 0..10 { + assert!(scores[slot].is_finite(), "dim={dim} non-finite sym score"); + assert!(ids[slot] >= 0 && (ids[slot] as usize) < n); + } + for slot in 1..10 { + assert!( + scores[slot].total_cmp(&scores[slot - 1]).is_le(), + "dim={dim} symmetric results not sorted descending" + ); + } + } + + // Asymmetric path runs too. + let asym = idx.search_asymmetric(&queries, 10); + assert_eq!(asym.nq, 3); + assert_eq!(asym.k, 10); + } +} + +// --------------------------------------------------------------------- +// dim=384 b=4 UNCHANGED (sanity that the b=8 work didn't disturb b=4). +// --------------------------------------------------------------------- + +#[test] +fn b4_384_unchanged_full_capability_and_search() { + let dim = 384; + let n = 40; + let corpus = random_corpus(4384, n, dim); + let mut idx = RankQuant::new(dim, 4); + assert_eq!( + idx.capability(), + RankQuantCapability::SymmetricAndAsymmetric + ); + assert!(idx.symmetric_supported()); + idx.add(&corpus); + let queries = random_corpus(4385, 3, dim); + let sym = idx.search(&queries, 10); + assert_eq!(sym.k, 10); + let asym = idx.search_asymmetric(&queries, 10); + assert_eq!(asym.k, 10); +} + +// --------------------------------------------------------------------- +// Brute-force parity: b=8 asymmetric scores match a naive reference. +// --------------------------------------------------------------------- + +#[test] +fn b8_asymmetric_matches_naive_reference_any_dim() { + // Cover both an asymmetric-only (384) and a full-capability (768) dim; + // the asymmetric scalar path is identical for both. 
+ for &dim in &[384usize, 768] { + let n = 60; + let corpus = random_corpus(7000 + dim as u64, n, dim); + let mut idx = RankQuant::new_asymmetric(dim, 8); + idx.add(&corpus); + + let mut rng = ChaCha8Rng::seed_from_u64(7777 + dim as u64); + let query: Vec = (0..dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + let res = idx.search_asymmetric(&query, 10); + + let ref_scores: Vec = (0..n) + .map(|di| ref_b8_asymmetric(&query, &corpus[di * dim..(di + 1) * dim])) + .collect(); + + // Every returned score must agree with the reference at its doc id. + for slot in 0..10 { + let di = res.indices_for_query(0)[slot] as usize; + let got = res.scores_for_query(0)[slot]; + let want = ref_scores[di]; + assert!( + (got - want).abs() < 1e-4, + "dim={dim} slot {slot} doc {di}: {got} vs ref {want}" + ); + } + + // And the returned top-10 set must equal the reference top-10 set. + let mut ref_sorted: Vec<(usize, f32)> = ref_scores + .iter() + .enumerate() + .map(|(i, &s)| (i, s)) + .collect(); + ref_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + let top_ref: std::collections::HashSet = + ref_sorted[..10].iter().map(|x| x.0).collect(); + let top_got: std::collections::HashSet = res + .indices_for_query(0) + .iter() + .map(|&i| i as usize) + .collect(); + assert_eq!(top_got, top_ref, "dim={dim} b=8 top-10 set mismatch"); + } +} + +// --------------------------------------------------------------------- +// Optimized (AVX-512 gather) b=8 asymmetric path is parity-correct vs the +// naive reference across the headline embedding dims. +// +// On an AVX-512 host `search_asymmetric` dispatches the b=8 score to the +// `vgatherdps` kernel; on every other host it takes the scalar LUT path. +// Either way the returned top-k scores must agree with the naive per-doc +// reference within the crate's 1e-4 cross-backend score tolerance, and the +// returned top-k *set* must equal the reference top-k set. 
This is the +// end-to-end parity gate for the optimized kernel at dims 384/768/1024/1536. +// --------------------------------------------------------------------- + +#[test] +fn b8_asymmetric_optimized_path_parity_headline_dims() { + for &dim in &[384usize, 768, 1024, 1536] { + let n = 200; + let corpus = random_corpus(6000 + dim as u64, n, dim); + let mut idx = RankQuant::new_asymmetric(dim, 8); + idx.add(&corpus); + + let mut rng = ChaCha8Rng::seed_from_u64(6666 + dim as u64); + let query: Vec = (0..dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + + let k = 25; + let res = idx.search_asymmetric(&query, k); + + // Naive scalar reference score per doc. + let ref_scores: Vec = (0..n) + .map(|di| ref_b8_asymmetric(&query, &corpus[di * dim..(di + 1) * dim])) + .collect(); + + // (a) every returned score agrees with the reference at its doc id. + for slot in 0..k { + let di = res.indices_for_query(0)[slot] as usize; + let got = res.scores_for_query(0)[slot]; + let want = ref_scores[di]; + assert!( + (got - want).abs() < 1e-4, + "dim={dim} slot {slot} doc {di}: optimized {got} vs ref {want}" + ); + } + + // (b) the returned top-k *set* equals the reference top-k set. + let mut ref_sorted: Vec<(usize, f32)> = ref_scores + .iter() + .enumerate() + .map(|(i, &s)| (i, s)) + .collect(); + ref_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + let top_ref: std::collections::HashSet = + ref_sorted[..k].iter().map(|x| x.0).collect(); + let top_got: std::collections::HashSet = res + .indices_for_query(0) + .iter() + .map(|&i| i as usize) + .collect(); + assert_eq!( + top_got, top_ref, + "dim={dim} optimized b=8 top-{k} set mismatch vs reference" + ); + } +} + +// The optimized b=8 path must also be parity-correct through the subset +// rerank entry point (`search_asymmetric_subset`), which gathers candidate +// bytes into a scratch buffer and runs the same gather kernel. 
+#[test] +fn b8_asymmetric_subset_optimized_path_parity() { + let dim = 768; + let n = 300; + let corpus = random_corpus(6321, n, dim); + let mut idx = RankQuant::new_asymmetric(dim, 8); + idx.add(&corpus); + + let mut rng = ChaCha8Rng::seed_from_u64(6322); + let query: Vec = (0..dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + + // An arbitrary, intentionally-unsorted candidate subset. + let candidates: Vec = (0..n as u32).rev().step_by(3).collect(); + let k = 10; + let (scores, indices) = idx.search_asymmetric_subset(&query, &candidates, k); + + for slot in 0..k { + let di = indices[slot] as usize; + let want = ref_b8_asymmetric(&query, &corpus[di * dim..(di + 1) * dim]); + assert!( + (scores[slot] - want).abs() < 1e-4, + "subset slot {slot} doc {di}: optimized {} vs ref {want}", + scores[slot] + ); + } +} + +// The b=8 routing also runs through the *batched* two-stage rerank entry point +// (`search_asymmetric_subset_batched_serial`), which packs each query's +// candidate row into a reused `SubsetScratch` and scans it with the same b=8 +// gather kernel. Cover both a non-256-aligned dim (384, exercising the +// empirical asymmetric norm) and an aligned dim (768), with two queries that +// have distinct candidate rows (exercising the CSR offsets and scratch reuse +// across rows). Every returned score must match the per-doc naive reference. +#[test] +fn b8_asymmetric_subset_batched_serial_path_parity() { + for &dim in &[384usize, 768] { + let n = 256; + let corpus = random_corpus(8100 + dim as u64, n, dim); + let mut idx = RankQuant::new_asymmetric(dim, 8); + idx.add(&corpus); + + let mut rng = ChaCha8Rng::seed_from_u64(8200 + dim as u64); + let q0: Vec = (0..dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + let q1: Vec = (0..dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + let mut queries = q0.clone(); + queries.extend_from_slice(&q1); + + // Two distinct, intentionally-unsorted candidate rows in CSR layout. 
+ let cand0: Vec = (0..n as u32).rev().step_by(3).collect(); + let cand1: Vec = (0..n as u32).step_by(5).collect(); + let mut candidates = cand0.clone(); + candidates.extend_from_slice(&cand1); + let candidate_offsets = [0usize, cand0.len(), cand0.len() + cand1.len()]; + + let k = 10; + let res = idx.search_asymmetric_subset_batched_serial( + &queries, + &candidate_offsets, + &candidates, + k, + ); + + for (qi, q) in [&q0, &q1].into_iter().enumerate() { + let got_scores = res.scores_for_query(qi); + let got_indices = res.indices_for_query(qi); + for slot in 0..k { + let di = got_indices[slot]; + if di < 0 { + continue; // fewer candidates than k in this row + } + let di = di as usize; + let want = ref_b8_asymmetric(q, &corpus[di * dim..(di + 1) * dim]); + assert!( + (got_scores[slot] - want).abs() < 1e-4, + "dim={dim} q{qi} slot {slot} doc {di}: batched {} vs ref {want}", + got_scores[slot] + ); + } + } + } +} + +// --------------------------------------------------------------------- +// validate_params: b=8 is code-valid at any dim; b ∈ {1,2,4} unchanged. +// --------------------------------------------------------------------- + +#[test] +fn validate_params_b8_any_dim_but_b124_still_require_alignment() { + // b=8 accepts any dim >= 2 (no dim % 256 requirement). + assert!(RankQuant::validate_params(384, 8).is_ok()); + assert!(RankQuant::validate_params(2, 8).is_ok()); + assert!(RankQuant::validate_params(1000, 8).is_ok()); + assert!( + RankQuant::validate_params(1, 8).is_err(), + "dim < 2 rejected" + ); + + // b ∈ {1,2,4} keep their 2^bits divisibility requirement. + assert!(RankQuant::validate_params(6, 2).is_err(), "6 % 4 != 0"); + assert!(RankQuant::validate_params(8, 2).is_ok()); + assert!(RankQuant::validate_params(384, 4).is_ok()); + // b=3 is still not a packable width. 
+ assert!(RankQuant::validate_params(384, 3).is_err()); +} + +// --------------------------------------------------------------------- +// Symmetric b=8 (256-aligned) matches a naive symmetric reference. +// --------------------------------------------------------------------- + +#[test] +fn b8_symmetric_matches_naive_reference_aligned_dim() { + let dim = 512; // 256-aligned → exact analytical norm + let n = 40; + let corpus = random_corpus(5512, n, dim); + let mut idx = RankQuant::new(dim, 8); + idx.add(&corpus); + + let mut rng = ChaCha8Rng::seed_from_u64(5513); + let query: Vec = (0..dim).map(|_| rng.random_range(-1.0..1.0)).collect(); + let res = idx.search(&query, 10); + + // Naive symmetric reference: bucket query + doc to b=8, dot the centred + // bucket vectors, divide by norm^2. + let norm = rankquant_norm(dim, 8); + let inv_norm_sq = 1.0f32 / (norm * norm); + let q_codes = bucket_ranks(&rank_transform(&query), 8); + let ref_scores: Vec = (0..n) + .map(|di| { + let doc = &corpus[di * dim..(di + 1) * dim]; + let d_codes = bucket_ranks(&rank_transform(doc), 8); + let acc: f32 = q_codes + .iter() + .zip(&d_codes) + .map(|(&qc, &dc)| bucket_centre(qc, 8) * bucket_centre(dc, 8)) + .sum(); + acc * inv_norm_sq + }) + .collect(); + + for slot in 0..10 { + let di = res.indices_for_query(0)[slot] as usize; + let got = res.scores_for_query(0)[slot]; + assert!( + (got - ref_scores[di]).abs() < 1e-4, + "b=8 symmetric slot {slot} doc {di}: {got} vs ref {}", + ref_scores[di] + ); + } +} + +#[test] +fn rankquant_eval_search_supports_b8_at_any_dim() { + // The eval/empirical path (check_eval_bits widened to 1..=8) accepts b=8 even + // at a non-256-aligned dim, where the analytical symmetric norm is + // unavailable — it computes the norm empirically. Returns ranked results + // without panicking. + // + // This is a *distinct* surface from the analytical-norm `RankQuant::search`, + // whose b=8 symmetric scoring is gated to `dim % 256 == 0`. 
There is no + // contradiction: the eval path's empirical norm is exact under any bucket + // occupancy, which is precisely why it is unbound by the 256 gate. + let dim = 384usize; // not a multiple of 256 + let n = 32usize; + let nq = 2usize; + let corpus: Vec = (0..n * dim) + .map(|i| ((i * 7 % 101) as f32) - 50.0) + .collect(); + let queries: Vec = (0..nq * dim) + .map(|i| ((i * 13 % 97) as f32) - 48.0) + .collect(); + let res = ordvec::rankquant_eval_search(&corpus, &queries, dim, 8, 5); + assert_eq!(res.k, 5); + assert_eq!(res.nq, nq); + for &id in &res.indices { + assert!( + id >= 0 && (id as usize) < n, + "eval-search id out of range: {id}" + ); + } +} diff --git a/tests/redteam_gamma.rs b/tests/redteam_gamma.rs index 6a2f9282..a95c0968 100644 --- a/tests/redteam_gamma.rs +++ b/tests/redteam_gamma.rs @@ -22,8 +22,9 @@ fn rank_to_bucket_large_bits_panics() { // Signature is `rank_to_bucket(rank, d, bits)`, so this is rank=3, d=8, // bits=200 — the `bits` value is what's under test. `bits >= 32` makes // `1u32 << bits` overflow (silently-wrong bucket in release), so the - // function guards with `assert!(bits <= 7, "bits too large")`. bits=200 - // trips that guard; the panic must fire in release as well as debug. + // function guards with `assert!(bits <= 8, "bits too large")` (b=8 is the + // widest RankQuant width whose codes still fit a u8). bits=200 trips that + // guard; the panic must fire in release as well as debug. let _ = rank_to_bucket(3, 8, 200); }