-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Split traits for plain and bitpacked decoding and fix soundness issue in BitReader::get_batch #10172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+115
−137
Merged
Split traits for plain and bitpacked decoding and fix soundness issue in BitReader::get_batch #10172
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,11 +36,9 @@ fn array_from_slice<const N: usize>(bs: &[u8]) -> Result<[u8; N]> { | |
| } | ||
| } | ||
|
|
||
| /// # Safety | ||
| /// All bit patterns 00000xxxx, where there are `BIT_CAPACITY` `x`s, | ||
| /// must be valid, unless BIT_CAPACITY is 0. | ||
| pub unsafe trait FromBytes: Sized { | ||
| const BIT_CAPACITY: usize; | ||
| /// Types that can be decoded from plain representations. This includes non-primitive types like | ||
| /// `FixedLenByteArray` and also variable length types like `ByteArray`. | ||
| pub trait FromBytes: Sized { | ||
| type Buffer: AsMut<[u8]> + Default; | ||
| fn try_from_le_slice(b: &[u8]) -> Result<Self>; | ||
| fn from_le_bytes(bs: Self::Buffer) -> Self; | ||
|
|
@@ -52,17 +50,28 @@ pub unsafe trait FromBytes: Sized { | |
| /// directly converted from a u64 value. Types like Int96, ByteArray, | ||
| /// and FixedLenByteArray that cannot be represented in 64 bits do not | ||
| /// implement this trait. | ||
| pub trait FromBitpacked: FromBytes { | ||
| pub trait FromBitpacked { | ||
| /// The maximum number of bits that are allowed to be converted to this type. | ||
| /// This is at most the size of the type in bits, but could be less, for example | ||
| /// for the boolean type. | ||
| const BIT_CAPACITY: usize; | ||
| /// How many values are converted by one call to `unpack_batch`. | ||
| const BATCH_SIZE: usize; | ||
| /// Convert directly from a u64 value by truncation, avoiding byte slice copies. | ||
| fn from_u64(v: u64) -> Self; | ||
|
|
||
| /// Converts multiple bitpacked values from `input` to `output`. | ||
| /// The `output` slice needs to have space for at least `BATCH_SIZE` elements, | ||
| /// otherwise this method will panic. | ||
| fn unpack_batch(input: &[u8], output: &mut [Self], num_bits: usize) | ||
| where | ||
| Self: Sized; | ||
| } | ||
|
|
||
| macro_rules! from_le_bytes { | ||
| ($($ty: ty),*) => { | ||
| $( | ||
| // SAFETY: this macro is used for types for which all bit patterns are valid. | ||
| unsafe impl FromBytes for $ty { | ||
| const BIT_CAPACITY: usize = std::mem::size_of::<$ty>() * 8; | ||
| impl FromBytes for $ty { | ||
| type Buffer = [u8; size_of::<Self>()]; | ||
| fn try_from_le_slice(b: &[u8]) -> Result<Self> { | ||
| Ok(Self::from_le_bytes(array_from_slice(b)?)) | ||
|
|
@@ -71,59 +80,111 @@ macro_rules! from_le_bytes { | |
| <$ty>::from_le_bytes(bs) | ||
| } | ||
| } | ||
| impl FromBitpacked for $ty { | ||
| #[inline] | ||
| fn from_u64(v: u64) -> Self { | ||
| v as Self | ||
| } | ||
| } | ||
| )* | ||
| }; | ||
| } | ||
|
|
||
| from_le_bytes! { u8, u16, u32, u64, i8, i16, i32, i64 } | ||
| macro_rules! from_bitpacked { | ||
| ($($ty: ty => $unpack: path),*) => { | ||
| $( | ||
| impl FromBitpacked for $ty { | ||
| const BIT_CAPACITY: usize = std::mem::size_of::<$ty>() * 8; | ||
| // this has to match the signature of the unpack* functions | ||
| const BATCH_SIZE: usize = std::mem::size_of::<$ty>() * 8; | ||
|
|
||
| #[inline] | ||
| fn from_u64(v: u64) -> Self { | ||
| v as _ | ||
| } | ||
|
|
||
| // SAFETY: all bit patterns are valid for f32 and f64. | ||
| unsafe impl FromBytes for f32 { | ||
| const BIT_CAPACITY: usize = 32; | ||
| type Buffer = [u8; 4]; | ||
| fn try_from_le_slice(b: &[u8]) -> Result<Self> { | ||
| Ok(Self::from_le_bytes(array_from_slice(b)?)) | ||
| #[inline] | ||
| fn unpack_batch(input: &[u8], output: &mut [Self], num_bits: usize) { | ||
| $unpack(input, (&mut output[..Self::BATCH_SIZE]).try_into().unwrap(), num_bits) | ||
| } | ||
| } | ||
| )* | ||
| } | ||
| fn from_le_bytes(bs: Self::Buffer) -> Self { | ||
| f32::from_le_bytes(bs) | ||
| } | ||
|
|
||
| macro_rules! from_bitpacked_delegate { | ||
| ($($ty: ty => $delegate: ty),*) => { | ||
| $( | ||
| impl FromBitpacked for $ty { | ||
| const BIT_CAPACITY: usize = <$delegate as FromBitpacked>::BIT_CAPACITY; | ||
| const BATCH_SIZE: usize = <$delegate as FromBitpacked>::BATCH_SIZE; | ||
|
|
||
| #[inline] | ||
| fn from_u64(v: u64) -> Self { | ||
| v as _ | ||
| } | ||
|
|
||
| #[inline] | ||
| fn unpack_batch(input: &[u8], output: &mut [Self], num_bits: usize) { | ||
| // Guard against misuse of this macro: because the check sits in a const block, | ||
| // an incompatible type pair already fails at compile time instead of at runtime. | ||
| const { | ||
| assert!( | ||
| std::mem::size_of::<$ty>() == std::mem::size_of::<$delegate>() | ||
| && std::mem::align_of::<$ty>() == std::mem::align_of::<$delegate>(), | ||
| "types need to have the same size and alignment" | ||
| ); | ||
| } | ||
| // SAFETY: the const assertion above guarantees that $ty and $delegate have the same size and alignment, so the reinterpreted slice spans the same memory with the same element count; this macro must only be used for type pairs whose bit patterns are mutually transmutable. | ||
| let output: &mut [$delegate] = unsafe { std::slice::from_raw_parts_mut(output.as_mut_ptr().cast::<$delegate>(), output.len()) }; | ||
| <$delegate>::unpack_batch(input, output, num_bits); | ||
| } | ||
| } | ||
| )* | ||
| } | ||
| } | ||
|
|
||
| impl FromBitpacked for f32 { | ||
| from_le_bytes! { u8, u16, u32, u64, i8, i16, i32, i64 } | ||
| from_bitpacked!(u8 => unpack8, u16 => unpack16, u32 => unpack32, u64 => unpack64); | ||
| from_bitpacked_delegate!(i8 => u8, i16 => u16, i32 => u32, i64 => u64); | ||
|
|
||
| impl FromBitpacked for bool { | ||
| const BIT_CAPACITY: usize = 1; | ||
| const BATCH_SIZE: usize = <u8 as FromBitpacked>::BATCH_SIZE; | ||
|
|
||
| #[inline] | ||
| fn from_u64(v: u64) -> Self { | ||
| f32::from_bits(v as u32) | ||
| v != 0 | ||
| } | ||
|
|
||
| #[inline] | ||
| fn unpack_batch(input: &[u8], output: &mut [Self], num_bits: usize) { | ||
| assert!(num_bits == 1); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The compiler should hoist this assertion out of any loop, or even realize that for the boolean type it will only ever be called with a bit width of 1. I searched for benchmarks covering the boolean data type and did not find any, so it might make sense to add a benchmark first. |
||
| // Safety: | ||
| // we asserted that we will only decode with a bitwidth of 1, | ||
| // so the u8 can only be 0 or 1, which are the valid representations of a bool. | ||
| let output: &mut [u8] = unsafe { | ||
| std::slice::from_raw_parts_mut(output.as_mut_ptr().cast::<u8>(), output.len()) | ||
| }; | ||
| u8::unpack_batch(input, output, num_bits); | ||
| } | ||
| } | ||
|
|
||
| // SAFETY: all bit patterns are valid for f64. | ||
| unsafe impl FromBytes for f64 { | ||
| const BIT_CAPACITY: usize = 64; | ||
| type Buffer = [u8; 8]; | ||
| impl FromBytes for f32 { | ||
| type Buffer = [u8; 4]; | ||
| fn try_from_le_slice(b: &[u8]) -> Result<Self> { | ||
| Ok(Self::from_le_bytes(array_from_slice(b)?)) | ||
| } | ||
| fn from_le_bytes(bs: Self::Buffer) -> Self { | ||
| f64::from_le_bytes(bs) | ||
| f32::from_le_bytes(bs) | ||
| } | ||
| } | ||
|
|
||
| impl FromBitpacked for f64 { | ||
| #[inline] | ||
| fn from_u64(v: u64) -> Self { | ||
| f64::from_bits(v) | ||
| impl FromBytes for f64 { | ||
| type Buffer = [u8; 8]; | ||
| fn try_from_le_slice(b: &[u8]) -> Result<Self> { | ||
| Ok(Self::from_le_bytes(array_from_slice(b)?)) | ||
| } | ||
| fn from_le_bytes(bs: Self::Buffer) -> Self { | ||
| f64::from_le_bytes(bs) | ||
| } | ||
| } | ||
|
|
||
| // SAFETY: the 0000000x bit pattern is always valid for `bool`. | ||
| unsafe impl FromBytes for bool { | ||
| const BIT_CAPACITY: usize = 1; | ||
| impl FromBytes for bool { | ||
| type Buffer = [u8; 1]; | ||
|
|
||
| fn try_from_le_slice(b: &[u8]) -> Result<Self> { | ||
|
|
@@ -134,16 +195,7 @@ unsafe impl FromBytes for bool { | |
| } | ||
| } | ||
|
|
||
| impl FromBitpacked for bool { | ||
| #[inline] | ||
| fn from_u64(v: u64) -> Self { | ||
| v != 0 | ||
| } | ||
| } | ||
|
|
||
| // SAFETY: BIT_CAPACITY is 0. | ||
| unsafe impl FromBytes for Int96 { | ||
| const BIT_CAPACITY: usize = 0; | ||
| impl FromBytes for Int96 { | ||
| type Buffer = [u8; 12]; | ||
|
|
||
| fn try_from_le_slice(b: &[u8]) -> Result<Self> { | ||
|
|
@@ -168,9 +220,7 @@ unsafe impl FromBytes for Int96 { | |
| } | ||
| } | ||
|
|
||
| // SAFETY: BIT_CAPACITY is 0. | ||
| unsafe impl FromBytes for ByteArray { | ||
| const BIT_CAPACITY: usize = 0; | ||
| impl FromBytes for ByteArray { | ||
| type Buffer = Vec<u8>; | ||
|
|
||
| fn try_from_le_slice(b: &[u8]) -> Result<Self> { | ||
|
|
@@ -181,9 +231,7 @@ unsafe impl FromBytes for ByteArray { | |
| } | ||
| } | ||
|
|
||
| // SAFETY: BIT_CAPACITY is 0. | ||
| unsafe impl FromBytes for FixedLenByteArray { | ||
| const BIT_CAPACITY: usize = 0; | ||
| impl FromBytes for FixedLenByteArray { | ||
| type Buffer = Vec<u8>; | ||
|
|
||
| fn try_from_le_slice(b: &[u8]) -> Result<Self> { | ||
|
|
@@ -671,64 +719,10 @@ impl BitReader { | |
| assert!(num_bits <= T::BIT_CAPACITY); | ||
|
|
||
| // Read directly into output buffer | ||
| match size_of::<T>() { | ||
| 1 => { | ||
| let ptr = batch.as_mut_ptr() as *mut u8; | ||
| // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns | ||
| // in which only the lowest T::BIT_CAPACITY bits of T are set are valid, | ||
| // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we | ||
| // checked that num_bits <= T::BIT_CAPACITY. | ||
| let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; | ||
| while values_to_read - i >= 8 { | ||
| let out_slice = (&mut out[i..i + 8]).try_into().unwrap(); | ||
| unpack8(&self.buffer[self.byte_offset..], out_slice, num_bits); | ||
| self.byte_offset += num_bits; | ||
| i += 8; | ||
| } | ||
| } | ||
| 2 => { | ||
| let ptr = batch.as_mut_ptr() as *mut u16; | ||
| // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns | ||
| // in which only the lowest T::BIT_CAPACITY bits of T are set are valid, | ||
| // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we | ||
| // checked that num_bits <= T::BIT_CAPACITY. | ||
| let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; | ||
| while values_to_read - i >= 16 { | ||
| let out_slice = (&mut out[i..i + 16]).try_into().unwrap(); | ||
| unpack16(&self.buffer[self.byte_offset..], out_slice, num_bits); | ||
| self.byte_offset += 2 * num_bits; | ||
| i += 16; | ||
| } | ||
| } | ||
| 4 => { | ||
| let ptr = batch.as_mut_ptr() as *mut u32; | ||
| // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns | ||
| // in which only the lowest T::BIT_CAPACITY bits of T are set are valid, | ||
| // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we | ||
| // checked that num_bits <= T::BIT_CAPACITY. | ||
| let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; | ||
| while values_to_read - i >= 32 { | ||
| let out_slice = (&mut out[i..i + 32]).try_into().unwrap(); | ||
| unpack32(&self.buffer[self.byte_offset..], out_slice, num_bits); | ||
| self.byte_offset += 4 * num_bits; | ||
| i += 32; | ||
| } | ||
| } | ||
| 8 => { | ||
| let ptr = batch.as_mut_ptr() as *mut u64; | ||
| // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns | ||
| // in which only the lowest T::BIT_CAPACITY bits of T are set are valid, | ||
| // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we | ||
| // checked that num_bits <= T::BIT_CAPACITY. | ||
| let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; | ||
| while values_to_read - i >= 64 { | ||
| let out_slice = (&mut out[i..i + 64]).try_into().unwrap(); | ||
| unpack64(&self.buffer[self.byte_offset..], out_slice, num_bits); | ||
| self.byte_offset += 8 * num_bits; | ||
| i += 64; | ||
| } | ||
| } | ||
| _ => unreachable!(), | ||
| while values_to_read - i >= T::BATCH_SIZE { | ||
| T::unpack_batch(&self.buffer[self.byte_offset..], &mut batch[i..], num_bits); | ||
| self.byte_offset += num_bits * T::BATCH_SIZE / 8; | ||
| i += T::BATCH_SIZE; | ||
| } | ||
|
|
||
| // Try to read smaller batches if possible | ||
|
|
@@ -738,10 +732,7 @@ impl BitReader { | |
| self.byte_offset += 4 * num_bits; | ||
|
|
||
| for out in out_buf { | ||
| // Zero-allocate buffer | ||
| let mut out_bytes = T::Buffer::default(); | ||
| out_bytes.as_mut()[..4].copy_from_slice(&out.to_le_bytes()); | ||
| batch[i] = T::from_le_bytes(out_bytes); | ||
| batch[i] = T::from_u64(out as u64); | ||
| i += 1; | ||
| } | ||
| } | ||
|
|
@@ -752,10 +743,7 @@ impl BitReader { | |
| self.byte_offset += 2 * num_bits; | ||
|
|
||
| for out in out_buf { | ||
| // Zero-allocate buffer | ||
| let mut out_bytes = T::Buffer::default(); | ||
| out_bytes.as_mut()[..2].copy_from_slice(&out.to_le_bytes()); | ||
| batch[i] = T::from_le_bytes(out_bytes); | ||
| batch[i] = T::from_u64(out as u64); | ||
| i += 1; | ||
| } | ||
| } | ||
|
|
@@ -766,10 +754,7 @@ impl BitReader { | |
| self.byte_offset += num_bits; | ||
|
|
||
| for out in out_buf { | ||
| // Zero-allocate buffer | ||
| let mut out_bytes = T::Buffer::default(); | ||
| out_bytes.as_mut()[..1].copy_from_slice(&out.to_le_bytes()); | ||
| batch[i] = T::from_le_bytes(out_bytes); | ||
| batch[i] = T::from_u64(out as u64); | ||
| i += 1; | ||
| } | ||
| } | ||
|
|
@@ -1252,10 +1237,7 @@ mod tests { | |
| .collect(); | ||
|
|
||
| // Generic values used to check against actual values read from `get_batch`. | ||
| let expected_values: Vec<T> = values | ||
| .iter() | ||
| .map(|v| T::try_from_le_slice(v.as_bytes()).unwrap()) | ||
| .collect(); | ||
| let expected_values: Vec<T> = values.iter().map(|v| T::from_u64(*v)).collect(); | ||
|
|
||
| (0..total).for_each(|i| writer.put_value(values[i], num_bits)); | ||
|
|
||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.