diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 036c5c9aaf..c7dc82c1b9 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -26,7 +26,7 @@ import java.nio.channels.Channels import scala.jdk.CollectionConverters._ import org.apache.arrow.c.CDataDictionaryProvider -import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot} +import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot} import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} import org.apache.arrow.vector.dictionary.DictionaryProvider import org.apache.arrow.vector.ipc.ArrowStreamWriter @@ -282,7 +282,7 @@ object Utils extends CometTypeShim { _: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector | _: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector | _: FixedSizeBinaryVector | _: TimeStampMicroVector | _: StructVector | _: ListVector | - _: MapVector) => + _: MapVector | _: NullVector) => v.asInstanceOf[FieldVector] case _ => throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}") diff --git a/native/core/src/execution/columnar_to_row.rs b/native/core/src/execution/columnar_to_row.rs index 78ab7637e8..51f58c7c95 100644 --- a/native/core/src/execution/columnar_to_row.rs +++ b/native/core/src/execution/columnar_to_row.rs @@ -41,16 +41,45 @@ use arrow::array::types::{ UInt64Type, UInt8Type, }; use arrow::array::*; +use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{ArrowNativeType, DataType, TimeUnit}; use std::sync::Arc; /// Maximum digits for decimal that can fit in a long (8 bytes). const MAX_LONG_DIGITS: u8 = 18; +/// Helper macro for downcasting arrays with consistent error messages. +macro_rules! downcast_array { + ($array:expr, $array_type:ty) => { + $array + .as_any() + .downcast_ref::<$array_type>() + .ok_or_else(|| { + CometError::Internal(format!( + "Failed to downcast to {}, actual type: {:?}", + stringify!($array_type), + $array.data_type() + )) + }) + }; +} + +/// Writes bytes to buffer with 8-byte alignment padding. +/// Returns the unpadded length. +#[inline] +fn write_bytes_padded(buffer: &mut Vec, bytes: &[u8]) -> usize { + let len = bytes.len(); + buffer.extend_from_slice(bytes); + let padding = round_up_to_8(len) - len; + buffer.extend(std::iter::repeat_n(0u8, padding)); + len +} + /// Pre-downcast array reference to avoid type dispatch in inner loops. /// This enum holds references to concrete array types, allowing direct access /// without repeated downcast_ref calls. 
enum TypedArray<'a> { + Null, Boolean(&'a BooleanArray), Int8(&'a Int8Array), Int16(&'a Int16Array), @@ -65,6 +94,7 @@ enum TypedArray<'a> { LargeString(&'a LargeStringArray), Binary(&'a BinaryArray), LargeBinary(&'a LargeBinaryArray), + FixedSizeBinary(&'a FixedSizeBinaryArray), Struct( &'a StructArray, arrow::datatypes::Fields, @@ -78,119 +108,46 @@ enum TypedArray<'a> { impl<'a> TypedArray<'a> { /// Pre-downcast an ArrayRef to a TypedArray. - fn from_array(array: &'a ArrayRef, schema_type: &DataType) -> CometResult { + fn from_array(array: &'a ArrayRef) -> CometResult { let actual_type = array.data_type(); match actual_type { - DataType::Boolean => Ok(TypedArray::Boolean( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to BooleanArray".to_string()) - })?, - )), - DataType::Int8 => Ok(TypedArray::Int8( - array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int8Array".to_string()) - })?, - )), - DataType::Int16 => Ok(TypedArray::Int16( - array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int16Array".to_string()) - })?, - )), - DataType::Int32 => Ok(TypedArray::Int32( - array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int32Array".to_string()) - })?, - )), - DataType::Int64 => Ok(TypedArray::Int64( - array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int64Array".to_string()) - })?, - )), - DataType::Float32 => Ok(TypedArray::Float32( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to Float32Array".to_string()) - })?, - )), - DataType::Float64 => Ok(TypedArray::Float64( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to Float64Array".to_string()) - })?, - )), - DataType::Date32 => Ok(TypedArray::Date32( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to Date32Array".to_string()) - })?, - )), + DataType::Null => { + // Verify the array is actually a NullArray, but we don't need to store the reference + // since all values are null by definition + downcast_array!(array, NullArray)?; + Ok(TypedArray::Null) + } + DataType::Boolean => Ok(TypedArray::Boolean(downcast_array!(array, BooleanArray)?)), + DataType::Int8 => Ok(TypedArray::Int8(downcast_array!(array, Int8Array)?)), + DataType::Int16 => Ok(TypedArray::Int16(downcast_array!(array, Int16Array)?)), + DataType::Int32 => Ok(TypedArray::Int32(downcast_array!(array, Int32Array)?)), + DataType::Int64 => Ok(TypedArray::Int64(downcast_array!(array, Int64Array)?)), + DataType::Float32 => Ok(TypedArray::Float32(downcast_array!(array, Float32Array)?)), + DataType::Float64 => Ok(TypedArray::Float64(downcast_array!(array, Float64Array)?)), + DataType::Date32 => Ok(TypedArray::Date32(downcast_array!(array, Date32Array)?)), DataType::Timestamp(TimeUnit::Microsecond, _) => Ok(TypedArray::TimestampMicro( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal( - "Failed to downcast to TimestampMicrosecondArray".to_string(), - ) - })?, + downcast_array!(array, TimestampMicrosecondArray)?, )), DataType::Decimal128(p, _) => Ok(TypedArray::Decimal128( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to Decimal128Array".to_string()) - })?, + downcast_array!(array, Decimal128Array)?, *p, )), - 
DataType::Utf8 => Ok(TypedArray::String( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to StringArray".to_string()) - })?, - )), - DataType::LargeUtf8 => Ok(TypedArray::LargeString( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to LargeStringArray".to_string()) - })?, - )), - DataType::Binary => Ok(TypedArray::Binary( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to BinaryArray".to_string()) - })?, - )), - DataType::LargeBinary => Ok(TypedArray::LargeBinary( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to LargeBinaryArray".to_string()) - })?, - )), + DataType::Utf8 => Ok(TypedArray::String(downcast_array!(array, StringArray)?)), + DataType::LargeUtf8 => Ok(TypedArray::LargeString(downcast_array!( + array, + LargeStringArray + )?)), + DataType::Binary => Ok(TypedArray::Binary(downcast_array!(array, BinaryArray)?)), + DataType::LargeBinary => Ok(TypedArray::LargeBinary(downcast_array!( + array, + LargeBinaryArray + )?)), + DataType::FixedSizeBinary(_) => Ok(TypedArray::FixedSizeBinary(downcast_array!( + array, + FixedSizeBinaryArray + )?)), DataType::Struct(fields) => { - let struct_arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to StructArray".to_string()) - })?; + let struct_arr = downcast_array!(array, StructArray)?; // Pre-downcast all struct fields once let typed_fields: Vec = fields .iter() @@ -202,27 +159,18 @@ impl<'a> TypedArray<'a> { Ok(TypedArray::Struct(struct_arr, fields.clone(), typed_fields)) } DataType::List(field) => Ok(TypedArray::List( - array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to ListArray".to_string()) - })?, + downcast_array!(array, ListArray)?, Arc::clone(field), )), DataType::LargeList(field) => Ok(TypedArray::LargeList( - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal("Failed to downcast to LargeListArray".to_string()) - })?, + downcast_array!(array, LargeListArray)?, Arc::clone(field), )), DataType::Map(field, _) => Ok(TypedArray::Map( - array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to MapArray".to_string()) - })?, + downcast_array!(array, MapArray)?, Arc::clone(field), )), - DataType::Dictionary(_, _) => Ok(TypedArray::Dictionary(array, schema_type.clone())), + DataType::Dictionary(_, _) => Ok(TypedArray::Dictionary(array, actual_type.clone())), _ => Err(CometError::Internal(format!( "Unsupported data type for pre-downcast: {:?}", actual_type @@ -234,6 +182,7 @@ impl<'a> TypedArray<'a> { #[inline] fn is_null(&self, row_idx: usize) -> bool { match self { + TypedArray::Null => true, // Null type is always null TypedArray::Boolean(arr) => arr.is_null(row_idx), TypedArray::Int8(arr) => arr.is_null(row_idx), TypedArray::Int16(arr) => arr.is_null(row_idx), @@ -248,6 +197,7 @@ impl<'a> TypedArray<'a> { TypedArray::LargeString(arr) => arr.is_null(row_idx), TypedArray::Binary(arr) => arr.is_null(row_idx), TypedArray::LargeBinary(arr) => arr.is_null(row_idx), + TypedArray::FixedSizeBinary(arr) => arr.is_null(row_idx), TypedArray::Struct(arr, _, _) => arr.is_null(row_idx), TypedArray::List(arr, _) => arr.is_null(row_idx), TypedArray::LargeList(arr, _) => arr.is_null(row_idx), @@ -291,7 +241,8 @@ impl<'a> TypedArray<'a> { #[inline] fn is_variable_length(&self) -> bool { 
match self { - TypedArray::Boolean(_) + TypedArray::Null + | TypedArray::Boolean(_) | TypedArray::Int8(_) | TypedArray::Int16(_) | TypedArray::Int32(_) @@ -309,44 +260,17 @@ impl<'a> TypedArray<'a> { fn write_variable_to_buffer(&self, buffer: &mut Vec, row_idx: usize) -> CometResult { match self { TypedArray::String(arr) => { - let bytes = arr.value(row_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + Ok(write_bytes_padded(buffer, arr.value(row_idx).as_bytes())) } TypedArray::LargeString(arr) => { - let bytes = arr.value(row_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) - } - TypedArray::Binary(arr) => { - let bytes = arr.value(row_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) - } - TypedArray::LargeBinary(arr) => { - let bytes = arr.value(row_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + Ok(write_bytes_padded(buffer, arr.value(row_idx).as_bytes())) } + TypedArray::Binary(arr) => Ok(write_bytes_padded(buffer, arr.value(row_idx))), + TypedArray::LargeBinary(arr) => Ok(write_bytes_padded(buffer, arr.value(row_idx))), + TypedArray::FixedSizeBinary(arr) => Ok(write_bytes_padded(buffer, arr.value(row_idx))), TypedArray::Decimal128(arr, precision) if *precision > MAX_LONG_DIGITS => { let bytes = i128_to_spark_decimal_bytes(arr.value(row_idx)); - let len = bytes.len(); - buffer.extend_from_slice(&bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + Ok(write_bytes_padded(buffer, &bytes)) } TypedArray::Struct(arr, fields, typed_fields) => { write_struct_to_buffer_typed(buffer, arr, row_idx, fields, typed_fields) @@ -394,6 +318,7 @@ enum TypedElements<'a> { LargeString(&'a LargeStringArray), Binary(&'a BinaryArray), LargeBinary(&'a LargeBinaryArray), + FixedSizeBinary(&'a FixedSizeBinaryArray), // For nested types, fall back to ArrayRef Other(&'a ArrayRef, DataType), } @@ -472,6 +397,11 @@ impl<'a> TypedElements<'a> { return TypedElements::LargeBinary(arr); } } + DataType::FixedSizeBinary(_) => { + if let Some(arr) = array.as_any().downcast_ref::() { + return TypedElements::FixedSizeBinary(arr); + } + } _ => {} } TypedElements::Other(array, element_type.clone()) @@ -525,6 +455,7 @@ impl<'a> TypedElements<'a> { TypedElements::LargeString(arr) => arr.is_null(idx), TypedElements::Binary(arr) => arr.is_null(idx), TypedElements::LargeBinary(arr) => arr.is_null(idx), + TypedElements::FixedSizeBinary(arr) => arr.is_null(idx), TypedElements::Other(arr, _) => arr.is_null(idx), } } @@ -572,55 +503,21 @@ impl<'a> TypedElements<'a> { } /// Write variable-length data to buffer. Returns length written (0 for fixed-width). 
- fn write_variable_value( - &self, - buffer: &mut Vec, - idx: usize, - base_offset: usize, - ) -> CometResult { + fn write_variable_value(&self, buffer: &mut Vec, idx: usize) -> CometResult { match self { - TypedElements::String(arr) => { - let bytes = arr.value(idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) - } + TypedElements::String(arr) => Ok(write_bytes_padded(buffer, arr.value(idx).as_bytes())), TypedElements::LargeString(arr) => { - let bytes = arr.value(idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) - } - TypedElements::Binary(arr) => { - let bytes = arr.value(idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) - } - TypedElements::LargeBinary(arr) => { - let bytes = arr.value(idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + Ok(write_bytes_padded(buffer, arr.value(idx).as_bytes())) } + TypedElements::Binary(arr) => Ok(write_bytes_padded(buffer, arr.value(idx))), + TypedElements::LargeBinary(arr) => Ok(write_bytes_padded(buffer, arr.value(idx))), + TypedElements::FixedSizeBinary(arr) => Ok(write_bytes_padded(buffer, arr.value(idx))), TypedElements::Decimal128(arr, precision) if *precision > MAX_LONG_DIGITS => { let bytes = i128_to_spark_decimal_bytes(arr.value(idx)); - let len = bytes.len(); - buffer.extend_from_slice(&bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + Ok(write_bytes_padded(buffer, &bytes)) } TypedElements::Other(arr, element_type) => { - write_nested_variable_to_buffer(buffer, element_type, arr, idx, base_offset) + write_nested_variable_to_buffer(buffer, element_type, arr, idx) } _ => Ok(0), // Fixed-width types } @@ -771,11 +668,7 @@ impl<'a> TypedElements<'a> { set_null_bit(buffer, null_bitset_start, i); } else { let bytes = i128_to_spark_decimal_bytes(arr.value(src_idx)); - let len = bytes.len(); - buffer.extend_from_slice(&bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - + let len = write_bytes_padded(buffer, &bytes); let data_offset = buffer.len() - round_up_to_8(len) - array_start; let offset_and_len = ((data_offset as i64) << 32) | (len as i64); let slot_offset = elements_start + i * 8; @@ -790,12 +683,7 @@ impl<'a> TypedElements<'a> { if arr.is_null(src_idx) { set_null_bit(buffer, null_bitset_start, i); } else { - let bytes = arr.value(src_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - + let len = write_bytes_padded(buffer, arr.value(src_idx).as_bytes()); let data_offset = buffer.len() - round_up_to_8(len) - array_start; let offset_and_len = ((data_offset as i64) << 32) | (len as i64); let slot_offset = elements_start + i * 8; @@ -810,12 +698,7 @@ impl<'a> TypedElements<'a> { if arr.is_null(src_idx) { set_null_bit(buffer, null_bitset_start, i); } else { - let bytes = arr.value(src_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - 
buffer.extend(std::iter::repeat_n(0u8, padding)); - + let len = write_bytes_padded(buffer, arr.value(src_idx).as_bytes()); let data_offset = buffer.len() - round_up_to_8(len) - array_start; let offset_and_len = ((data_offset as i64) << 32) | (len as i64); let slot_offset = elements_start + i * 8; @@ -830,12 +713,7 @@ impl<'a> TypedElements<'a> { if arr.is_null(src_idx) { set_null_bit(buffer, null_bitset_start, i); } else { - let bytes = arr.value(src_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - + let len = write_bytes_padded(buffer, arr.value(src_idx)); let data_offset = buffer.len() - round_up_to_8(len) - array_start; let offset_and_len = ((data_offset as i64) << 32) | (len as i64); let slot_offset = elements_start + i * 8; @@ -850,12 +728,7 @@ impl<'a> TypedElements<'a> { if arr.is_null(src_idx) { set_null_bit(buffer, null_bitset_start, i); } else { - let bytes = arr.value(src_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - + let len = write_bytes_padded(buffer, arr.value(src_idx)); let data_offset = buffer.len() - round_up_to_8(len) - array_start; let offset_and_len = ((data_offset as i64) << 32) | (len as i64); let slot_offset = elements_start + i * 8; @@ -872,13 +745,8 @@ impl<'a> TypedElements<'a> { set_null_bit(buffer, null_bitset_start, i); } else { let slot_offset = elements_start + i * element_size; - let var_len = write_nested_variable_to_buffer( - buffer, - element_type, - arr, - src_idx, - array_start, - )?; + let var_len = + write_nested_variable_to_buffer(buffer, element_type, arr, src_idx)?; if var_len > 0 { let padded_len = round_up_to_8(var_len); @@ -1035,6 +903,16 @@ impl ColumnarToRowContext { ))); } + // Unpack any dictionary arrays to their underlying value type + // This is needed because Parquet may return dictionary-encoded arrays + // even when the schema expects a specific type like Decimal128 + let arrays: Vec = arrays + .iter() + .zip(self.schema.iter()) + .map(|(arr, schema_type)| Self::maybe_cast_to_schema_type(arr, schema_type)) + .collect::>>()?; + let arrays = arrays.as_slice(); + // Clear previous data self.buffer.clear(); self.offsets.clear(); @@ -1052,8 +930,7 @@ impl ColumnarToRowContext { // Pre-downcast all arrays to avoid type dispatch in inner loop let typed_arrays: Vec = arrays .iter() - .zip(self.schema.iter()) - .map(|(arr, dt)| TypedArray::from_array(arr, dt)) + .map(TypedArray::from_array) .collect::>>()?; // Pre-compute variable-length column indices (once per batch, not per row) @@ -1079,6 +956,83 @@ impl ColumnarToRowContext { Ok((self.buffer.as_ptr(), &self.offsets, &self.lengths)) } + /// Casts an array to match the expected schema type if needed. + /// This handles cases where: + /// 1. Parquet returns dictionary-encoded arrays but the schema expects a non-dictionary type + /// 2. Parquet returns NullArray when all values are null, but the schema expects a typed array + /// 3. 
Parquet returns Int32/Int64 for small-precision decimals but schema expects Decimal128 + fn maybe_cast_to_schema_type( + array: &ArrayRef, + schema_type: &DataType, + ) -> CometResult { + let actual_type = array.data_type(); + + // If types already match, no cast needed + if actual_type == schema_type { + return Ok(Arc::clone(array)); + } + + match (actual_type, schema_type) { + (DataType::Dictionary(_, _), schema) + if !matches!(schema, DataType::Dictionary(_, _)) => + { + // Unpack dictionary if the schema type is not a dictionary + let options = CastOptions::default(); + cast_with_options(array, schema_type, &options).map_err(|e| { + CometError::Internal(format!( + "Failed to unpack dictionary array from {:?} to {:?}: {}", + actual_type, schema_type, e + )) + }) + } + (DataType::Null, _) => { + // Cast NullArray to the expected schema type + // This happens when all values in a column are null + let options = CastOptions::default(); + cast_with_options(array, schema_type, &options).map_err(|e| { + CometError::Internal(format!( + "Failed to cast NullArray to {:?}: {}", + schema_type, e + )) + }) + } + (DataType::Int32, DataType::Decimal128(precision, scale)) => { + // Parquet stores small-precision decimals as Int32 for efficiency. + // When COMET_USE_DECIMAL_128 is false, BatchReader produces these types. + // The Int32 value is already scaled (e.g., -1 means -0.01 for scale 2). + // We need to reinterpret (not cast) to Decimal128 preserving the value. + let int_array = array.as_any().downcast_ref::().ok_or_else(|| { + CometError::Internal("Failed to downcast to Int32Array".to_string()) + })?; + let decimal_array: Decimal128Array = int_array + .iter() + .map(|v| v.map(|x| x as i128)) + .collect::() + .with_precision_and_scale(*precision, *scale) + .map_err(|e| { + CometError::Internal(format!("Invalid decimal precision/scale: {}", e)) + })?; + Ok(Arc::new(decimal_array)) + } + (DataType::Int64, DataType::Decimal128(precision, scale)) => { + // Same as Int32 but for medium-precision decimals stored as Int64. + let int_array = array.as_any().downcast_ref::().ok_or_else(|| { + CometError::Internal("Failed to downcast to Int64Array".to_string()) + })?; + let decimal_array: Decimal128Array = int_array + .iter() + .map(|v| v.map(|x| x as i128)) + .collect::() + .with_precision_and_scale(*precision, *scale) + .map_err(|e| { + CometError::Internal(format!("Invalid decimal precision/scale: {}", e)) + })?; + Ok(Arc::new(decimal_array)) + } + _ => Ok(Arc::clone(array)), + } + } + /// Fast path for schemas with only fixed-width columns. /// Pre-allocates entire buffer and processes more efficiently. 
fn convert_fixed_width( @@ -1157,7 +1111,10 @@ impl ColumnarToRowContext { .as_any() .downcast_ref::() .ok_or_else(|| { - CometError::Internal("Failed to downcast to BooleanArray".to_string()) + CometError::Internal(format!( + "Failed to downcast to BooleanArray, actual type: {:?}", + array.data_type() + )) })?; for row_idx in 0..num_rows { if !arr.is_null(row_idx) { @@ -1168,7 +1125,10 @@ impl ColumnarToRowContext { } DataType::Int8 => { let arr = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int8Array".to_string()) + CometError::Internal(format!( + "Failed to downcast to Int8Array, actual type: {:?}", + array.data_type() + )) })?; for row_idx in 0..num_rows { if !arr.is_null(row_idx) { @@ -1180,7 +1140,10 @@ impl ColumnarToRowContext { } DataType::Int16 => { let arr = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int16Array".to_string()) + CometError::Internal(format!( + "Failed to downcast to Int16Array, actual type: {:?}", + array.data_type() + )) })?; for row_idx in 0..num_rows { if !arr.is_null(row_idx) { @@ -1192,7 +1155,10 @@ impl ColumnarToRowContext { } DataType::Int32 => { let arr = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int32Array".to_string()) + CometError::Internal(format!( + "Failed to downcast to Int32Array, actual type: {:?}", + array.data_type() + )) })?; for row_idx in 0..num_rows { if !arr.is_null(row_idx) { @@ -1204,7 +1170,10 @@ impl ColumnarToRowContext { } DataType::Int64 => { let arr = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal("Failed to downcast to Int64Array".to_string()) + CometError::Internal(format!( + "Failed to downcast to Int64Array, actual type: {:?}", + array.data_type() + )) })?; for row_idx in 0..num_rows { if !arr.is_null(row_idx) { @@ -1219,7 +1188,10 @@ impl ColumnarToRowContext { .as_any() .downcast_ref::() .ok_or_else(|| { - CometError::Internal("Failed to downcast to Float32Array".to_string()) + CometError::Internal(format!( + "Failed to downcast to Float32Array, actual type: {:?}", + array.data_type() + )) })?; for row_idx in 0..num_rows { if !arr.is_null(row_idx) { @@ -1234,7 +1206,10 @@ impl ColumnarToRowContext { .as_any() .downcast_ref::() .ok_or_else(|| { - CometError::Internal("Failed to downcast to Float64Array".to_string()) + CometError::Internal(format!( + "Failed to downcast to Float64Array, actual type: {:?}", + array.data_type() + )) })?; for row_idx in 0..num_rows { if !arr.is_null(row_idx) { @@ -1249,7 +1224,10 @@ impl ColumnarToRowContext { .as_any() .downcast_ref::() .ok_or_else(|| { - CometError::Internal("Failed to downcast to Date32Array".to_string()) + CometError::Internal(format!( + "Failed to downcast to Date32Array, actual type: {:?}", + array.data_type() + )) })?; for row_idx in 0..num_rows { if !arr.is_null(row_idx) { @@ -1264,9 +1242,10 @@ impl ColumnarToRowContext { .as_any() .downcast_ref::() .ok_or_else(|| { - CometError::Internal( - "Failed to downcast to TimestampMicrosecondArray".to_string(), - ) + CometError::Internal(format!( + "Failed to downcast to TimestampMicrosecondArray, actual type: {:?}", + array.data_type() + )) })?; for row_idx in 0..num_rows { if !arr.is_null(row_idx) { @@ -1281,7 +1260,10 @@ impl ColumnarToRowContext { .as_any() .downcast_ref::() .ok_or_else(|| { - CometError::Internal("Failed to downcast to Decimal128Array".to_string()) + CometError::Internal(format!( + "Failed to downcast to Decimal128Array, actual type: 
{:?}", + array.data_type() + )) })?; for row_idx in 0..num_rows { if !arr.is_null(row_idx) { @@ -1605,72 +1587,26 @@ fn write_dictionary_to_buffer_with_key( match value_type { DataType::Utf8 => { - let string_values = values - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast dictionary values to StringArray, actual type: {:?}", - values.data_type() - )) - })?; - let bytes = string_values.value(key_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let string_values = downcast_array!(values, StringArray)?; + Ok(write_bytes_padded( + buffer, + string_values.value(key_idx).as_bytes(), + )) } DataType::LargeUtf8 => { - let string_values = values - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast dictionary values to LargeStringArray, actual type: {:?}", - values.data_type() - )) - })?; - let bytes = string_values.value(key_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let string_values = downcast_array!(values, LargeStringArray)?; + Ok(write_bytes_padded( + buffer, + string_values.value(key_idx).as_bytes(), + )) } DataType::Binary => { - let binary_values = values - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast dictionary values to BinaryArray, actual type: {:?}", - values.data_type() - )) - })?; - let bytes = binary_values.value(key_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let binary_values = downcast_array!(values, BinaryArray)?; + Ok(write_bytes_padded(buffer, binary_values.value(key_idx))) } DataType::LargeBinary => { - let binary_values = values - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast dictionary values to LargeBinaryArray, actual type: {:?}", - values.data_type() - )) - })?; - let bytes = binary_values.value(key_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let binary_values = downcast_array!(values, LargeBinaryArray)?; + Ok(write_bytes_padded(buffer, binary_values.value(key_idx))) } _ => Err(CometError::Internal(format!( "Unsupported dictionary value type for direct buffer write: {:?}", @@ -1781,7 +1717,7 @@ fn write_struct_to_buffer_typed( buffer[field_offset..field_offset + 8].copy_from_slice(&value.to_le_bytes()); } else { // Variable-length field - use pre-downcast writer - let var_len = typed_field.write_variable_value(buffer, row_idx, struct_start)?; + let var_len = typed_field.write_variable_value(buffer, row_idx)?; if var_len > 0 { let padded_len = round_up_to_8(var_len); let data_offset = buffer.len() - padded_len - struct_start; @@ -1829,48 +1765,45 @@ fn write_struct_to_buffer( let field_offset = struct_start + nested_bitset_width + field_idx * 8; // Inline type dispatch for fixed-width types (most common case) - let value = match data_type { + let value: Option = match data_type { DataType::Boolean => { - let arr = column.as_any().downcast_ref::().unwrap(); + let arr = downcast_array!(column, BooleanArray)?; Some(if arr.value(row_idx) 
{ 1i64 } else { 0i64 }) } DataType::Int8 => { - let arr = column.as_any().downcast_ref::().unwrap(); + let arr = downcast_array!(column, Int8Array)?; Some(arr.value(row_idx) as i64) } DataType::Int16 => { - let arr = column.as_any().downcast_ref::().unwrap(); + let arr = downcast_array!(column, Int16Array)?; Some(arr.value(row_idx) as i64) } DataType::Int32 => { - let arr = column.as_any().downcast_ref::().unwrap(); + let arr = downcast_array!(column, Int32Array)?; Some(arr.value(row_idx) as i64) } DataType::Int64 => { - let arr = column.as_any().downcast_ref::().unwrap(); + let arr = downcast_array!(column, Int64Array)?; Some(arr.value(row_idx)) } DataType::Float32 => { - let arr = column.as_any().downcast_ref::().unwrap(); - Some((arr.value(row_idx).to_bits() as i32) as i64) + let arr = downcast_array!(column, Float32Array)?; + Some(arr.value(row_idx).to_bits() as i64) } DataType::Float64 => { - let arr = column.as_any().downcast_ref::().unwrap(); + let arr = downcast_array!(column, Float64Array)?; Some(arr.value(row_idx).to_bits() as i64) } DataType::Date32 => { - let arr = column.as_any().downcast_ref::().unwrap(); + let arr = downcast_array!(column, Date32Array)?; Some(arr.value(row_idx) as i64) } DataType::Timestamp(TimeUnit::Microsecond, _) => { - let arr = column - .as_any() - .downcast_ref::() - .unwrap(); + let arr = downcast_array!(column, TimestampMicrosecondArray)?; Some(arr.value(row_idx)) } DataType::Decimal128(p, _) if *p <= MAX_LONG_DIGITS => { - let arr = column.as_any().downcast_ref::().unwrap(); + let arr = downcast_array!(column, Decimal128Array)?; Some(arr.value(row_idx) as i64) } _ => None, // Variable-length type @@ -1881,13 +1814,7 @@ fn write_struct_to_buffer( buffer[field_offset..field_offset + 8].copy_from_slice(&v.to_le_bytes()); } else { // Variable-length field - let var_len = write_nested_variable_to_buffer( - buffer, - data_type, - column, - row_idx, - struct_start, - )?; + let var_len = write_nested_variable_to_buffer(buffer, data_type, column, row_idx)?; if var_len > 0 { let padded_len = round_up_to_8(var_len); let data_offset = buffer.len() - padded_len - struct_start; @@ -2016,136 +1943,45 @@ fn write_nested_variable_to_buffer( data_type: &DataType, array: &ArrayRef, row_idx: usize, - _base_offset: usize, ) -> CometResult { let actual_type = array.data_type(); match actual_type { DataType::Utf8 => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to StringArray for type {:?}", - actual_type - )) - })?; - let bytes = arr.value(row_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let arr = downcast_array!(array, StringArray)?; + Ok(write_bytes_padded(buffer, arr.value(row_idx).as_bytes())) } DataType::LargeUtf8 => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to LargeStringArray for type {:?}", - actual_type - )) - })?; - let bytes = arr.value(row_idx).as_bytes(); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let arr = downcast_array!(array, LargeStringArray)?; + Ok(write_bytes_padded(buffer, arr.value(row_idx).as_bytes())) } DataType::Binary => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - 
"Failed to downcast to BinaryArray for type {:?}", - actual_type - )) - })?; - let bytes = arr.value(row_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let arr = downcast_array!(array, BinaryArray)?; + Ok(write_bytes_padded(buffer, arr.value(row_idx))) } DataType::LargeBinary => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to LargeBinaryArray for type {:?}", - actual_type - )) - })?; - let bytes = arr.value(row_idx); - let len = bytes.len(); - buffer.extend_from_slice(bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + let arr = downcast_array!(array, LargeBinaryArray)?; + Ok(write_bytes_padded(buffer, arr.value(row_idx))) } DataType::Decimal128(precision, _) if *precision > MAX_LONG_DIGITS => { - let arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to Decimal128Array for type {:?}", - actual_type - )) - })?; + let arr = downcast_array!(array, Decimal128Array)?; let bytes = i128_to_spark_decimal_bytes(arr.value(row_idx)); - let len = bytes.len(); - buffer.extend_from_slice(&bytes); - let padding = round_up_to_8(len) - len; - buffer.extend(std::iter::repeat_n(0u8, padding)); - Ok(len) + Ok(write_bytes_padded(buffer, &bytes)) } DataType::Struct(fields) => { - let struct_array = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to StructArray for type {:?}", - actual_type - )) - })?; + let struct_array = downcast_array!(array, StructArray)?; write_struct_to_buffer(buffer, struct_array, row_idx, fields) } DataType::List(field) => { - let list_array = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to ListArray for type {:?}", - actual_type - )) - })?; + let list_array = downcast_array!(array, ListArray)?; write_list_to_buffer(buffer, list_array, row_idx, field) } DataType::LargeList(field) => { - let list_array = array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to LargeListArray for type {:?}", - actual_type - )) - })?; + let list_array = downcast_array!(array, LargeListArray)?; write_large_list_to_buffer(buffer, list_array, row_idx, field) } DataType::Map(field, _) => { - let map_array = array.as_any().downcast_ref::().ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast to MapArray for type {:?}", - actual_type - )) - })?; + let map_array = downcast_array!(array, MapArray)?; write_map_to_buffer(buffer, map_array, row_idx, field) } DataType::Dictionary(key_type, value_type) => { @@ -2748,4 +2584,163 @@ mod tests { assert_eq!(value, i as i32, "element {} should be {}", i, i); } } + + #[test] + fn test_convert_fixed_size_binary_array() { + // FixedSizeBinary(3) - each value is exactly 3 bytes + let schema = vec![DataType::FixedSizeBinary(3)]; + let mut ctx = ColumnarToRowContext::new(schema, 100); + + let array: ArrayRef = Arc::new(FixedSizeBinaryArray::from(vec![ + Some(&[1u8, 2, 3][..]), + Some(&[4u8, 5, 6][..]), + None, // Test null handling + ])); + let arrays = vec![array]; + + let (ptr, offsets, lengths) = ctx.convert(&arrays, 3).unwrap(); + + assert!(!ptr.is_null()); + assert_eq!(offsets.len(), 3); + assert_eq!(lengths.len(), 3); + + // Row 0: 8 (bitset) + 8 (field 
slot) + 8 (aligned 3-byte data) = 24 + // Row 1: 8 (bitset) + 8 (field slot) + 8 (aligned 3-byte data) = 24 + // Row 2: 8 (bitset) + 8 (field slot) = 16 (null, no variable data) + assert_eq!(lengths[0], 24); + assert_eq!(lengths[1], 24); + assert_eq!(lengths[2], 16); + + // Verify the data is correct for non-null rows + unsafe { + let row0 = + std::slice::from_raw_parts(ptr.add(offsets[0] as usize), lengths[0] as usize); + // Variable data starts at offset 16 (8 bitset + 8 field slot) + assert_eq!(&row0[16..19], &[1u8, 2, 3]); + + let row1 = + std::slice::from_raw_parts(ptr.add(offsets[1] as usize), lengths[1] as usize); + assert_eq!(&row1[16..19], &[4u8, 5, 6]); + } + } + + #[test] + fn test_convert_dictionary_decimal_array() { + // Test that dictionary-encoded decimals are correctly unpacked and converted + // This tests the fix for casting to schema_type instead of value_type + use arrow::datatypes::Int8Type; + + // Create a dictionary array with Decimal128 values + // Values: [-0.01, -0.02, -0.03] represented as [-1, -2, -3] with scale 2 + let values = Decimal128Array::from(vec![-1i128, -2, -3]) + .with_precision_and_scale(5, 2) + .unwrap(); + + // Keys: [0, 1, 2, 0, 1, 2] - each value appears twice + let keys = Int8Array::from(vec![0i8, 1, 2, 0, 1, 2]); + + let dict_array: ArrayRef = + Arc::new(DictionaryArray::::try_new(keys, Arc::new(values)).unwrap()); + + // Schema expects Decimal128(5, 2) - not a dictionary type + let schema = vec![DataType::Decimal128(5, 2)]; + let mut ctx = ColumnarToRowContext::new(schema, 100); + + let arrays = vec![dict_array]; + let (ptr, offsets, lengths) = ctx.convert(&arrays, 6).unwrap(); + + assert!(!ptr.is_null()); + assert_eq!(offsets.len(), 6); + assert_eq!(lengths.len(), 6); + + // Verify the decimal values are correct (not doubled or otherwise corrupted) + // Fixed-width decimal is stored directly in the 8-byte field slot + unsafe { + for (i, expected) in [-1i64, -2, -3, -1, -2, -3].iter().enumerate() { + let row = + std::slice::from_raw_parts(ptr.add(offsets[i] as usize), lengths[i] as usize); + // Field value starts at offset 8 (after null bitset) + let value = i64::from_le_bytes(row[8..16].try_into().unwrap()); + assert_eq!( + value, *expected, + "Row {} should have value {}, got {}", + i, expected, value + ); + } + } + } + + #[test] + fn test_convert_int32_to_decimal128() { + // Test that Int32 arrays are correctly cast to Decimal128 when schema expects Decimal128. + // This can happen when COMET_USE_DECIMAL_128 is false and the parquet reader produces + // Int32 for small-precision decimals. 
+ + // Create an Int32 array representing decimals: [-1, -2, -3] which at scale 2 means + // [-0.01, -0.02, -0.03] + let int_array: ArrayRef = Arc::new(Int32Array::from(vec![-1i32, -2, -3])); + + // Schema expects Decimal128(5, 2) + let schema = vec![DataType::Decimal128(5, 2)]; + let mut ctx = ColumnarToRowContext::new(schema, 100); + + let arrays = vec![int_array]; + let (ptr, offsets, lengths) = ctx.convert(&arrays, 3).unwrap(); + + assert!(!ptr.is_null()); + assert_eq!(offsets.len(), 3); + assert_eq!(lengths.len(), 3); + + // Verify the decimal values are correct after casting + // Fixed-width decimal is stored directly in the 8-byte field slot + unsafe { + for (i, expected) in [-1i64, -2, -3].iter().enumerate() { + let row = + std::slice::from_raw_parts(ptr.add(offsets[i] as usize), lengths[i] as usize); + // Field value starts at offset 8 (after null bitset) + let value = i64::from_le_bytes(row[8..16].try_into().unwrap()); + assert_eq!( + value, *expected, + "Row {} should have value {}, got {}", + i, expected, value + ); + } + } + } + + #[test] + fn test_convert_int64_to_decimal128() { + // Test that Int64 arrays are correctly cast to Decimal128 when schema expects Decimal128. + // This can happen when COMET_USE_DECIMAL_128 is false and the parquet reader produces + // Int64 for medium-precision decimals. + + // Create an Int64 array representing decimals + let int_array: ArrayRef = Arc::new(Int64Array::from(vec![-100i64, -200, -300])); + + // Schema expects Decimal128(10, 2) + let schema = vec![DataType::Decimal128(10, 2)]; + let mut ctx = ColumnarToRowContext::new(schema, 100); + + let arrays = vec![int_array]; + let (ptr, offsets, lengths) = ctx.convert(&arrays, 3).unwrap(); + + assert!(!ptr.is_null()); + assert_eq!(offsets.len(), 3); + assert_eq!(lengths.len(), 3); + + // Verify the decimal values are correct after casting + unsafe { + for (i, expected) in [-100i64, -200, -300].iter().enumerate() { + let row = + std::slice::from_raw_parts(ptr.add(offsets[i] as usize), lengths[i] as usize); + // Field value starts at offset 8 (after null bitset) + let value = i64::from_le_bytes(row[8..16].try_into().unwrap()); + assert_eq!( + value, *expected, + "Row {} should have value {}, got {}", + i, expected, value + ); + } + } + } } diff --git a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala index 7402a83248..d1c3b07677 100644 --- a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala +++ b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala @@ -22,13 +22,14 @@ package org.apache.comet.rules import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.sideBySide -import org.apache.spark.sql.comet.{CometCollectLimitExec, CometColumnarToRowExec, CometNativeColumnarToRowExec, CometNativeWriteExec, CometPlan, CometSparkToColumnarExec} +import org.apache.spark.sql.comet.{CometBatchScanExec, CometCollectLimitExec, CometColumnarToRowExec, CometNativeColumnarToRowExec, CometNativeWriteExec, CometPlan, CometScanExec, CometSparkToColumnarExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.QueryStageExec import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import 
org.apache.comet.CometConf +import org.apache.comet.parquet.CometParquetScan // This rule is responsible for eliminating redundant transitions between row-based and // columnar-based operators for Comet. Currently, three potential redundant transitions are: @@ -139,7 +140,8 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa private def createColumnarToRowExec(child: SparkPlan): SparkPlan = { val schema = child.schema val useNative = CometConf.COMET_NATIVE_COLUMNAR_TO_ROW_ENABLED.get() && - CometNativeColumnarToRowExec.supportsSchema(schema) + CometNativeColumnarToRowExec.supportsSchema(schema) && + !hasScanUsingMutableBuffers(child) if (useNative) { CometNativeColumnarToRowExec(child) @@ -147,4 +149,30 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa CometColumnarToRowExec(child) } } + + /** + * Checks if the plan contains a scan that uses mutable buffers. Native C2R is not compatible + * with such scans because the buffers may be modified after C2R reads them. + * + * This includes: + * - CometScanExec with native_comet scan implementation (V1 path) - uses BatchReader + * - CometScanExec with native_iceberg_compat and partition columns - uses + * ConstantColumnReader + * - CometBatchScanExec with CometParquetScan (V2 Parquet path) - uses BatchReader + */ + private def hasScanUsingMutableBuffers(op: SparkPlan): Boolean = { + op match { + case c: QueryStageExec => hasScanUsingMutableBuffers(c.plan) + case c: ReusedExchangeExec => hasScanUsingMutableBuffers(c.child) + case _ => + op.exists { + case scan: CometScanExec => + scan.scanImpl == CometConf.SCAN_NATIVE_COMET || + (scan.scanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT && + scan.relation.partitionSchema.nonEmpty) + case scan: CometBatchScanExec => scan.scan.isInstanceOf[CometParquetScan] + case _ => false + } + } + } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeColumnarToRowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeColumnarToRowExec.scala index 93526573c0..a520098ed1 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeColumnarToRowExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeColumnarToRowExec.scala @@ -19,15 +19,25 @@ package org.apache.spark.sql.comet -import org.apache.spark.TaskContext +import java.util.UUID +import java.util.concurrent.{Future, TimeoutException, TimeUnit} + +import scala.concurrent.Promise +import scala.util.control.NonFatal + +import org.apache.spark.{broadcast, SparkException, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} +import org.apache.spark.sql.comet.util.{Utils => CometUtils} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils +import org.apache.spark.util.{SparkFatalException, Utils} import org.apache.comet.{CometConf, NativeColumnarToRowConverter} @@ -64,6 +74,116 @@ case class 
CometNativeColumnarToRowExec(child: SparkPlan) "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"), "convertTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time in conversion")) + @transient + private lazy val promise = Promise[broadcast.Broadcast[Any]]() + + @transient + private val timeout: Long = conf.broadcastTimeout + + private val runId: UUID = UUID.randomUUID + + private lazy val cometBroadcastExchange = findCometBroadcastExchange(child) + + @transient + lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { + SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]]( + session, + CometBroadcastExchangeExec.executionContext) { + try { + // Setup a job group here so later it may get cancelled by groupId if necessary. + sparkContext.setJobGroup( + runId.toString, + s"CometNativeColumnarToRow broadcast exchange (runId $runId)", + interruptOnCancel = true) + + val numOutputRows = longMetric("numOutputRows") + val numInputBatches = longMetric("numInputBatches") + val localSchema = this.schema + val batchSize = CometConf.COMET_BATCH_SIZE.get() + val broadcastColumnar = child.executeBroadcast() + val serializedBatches = + broadcastColumnar.value.asInstanceOf[Array[org.apache.spark.util.io.ChunkedByteBuffer]] + + // Use native converter to convert columnar data to rows + val converter = new NativeColumnarToRowConverter(localSchema, batchSize) + try { + val rows = serializedBatches.iterator + .flatMap(CometUtils.decodeBatches(_, this.getClass.getSimpleName)) + .flatMap { batch => + numInputBatches += 1 + numOutputRows += batch.numRows() + val result = converter.convert(batch) + // Wrap iterator to close batch after consumption + new Iterator[InternalRow] { + override def hasNext: Boolean = { + val hasMore = result.hasNext + if (!hasMore) { + batch.close() + } + hasMore + } + override def next(): InternalRow = result.next() + } + } + + val mode = cometBroadcastExchange.get.mode + val relation = mode.transform(rows, Some(numOutputRows.value)) + val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) + promise.trySuccess(broadcasted) + broadcasted + } finally { + converter.close() + } + } catch { + // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw + // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult + // will catch this exception and re-throw the wrapped fatal throwable. 
+ case oe: OutOfMemoryError => + val ex = new SparkFatalException(oe) + promise.tryFailure(ex) + throw ex + case e if !NonFatal(e) => + val ex = new SparkFatalException(e) + promise.tryFailure(ex) + throw ex + case e: Throwable => + promise.tryFailure(e) + throw e + } + } + } + + override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + if (cometBroadcastExchange.isEmpty) { + throw new SparkException( + "CometNativeColumnarToRowExec only supports doExecuteBroadcast when child contains a " + + "CometBroadcastExchange, but got " + child) + } + + try { + relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]] + } catch { + case ex: TimeoutException => + logError(s"Could not execute broadcast in $timeout secs.", ex) + if (!relationFuture.isDone) { + sparkContext.cancelJobGroup(runId.toString) + relationFuture.cancel(true) + } + throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex)) + } + } + + private def findCometBroadcastExchange(op: SparkPlan): Option[CometBroadcastExchangeExec] = { + op match { + case b: CometBroadcastExchangeExec => Some(b) + case b: BroadcastQueryStageExec => findCometBroadcastExchange(b.plan) + case b: ReusedExchangeExec => findCometBroadcastExchange(b.child) + case _ => op.children.collectFirst(Function.unlift(findCometBroadcastExchange)) + } + } + override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val numInputBatches = longMetric("numInputBatches") @@ -91,7 +211,17 @@ case class CometNativeColumnarToRowExec(child: SparkPlan) val result = converter.convert(batch) convertTime += System.nanoTime() - startTime - result + // Wrap iterator to close batch after consumption + new Iterator[InternalRow] { + override def hasNext: Boolean = { + val hasMore = result.hasNext + if (!hasMore) { + batch.close() + } + hasMore + } + override def next(): InternalRow = result.next() + } } } } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index e0a5c43aef..fe5ea77a89 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -30,8 +30,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, FromUnixTime, Literal, TruncDate, TruncTimestamp} import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps -import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometProjectExec} -import org.apache.spark.sql.execution.{InputAdapter, ProjectExec, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.comet.{CometNativeColumnarToRowExec, CometProjectExec} +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -1020,11 +1020,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { val query = sql(s"select cast(id as string) from $table") val (_, cometPlan) = checkSparkAnswerAndOperator(query) val project = cometPlan - .asInstanceOf[WholeStageCodegenExec] - .child - .asInstanceOf[CometColumnarToRowExec] - .child - .asInstanceOf[InputAdapter] + .asInstanceOf[CometNativeColumnarToRowExec] .child .asInstanceOf[CometProjectExec] val id = project.expressions.head diff --git 
a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 1b2373ad71..696a12d4a2 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, He import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate} import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} -import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec} +import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SparkPlan, SQLExecution, UnionExec} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} @@ -864,9 +864,11 @@ class CometExecSuite extends CometTestBase { checkSparkAnswerAndOperator(df) // Before AQE: one CometBroadcastExchange, no CometColumnarToRow - var columnarToRowExec = stripAQEPlan(df.queryExecution.executedPlan).collect { - case s: CometColumnarToRowExec => s - } + var columnarToRowExec: Seq[SparkPlan] = + stripAQEPlan(df.queryExecution.executedPlan).collect { + case s: CometColumnarToRowExec => s + case s: CometNativeColumnarToRowExec => s + } assert(columnarToRowExec.isEmpty) // Disable CometExecRule after the initial plan is generated. The CometSortMergeJoin and @@ -880,14 +882,25 @@ class CometExecSuite extends CometTestBase { // After AQE: CometBroadcastExchange has to be converted to rows to conform to Spark // BroadcastHashJoin. val plan = stripAQEPlan(df.queryExecution.executedPlan) - columnarToRowExec = plan.collect { case s: CometColumnarToRowExec => - s + columnarToRowExec = plan.collect { + case s: CometColumnarToRowExec => s + case s: CometNativeColumnarToRowExec => s } assert(columnarToRowExec.length == 1) - // This ColumnarToRowExec should be the immediate child of BroadcastHashJoinExec - val parent = plan.find(_.children.contains(columnarToRowExec.head)) - assert(parent.get.isInstanceOf[BroadcastHashJoinExec]) + // This ColumnarToRowExec should be a descendant of BroadcastHashJoinExec (possibly + // wrapped by InputAdapter for codegen). 
+ val broadcastJoins = plan.collect { case b: BroadcastHashJoinExec => b } + assert(broadcastJoins.nonEmpty, s"Expected BroadcastHashJoinExec in plan:\n$plan") + val hasC2RDescendant = broadcastJoins.exists { join => + join.find { + case _: CometColumnarToRowExec | _: CometNativeColumnarToRowExec => true + case _ => false + }.isDefined + } + assert( + hasC2RDescendant, + "BroadcastHashJoinExec should have a columnar-to-row descendant") // There should be a CometBroadcastExchangeExec under CometColumnarToRowExec val broadcastQueryStage = diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 81ac72247f..73ffb244f7 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -80,6 +80,7 @@ abstract class CometTestBase conf.set(CometConf.COMET_ONHEAP_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true") + conf.set(CometConf.COMET_NATIVE_COLUMNAR_TO_ROW_ENABLED.key, "true") conf.set(CometConf.COMET_RESPECT_PARQUET_FILTER_PUSHDOWN.key, "true") conf.set(CometConf.COMET_SPARK_TO_ARROW_ENABLED.key, "true") conf.set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true") diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala index 7caac71351..c8c4baff4a 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala @@ -46,7 +46,7 @@ trait CometPlanChecker { case _: CometNativeScanExec | _: CometScanExec | _: CometBatchScanExec | _: CometIcebergNativeScanExec => case _: CometSinkPlaceHolder | _: CometScanWrapper => - case _: CometColumnarToRowExec => + case _: CometColumnarToRowExec | _: CometNativeColumnarToRowExec => case _: CometSparkToColumnarExec => case _: CometExec | _: CometShuffleExchangeExec => case _: CometBroadcastExchangeExec =>
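
Illustrative sketch, not part of the patch: the UnsafeRow-style variable-length write that this diff factors into `write_bytes_padded`. Bytes are appended to the row buffer, zero-padded to an 8-byte boundary, and the fixed-width slot for that field stores `(offset_from_row_start << 32) | unpadded_len`, which is the layout the new FixedSizeBinary test asserts (16 fixed bytes plus one padded 8-byte chunk). Names and the `main` driver below are hypothetical simplifications; the real code threads `CometResult` and per-row offset bookkeeping.

fn round_up_to_8(len: usize) -> usize {
    (len + 7) & !7
}

fn write_bytes_padded(buffer: &mut Vec<u8>, bytes: &[u8]) -> usize {
    let len = bytes.len();
    buffer.extend_from_slice(bytes);
    // Pad with zeros so the next value stays 8-byte aligned.
    buffer.extend(std::iter::repeat(0u8).take(round_up_to_8(len) - len));
    len
}

fn main() {
    let mut buffer = vec![0u8; 16]; // 8-byte null bitset + one 8-byte field slot
    let row_start = 0usize;
    let slot_offset = 8usize; // field 0 sits right after the bitset

    let len = write_bytes_padded(&mut buffer, b"hello");
    let data_offset = buffer.len() - round_up_to_8(len) - row_start;
    let offset_and_len = ((data_offset as i64) << 32) | (len as i64);
    buffer[slot_offset..slot_offset + 8].copy_from_slice(&offset_and_len.to_le_bytes());

    assert_eq!(buffer.len(), 24); // 16 fixed bytes + "hello" padded to 8
    assert_eq!(data_offset, 16);
    assert_eq!(len, 5);
}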
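
Illustrative sketch, not part of the patch: the Int32-to-Decimal128 reinterpretation performed by `maybe_cast_to_schema_type` when the Parquet reader hands back the narrower physical type for a small-precision decimal. The integer value is already scaled, so only the logical type changes; `int32_to_decimal128` and the `main` driver below are hypothetical standalone wrappers around the same arrow-rs calls used in the diff.

use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Decimal128Array, Int32Array};

fn int32_to_decimal128(array: &Int32Array, precision: u8, scale: i8) -> ArrayRef {
    let decimal: Decimal128Array = array
        .iter()
        .map(|v| v.map(|x| x as i128)) // widen the value, do not rescale it
        .collect::<Decimal128Array>()
        .with_precision_and_scale(precision, scale)
        .expect("valid precision/scale");
    Arc::new(decimal)
}

fn main() {
    // -1 at scale 2 means -0.01; the reinterpretation must preserve that value.
    let ints = Int32Array::from(vec![Some(-1), Some(-2), None]);
    let decimals = int32_to_decimal128(&ints, 5, 2);
    let decimals = decimals
        .as_any()
        .downcast_ref::<Decimal128Array>()
        .unwrap();
    assert_eq!(decimals.value(0), -1i128);
    assert!(decimals.is_null(2));
}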