diff --git a/parquet/src/encodings/rle.rs b/parquet/src/encodings/rle.rs
index 47e52d00b094..5de9dc850a34 100644
--- a/parquet/src/encodings/rle.rs
+++ b/parquet/src/encodings/rle.rs
@@ -400,13 +400,10 @@ impl RleDecoder {
         }
 
         let value = if self.rle_left > 0 {
-            let rle_value = T::try_from_le_slice(
-                &self
-                    .current_value
-                    .as_mut()
-                    .ok_or_else(|| general_err!("current_value should be Some"))?
-                    .to_ne_bytes(),
-            )?;
+            let current_value = self
+                .current_value
+                .ok_or_else(|| general_err!("current_value should be Some"))?;
+            let rle_value = T::from_u64(current_value);
             self.rle_left -= 1;
             rle_value
         } else {
@@ -433,8 +430,7 @@ impl RleDecoder {
         while values_read < buffer.len() {
             if self.rle_left > 0 {
                 let num_values = cmp::min(buffer.len() - values_read, self.rle_left as usize);
-                let repeated_value =
-                    T::try_from_le_slice(&self.current_value.as_mut().unwrap().to_ne_bytes())?;
+                let repeated_value = T::from_u64(self.current_value.unwrap());
                 buffer[values_read..values_read + num_values].fill(repeated_value);
                 self.rle_left -= num_values as u32;
                 values_read += num_values;
diff --git a/parquet/src/util/bit_util.rs b/parquet/src/util/bit_util.rs
index d0b13072dc4e..22c126261b53 100644
--- a/parquet/src/util/bit_util.rs
+++ b/parquet/src/util/bit_util.rs
@@ -36,11 +36,9 @@ fn array_from_slice<const N: usize>(bs: &[u8]) -> Result<[u8; N]> {
     }
 }
 
-/// # Safety
-/// All bit patterns 00000xxxx, where there are `BIT_CAPACITY` `x`s,
-/// must be valid, unless BIT_CAPACITY is 0.
-pub unsafe trait FromBytes: Sized {
-    const BIT_CAPACITY: usize;
+/// Types that can be decoded from plain representations. This includes non-primitive types like
+/// `FixedLenByteArray` and also variable length types like `ByteArray`.
+pub trait FromBytes: Sized {
     type Buffer: AsMut<[u8]> + Default;
     fn try_from_le_slice(b: &[u8]) -> Result<Self>;
     fn from_le_bytes(bs: Self::Buffer) -> Self;
@@ -52,17 +50,28 @@ pub unsafe trait FromBytes: Sized {
 /// directly converted from a u64 value. Types like Int96, ByteArray,
 /// and FixedLenByteArray that cannot be represented in 64 bits do not
 /// implement this trait.
-pub trait FromBitpacked: FromBytes {
+pub trait FromBitpacked {
+    /// The maximum number of bits that are allowed to be converted to this type.
+    /// This is at most the size of the type in bits, but could be less, for example
+    /// for the boolean type.
+    const BIT_CAPACITY: usize;
+    /// How many values are converted by one call to `unpack_batch`.
+    const BATCH_SIZE: usize;
     /// Convert directly from a u64 value by truncation, avoiding byte slice copies.
     fn from_u64(v: u64) -> Self;
+
+    /// Converts multiple bitpacked values from `input` to `output`.
+    /// The `output` slice needs to have space for at least `BATCH_SIZE` elements,
+    /// otherwise this method will panic.
+    fn unpack_batch(input: &[u8], output: &mut [Self], num_bits: usize)
+    where
+        Self: Sized;
 }
 
 macro_rules! from_le_bytes {
     ($($ty: ty),*) => {
         $(
-        // SAFETY: this macro is used for types for which all bit patterns are valid.
-        unsafe impl FromBytes for $ty {
-            const BIT_CAPACITY: usize = std::mem::size_of::<$ty>() * 8;
+        impl FromBytes for $ty {
             type Buffer = [u8; size_of::<Self>()];
             fn try_from_le_slice(b: &[u8]) -> Result<Self> {
                 Ok(Self::from_le_bytes(array_from_slice(b)?))
@@ -71,59 +80,111 @@ macro_rules! from_le_bytes {
                 <$ty>::from_le_bytes(bs)
             }
         }
-        impl FromBitpacked for $ty {
-            #[inline]
-            fn from_u64(v: u64) -> Self {
-                v as Self
-            }
-        }
         )*
     };
 }
 
-from_le_bytes! { u8, u16, u32, u64, i8, i16, i32, i64 }
+macro_rules! from_bitpacked {
+    ($($ty: ty => $unpack: path),*) => {
+        $(
+            impl FromBitpacked for $ty {
+                const BIT_CAPACITY: usize = std::mem::size_of::<$ty>() * 8;
+                // this has to match the signature of the unpack* functions
+                const BATCH_SIZE: usize = std::mem::size_of::<$ty>() * 8;
+
+                #[inline]
+                fn from_u64(v: u64) -> Self {
+                    v as _
+                }
 
-// SAFETY: all bit patterns are valid for f32 and f64.
-unsafe impl FromBytes for f32 {
-    const BIT_CAPACITY: usize = 32;
-    type Buffer = [u8; 4];
-    fn try_from_le_slice(b: &[u8]) -> Result<Self> {
-        Ok(Self::from_le_bytes(array_from_slice(b)?))
+                #[inline]
+                fn unpack_batch(input: &[u8], output: &mut [Self], num_bits: usize) {
+                    $unpack(input, (&mut output[..Self::BATCH_SIZE]).try_into().unwrap(), num_bits)
+                }
+            }
+        )*
     }
-    fn from_le_bytes(bs: Self::Buffer) -> Self {
-        f32::from_le_bytes(bs)
+}
+
+macro_rules! from_bitpacked_delegate {
+    ($($ty: ty => $delegate: ty),*) => {
+        $(
+            impl FromBitpacked for $ty {
+                const BIT_CAPACITY: usize = <$delegate as FromBitpacked>::BIT_CAPACITY;
+                const BATCH_SIZE: usize = <$delegate as FromBitpacked>::BATCH_SIZE;
+
+                #[inline]
+                fn from_u64(v: u64) -> Self {
+                    v as _
+                }
+
+                #[inline]
+                fn unpack_batch(input: &[u8], output: &mut [Self], num_bits: usize) {
+                    // Guard against misuse of this macro: because of the const block, this will fail
+                    // already at compile-time if the types are not compatible.
+                    const {
+                        assert!(
+                            std::mem::size_of::<$ty>() == std::mem::size_of::<$delegate>()
+                                && std::mem::align_of::<$ty>() == std::mem::align_of::<$delegate>(),
+                            "types need to have the same size and alignment"
+                        );
+                    }
+                    // SAFETY: ty and delegate have the same size and alignment, and this macro is only used for types that have transmutable bit patterns.
+                    let output: &mut [$delegate] = unsafe { std::slice::from_raw_parts_mut(output.as_mut_ptr().cast::<$delegate>(), output.len()) };
+                    <$delegate>::unpack_batch(input, output, num_bits);
+                }
+            }
+        )*
     }
 }
 
-impl FromBitpacked for f32 {
+from_le_bytes! { u8, u16, u32, u64, i8, i16, i32, i64 }
+from_bitpacked!(u8 => unpack8, u16 => unpack16, u32 => unpack32, u64 => unpack64);
+from_bitpacked_delegate!(i8 => u8, i16 => u16, i32 => u32, i64 => u64);
+
+impl FromBitpacked for bool {
+    const BIT_CAPACITY: usize = 1;
+    const BATCH_SIZE: usize = <u8 as FromBitpacked>::BATCH_SIZE;
+
     #[inline]
     fn from_u64(v: u64) -> Self {
-        f32::from_bits(v as u32)
+        v != 0
+    }
+
+    #[inline]
+    fn unpack_batch(input: &[u8], output: &mut [Self], num_bits: usize) {
+        assert!(num_bits <= Self::BIT_CAPACITY);
+        // SAFETY:
+        // we asserted that the bit width is at most 1, so every decoded u8 is
+        // either 0 or 1, which are the only valid representations of a bool.
+        let output: &mut [u8] = unsafe {
+            std::slice::from_raw_parts_mut(output.as_mut_ptr().cast::<u8>(), output.len())
+        };
+        u8::unpack_batch(input, output, num_bits);
     }
 }
 
-// SAFETY: all bit patterns are valid for f64.
-unsafe impl FromBytes for f64 {
-    const BIT_CAPACITY: usize = 64;
-    type Buffer = [u8; 8];
+impl FromBytes for f32 {
+    type Buffer = [u8; 4];
     fn try_from_le_slice(b: &[u8]) -> Result<Self> {
         Ok(Self::from_le_bytes(array_from_slice(b)?))
     }
     fn from_le_bytes(bs: Self::Buffer) -> Self {
-        f64::from_le_bytes(bs)
+        f32::from_le_bytes(bs)
     }
 }
 
-impl FromBitpacked for f64 {
-    #[inline]
-    fn from_u64(v: u64) -> Self {
-        f64::from_bits(v)
+impl FromBytes for f64 {
+    type Buffer = [u8; 8];
+    fn try_from_le_slice(b: &[u8]) -> Result<Self> {
+        Ok(Self::from_le_bytes(array_from_slice(b)?))
+    }
+    fn from_le_bytes(bs: Self::Buffer) -> Self {
+        f64::from_le_bytes(bs)
     }
 }
 
-// SAFETY: the 0000000x bit pattern is always valid for `bool`.
-unsafe impl FromBytes for bool {
-    const BIT_CAPACITY: usize = 1;
+impl FromBytes for bool {
     type Buffer = [u8; 1];
 
     fn try_from_le_slice(b: &[u8]) -> Result<Self> {
@@ -134,16 +195,7 @@ unsafe impl FromBytes for bool {
     }
 }
 
-impl FromBitpacked for bool {
-    #[inline]
-    fn from_u64(v: u64) -> Self {
-        v != 0
-    }
-}
-
-// SAFETY: BIT_CAPACITY is 0.
-unsafe impl FromBytes for Int96 {
-    const BIT_CAPACITY: usize = 0;
+impl FromBytes for Int96 {
     type Buffer = [u8; 12];
 
     fn try_from_le_slice(b: &[u8]) -> Result<Self> {
@@ -168,9 +220,7 @@ unsafe impl FromBytes for Int96 {
     }
 }
 
-// SAFETY: BIT_CAPACITY is 0.
-unsafe impl FromBytes for ByteArray {
-    const BIT_CAPACITY: usize = 0;
+impl FromBytes for ByteArray {
     type Buffer = Vec<u8>;
 
     fn try_from_le_slice(b: &[u8]) -> Result<Self> {
@@ -181,9 +231,7 @@ unsafe impl FromBytes for ByteArray {
     }
 }
 
-// SAFETY: BIT_CAPACITY is 0.
-unsafe impl FromBytes for FixedLenByteArray {
-    const BIT_CAPACITY: usize = 0;
+impl FromBytes for FixedLenByteArray {
     type Buffer = Vec<u8>;
 
     fn try_from_le_slice(b: &[u8]) -> Result<Self> {
@@ -671,64 +719,10 @@ impl BitReader {
         assert!(num_bits <= T::BIT_CAPACITY);
 
         // Read directly into output buffer
-        match size_of::<T>() {
-            1 => {
-                let ptr = batch.as_mut_ptr() as *mut u8;
-                // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
-                // in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
-                // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we
-                // checked that num_bits <= T::BIT_CAPACITY.
-                let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
-                while values_to_read - i >= 8 {
-                    let out_slice = (&mut out[i..i + 8]).try_into().unwrap();
-                    unpack8(&self.buffer[self.byte_offset..], out_slice, num_bits);
-                    self.byte_offset += num_bits;
-                    i += 8;
-                }
-            }
-            2 => {
-                let ptr = batch.as_mut_ptr() as *mut u16;
-                // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
-                // in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
-                // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we
-                // checked that num_bits <= T::BIT_CAPACITY.
-                let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
-                while values_to_read - i >= 16 {
-                    let out_slice = (&mut out[i..i + 16]).try_into().unwrap();
-                    unpack16(&self.buffer[self.byte_offset..], out_slice, num_bits);
-                    self.byte_offset += 2 * num_bits;
-                    i += 16;
-                }
-            }
-            4 => {
-                let ptr = batch.as_mut_ptr() as *mut u32;
-                // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
-                // in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
-                // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we
-                // checked that num_bits <= T::BIT_CAPACITY.
-                let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
-                while values_to_read - i >= 32 {
-                    let out_slice = (&mut out[i..i + 32]).try_into().unwrap();
-                    unpack32(&self.buffer[self.byte_offset..], out_slice, num_bits);
-                    self.byte_offset += 4 * num_bits;
-                    i += 32;
-                }
-            }
-            8 => {
-                let ptr = batch.as_mut_ptr() as *mut u64;
-                // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
-                // in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
-                // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we
-                // checked that num_bits <= T::BIT_CAPACITY.
-                let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
-                while values_to_read - i >= 64 {
-                    let out_slice = (&mut out[i..i + 64]).try_into().unwrap();
-                    unpack64(&self.buffer[self.byte_offset..], out_slice, num_bits);
-                    self.byte_offset += 8 * num_bits;
-                    i += 64;
-                }
-            }
-            _ => unreachable!(),
+        while values_to_read - i >= T::BATCH_SIZE {
+            T::unpack_batch(&self.buffer[self.byte_offset..], &mut batch[i..], num_bits);
+            self.byte_offset += num_bits * T::BATCH_SIZE / 8;
+            i += T::BATCH_SIZE;
         }
 
         // Try to read smaller batches if possible
@@ -738,10 +732,7 @@ impl BitReader {
             self.byte_offset += 4 * num_bits;
 
             for out in out_buf {
-                // Zero-allocate buffer
-                let mut out_bytes = T::Buffer::default();
-                out_bytes.as_mut()[..4].copy_from_slice(&out.to_le_bytes());
-                batch[i] = T::from_le_bytes(out_bytes);
+                batch[i] = T::from_u64(out as u64);
                 i += 1;
             }
         }
@@ -752,10 +743,7 @@ impl BitReader {
             self.byte_offset += 2 * num_bits;
 
             for out in out_buf {
-                // Zero-allocate buffer
-                let mut out_bytes = T::Buffer::default();
-                out_bytes.as_mut()[..2].copy_from_slice(&out.to_le_bytes());
-                batch[i] = T::from_le_bytes(out_bytes);
+                batch[i] = T::from_u64(out as u64);
                 i += 1;
             }
         }
@@ -766,10 +754,7 @@ impl BitReader {
             self.byte_offset += num_bits;
 
             for out in out_buf {
-                // Zero-allocate buffer
-                let mut out_bytes = T::Buffer::default();
-                out_bytes.as_mut()[..1].copy_from_slice(&out.to_le_bytes());
-                batch[i] = T::from_le_bytes(out_bytes);
+                batch[i] = T::from_u64(out as u64);
                 i += 1;
             }
         }
@@ -1252,10 +1237,7 @@ mod tests {
             .collect();
 
         // Generic values used to check against actual values read from `get_batch`.
-        let expected_values: Vec<T> = values
-            .iter()
-            .map(|v| T::try_from_le_slice(v.as_bytes()).unwrap())
-            .collect();
+        let expected_values: Vec<T> = values.iter().map(|v| T::from_u64(*v)).collect();
 
         (0..total).for_each(|i| writer.put_value(values[i], num_bits));
 