diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 9ccfc3e6af..086af2d312 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -16,11 +16,12 @@
 // under the License.
 
 use crate::utils::array_with_timezone;
+use crate::EvalMode::Legacy;
 use crate::{timezone, BinaryOutputStyle};
 use crate::{EvalMode, SparkError, SparkResult};
 use arrow::array::builder::StringBuilder;
 use arrow::array::{
-    BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
+    BinaryBuilder, BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
     PrimitiveBuilder, StringArray, StructArray,
 };
 use arrow::compute::can_cast_types;
@@ -304,14 +305,17 @@ fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> b
 
 fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
     use DataType::*;
-    matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
+    matches!(
+        to_type,
+        Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
+    )
 }
 
 fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
     use DataType::*;
     matches!(
         to_type,
-        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) | Binary
     )
 }
 
@@ -319,14 +323,14 @@ fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool {
     use DataType::*;
     matches!(
         to_type,
-        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) | Binary
     )
 }
 
 fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
     use DataType::*;
     match to_type {
-        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true,
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 | Binary => true,
         Decimal128(_, _) => {
             // incompatible: no overflow check
             options.allow_incompat
@@ -338,7 +342,7 @@ fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool {
     use DataType::*;
     match to_type {
-        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Binary => true,
         Decimal128(_, _) => {
             // incompatible: no overflow check
             options.allow_incompat
         }
@@ -501,6 +505,29 @@
     }};
 }
 
+// eval mode is not needed since every integer value can be represented in binary format (this cast cannot fail)
+macro_rules! cast_whole_num_to_binary {
+    ($array:expr, $primitive_type:ty, $byte_size:expr) => {{
+        let input_arr = $array
+            .as_any()
+            .downcast_ref::<$primitive_type>()
+            .ok_or_else(|| SparkError::Internal("Expected numeric array".to_string()))?;
+
+        let len = input_arr.len();
+        let mut builder = BinaryBuilder::with_capacity(len, len * $byte_size);
+
+        for i in 0..input_arr.len() {
+            if input_arr.is_null(i) {
+                builder.append_null();
+            } else {
+                builder.append_value(input_arr.value(i).to_be_bytes());
+            }
+        }
+
+        Ok(Arc::new(builder.finish()) as ArrayRef)
+    }};
+}
+
 macro_rules! cast_int_to_int_macro {
     (
         $array: expr,
@@ -1100,6 +1127,19 @@ fn cast_array(
             Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
         }
         (Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
+        (Int8, Binary) if (eval_mode == Legacy) => cast_whole_num_to_binary!(&array, Int8Array, 1),
+        (Int16, Binary) if (eval_mode == Legacy) => {
+            cast_whole_num_to_binary!(&array, Int16Array, 2)
+        }
+        (Int32, Binary) if (eval_mode == Legacy) => {
+            cast_whole_num_to_binary!(&array, Int32Array, 4)
+        }
+        (Int64, Binary) if (eval_mode == Legacy) => {
+            cast_whole_num_to_binary!(&array, Int64Array, 8)
+        }
+        (Boolean, Decimal128(precision, scale)) => {
+            cast_boolean_to_decimal(&array, *precision, *scale)
+        }
         _ if cast_options.is_adapting_schema
             || is_datafusion_spark_compatible(from_type, to_type) =>
         {
@@ -1118,6 +1158,16 @@
     Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
 }
 
+fn cast_boolean_to_decimal(array: &ArrayRef, precision: u8, scale: i8) -> SparkResult<ArrayRef> {
+    let bool_array = array.as_boolean();
+    let scaled_value = 10_i128.pow(scale as u32);
+    let result: Decimal128Array = bool_array
+        .iter()
+        .map(|v| v.map(|b| if b { scaled_value } else { 0 }))
+        .collect();
+    Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
+}
+
 fn cast_string_to_float(
     array: &ArrayRef,
     to_type: &DataType,
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 9fc4b3afdf..853a288eb7 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -147,13 +147,13 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
       case (DataTypes.BooleanType, _) =>
         canCastFromBoolean(toType)
       case (DataTypes.ByteType, _) =>
-        canCastFromByte(toType)
+        canCastFromByte(toType, evalMode)
       case (DataTypes.ShortType, _) =>
-        canCastFromShort(toType)
+        canCastFromShort(toType, evalMode)
       case (DataTypes.IntegerType, _) =>
-        canCastFromInt(toType)
+        canCastFromInt(toType, evalMode)
       case (DataTypes.LongType, _) =>
-        canCastFromLong(toType)
+        canCastFromLong(toType, evalMode)
       case (DataTypes.FloatType, _) =>
         canCastFromFloat(toType)
       case (DataTypes.DoubleType, _) =>
@@ -263,58 +263,68 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
 
   private def canCastFromBoolean(toType: DataType): SupportLevel = toType match {
     case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType |
-        DataTypes.FloatType | DataTypes.DoubleType =>
+        DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
       Compatible()
     case _ =>
       unsupported(DataTypes.BooleanType, toType)
   }
 
-  private def canCastFromByte(toType: DataType): SupportLevel = toType match {
-    case DataTypes.BooleanType =>
-      Compatible()
-    case DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType =>
-      Compatible()
-    case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
-      Compatible()
-    case _ =>
-      unsupported(DataTypes.ByteType, toType)
-  }
+  private def canCastFromByte(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
+    toType match {
+      case DataTypes.BooleanType =>
+        Compatible()
+      case DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType =>
+        Compatible()
+      case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
+        Compatible()
+      case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
+        Compatible()
+      case _ =>
+        unsupported(DataTypes.ByteType, toType)
+    }
 
-  private def canCastFromShort(toType: DataType): SupportLevel = toType match {
-    case DataTypes.BooleanType =>
-      Compatible()
-    case DataTypes.ByteType | DataTypes.IntegerType | DataTypes.LongType =>
-      Compatible()
-    case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
-      Compatible()
-    case _ =>
-      unsupported(DataTypes.ShortType, toType)
-  }
+  private def canCastFromShort(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
+    toType match {
+      case DataTypes.BooleanType =>
+        Compatible()
+      case DataTypes.ByteType | DataTypes.IntegerType | DataTypes.LongType =>
+        Compatible()
+      case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
+        Compatible()
+      case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
+        Compatible()
+      case _ =>
+        unsupported(DataTypes.ShortType, toType)
+    }
 
-  private def canCastFromInt(toType: DataType): SupportLevel = toType match {
-    case DataTypes.BooleanType =>
-      Compatible()
-    case DataTypes.ByteType | DataTypes.ShortType | DataTypes.LongType =>
-      Compatible()
-    case DataTypes.FloatType | DataTypes.DoubleType =>
-      Compatible()
-    case _: DecimalType =>
-      Compatible()
-    case _ =>
-      unsupported(DataTypes.IntegerType, toType)
-  }
+  private def canCastFromInt(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
+    toType match {
+      case DataTypes.BooleanType =>
+        Compatible()
+      case DataTypes.ByteType | DataTypes.ShortType | DataTypes.LongType =>
+        Compatible()
+      case DataTypes.FloatType | DataTypes.DoubleType =>
+        Compatible()
+      case _: DecimalType =>
+        Compatible()
+      case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
+      case _ =>
+        unsupported(DataTypes.IntegerType, toType)
+    }
 
-  private def canCastFromLong(toType: DataType): SupportLevel = toType match {
-    case DataTypes.BooleanType =>
-      Compatible()
-    case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType =>
-      Compatible()
-    case DataTypes.FloatType | DataTypes.DoubleType =>
-      Compatible()
-    case _: DecimalType =>
-      Compatible()
-    case _ =>
-      unsupported(DataTypes.LongType, toType)
-  }
+  private def canCastFromLong(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
+    toType match {
+      case DataTypes.BooleanType =>
+        Compatible()
+      case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType =>
+        Compatible()
+      case DataTypes.FloatType | DataTypes.DoubleType =>
+        Compatible()
+      case _: DecimalType =>
+        Compatible()
+      case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
+      case _ =>
+        unsupported(DataTypes.LongType, toType)
+    }
 
   private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
     case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 8a68df3820..9127d85476 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -135,11 +135,18 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateBools(), DataTypes.DoubleType)
   }
 
-  ignore("cast BooleanType to DecimalType(10,2)") {
-    // Arrow error: Cast error: Casting from Boolean to Decimal128(10, 2) not supported
+  test("cast BooleanType to DecimalType(10,2)") {
     castTest(generateBools(), DataTypes.createDecimalType(10, 2))
   }
 
+  test("cast BooleanType to DecimalType(14,4)") {
+    castTest(generateBools(), DataTypes.createDecimalType(14, 4))
+  }
+
+  test("cast BooleanType to DecimalType(30,0)") {
+    castTest(generateBools(), DataTypes.createDecimalType(30, 0))
+  }
+
   test("cast BooleanType to StringType") {
     castTest(generateBools(), DataTypes.StringType)
   }
@@ -207,11 +214,14 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       hasIncompatibleType = usingParquetExecWithIncompatTypes)
   }
 
-  ignore("cast ByteType to BinaryType") {
+  test("cast ByteType to BinaryType") {
+    // Spark does not support this cast in ANSI or TRY mode
     castTest(
       generateBytes(),
       DataTypes.BinaryType,
-      hasIncompatibleType = usingParquetExecWithIncompatTypes)
+      hasIncompatibleType = usingParquetExecWithIncompatTypes,
+      testAnsi = false,
+      testTry = false)
   }
 
   ignore("cast ByteType to TimestampType") {
@@ -281,11 +291,14 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       hasIncompatibleType = usingParquetExecWithIncompatTypes)
   }
 
-  ignore("cast ShortType to BinaryType") {
+  test("cast ShortType to BinaryType") {
+    // Spark does not support this cast in ANSI or TRY mode
     castTest(
       generateShorts(),
       DataTypes.BinaryType,
-      hasIncompatibleType = usingParquetExecWithIncompatTypes)
+      hasIncompatibleType = usingParquetExecWithIncompatTypes,
+      testAnsi = false,
+      testTry = false)
   }
 
   ignore("cast ShortType to TimestampType") {
@@ -346,8 +359,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateInts(), DataTypes.StringType)
   }
 
-  ignore("cast IntegerType to BinaryType") {
-    castTest(generateInts(), DataTypes.BinaryType)
+  test("cast IntegerType to BinaryType") {
+    // Spark does not support this cast in ANSI or TRY mode
+    castTest(generateInts(), DataTypes.BinaryType, testAnsi = false, testTry = false)
   }
 
   ignore("cast IntegerType to TimestampType") {
@@ -392,8 +406,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateLongs(), DataTypes.StringType)
   }
 
-  ignore("cast LongType to BinaryType") {
-    castTest(generateLongs(), DataTypes.BinaryType)
+  test("cast LongType to BinaryType") {
+    // Spark does not support this cast in ANSI or TRY mode
+    castTest(generateLongs(), DataTypes.BinaryType, testAnsi = false, testTry = false)
   }
 
   ignore("cast LongType to TimestampType") {
@@ -1329,28 +1344,32 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       input: DataFrame,
       toType: DataType,
       hasIncompatibleType: Boolean = false,
-      testAnsi: Boolean = true): Unit = {
+      testAnsi: Boolean = true,
+      testTry: Boolean = true): Unit = {
     withTempPath { dir =>
       val data = roundtripParquet(input, dir).coalesce(1)
-      data.createOrReplaceTempView("t")
 
       withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
         // cast() should return null for invalid inputs when ansi mode is disabled
-        val df = spark.sql(s"select a, cast(a as ${toType.sql}) from t order by a")
+        val df = data.select(col("a"), col("a").cast(toType)).orderBy(col("a"))
         if (hasIncompatibleType) {
           checkSparkAnswer(df)
         } else {
           checkSparkAnswerAndOperator(df)
         }
 
-        // try_cast() should always return null for invalid inputs
-        val df2 =
-          spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
-        if (hasIncompatibleType) {
-          checkSparkAnswer(df2)
-        } else {
-          checkSparkAnswerAndOperator(df2)
+        if (testTry) {
+          data.createOrReplaceTempView("t")
+          // try_cast() should always return null for invalid inputs
+          // not using the Spark DSL here since `try_cast` is only available from Spark 4.x
+          val df2 =
+            spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
+          if (hasIncompatibleType) {
+            checkSparkAnswer(df2)
+          } else {
+            checkSparkAnswerAndOperator(df2)
+          }
         }
       }
 
@@ -1408,14 +1427,16 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         }
 
         // try_cast() should always return null for invalid inputs
-        val df2 =
-          spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
-        if (hasIncompatibleType) {
-          checkSparkAnswer(df2)
-        } else {
-          checkSparkAnswerAndOperator(df2)
+        if (testTry) {
+          data.createOrReplaceTempView("t")
+          val df2 =
+            spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
+          if (hasIncompatibleType) {
+            checkSparkAnswer(df2)
+          } else {
+            checkSparkAnswerAndOperator(df2)
+          }
         }
-      }
     }
   }
 }
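
Note (illustrative sketch, not part of the patch): the snippet below, assuming a plain local Spark session with names such as `demo` used as placeholders, shows the Spark behaviour the new native kernels are matched against. In LEGACY mode an integral value cast to BINARY yields its big-endian two's-complement bytes (what cast_whole_num_to_binary produces via to_be_bytes), and a BooleanType value cast to a DecimalType becomes 1 or 0 at the target scale (what cast_boolean_to_decimal produces by scaling with 10^scale).

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DecimalType

val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
import spark.implicits._

// Integral -> BINARY in default (legacy) mode: big-endian two's-complement bytes.
val ints = Seq(1, 256, -1).toDF("a")
ints
  .select(col("a"), col("a").cast("binary").as("b"))
  .collect()
  .foreach { r =>
    val hex = r.getAs[Array[Byte]]("b").map(b => f"$b%02X").mkString(" ")
    // 1 -> 00 00 00 01, 256 -> 00 00 01 00, -1 -> FF FF FF FF
    println(s"${r.getInt(0)} -> $hex")
  }

// BooleanType -> DecimalType: true becomes 1 and false becomes 0 at the target scale.
val bools = Seq(true, false).toDF("a")
// shows true -> 1.00 and false -> 0.00 for DecimalType(10, 2)
bools.select(col("a"), col("a").cast(DecimalType(10, 2)).as("d")).show()

spark.stop()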