Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions parquet-variant-compute/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ parquet-variant-json = { workspace = true }
chrono = { workspace = true }
uuid = { version = "1.18.0", features = ["v4"] }
serde_json = "1.0"
num-traits = { version = "0.2", default-features = false }

# uuid requires the `js` feature to run on wasm
[target.'cfg(target_arch = "wasm32")'.dependencies]
Expand Down
222 changes: 218 additions & 4 deletions parquet-variant-compute/src/shred_variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ pub(crate) fn shred_variant_with_options(
cast_options,
array.len(),
NullValue::TopLevelVariant,
true,
)?;
for i in 0..array.len() {
if array.is_null(i) {
Expand Down Expand Up @@ -145,6 +146,7 @@ pub(crate) fn make_variant_to_shredded_variant_arrow_row_builder<'a>(
cast_options: &'a CastOptions,
capacity: usize,
null_value: NullValue,
shred: bool,
) -> Result<VariantToShreddedVariantRowBuilder<'a>> {
let builder = match data_type {
DataType::Struct(fields) => {
Expand All @@ -153,6 +155,7 @@ pub(crate) fn make_variant_to_shredded_variant_arrow_row_builder<'a>(
cast_options,
capacity,
null_value,
shred,
)?;
VariantToShreddedVariantRowBuilder::Object(typed_value_builder)
}
Expand Down Expand Up @@ -193,7 +196,7 @@ pub(crate) fn make_variant_to_shredded_variant_arrow_row_builder<'a>(
| DataType::FixedSizeBinary(16) // UUID
=> {
let builder =
make_primitive_variant_to_arrow_row_builder(data_type, cast_options, capacity)?;
make_primitive_variant_to_arrow_row_builder(data_type, cast_options, capacity, shred)?;
let typed_value_builder =
VariantToShreddedPrimitiveVariantRowBuilder::new(builder, capacity, null_value);
VariantToShreddedVariantRowBuilder::Primitive(typed_value_builder)
Expand Down Expand Up @@ -369,13 +372,15 @@ impl<'a> VariantToShreddedObjectVariantRowBuilder<'a> {
cast_options: &'a CastOptions,
capacity: usize,
null_value: NullValue,
shred: bool,
) -> Result<Self> {
let typed_value_builders = fields.iter().map(|field| {
let builder = make_variant_to_shredded_variant_arrow_row_builder(
field.data_type(),
cast_options,
capacity,
NullValue::ObjectField,
shred,
)?;
Ok((field.name().as_str(), builder))
});
Expand Down Expand Up @@ -710,9 +715,12 @@ mod tests {
use arrow::datatypes::{
ArrowPrimitiveType, DataType, Field, Fields, Int64Type, TimeUnit, UnionFields, UnionMode,
};
use arrow_schema::IntervalUnit;
use chrono::{DateTime, NaiveDate, NaiveTime};
use parquet_variant::{
BuilderSpecificState, EMPTY_VARIANT_METADATA_BYTES, ObjectBuilder, ReadOnlyMetadataBuilder,
Variant, VariantBuilder, VariantPath, VariantPathElement,
ShortString, Variant, VariantBuilder, VariantDecimal4, VariantDecimal8, VariantDecimal16,
VariantPath, VariantPathElement,
};
use std::sync::Arc;
use uuid::Uuid;
Expand Down Expand Up @@ -1046,6 +1054,7 @@ mod tests {
&cast_options,
1,
mode,
true,
)
.unwrap();
primitive_builder.append_null().unwrap();
Expand Down Expand Up @@ -1076,6 +1085,7 @@ mod tests {
&cast_options,
1,
mode,
true,
)
.unwrap();
array_builder.append_null().unwrap();
Expand Down Expand Up @@ -1104,6 +1114,7 @@ mod tests {
&cast_options,
1,
mode,
true,
)
.unwrap();
object_builder.append_null().unwrap();
Expand Down Expand Up @@ -1310,7 +1321,7 @@ mod tests {
.downcast_ref::<arrow::array::Int32Array>()
.unwrap();
assert_eq!(typed_value_int32.value(0), 42);
assert_eq!(typed_value_int32.value(1), 3);
assert!(typed_value_int32.is_null(1)); // float doesn't shred to int32
assert!(typed_value_int32.is_null(2)); // string doesn't convert to int32

// Test Float64 target
Expand All @@ -1321,7 +1332,7 @@ mod tests {
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert_eq!(typed_value_float64.value(0), 42.0); // int converts to float
assert!(typed_value_float64.is_null(0)); // int doesn't shred to float
assert_eq!(typed_value_float64.value(1), 3.15);
assert!(typed_value_float64.is_null(2)); // string doesn't convert
}
Expand Down Expand Up @@ -2807,4 +2818,207 @@ mod tests {
let shredding_type = ShreddedSchemaBuilder::default().build();
assert_eq!(shredding_type, DataType::Null);
}

// Verifies, for each variant value built below, whether it can or cannot be shredded to each candidate data type.
#[test]
fn test_variant_type_shredded_correctly() {
// Build the input array: one value for each variant kind this test exercises.
// No UUID, object, or list values are appended here.
// NOTE(review): the `30` passed to `new` is presumably a capacity hint (21 values are
// appended below) — confirm against `VariantArrayBuilder::new`.
let mut array_builder = VariantArrayBuilder::new(30);
array_builder.append_value(Variant::Null);
// Integers, one per width.
array_builder.append_value(Variant::Int8(1));
array_builder.append_value(Variant::Int16(2));
array_builder.append_value(Variant::Int32(3));
array_builder.append_value(Variant::Int64(4));
// Date, followed by the four timestamp variants (micros/nanos, plain and `Ntz`).
array_builder.append_value(Variant::Date(NaiveDate::from_epoch_days(12345).unwrap()));
array_builder.append_value(Variant::TimestampMicros(
DateTime::from_timestamp_micros(123456789).unwrap(),
));
array_builder.append_value(Variant::TimestampNtzMicros(
DateTime::from_timestamp_micros(123456789)
.unwrap()
.naive_utc(),
));
array_builder.append_value(Variant::TimestampNanos(DateTime::from_timestamp_nanos(
1234567890123,
)));
array_builder.append_value(Variant::TimestampNtzNanos(
DateTime::from_timestamp_nanos(1234567890123).naive_utc(),
));
// Decimals, one per variant decimal type; appended directly (not wrapped in `Variant::…`).
array_builder.append_value(VariantDecimal4::try_new(123, 2).unwrap());
array_builder.append_value(VariantDecimal8::try_new(123, 3).unwrap());
array_builder.append_value(VariantDecimal16::try_new(123, 4).unwrap());
// Floating point.
array_builder.append_value(Variant::Float(5.2));
array_builder.append_value(Variant::Double(6.4));
// Booleans are two distinct variant kinds.
array_builder.append_value(Variant::BooleanTrue);
array_builder.append_value(Variant::BooleanFalse);
// Binary and both string representations.
array_builder.append_value(Variant::Binary("helow".as_bytes()));
array_builder.append_value(Variant::String("hello"));
array_builder.append_value(Variant::ShortString(
ShortString::try_from("world").unwrap(),
));
// Time of day.
array_builder.append_value(Variant::Time(
NaiveTime::from_num_seconds_from_midnight_opt(12345, 123).unwrap(),
));

// Finish the builder; `array` is the input used by the rest of the test.
let array = array_builder.build();

fn can_shred_to(v: &Variant, dt: &DataType) -> bool {
matches!(
(v, dt),
(Variant::Int8(_), DataType::Int8)
| (Variant::Int8(_), DataType::Int16)
| (Variant::Int8(_), DataType::Int32)
| (Variant::Int8(_), DataType::Int64)
| (Variant::Int16(_), DataType::Int8)
| (Variant::Int16(_), DataType::Int16)
| (Variant::Int16(_), DataType::Int32)
| (Variant::Int16(_), DataType::Int64)
| (Variant::Int32(_), DataType::Int8)
| (Variant::Int32(_), DataType::Int16)
| (Variant::Int32(_), DataType::Int32)
| (Variant::Int32(_), DataType::Int64)
| (Variant::Int64(_), DataType::Int8)
| (Variant::Int64(_), DataType::Int16)
| (Variant::Int64(_), DataType::Int32)
| (Variant::Int64(_), DataType::Int64)
| (Variant::Date(_), DataType::Date32)
| (
Variant::TimestampMicros(_),
DataType::Timestamp(TimeUnit::Microsecond, Some(_)),
)
| (
Variant::TimestampMicros(_),
DataType::Timestamp(TimeUnit::Nanosecond, Some(_))
)
| (
Variant::TimestampNtzMicros(_),
DataType::Timestamp(TimeUnit::Microsecond, None),
)
| (
Variant::TimestampNtzMicros(_),
DataType::Timestamp(TimeUnit::Nanosecond, None)
)
| (
Variant::TimestampNanos(_),
DataType::Timestamp(TimeUnit::Nanosecond, Some(_)),
)
| (
Variant::TimestampNtzNanos(_),
DataType::Timestamp(TimeUnit::Nanosecond, None),
)
| (Variant::Decimal4(_), DataType::Decimal32(_, _))
| (Variant::Decimal4(_), DataType::Decimal64(_, _))
| (Variant::Decimal4(_), DataType::Decimal128(_, _))
| (Variant::Decimal8(_), DataType::Decimal32(_, _))
| (Variant::Decimal8(_), DataType::Decimal64(_, _))
| (Variant::Decimal8(_), DataType::Decimal128(_, _))
| (Variant::Decimal16(_), DataType::Decimal32(_, _))
| (Variant::Decimal16(_), DataType::Decimal64(_, _))
| (Variant::Decimal16(_), DataType::Decimal128(_, _))
| (Variant::Float(_), DataType::Float32)
| (Variant::Float(_), DataType::Float64)
| (Variant::Double(_), DataType::Float32)
| (Variant::Double(_), DataType::Float64)
| (Variant::BooleanFalse, DataType::Boolean)
| (Variant::BooleanTrue, DataType::Boolean)
| (Variant::Binary(_), DataType::Binary)
| (Variant::Binary(_), DataType::BinaryView)
| (Variant::Binary(_), DataType::LargeBinary)
| (Variant::ShortString(_), DataType::Utf8)
| (Variant::ShortString(_), DataType::Utf8View)
| (Variant::ShortString(_), DataType::LargeUtf8)
| (Variant::String(_), DataType::Utf8)
| (Variant::String(_), DataType::Utf8View)
| (Variant::String(_), DataType::LargeUtf8)
| (Variant::Time(_), DataType::Time64(_))
)
}

// Shreds the enclosing test's `array` to `$shred_type` and checks the result.
//
// * `$shred_type`: the target `DataType` expression.
// * `$expected_shreddable`: a collection of `bool`s, one per row of `array`. A `true`
//   entry means the row is expected to shred, so its `value` column entry must be NULL;
//   a `false` entry means the row must still be present (valid) in the `value` column.
//
// If `shred_variant` returns an error instead, the error text must contain
// "is not a valid variant shredding type"; any other error fails the test.
macro_rules! assert_shred_type {
    ($shred_type:expr, $expected_shreddable:expr) => {
        match shred_variant(&array, &$shred_type) {
            Ok(shredded_array) => {
                let value_column = shredded_array.inner().column_by_name("value").unwrap();
                for (idx, &shreddable) in $expected_shreddable.iter().enumerate() {
                    if shreddable {
                        // Shredded rows leave nothing behind in `value`.
                        assert!(
                            value_column.is_null(idx),
                            "{:?} should be shredded to {}",
                            array.value(idx),
                            $shred_type
                        );
                    } else {
                        assert!(
                            value_column.is_valid(idx),
                            "{:?} should not be shredded to {}",
                            array.value(idx),
                            $shred_type
                        );
                    }
                }
            }
            Err(e) => {
                // Render the error once; reuse it for both the check and the failure message.
                let rendered = e.to_string();
                assert!(
                    rendered.contains("is not a valid variant shredding type"),
                    "{} => {}",
                    $shred_type,
                    rendered
                );
            }
        }
    };
}

// Candidate shredding target types tried by the loop below. The list deliberately
// includes types with no entry in `can_shred_to` (e.g. unsigned integers, `Date64`,
// `Time32`, `Duration`, `Interval`, `Decimal256`).
let types = [
DataType::Null,
DataType::Boolean,
// Signed integers.
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
// Unsigned integers.
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
// Floating point.
DataType::Float32,
DataType::Float64,
// Timestamps: every time unit, each with and without a time-zone string.
DataType::Timestamp(TimeUnit::Second, Some("+00:00".into())),
DataType::Timestamp(TimeUnit::Second, None),
DataType::Timestamp(TimeUnit::Millisecond, Some("-00:00".into())),
DataType::Timestamp(TimeUnit::Millisecond, None),
DataType::Timestamp(TimeUnit::Microsecond, Some("-00:00".into())),
DataType::Timestamp(TimeUnit::Microsecond, None),
DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())),
DataType::Timestamp(TimeUnit::Nanosecond, None),
// Dates, times, durations and intervals.
DataType::Date32,
DataType::Date64,
DataType::Time32(TimeUnit::Second),
DataType::Time32(TimeUnit::Millisecond),
DataType::Time64(TimeUnit::Microsecond),
DataType::Time64(TimeUnit::Nanosecond),
DataType::Duration(TimeUnit::Nanosecond),
DataType::Interval(IntervalUnit::DayTime),
// Binary flavours.
DataType::Binary,
DataType::FixedSizeBinary(16), // uuid
DataType::FixedSizeBinary(32),
DataType::LargeBinary,
DataType::BinaryView,
// String flavours.
DataType::Utf8,
DataType::LargeUtf8,
DataType::Utf8View,
// Decimals of each width.
DataType::Decimal32(4, 2),
DataType::Decimal64(10, 4),
DataType::Decimal128(20, 10),
DataType::Decimal256(30, 10),
];

// For each candidate type, compute the per-row expectation from `can_shred_to`
// and let `assert_shred_type!` compare it against the actual shredding result.
for target_type in types {
    let shreddable: Vec<bool> = array
        .iter()
        .map(|row| {
            let variant = row.unwrap();
            can_shred_to(&variant, &target_type)
        })
        .collect();
    assert_shred_type!(target_type, shreddable);
}
}
}
Loading
Loading