Merged
4 changes: 3 additions & 1 deletion README.md
@@ -34,7 +34,9 @@ vector on its own:
the product / scalar / binary quantization most crates use.
- **Predictable footprint.** Exactly `dim * bits / 8` bytes per document —
known before you see any data (256 B at dim = 1024, 2-bit), with
`bits ∈ {1, 2, 4}` the size/recall knob.
`bits ∈ {1, 2, 4}` the size/recall knob. (`b = 8` is an opt-in
evidence/refinement width — asymmetric scoring at any dim, symmetric only
when `dim % 256 == 0` — not a broad retrieval mode.)
- **Two-stage retrieval, built in.** A cheap bitmap / sign-popcount
prefilter feeds an exact rerank — the coarse→fine pipeline ships as
library primitives.
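
To make the footprint rule concrete, here is a minimal sketch of the `dim * bits / 8` arithmetic from the bullet above. `bytes_per_doc` is a hypothetical helper written for illustration, not part of the crate's API; it assumes `dim` is a multiple of `8 / bits` so the division is exact.

```rust
/// Hypothetical helper (not the crate's API): packed bytes per document.
/// Assumes `bits` is 1, 2, 4, or 8 and `dim` is a multiple of `8 / bits`.
fn bytes_per_doc(dim: usize, bits: u8) -> usize {
    assert!(matches!(bits, 1 | 2 | 4 | 8), "bits must be 1, 2, 4, or 8");
    // How many codes share one byte: 8, 4, 2, or 1.
    let codes_per_byte = (8 / bits) as usize;
    assert_eq!(dim % codes_per_byte, 0, "dim must fill whole bytes");
    dim / codes_per_byte
}
```

Under these assumptions `bytes_per_doc(1024, 2)` gives the 256 B quoted above, and the opt-in `b = 8` width costs one byte per coordinate (1024 B at the same dim).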
11 changes: 9 additions & 2 deletions src/lib.rs
@@ -13,7 +13,14 @@
//! coordinate, `2 * dim` bytes per document).
//! - [`RankQuant`] buckets each rank into `1 << bits` equal-width
//! bins and packs `bits` bits per coordinate (`dim * bits / 8` bytes
//! per document).
//! per document). `bits ∈ {1, 2, 4}` are the stable retrieval widths;
//! `b = 8` is a capability-gated evidence/refinement width — asymmetric
//! scoring and code/projection generation at any dim, *analytical-norm*
//! symmetric scoring (via [`RankQuant::search`]) only when
//! `dim % 256 == 0` (see [`RankQuant::new_asymmetric`]). The standalone
//! [`rankquant_eval_search`] computes its norm *empirically*, so it scores
//! any `bits ∈ 1..=8` at any dim (including `b = 8` off the 256 grid) and
//! carries no such restriction.
//! - [`Bitmap`] stores a top-bucket bitmap per document (one bit
//! per coordinate) and scores via `popcount(Q AND D)`.
//! - [`SignBitmap`] stores a sign bitmap per document (one bit per
@@ -64,7 +71,7 @@ mod util;

pub use bitmap::Bitmap;
pub use quant::SubsetScratch;
pub use quant::{rankquant_eval_search, RankQuant, TwoStageCandidatePolicy};
pub use quant::{rankquant_eval_search, RankQuant, RankQuantCapability, TwoStageCandidatePolicy};
pub use rank::Rank;
pub use rank_io::{probe_index_metadata, IndexKind, IndexMetadata, IndexParams};
pub use sign_bitmap::CandidateBatch;
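
The `dim % 256 == 0` condition in the crate docs above follows from bucket occupancy: the analytical norm assumes every bucket holds the same number of coordinates. A minimal sketch of that check, assuming the floor-division bucketing `rank * 2^bits / dim` described in `src/rank.rs`. Both function names are hypothetical and are not the crate's `RankQuantCapability` API.

```rust
/// Hypothetical: count how many ranks in `[0, dim)` land in each bucket,
/// assuming bucket = floor(rank * 2^bits / dim).
fn bucket_occupancy(dim: usize, bits: u8) -> Vec<usize> {
    assert!((1..=8).contains(&bits), "bits must be in 1..=8");
    assert!(dim > 0, "dim must be positive");
    let n_buckets = 1usize << bits;
    let mut counts = vec![0usize; n_buckets];
    for rank in 0..dim {
        // rank < dim keeps the quotient strictly below n_buckets.
        counts[rank * n_buckets / dim] += 1;
    }
    counts
}

/// Hypothetical: true when every bucket holds the same number of ranks,
/// which is the precondition for the closed-form (analytical) norm.
fn occupancy_is_uniform(dim: usize, bits: u8) -> bool {
    let counts = bucket_occupancy(dim, bits);
    counts.iter().all(|&c| c == counts[0])
}
```

With this sketch, `dim = 512` at `bits = 8` is uniform (two ranks per bucket), while `dim = 384` is not, which is why a symmetric analytical path would need the 256 grid and an empirically computed norm would not.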
482 changes: 398 additions & 84 deletions src/quant.rs

Large diffs are not rendered by default.

458 changes: 458 additions & 0 deletions src/quant_kernels.rs

Large diffs are not rendered by default.

169 changes: 141 additions & 28 deletions src/rank.rs
@@ -74,21 +74,27 @@ pub fn rank_transform_into(v: &[f32], out: &mut [u16]) {
/// Bucket a single rank into one of `1 << bits` equal-width bins on
/// `[0, d)`. Returns a value in `[0, 1 << bits)`.
///
/// For `bits == 8` the codomain is the full `u8` range `[0, 256)`; a
/// valid `rank < d` keeps the quotient `rank * 256 / d < 256`, so the
/// result still fits the returned `u8`.
///
/// # Panics
/// Panics if `bits > 7`, if `d == 0`, or if `rank >= d`. The `rank < d`
/// Panics if `bits > 8`, if `d == 0`, or if `rank >= d`. The `rank < d`
/// guard fails loud in *every* build — like the sibling [`pack_buckets`] and
/// [`bucket_centre`] checks — rather than silently clamping an out-of-range
/// rank into the top bucket. Internal callers feed ranks straight from
/// [`rank_transform`] (a permutation of `[0, d)`), so it never trips on the
/// hot path.
#[inline]
pub fn rank_to_bucket(rank: u16, d: usize, bits: u8) -> u8 {
// `bits` is a `u8`, so a caller could pass e.g. 8 or 255. `1u32 << bits`
// `bits` is a `u8`, so a caller could pass e.g. 9 or 255. `1u32 << bits`
// overflows for `bits >= 32` (in release that silently wraps and yields a
// wrong bucket; in debug it panics inconsistently), and the result must
// also fit in the returned `u8`, so cap at 7. `d == 0` would divide by
// zero. Guard both up front so the failure is loud in every build.
assert!(bits <= 7, "bits too large");
// also fit in the returned `u8`, so cap at 8 — the widest RankQuant width
// (b=8 yields one bucket per code value in `[0, 256)`, which still fits a
// `u8`). `d == 0` would divide by zero. Guard both up front so the failure
// is loud in every build.
assert!(bits <= 8, "bits too large");
assert!(d > 0, "d must be positive");
// A valid rank is a position in `[0, d)`. Reject `rank >= d` loudly instead
// of silently clamping the quotient back into range: the rest of the public
@@ -121,7 +127,7 @@ pub fn bucket_ranks(ranks: &[u16], bits: u8) -> Vec<u8> {
// input — an empty `ranks` skips the per-entry `rank_to_bucket` check and
// would otherwise silently return an empty vec. Mirrors the Python binding,
// which checks `bits` before its empty short-circuit.
assert!(bits <= 7, "bits too large");
assert!(bits <= 8, "bits too large");
let d = ranks.len();
ranks.iter().map(|&r| rank_to_bucket(r, d, bits)).collect()
}
@@ -130,27 +136,33 @@ pub fn bucket_ranks(ranks: &[u16], bits: u8) -> Vec<u8> {
/// dense byte stream.
///
/// Layout: the bucket with index 0 occupies the most-significant bits
/// of the first byte. Requires `bits ∈ {1, 2, 4}` and `d`'s length to
/// be a multiple of `8 / bits`.
/// of the first byte. Requires `bits ∈ {1, 2, 4, 8}` and `buckets.len()`
/// to be a multiple of `8 / bits`. For `bits == 8` the packing is the
/// degenerate one-code-per-byte case: each code is copied verbatim into
/// its own byte (no sub-byte shifting), so any length is valid.
///
/// # Panics
/// Panics if `bits ∉ {1, 2, 4}`, if `buckets.len()` is not a multiple
/// Panics if `bits ∉ {1, 2, 4, 8}`, if `buckets.len()` is not a multiple
/// of `8 / bits`, or if any code is `>= 1 << bits`. The last guard is
/// the public-contract backstop: an out-of-range code would otherwise
/// be silently truncated to `code & ((1 << bits) - 1)` and corrupt the
/// packed stream. (Internal callers feed codes straight from
/// [`rank_to_bucket`], which is always in range; this protects direct
/// callers of the primitive.)
/// callers of the primitive.) Note the `b=8` code range is the full
/// `u8`, so the range guard is vacuously satisfied for that width.
pub fn pack_buckets(buckets: &[u8], bits: u8) -> Vec<u8> {
assert!(matches!(bits, 1 | 2 | 4), "bits must be 1, 2, or 4");
assert!(matches!(bits, 1 | 2 | 4 | 8), "bits must be 1, 2, 4, or 8");
let codes_per_byte = (8 / bits) as usize;
let d = buckets.len();
assert_eq!(
d % codes_per_byte,
0,
"d ({d}) must be a multiple of codes_per_byte ({codes_per_byte}) for bits = {bits}",
);
let mask = (1u8 << bits) - 1;
// `1u8 << 8` overflows a `u8`, so compute the mask in `u16` and narrow it
// back: `b=8` yields the full-byte mask `0xFF`, and `b ∈ {1,2,4}` yields the
// same value the old `(1u8 << bits) - 1` produced.
let mask = ((1u16 << bits) - 1) as u8;
let n_bytes = d / codes_per_byte;
let mut out = vec![0u8; n_bytes];
let bits_u = bits as usize;
@@ -160,6 +172,8 @@ pub fn pack_buckets(buckets: &[u8], bits: u8) -> Vec<u8> {
// fail-loud guarantee without a second O(d) pass over `buckets`; the
// branch is loop-invariant-predictable for the always-valid internal
// callers. Asserting `b <= mask` makes the trailing `& mask` redundant.
// At `b=8`, `codes_per_byte == 1`, so `shift == 0` and each byte holds one
// code verbatim.
for (i, &b) in buckets.iter().enumerate() {
assert!(
b <= mask,
@@ -178,10 +192,12 @@ pub fn pack_buckets(buckets: &[u8], bits: u8) -> Vec<u8> {
///
/// Inverse of [`pack_buckets`].
pub fn unpack_buckets(packed: &[u8], d: usize, bits: u8) -> Vec<u8> {
assert!(matches!(bits, 1 | 2 | 4), "bits must be 1, 2, or 4");
assert!(matches!(bits, 1 | 2 | 4 | 8), "bits must be 1, 2, 4, or 8");
let codes_per_byte = (8 / bits) as usize;
assert_eq!(packed.len() * codes_per_byte, d);
let mask = (1u8 << bits) - 1;
// `(1u8 << 8)` overflows a `u8`; compute in `u16` and narrow so `b=8`
// yields the full-byte mask `0xFF` (each byte already holds one code).
let mask = ((1u16 << bits) - 1) as u8;
let bits_u = bits as usize;
let mut out = vec![0u8; d];
#[allow(clippy::needless_range_loop)] // indexed access is clearer / matches the kernel layout
@@ -195,13 +211,16 @@ pub fn unpack_buckets(packed: &[u8], d: usize, bits: u8) -> Vec<u8> {
}
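
The `u16` mask widening and the MSB-first layout described above can be sketched end to end. This is a minimal illustration under stated assumptions, not the crate's implementation: `code_mask`, `shift_for`, `pack`, and `unpack` are hypothetical names, and the shift formula is inferred from the documented layout (code 0 in the most-significant bits of the first byte).

```rust
/// Hypothetical: per-code mask, computed in `u16` so `bits == 8` gives 0xFF
/// instead of overflowing a `u8` shift.
fn code_mask(bits: u8) -> u8 {
    assert!(matches!(bits, 1 | 2 | 4 | 8), "bits must be 1, 2, 4, or 8");
    ((1u16 << bits) - 1) as u8
}

/// Hypothetical: left shift for the `i`-th code, MSB-first within its byte.
/// At `bits == 8` there is one code per byte, so the shift is always 0.
fn shift_for(i: usize, bits: u8) -> usize {
    let per_byte = (8 / bits) as usize;
    8 - (bits as usize) * (i % per_byte + 1)
}

fn pack(codes: &[u8], bits: u8) -> Vec<u8> {
    let mask = code_mask(bits); // validates `bits` before dividing by it
    let per_byte = (8 / bits) as usize;
    assert_eq!(codes.len() % per_byte, 0, "length must fill whole bytes");
    let mut out = vec![0u8; codes.len() / per_byte];
    for (i, &c) in codes.iter().enumerate() {
        // Fail loudly instead of silently truncating an out-of-range code.
        assert!(c <= mask, "code {c} out of range for bits = {bits}");
        out[i / per_byte] |= c << shift_for(i, bits);
    }
    out
}

fn unpack(packed: &[u8], d: usize, bits: u8) -> Vec<u8> {
    let mask = code_mask(bits);
    let per_byte = (8 / bits) as usize;
    assert_eq!(packed.len() * per_byte, d, "d must match the packed length");
    (0..d)
        .map(|i| (packed[i / per_byte] >> shift_for(i, bits)) & mask)
        .collect()
}
```

In this sketch the `bits == 8` case needs no special branch: the shift collapses to zero and the mask to `0xFF`, so packing is the identity on the byte stream.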

/// Number of bytes per packed RankQuant document at dimension `d` and
/// bit width `bits ∈ {1, 2, 4}`.
/// bit width `bits ∈ {1, 2, 4, 8}`.
///
/// At `bits == 8` each coordinate occupies its own byte (`codes_per_byte
/// == 1`), so the storage is exactly `d` bytes per document.
#[inline]
pub fn rankquant_bytes_per_vec(d: usize, bits: u8) -> usize {
// Guard the same domain as the sibling pack/unpack helpers: `bits == 0`
// would divide by zero computing `codes_per_byte`, and only 1/2/4 give an
// would divide by zero computing `codes_per_byte`, and only 1/2/4/8 give an
// integral codes-per-byte.
assert!(matches!(bits, 1 | 2 | 4), "bits must be 1,2,4");
assert!(matches!(bits, 1 | 2 | 4 | 8), "bits must be 1,2,4,8");
let codes_per_byte = (8 / bits) as usize;
assert_eq!(
d % codes_per_byte,
@@ -219,17 +238,20 @@ pub fn rankquant_bytes_per_vec(d: usize, bits: u8) -> usize {
/// pattern `..., -1.5, -0.5, +0.5, +1.5, ...` for `B = 2`.
///
/// # Panics
/// Panics if `bits > 7` — bucket codes are `u8`, so the bit width is
/// Panics if `bits > 8` — bucket codes are `u8`, so the bit width is
/// capped at the representable bucketing range, matching
/// [`rank_to_bucket`] (the RankQuant family uses `bits ∈ {1, 2, 4}`).
/// [`rank_to_bucket`] (the RankQuant family uses `bits ∈ {1, 2, 4, 8}`).
/// Also panics if `bucket >= 1 << bits`; this guard fails loud in *every*
/// build — like the sibling [`pack_buckets`] check — so a direct caller
/// cannot silently receive a centre outside the symmetric range. The
/// internal LUT builders only ever pass `bucket ∈ [0, 1 << bits)` (the
/// loop bound *is* `1 << bits`), so the assert never trips on the hot path.
/// For `bits == 8` the centres are `bucket - 127.5`, spanning `-127.5` to
/// `+127.5` in unit steps; the range guard is vacuous (every `u8` is a
/// valid code).
#[inline]
pub fn bucket_centre(bucket: u8, bits: u8) -> f32 {
assert!(bits <= 7, "bits too large");
assert!(bits <= 8, "bits too large");
assert!(
(bucket as u32) < (1u32 << bits),
"bucket {bucket} out of range for bits = {bits}",
@@ -270,14 +292,21 @@ pub fn rank_norm(d: usize) -> f32 {
/// The mean-centred bucket index has variance `(2^(2B) - 1) / 12`, so
/// the per-vector L2 norm is `sqrt(d * (2^(2B) - 1) / 12)`.
///
/// This is the **symmetric analytical** norm: it is exact only when
/// every bucket receives exactly `d / 2^B` coordinates, i.e. when
/// `d % 2^B == 0`. For `bits == 8` that precondition is `d % 256 == 0`;
/// the [`crate::RankQuant`] symmetric path enforces it before calling
/// this (see `RankQuant::new` / `symmetric_supported`). This primitive
/// itself only computes the closed form and does not re-check occupancy.
///
/// # Panics
/// Panics if `bits ∉ {1, 2, 4}`, mirroring the [`crate::RankQuant`]
/// Panics if `bits ∉ {1, 2, 4, 8}`, mirroring the [`crate::RankQuant`]
/// bit-width domain (and [`rankquant_bytes_per_vec`]). Without it a
/// nonsensical `bits` would return a norm for a scheme that does not
/// exist (or overflow `1 << bits`).
#[inline]
pub fn rankquant_norm(d: usize, bits: u8) -> f32 {
assert!(matches!(bits, 1 | 2 | 4), "bits must be 1,2,4");
assert!(matches!(bits, 1 | 2 | 4 | 8), "bits must be 1,2,4,8");
let n = (1u32 << bits) as f64;
let var = (n * n - 1.0) / 12.0;
((d as f64) * var).sqrt() as f32
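
The closed form above can be checked against a direct sum. The sketch below is an illustration under stated assumptions: `centre`, `closed_form_norm`, and `direct_norm` are hypothetical names, the centre formula `bucket - (2^bits - 1) / 2` is inferred from the `-1.5, -0.5, +0.5, +1.5` pattern documented for `B = 2`, and uniform occupancy (`d % 2^bits == 0`) is assumed.

```rust
/// Hypothetical: mean-centred bucket value, e.g. `bucket - 127.5` at bits = 8.
fn centre(bucket: u8, bits: u8) -> f64 {
    assert!((1..=8).contains(&bits), "bits must be in 1..=8");
    assert!((bucket as u32) < (1u32 << bits), "bucket out of range");
    bucket as f64 - ((1u32 << bits) as f64 - 1.0) / 2.0
}

/// Closed form from the doc comment: sqrt(d * (2^(2B) - 1) / 12).
fn closed_form_norm(d: usize, bits: u8) -> f64 {
    let n = (1u32 << bits) as f64;
    (d as f64 * (n * n - 1.0) / 12.0).sqrt()
}

/// Direct L2 norm, assuming each bucket holds exactly `d / 2^bits` entries.
fn direct_norm(d: usize, bits: u8) -> f64 {
    let n = 1usize << bits;
    assert_eq!(d % n, 0, "uniform occupancy needs d % 2^bits == 0");
    let per_bucket = (d / n) as f64;
    let sum_sq: f64 = (0..n)
        .map(|b| per_bucket * centre(b as u8, bits).powi(2))
        .sum();
    sum_sq.sqrt()
}
```

When `d` is not a multiple of `2^bits` the occupancy assumption fails and the two values are no longer guaranteed to agree, which is the case the doc comment leaves to the caller.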
@@ -629,10 +658,20 @@ mod tests {

#[test]
#[should_panic(expected = "bits too large")]
fn bucket_ranks_rejects_bits_above_7_even_when_empty() {
fn bucket_ranks_rejects_bits_above_8_even_when_empty() {
// `bits` is validated up front, so an invalid width fails loud even on
// empty input — which never reaches the per-entry rank_to_bucket guard.
let _ = bucket_ranks(&[], 8);
// The valid bucketing range now extends to b=8 (the widest RankQuant
// width), so b=9 is the first rejected width.
let _ = bucket_ranks(&[], 9);
}

#[test]
fn bucket_ranks_accepts_bits_8() {
// b=8 is a supported width: a 4-element rank vector buckets without
// panicking and yields codes in [0, 256).
let codes = bucket_ranks(&[0, 1, 2, 3], 8);
assert_eq!(codes.len(), 4);
}

#[test]
@@ -662,6 +701,72 @@ mod tests {
assert_eq!(unpacked, buckets);
}

#[test]
fn pack_unpack_round_trip_bits8() {
// b=8 is the degenerate one-code-per-byte packing: each byte holds a
// full code in `[0, 256)`, so packed length == code count and the
// bytes are the codes verbatim. Cover the full code range including
// 0 and 255 (the extremes the `b ∈ {1,2,4}` mask would have clipped).
let buckets: Vec<u8> = (0..256).map(|i| i as u8).collect();
let packed = pack_buckets(&buckets, 8);
assert_eq!(packed.len(), 256, "b=8 stores one byte per code");
assert_eq!(packed, buckets, "b=8 packing is the identity byte stream");
let unpacked = unpack_buckets(&packed, 256, 8);
assert_eq!(unpacked, buckets);
}

#[test]
fn pack_unpack_round_trip_bits8_arbitrary_len() {
// Any `d` is a valid b=8 length (codes_per_byte == 1); 384 is not a
// multiple of 256 yet still round-trips — code generation never needs
// the equal-bucket precondition that only the symmetric norm requires.
let buckets: Vec<u8> = (0..384u16).map(|i| (i % 256) as u8).collect();
let packed = pack_buckets(&buckets, 8);
assert_eq!(packed.len(), 384);
let unpacked = unpack_buckets(&packed, 384, 8);
assert_eq!(unpacked, buckets);
}

#[test]
fn rank_to_bucket_b8_spans_full_byte_range() {
// rank in [0, d) with bits=8 must land in [0, 256). Check the extremes
// and that the top rank maps to the top bucket for d == 256.
let d = 256usize;
assert_eq!(rank_to_bucket(0, d, 8), 0);
assert_eq!(rank_to_bucket(255, d, 8), 255);
// A `d` off the 256 grid still keeps the quotient in range: assuming floor
// division, 383 * 256 / 384 rounds down to 255, the top bucket.
assert_eq!(rank_to_bucket(383, 384, 8), 255);
for rank in 0..d as u16 {
let _ = rank_to_bucket(rank, d, 8); // never panics, always < 256
}
}

#[test]
fn bucket_centre_b8_is_symmetric_around_zero() {
// For b=8 the 256 centres span -127.5 ..= +127.5 and sum to 0.
assert_eq!(bucket_centre(0, 8), -127.5);
assert_eq!(bucket_centre(255, 8), 127.5);
let sum: f32 = (0..256u16).map(|b| bucket_centre(b as u8, 8)).sum();
assert!(sum.abs() < 1e-3, "b=8 centres should sum to ~0, got {sum}");
}

#[test]
fn rankquant_norm_b8_matches_direct_computation() {
// d % 256 == 0 so every bucket gets exactly d/256 entries and the
// analytical norm is exact.
let d = 512usize;
let bits = 8u8;
let analytical = rankquant_norm(d, bits);
let ranks: Vec<u16> = (0..d as u16).collect();
let buckets = bucket_ranks(&ranks, bits);
let centred: Vec<f32> = buckets.iter().map(|&b| bucket_centre(b, bits)).collect();
let direct: f32 = centred.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(analytical - direct).abs() / direct < 1e-5,
"analytical {analytical}, direct {direct}"
);
}

#[test]
fn bucket_centres_are_symmetric_around_zero() {
// For B = 2: bucket values are {-1.5, -0.5, +0.5, +1.5}.
@@ -718,10 +823,18 @@ mod tests {

#[test]
#[should_panic(expected = "bits too large")]
fn bucket_centre_rejects_bits_above_7() {
// bits >= 32 overflows `1 << bits`; the guard caps at 7 (the u8
// bucket domain), matching `rank_to_bucket`.
let _ = bucket_centre(0, 8);
fn bucket_centre_rejects_bits_above_8() {
// bits >= 32 overflows `1 << bits`; the guard caps at 8 (the widest
// RankQuant width, whose codes still fit a u8), matching
// `rank_to_bucket`. b=9 is the first rejected width.
let _ = bucket_centre(0, 9);
}

#[test]
fn bucket_centre_accepts_bits_8() {
// b=8 centres are valid: code 0 → -127.5, code 255 → +127.5.
assert_eq!(bucket_centre(0, 8), -127.5);
assert_eq!(bucket_centre(255, 8), 127.5);
}

#[test]
1 change: 1 addition & 0 deletions tests/index/main.rs
@@ -33,6 +33,7 @@ mod rank;
#[cfg(feature = "experimental")]
mod multi_bucket;
mod quant;
mod quant_b8;
mod two_stage;

pub const D: usize = 128;