diff --git a/src/rank_io.rs b/src/rank_io.rs index a297c5f4..8e3be947 100644 --- a/src/rank_io.rs +++ b/src/rank_io.rs @@ -183,15 +183,21 @@ fn read_le_vec( /// offset accounting. fn check_payload_matches_file( reader: &mut R, + label: &str, file_len: u64, payload_bytes: usize, ) -> io::Result<()> { let pos = reader.stream_position()?; let remaining = file_len.saturating_sub(pos); - if payload_bytes as u64 != remaining { + let payload_bytes = payload_bytes as u64; + if payload_bytes > remaining { return Err(invalid(format!( - "declared payload ({payload_bytes} B) does not match remaining \ - file size ({remaining} B): truncated, forged, or trailing bytes" + "{label} payload truncated: header declares {payload_bytes} B but file has {remaining} B remaining" + ))); + } + if payload_bytes < remaining { + return Err(invalid(format!( + "{label} payload has trailing bytes: header declares {payload_bytes} B but file has {remaining} B remaining" ))); } Ok(()) @@ -253,19 +259,76 @@ fn check_payload_bytes(payload_bytes: usize) -> io::Result<()> { Ok(()) } -fn read_u32_le(reader: &mut R) -> io::Result { - let mut buf = [0u8; 4]; - reader.read_exact(&mut buf)?; - Ok(u32::from_le_bytes(buf)) +fn truncated_field(label: &str, field: &str) -> io::Error { + io::Error::new( + io::ErrorKind::UnexpectedEof, + format!("{label} header truncated while reading {field}"), + ) +} + +fn read_exact_field( + reader: &mut R, + label: &str, + field: &str, +) -> io::Result<[u8; N]> { + let mut buf = [0u8; N]; + reader + .read_exact(&mut buf) + .map_err(|err| match err.kind() { + io::ErrorKind::UnexpectedEof => truncated_field(label, field), + _ => err, + })?; + Ok(buf) +} + +fn read_u8_field(reader: &mut R, label: &str, field: &str) -> io::Result { + Ok(read_exact_field::<_, 1>(reader, label, field)?[0]) +} + +fn read_u32_le(reader: &mut R, label: &str, field: &str) -> io::Result { + Ok(u32::from_le_bytes(read_exact_field::<_, 4>( + reader, label, field, + )?)) } fn read_version(reader: &mut R, label: &str) -> io::Result { - let mut ver = [0u8; 1]; - reader.read_exact(&mut ver)?; - if ver[0] != VERSION { - return Err(invalid(format!("unsupported {label} version: {}", ver[0]))); + let ver = read_u8_field(reader, label, "version")?; + if ver != VERSION { + return Err(invalid(format!("unsupported {label} version: {ver}"))); } - Ok(ver[0]) + Ok(ver) +} + +fn read_magic(reader: &mut R, label: &str) -> io::Result<[u8; 4]> { + read_exact_field(reader, label, "magic") +} + +fn rank_payload_bytes(dim: usize, vector_count: usize) -> io::Result { + vector_count + .checked_mul(dim) + .and_then(|x| x.checked_mul(2)) + .ok_or_else(|| invalid("TVR1 payload size overflows usize")) +} + +fn rankquant_bytes_per_vec(dim: usize, bits: u8) -> io::Result { + dim.checked_mul(bits as usize) + .map(|x| x / 8) + .ok_or_else(|| invalid("TVRQ bytes_per_vec overflows usize")) +} + +fn rankquant_payload_bytes(dim: usize, vector_count: usize, bits: u8) -> io::Result { + let bytes_per_vec = rankquant_bytes_per_vec(dim, bits)?; + vector_count + .checked_mul(bytes_per_vec) + .ok_or_else(|| invalid("TVRQ payload size overflows usize")) +} + +fn bitmap_payload_bytes(dim: usize, vector_count: usize, label: &str) -> io::Result { + let qpv = dim / 64; + vector_count + .checked_mul(qpv) + .and_then(|x| x.checked_mul(8)) + .ok_or_else(|| invalid(format!("{label} payload size overflows usize"))) } /// Probe an ordvec index file's fixed header and declared byte shape. @@ -280,8 +343,7 @@ pub fn probe_index_metadata(path: impl AsRef) -> io::Result let file = File::open(path)?; let file_size_bytes = file.metadata()?.len(); let mut f = BufReader::new(file); - let mut magic = [0u8; 4]; - f.read_exact(&mut magic)?; + let magic = read_magic(&mut f, "ordvec index")?; match &magic { TVR_MAGIC => probe_rank_metadata(&mut f, file_size_bytes), TVRQ_MAGIC => probe_rankquant_metadata(&mut f, file_size_bytes), @@ -296,18 +358,14 @@ fn probe_rank_metadata( file_size_bytes: u64, ) -> io::Result { let format_version = read_version(reader, "TVR1")?; - let dim = read_u32_le(reader)? as usize; + let dim = read_u32_le(reader, "TVR1", "dim")? as usize; check_dim(dim)?; - let vector_count = read_u32_le(reader)? as usize; + let vector_count = read_u32_le(reader, "TVR1", "n_vectors")? as usize; check_n_vectors(vector_count)?; - let bytes_per_vec = dim - .checked_mul(2) - .ok_or_else(|| invalid("bytes_per_vec overflows usize"))?; - let payload_bytes = vector_count - .checked_mul(bytes_per_vec) - .ok_or_else(|| invalid("payload size overflows usize"))?; + let bytes_per_vec = rank_payload_bytes(dim, 1)?; + let payload_bytes = rank_payload_bytes(dim, vector_count)?; check_payload_bytes(payload_bytes)?; - check_payload_matches_file(reader, file_size_bytes, payload_bytes)?; + check_payload_matches_file(reader, "TVR1", file_size_bytes, payload_bytes)?; Ok(IndexMetadata { kind: IndexKind::Rank, format_version, @@ -324,15 +382,13 @@ fn probe_rankquant_metadata( file_size_bytes: u64, ) -> io::Result { let format_version = read_version(reader, "TVRQ")?; - let mut bits_buf = [0u8; 1]; - reader.read_exact(&mut bits_buf)?; - let bits = bits_buf[0]; + let bits = read_u8_field(reader, "TVRQ", "bits")?; if !matches!(bits, 1 | 2 | 4) { return Err(invalid(format!( "unsupported TVRQ bits: {bits} (expected 1, 2, or 4)" ))); } - let dim = read_u32_le(reader)? as usize; + let dim = read_u32_le(reader, "TVRQ", "dim")? as usize; check_dim(dim)?; let n_buckets = 1usize << bits; if !dim.is_multiple_of(n_buckets) { @@ -347,19 +403,12 @@ fn probe_rankquant_metadata( "TVRQ dim {dim} is not a multiple of codes_per_byte = {codes_per_byte}" ))); } - let vector_count = read_u32_le(reader)? as usize; + let vector_count = read_u32_le(reader, "TVRQ", "n_vectors")? as usize; check_n_vectors(vector_count)?; - let payload_bytes = vector_count - .checked_mul(dim) - .and_then(|x| x.checked_mul(bits as usize)) - .map(|x| x / 8) - .ok_or_else(|| invalid("payload size overflows usize"))?; + let payload_bytes = rankquant_payload_bytes(dim, vector_count, bits)?; check_payload_bytes(payload_bytes)?; - check_payload_matches_file(reader, file_size_bytes, payload_bytes)?; - let bytes_per_vec = dim - .checked_mul(bits as usize) - .map(|x| x / 8) - .ok_or_else(|| invalid("bytes_per_vec overflows usize"))?; + check_payload_matches_file(reader, "TVRQ", file_size_bytes, payload_bytes)?; + let bytes_per_vec = rankquant_bytes_per_vec(dim, bits)?; Ok(IndexMetadata { kind: IndexKind::RankQuant, format_version, @@ -376,26 +425,22 @@ fn probe_bitmap_metadata( file_size_bytes: u64, ) -> io::Result { let format_version = read_version(reader, "TVBM")?; - let dim = read_u32_le(reader)? as usize; + let dim = read_u32_le(reader, "TVBM", "dim")? as usize; check_dim(dim)?; if !dim.is_multiple_of(64) { return Err(invalid(format!("TVBM dim {dim} is not a multiple of 64"))); } - let n_top = read_u32_le(reader)? as usize; + let n_top = read_u32_le(reader, "TVBM", "n_top")? as usize; if n_top == 0 || n_top >= dim { return Err(invalid(format!( "TVBM n_top {n_top} must satisfy 0 < n_top < dim ({dim})" ))); } - let vector_count = read_u32_le(reader)? as usize; + let vector_count = read_u32_le(reader, "TVBM", "n_vectors")? as usize; check_n_vectors(vector_count)?; - let qpv = dim / 64; - let payload_bytes = vector_count - .checked_mul(qpv) - .and_then(|x| x.checked_mul(8)) - .ok_or_else(|| invalid("payload size overflows usize"))?; + let payload_bytes = bitmap_payload_bytes(dim, vector_count, "TVBM")?; check_payload_bytes(payload_bytes)?; - check_payload_matches_file(reader, file_size_bytes, payload_bytes)?; + check_payload_matches_file(reader, "TVBM", file_size_bytes, payload_bytes)?; Ok(IndexMetadata { kind: IndexKind::Bitmap, format_version, @@ -412,17 +457,13 @@ fn probe_sign_bitmap_metadata( file_size_bytes: u64, ) -> io::Result { let format_version = read_version(reader, "TVSB")?; - let dim = read_u32_le(reader)? as usize; + let dim = read_u32_le(reader, "TVSB", "dim")? as usize; check_sign_bitmap_dim(dim)?; - let vector_count = read_u32_le(reader)? as usize; + let vector_count = read_u32_le(reader, "TVSB", "n_vectors")? as usize; check_n_vectors(vector_count)?; - let qpv = dim / 64; - let payload_bytes = vector_count - .checked_mul(qpv) - .and_then(|x| x.checked_mul(8)) - .ok_or_else(|| invalid("payload size overflows usize"))?; + let payload_bytes = bitmap_payload_bytes(dim, vector_count, "TVSB")?; check_payload_bytes(payload_bytes)?; - check_payload_matches_file(reader, file_size_bytes, payload_bytes)?; + check_payload_matches_file(reader, "TVSB", file_size_bytes, payload_bytes)?; Ok(IndexMetadata { kind: IndexKind::SignBitmap, format_version, @@ -449,10 +490,7 @@ pub(crate) fn write_rank( // Enforce the loaders' MAX_PAYLOAD cap *before* File::create so a rejected // oversized write never truncates an existing file. Defense-in-depth; the // round-trip guarantee is type-level (see module docs). Mirrors load_rank. - let payload_bytes = n_vectors - .checked_mul(dim) - .and_then(|x| x.checked_mul(2)) - .ok_or_else(|| invalid("payload size overflows usize"))?; + let payload_bytes = rank_payload_bytes(dim, n_vectors)?; check_payload_bytes(payload_bytes)?; assert_eq!(ranks.len(), payload_bytes / 2); let mut f = BufWriter::new(File::create(path)?); @@ -476,30 +514,18 @@ pub(crate) fn load_rank(path: impl AsRef) -> io::Result<(usize, usize, Vec // the trailing-byte check. Both are wrong on a metadata race (NFS/procfs). let file_len = file.metadata()?.len(); let mut f = BufReader::new(file); - let mut magic = [0u8; 4]; - f.read_exact(&mut magic)?; + let magic = read_magic(&mut f, "TVR1")?; if &magic != TVR_MAGIC { return Err(invalid("not a TVR1 file: wrong magic")); } - let mut ver = [0u8; 1]; - f.read_exact(&mut ver)?; - if ver[0] != VERSION { - return Err(invalid(format!("unsupported TVR1 version: {}", ver[0]))); - } - let mut dim_buf = [0u8; 4]; - f.read_exact(&mut dim_buf)?; - let dim = u32::from_le_bytes(dim_buf) as usize; + read_version(&mut f, "TVR1")?; + let dim = read_u32_le(&mut f, "TVR1", "dim")? as usize; check_dim(dim)?; - let mut n_buf = [0u8; 4]; - f.read_exact(&mut n_buf)?; - let n_vectors = u32::from_le_bytes(n_buf) as usize; + let n_vectors = read_u32_le(&mut f, "TVR1", "n_vectors")? as usize; check_n_vectors(n_vectors)?; - let payload_bytes = n_vectors - .checked_mul(dim) - .and_then(|x| x.checked_mul(2)) - .ok_or_else(|| invalid("payload size overflows usize"))?; + let payload_bytes = rank_payload_bytes(dim, n_vectors)?; check_payload_bytes(payload_bytes)?; - check_payload_matches_file(&mut f, file_len, payload_bytes)?; + check_payload_matches_file(&mut f, "TVR1", file_len, payload_bytes)?; // `payload_bytes == n_vectors * dim * 2`, so the u16 element count is // `payload_bytes / 2`. Read directly into a fallibly reserved Vec // instead of allocating a byte buffer and `.collect()`-ing it — the old @@ -557,11 +583,7 @@ pub(crate) fn write_rankquant( // Enforce the loaders' MAX_PAYLOAD cap *before* File::create (defense-in- // depth; a rejected write must not truncate an existing file). Mirrors // load_rankquant: checked multiply before the /8 divide. - let payload_bytes = n_vectors - .checked_mul(dim) - .and_then(|x| x.checked_mul(bits as usize)) - .map(|x| x / 8) - .ok_or_else(|| invalid("payload size overflows usize"))?; + let payload_bytes = rankquant_payload_bytes(dim, n_vectors, bits)?; check_payload_bytes(payload_bytes)?; assert_eq!(packed.len(), payload_bytes); let mut f = BufWriter::new(File::create(path)?); @@ -584,27 +606,18 @@ pub(crate) fn load_rankquant(path: impl AsRef) -> io::Result<(u8, usize, u // the trailing-byte check. Both are wrong on a metadata race (NFS/procfs). let file_len = file.metadata()?.len(); let mut f = BufReader::new(file); - let mut magic = [0u8; 4]; - f.read_exact(&mut magic)?; + let magic = read_magic(&mut f, "TVRQ")?; if &magic != TVRQ_MAGIC { return Err(invalid("not a TVRQ file: wrong magic")); } - let mut ver = [0u8; 1]; - f.read_exact(&mut ver)?; - if ver[0] != VERSION { - return Err(invalid(format!("unsupported TVRQ version: {}", ver[0]))); - } - let mut bits_buf = [0u8; 1]; - f.read_exact(&mut bits_buf)?; - let bits = bits_buf[0]; + read_version(&mut f, "TVRQ")?; + let bits = read_u8_field(&mut f, "TVRQ", "bits")?; if !matches!(bits, 1 | 2 | 4) { return Err(invalid(format!( "unsupported TVRQ bits: {bits} (expected 1, 2, or 4)" ))); } - let mut dim_buf = [0u8; 4]; - f.read_exact(&mut dim_buf)?; - let dim = u32::from_le_bytes(dim_buf) as usize; + let dim = read_u32_le(&mut f, "TVRQ", "dim")? as usize; check_dim(dim)?; // Constant-composition invariants (documented at module level and // enforced by `RankQuant::new`): `dim` must be a multiple of @@ -626,17 +639,11 @@ pub(crate) fn load_rankquant(path: impl AsRef) -> io::Result<(u8, usize, u "TVRQ dim {dim} is not a multiple of codes_per_byte = {codes_per_byte}" ))); } - let mut n_buf = [0u8; 4]; - f.read_exact(&mut n_buf)?; - let n_vectors = u32::from_le_bytes(n_buf) as usize; + let n_vectors = read_u32_le(&mut f, "TVRQ", "n_vectors")? as usize; check_n_vectors(n_vectors)?; - let payload_bytes = n_vectors - .checked_mul(dim) - .and_then(|x| x.checked_mul(bits as usize)) - .map(|x| x / 8) - .ok_or_else(|| invalid("payload size overflows usize"))?; + let payload_bytes = rankquant_payload_bytes(dim, n_vectors, bits)?; check_payload_bytes(payload_bytes)?; - check_payload_matches_file(&mut f, file_len, payload_bytes)?; + check_payload_matches_file(&mut f, "TVRQ", file_len, payload_bytes)?; let mut packed = try_alloc_zeroed(payload_bytes)?; f.read_exact(&mut packed)?; // Constant-composition invariant: every document must place exactly @@ -683,14 +690,10 @@ pub(crate) fn write_bitmap( n_vectors: usize, bitmaps: &[u64], ) -> io::Result<()> { - let qpv = dim / 64; // Enforce the loaders' MAX_PAYLOAD cap *before* File::create (defense-in- // depth; a rejected write must not truncate an existing file). Mirrors // load_bitmap. - let payload_bytes = n_vectors - .checked_mul(qpv) - .and_then(|x| x.checked_mul(8)) - .ok_or_else(|| invalid("payload size overflows usize"))?; + let payload_bytes = bitmap_payload_bytes(dim, n_vectors, "TVBM")?; check_payload_bytes(payload_bytes)?; assert_eq!(bitmaps.len(), payload_bytes / 8); let mut f = BufWriter::new(File::create(path)?); @@ -715,42 +718,28 @@ pub(crate) fn load_bitmap(path: impl AsRef) -> io::Result<(usize, usize, u // the trailing-byte check. Both are wrong on a metadata race (NFS/procfs). let file_len = file.metadata()?.len(); let mut f = BufReader::new(file); - let mut magic = [0u8; 4]; - f.read_exact(&mut magic)?; + let magic = read_magic(&mut f, "TVBM")?; if &magic != TVBM_MAGIC { return Err(invalid("not a TVBM file: wrong magic")); } - let mut ver = [0u8; 1]; - f.read_exact(&mut ver)?; - if ver[0] != VERSION { - return Err(invalid(format!("unsupported TVBM version: {}", ver[0]))); - } - let mut dim_buf = [0u8; 4]; - f.read_exact(&mut dim_buf)?; - let dim = u32::from_le_bytes(dim_buf) as usize; + read_version(&mut f, "TVBM")?; + let dim = read_u32_le(&mut f, "TVBM", "dim")? as usize; check_dim(dim)?; if !dim.is_multiple_of(64) { return Err(invalid(format!("TVBM dim {dim} is not a multiple of 64"))); } - let mut top_buf = [0u8; 4]; - f.read_exact(&mut top_buf)?; - let n_top = u32::from_le_bytes(top_buf) as usize; + let n_top = read_u32_le(&mut f, "TVBM", "n_top")? as usize; if n_top == 0 || n_top >= dim { return Err(invalid(format!( "TVBM n_top {n_top} must satisfy 0 < n_top < dim ({dim})" ))); } - let mut n_buf = [0u8; 4]; - f.read_exact(&mut n_buf)?; - let n_vectors = u32::from_le_bytes(n_buf) as usize; + let n_vectors = read_u32_le(&mut f, "TVBM", "n_vectors")? as usize; check_n_vectors(n_vectors)?; let qpv = dim / 64; - let payload_bytes = n_vectors - .checked_mul(qpv) - .and_then(|x| x.checked_mul(8)) - .ok_or_else(|| invalid("payload size overflows usize"))?; + let payload_bytes = bitmap_payload_bytes(dim, n_vectors, "TVBM")?; check_payload_bytes(payload_bytes)?; - check_payload_matches_file(&mut f, file_len, payload_bytes)?; + check_payload_matches_file(&mut f, "TVBM", file_len, payload_bytes)?; // `payload_bytes == n_vectors * qpv * 8`, so the u64 element count is // `payload_bytes / 8`. Read directly into a fallibly reserved Vec // rather than allocating a byte buffer and `.collect()`-ing it. @@ -792,14 +781,10 @@ pub(crate) fn write_sign_bitmap( n_vectors: usize, bitmaps: &[u64], ) -> io::Result<()> { - let qpv = dim / 64; // Enforce the loaders' MAX_PAYLOAD cap *before* File::create (defense-in- // depth; a rejected write must not truncate an existing file). Mirrors // load_sign_bitmap. - let payload_bytes = n_vectors - .checked_mul(qpv) - .and_then(|x| x.checked_mul(8)) - .ok_or_else(|| invalid("payload size overflows usize"))?; + let payload_bytes = bitmap_payload_bytes(dim, n_vectors, "TVSB")?; check_payload_bytes(payload_bytes)?; assert_eq!(bitmaps.len(), payload_bytes / 8); let mut f = BufWriter::new(File::create(path)?); @@ -820,8 +805,9 @@ pub(crate) fn write_sign_bitmap( /// `[64, MAX_SIGN_BITMAP_DIM]` and a multiple of 64), and `n_vectors` /// (≤ `MAX_VECTORS`). Payload size is computed with `checked_mul` and /// rejected if it overflows or exceeds the 128 GiB hard cap from -/// `check_payload_bytes`. Any malformed input returns -/// `io::Error::InvalidData`. +/// `check_payload_bytes`. Malformed input returns `io::Error`; structurally +/// invalid fields use `InvalidData`, while truncated headers surface +/// `UnexpectedEof` with field context. /// /// Dim validation deliberately does NOT use `check_dim`: that helper /// caps at `u16::MAX` to honour [`crate::Rank`]'s `u16` rank @@ -837,31 +823,18 @@ pub(crate) fn load_sign_bitmap(path: impl AsRef) -> io::Result<(usize, usi // the trailing-byte check. Both are wrong on a metadata race (NFS/procfs). let file_len = file.metadata()?.len(); let mut f = BufReader::new(file); - let mut magic = [0u8; 4]; - f.read_exact(&mut magic)?; + let magic = read_magic(&mut f, "TVSB")?; if &magic != TVSB_MAGIC { return Err(invalid("not a TVSB file: wrong magic")); } - let mut ver = [0u8; 1]; - f.read_exact(&mut ver)?; - if ver[0] != VERSION { - return Err(invalid(format!("unsupported TVSB version: {}", ver[0]))); - } - let mut dim_buf = [0u8; 4]; - f.read_exact(&mut dim_buf)?; - let dim = u32::from_le_bytes(dim_buf) as usize; + read_version(&mut f, "TVSB")?; + let dim = read_u32_le(&mut f, "TVSB", "dim")? as usize; check_sign_bitmap_dim(dim)?; - let mut n_buf = [0u8; 4]; - f.read_exact(&mut n_buf)?; - let n_vectors = u32::from_le_bytes(n_buf) as usize; + let n_vectors = read_u32_le(&mut f, "TVSB", "n_vectors")? as usize; check_n_vectors(n_vectors)?; - let qpv = dim / 64; - let payload_bytes = n_vectors - .checked_mul(qpv) - .and_then(|x| x.checked_mul(8)) - .ok_or_else(|| invalid("payload size overflows usize"))?; + let payload_bytes = bitmap_payload_bytes(dim, n_vectors, "TVSB")?; check_payload_bytes(payload_bytes)?; - check_payload_matches_file(&mut f, file_len, payload_bytes)?; + check_payload_matches_file(&mut f, "TVSB", file_len, payload_bytes)?; // `payload_bytes == n_vectors * qpv * 8`, so the u64 element count is // `payload_bytes / 8`. Read directly into a fallibly reserved Vec // rather than allocating a byte buffer and `.collect()`-ing it. @@ -916,6 +889,55 @@ mod tests { p } + fn assert_err_contains(result: std::io::Result, expected: &str) { + let Err(err) = result else { + panic!("expected error containing {expected:?}, got Ok(_)"); + }; + let text = err.to_string(); + assert!( + text.contains(expected), + "expected error containing {expected:?}, got {text:?}" + ); + } + + fn rank_header(dim: u32, n_vectors: u32) -> Vec { + let mut v = Vec::new(); + v.extend_from_slice(b"TVR1"); + v.push(VERSION); + v.extend_from_slice(&dim.to_le_bytes()); + v.extend_from_slice(&n_vectors.to_le_bytes()); + v + } + + fn rankquant_header(bits: u8, dim: u32, n_vectors: u32) -> Vec { + let mut v = Vec::new(); + v.extend_from_slice(b"TVRQ"); + v.push(VERSION); + v.push(bits); + v.extend_from_slice(&dim.to_le_bytes()); + v.extend_from_slice(&n_vectors.to_le_bytes()); + v + } + + fn bitmap_header(dim: u32, n_top: u32, n_vectors: u32) -> Vec { + let mut v = Vec::new(); + v.extend_from_slice(b"TVBM"); + v.push(VERSION); + v.extend_from_slice(&dim.to_le_bytes()); + v.extend_from_slice(&n_top.to_le_bytes()); + v.extend_from_slice(&n_vectors.to_le_bytes()); + v + } + + fn sign_bitmap_header(dim: u32, n_vectors: u32) -> Vec { + let mut v = Vec::new(); + v.extend_from_slice(b"TVSB"); + v.push(VERSION); + v.extend_from_slice(&dim.to_le_bytes()); + v.extend_from_slice(&n_vectors.to_le_bytes()); + v + } + #[test] fn probe_metadata_matches_full_loaders_on_generated_fixtures() { let mut paths = Vec::new(); @@ -1029,16 +1051,18 @@ mod tests { let truncated = forge("truncated_header", b"TVR1\x01"); let err = probe_index_metadata(&truncated).unwrap_err(); assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof); + assert!( + err.to_string() + .contains("TVR1 header truncated while reading dim"), + "unexpected error: {err}" + ); std::fs::remove_file(&truncated).ok(); - let mut length_mismatch = Vec::new(); - length_mismatch.extend_from_slice(b"TVR1"); - length_mismatch.push(VERSION); - length_mismatch.extend_from_slice(&8u32.to_le_bytes()); - length_mismatch.extend_from_slice(&1u32.to_le_bytes()); - let length_mismatch = forge("length_mismatch", &length_mismatch); - let err = probe_index_metadata(&length_mismatch).unwrap_err(); - assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + let length_mismatch = forge("length_mismatch", &rank_header(8, 1)); + assert_err_contains( + probe_index_metadata(&length_mismatch), + "TVR1 payload truncated", + ); std::fs::remove_file(&length_mismatch).ok(); let mut huge_declared = Vec::new(); @@ -1056,6 +1080,96 @@ mod tests { std::fs::remove_file(&huge_declared).ok(); } + #[test] + fn probe_reports_header_field_context_for_truncated_headers() { + let cases: [(&str, Vec, &str); 5] = [ + ( + "short_magic", + b"TV".to_vec(), + "ordvec index header truncated while reading magic", + ), + ( + "rank_version", + b"TVR1".to_vec(), + "TVR1 header truncated while reading version", + ), + ( + "rankquant_bits", + b"TVRQ\x01".to_vec(), + "TVRQ header truncated while reading bits", + ), + ( + "bitmap_n_top", + { + let mut v = Vec::new(); + v.extend_from_slice(b"TVBM"); + v.push(VERSION); + v.extend_from_slice(&64u32.to_le_bytes()); + v + }, + "TVBM header truncated while reading n_top", + ), + ( + "sign_n_vectors", + { + let mut v = Vec::new(); + v.extend_from_slice(b"TVSB"); + v.push(VERSION); + v.extend_from_slice(&64u32.to_le_bytes()); + v + }, + "TVSB header truncated while reading n_vectors", + ), + ]; + for (suffix, bytes, expected) in cases { + let path = forge(suffix, &bytes); + assert_err_contains(probe_index_metadata(&path), expected); + std::fs::remove_file(&path).ok(); + } + } + + #[test] + fn probe_reports_distinct_payload_truncation_and_trailing_bytes_for_all_formats() { + let cases: [(&str, Vec, Vec, &str); 4] = [ + ("rank", rank_header(8, 1), rank_header(8, 0), "TVR1"), + ( + "rankquant", + rankquant_header(2, 8, 1), + rankquant_header(2, 8, 0), + "TVRQ", + ), + ( + "bitmap", + bitmap_header(64, 16, 1), + bitmap_header(64, 16, 0), + "TVBM", + ), + ( + "sign_bitmap", + sign_bitmap_header(64, 1), + sign_bitmap_header(64, 0), + "TVSB", + ), + ]; + + for (suffix, truncated_header, mut trailing_bytes, label) in cases { + let truncated = forge(&format!("{suffix}_truncated"), &truncated_header); + assert_err_contains( + probe_index_metadata(&truncated), + &format!("{label} payload truncated"), + ); + std::fs::remove_file(&truncated).ok(); + + trailing_bytes.push(0); + let trailing = forge(&format!("{suffix}_trailing"), &trailing_bytes); + assert_err_contains( + probe_index_metadata(&trailing), + &format!("{label} payload has trailing bytes"), + ); + std::fs::remove_file(&trailing).ok(); + } + } + #[test] fn probe_rejects_format_specific_header_errors() { let mut bad_bits = Vec::new(); diff --git a/tests/index/loader_validation.rs b/tests/index/loader_validation.rs index 918755cb..a454d5c5 100644 --- a/tests/index/loader_validation.rs +++ b/tests/index/loader_validation.rs @@ -29,6 +29,65 @@ fn tmp(name: &str) -> std::path::PathBuf { )) } +fn assert_load_err_contains(result: std::io::Result, expected: &str) { + let Err(err) = result else { + panic!("expected error containing {expected:?}, got Ok(_)"); + }; + let text = err.to_string(); + assert!( + text.contains(expected), + "expected error containing {expected:?}, got {text:?}" + ); +} + +fn set_u32_field(bytes: &mut [u8], offset: usize, value: u32) { + bytes[offset..offset + 4].copy_from_slice(&value.to_le_bytes()); +} + +fn rank_payload_cases(dim: usize) -> (Vec, Vec) { + let p = tmp("rank_empty_payload_case"); + Rank::new(dim).write(&p).unwrap(); + let trailing = read_bytes(&p); + std::fs::remove_file(&p).ok(); + assert_eq!(trailing.len(), 13, "empty Rank file is header-only"); + let mut truncated = trailing.clone(); + set_u32_field(&mut truncated, 9, 1); + (truncated, trailing) +} + +fn rankquant_payload_cases(bits: u8, dim: usize) -> (Vec, Vec) { + let p = tmp("rankquant_empty_payload_case"); + RankQuant::new(dim, bits).write(&p).unwrap(); + let trailing = read_bytes(&p); + std::fs::remove_file(&p).ok(); + assert_eq!(trailing.len(), 14, "empty RankQuant file is header-only"); + let mut truncated = trailing.clone(); + set_u32_field(&mut truncated, 10, 1); + (truncated, trailing) +} + +fn bitmap_payload_cases(dim: usize, n_top: usize) -> (Vec, Vec) { + let p = tmp("bitmap_empty_payload_case"); + Bitmap::new(dim, n_top).write(&p).unwrap(); + let trailing = read_bytes(&p); + std::fs::remove_file(&p).ok(); + assert_eq!(trailing.len(), 17, "empty Bitmap file is header-only"); + let mut truncated = trailing.clone(); + set_u32_field(&mut truncated, 13, 1); + (truncated, trailing) +} + +fn sign_bitmap_payload_cases(dim: usize) -> (Vec, Vec) { + let p = tmp("sign_bitmap_empty_payload_case"); + SignBitmap::new(dim).write(&p).unwrap(); + let trailing = read_bytes(&p); + std::fs::remove_file(&p).ok(); + assert_eq!(trailing.len(), 13, "empty SignBitmap file is header-only"); + let mut truncated = trailing.clone(); + set_u32_field(&mut truncated, 9, 1); + (truncated, trailing) +} + #[test] fn load_rank_rejects_non_permutation_row() { let corpus = make_corpus(1); @@ -142,6 +201,69 @@ fn load_sign_bitmap_accepts_any_bit_pattern() { ); } +#[test] +fn public_loaders_report_stable_malformed_payload_context() { + let rank = rank_payload_cases(4); + let rankquant = rankquant_payload_cases(2, 8); + let bitmap = bitmap_payload_cases(64, 16); + let sign_bitmap = sign_bitmap_payload_cases(64); + let cases: [(&str, Vec, Vec, &str); 4] = [ + ("rank", rank.0, rank.1, "TVR1"), + ("rankquant", rankquant.0, rankquant.1, "TVRQ"), + ("bitmap", bitmap.0, bitmap.1, "TVBM"), + ("sign_bitmap", sign_bitmap.0, sign_bitmap.1, "TVSB"), + ]; + + for (suffix, truncated_header, mut trailing_bytes, label) in cases { + let truncated = tmp(&format!("{suffix}_truncated_context")); + write_bytes(&truncated, &truncated_header); + match label { + "TVR1" => assert_load_err_contains( + Rank::load(&truncated), + &format!("{label} payload truncated"), + ), + "TVRQ" => assert_load_err_contains( + RankQuant::load(&truncated), + &format!("{label} payload truncated"), + ), + "TVBM" => assert_load_err_contains( + Bitmap::load(&truncated), + &format!("{label} payload truncated"), + ), + "TVSB" => assert_load_err_contains( + SignBitmap::load(&truncated), + &format!("{label} payload truncated"), + ), + _ => unreachable!(), + } + std::fs::remove_file(&truncated).ok(); + + trailing_bytes.push(0); + let trailing = tmp(&format!("{suffix}_trailing_context")); + write_bytes(&trailing, &trailing_bytes); + match label { + "TVR1" => assert_load_err_contains( + Rank::load(&trailing), + &format!("{label} payload has trailing bytes"), + ), + "TVRQ" => assert_load_err_contains( + RankQuant::load(&trailing), + &format!("{label} payload has trailing bytes"), + ), + "TVBM" => assert_load_err_contains( + Bitmap::load(&trailing), + &format!("{label} payload has trailing bytes"), + ), + "TVSB" => assert_load_err_contains( + SignBitmap::load(&trailing), + &format!("{label} payload has trailing bytes"), + ), + _ => unreachable!(), + } + std::fs::remove_file(&trailing).ok(); + } +} + #[test] fn loaders_reject_trailing_bytes() { // Every v1 format's payload is the file's final section, so the loader diff --git a/tests/index/main.rs b/tests/index/main.rs index 9f89d2f1..29649d2d 100644 --- a/tests/index/main.rs +++ b/tests/index/main.rs @@ -188,14 +188,15 @@ fn rank_io_loaders_reject_malformed_files_without_panicking() { v }), // TVR1 with truncated payload (header claims a payload bigger - // than what's on disk → read_exact returns UnexpectedEof, not - // a panic). + // than what's on disk → exact-length validation returns Err, not + // a panic or allocation attempt). ("tvr_truncated", { let mut v = Vec::new(); v.extend_from_slice(b"TVR1"); v.push(1); // Header claims 100 * 64 * 2 = 12800 payload bytes but only 100 - // are provided, so the loader hits UnexpectedEof, not a panic. + // are provided, so the loader rejects the exact-length mismatch, + // not a panic. v.extend_from_slice(&64u32.to_le_bytes()); // dim v.extend_from_slice(&100u32.to_le_bytes()); // n_vectors v.extend(std::iter::repeat_n(0u8, 100)); @@ -222,7 +223,7 @@ fn rank_io_loaders_reject_malformed_files_without_panicking() { }), // TVSB with truncated payload: header declares 8 docs * 128/64 = // 16 qwords = 128 payload bytes but the file ends right after the - // header, so read_exact yields UnexpectedEof rather than a panic. + // header, so exact-length validation fails before payload allocation. ("tvsb_truncated", { let mut v = Vec::new(); v.extend_from_slice(b"TVSB");