Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 56 additions & 6 deletions native/spark-expr/src/conversion_funcs/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
// under the License.

use crate::utils::array_with_timezone;
use crate::EvalMode::Legacy;
use crate::{timezone, BinaryOutputStyle};
use crate::{EvalMode, SparkError, SparkResult};
use arrow::array::builder::StringBuilder;
use arrow::array::{
BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
BinaryBuilder, BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
PrimitiveBuilder, StringArray, StructArray,
};
use arrow::compute::can_cast_types;
Expand Down Expand Up @@ -304,29 +305,32 @@ fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> b

/// Reports whether a cast from a boolean source to `to_type` is accepted.
///
/// The cast options argument is ignored. Both `matches!` forms below accept
/// `Int8`, `Int16`, `Int32`, `Int64`, `Float32` and `Float64`; the multi-line
/// form additionally accepts any `Decimal128`.
// NOTE(review): the two consecutive `matches!` invocations look like the
// before/after sides of a diff hunk rather than a single function body --
// confirm which one the real file keeps.
fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
use DataType::*;
matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
matches!(
to_type,
Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
)
}

/// Reports whether a cast from a byte source to `to_type` is accepted.
///
/// The cast options argument is ignored. Accepted targets are `Boolean`, the
/// integer widths `Int8`..`Int64`, `Float32`, `Float64` and any `Decimal128`;
/// the second pattern line also lists `Binary`.
// NOTE(review): the two pattern lines inside `matches!` look like the
// before/after sides of a diff hunk -- confirm which one the real file keeps.
fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
use DataType::*;
matches!(
to_type,
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) | Binary
)
}

/// Reports whether a cast from a short source to `to_type` is accepted.
///
/// The cast options argument is ignored. Accepted targets are `Boolean`, the
/// integer widths `Int8`..`Int64`, `Float32`, `Float64` and any `Decimal128`;
/// the second pattern line also lists `Binary`.
// NOTE(review): the two pattern lines inside `matches!` look like the
// before/after sides of a diff hunk -- confirm which one the real file keeps.
fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool {
use DataType::*;
matches!(
to_type,
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) | Binary
)
}

fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
use DataType::*;
match to_type {
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true,
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 | Binary => true,
Decimal128(_, _) => {
// incompatible: no overflow check
options.allow_incompat
Expand All @@ -338,7 +342,7 @@ fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool {
use DataType::*;
match to_type {
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Binary => true,
Decimal128(_, _) => {
// incompatible: no overflow check
options.allow_incompat
Expand Down Expand Up @@ -501,6 +505,29 @@ macro_rules! cast_float_to_string {
}};
}

// Builds a Binary array from a whole-number array by writing every value as
// its big-endian bytes. The eval mode plays no part here because every integer
// value has a binary form, so no element can fail to convert.
//
// Arguments: the source array expression, the concrete primitive array type to
// downcast to, and the number of bytes one value occupies (used only to size
// the builder's data buffer up front).
macro_rules! cast_whole_num_to_binary {
    ($array:expr, $primitive_type:ty, $byte_size:expr) => {{
        let typed = $array
            .as_any()
            .downcast_ref::<$primitive_type>()
            .ok_or_else(|| SparkError::Internal("Expected numeric array".to_string()))?;

        let row_count = typed.len();
        let mut out = BinaryBuilder::with_capacity(row_count, row_count * $byte_size);

        // Null slots stay null; every other slot contributes its big-endian bytes.
        for slot in typed.iter() {
            match slot {
                Some(value) => out.append_value(value.to_be_bytes()),
                None => out.append_null(),
            }
        }

        Ok(Arc::new(out.finish()) as ArrayRef)
    }};
}

macro_rules! cast_int_to_int_macro {
(
$array: expr,
Expand Down Expand Up @@ -1100,6 +1127,19 @@ fn cast_array(
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
}
(Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
(Int8, Binary) if (eval_mode == Legacy) => cast_whole_num_to_binary!(&array, Int8Array, 1),
(Int16, Binary) if (eval_mode == Legacy) => {
cast_whole_num_to_binary!(&array, Int16Array, 2)
}
(Int32, Binary) if (eval_mode == Legacy) => {
cast_whole_num_to_binary!(&array, Int32Array, 4)
}
(Int64, Binary) if (eval_mode == Legacy) => {
cast_whole_num_to_binary!(&array, Int64Array, 8)
}
(Boolean, Decimal128(precision, scale)) => {
cast_boolean_to_decimal(&array, *precision, *scale)
}
_ if cast_options.is_adapting_schema
|| is_datafusion_spark_compatible(from_type, to_type) =>
{
Expand All @@ -1118,6 +1158,16 @@ fn cast_array(
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}

fn cast_boolean_to_decimal(array: &ArrayRef, precision: u8, scale: i8) -> SparkResult<ArrayRef> {
let bool_array = array.as_boolean();
let scaled_value = 10_i128.pow(scale as u32);
let result: Decimal128Array = bool_array
.iter()
.map(|v| v.map(|b| if b { scaled_value } else { 0 }))
.collect();
Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
}

fn cast_string_to_float(
array: &ArrayRef,
to_type: &DataType,
Expand Down
108 changes: 59 additions & 49 deletions spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,13 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
case (DataTypes.BooleanType, _) =>
canCastFromBoolean(toType)
case (DataTypes.ByteType, _) =>
canCastFromByte(toType)
canCastFromByte(toType, evalMode)
case (DataTypes.ShortType, _) =>
canCastFromShort(toType)
canCastFromShort(toType, evalMode)
case (DataTypes.IntegerType, _) =>
canCastFromInt(toType)
canCastFromInt(toType, evalMode)
case (DataTypes.LongType, _) =>
canCastFromLong(toType)
canCastFromLong(toType, evalMode)
case (DataTypes.FloatType, _) =>
canCastFromFloat(toType)
case (DataTypes.DoubleType, _) =>
Expand Down Expand Up @@ -263,58 +263,68 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {

// Support level for casting from BooleanType to `toType`.
// Byte, short, int, long, float and double targets are reported as
// Compatible(); the second alternative line also accepts any DecimalType.
// Every other target is handed to unsupported(DataTypes.BooleanType, toType).
// NOTE(review): the two consecutive alternative lines (one ending in
// DoubleType, one ending in DecimalType) look like the before/after sides of
// a diff hunk -- confirm which one the real file keeps.
private def canCastFromBoolean(toType: DataType): SupportLevel = toType match {
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType |
DataTypes.FloatType | DataTypes.DoubleType =>
DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
Compatible()
case _ => unsupported(DataTypes.BooleanType, toType)
}

// Support level for casting from ByteType to `toType` (single-argument form).
// Boolean, the wider integer types (short, int, long), float, double and any
// DecimalType are reported as Compatible(); every other target is handed to
// unsupported(DataTypes.ByteType, toType).
private def canCastFromByte(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType =>
Compatible()
case DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
Compatible()
case _ =>
unsupported(DataTypes.ByteType, toType)
}
// Support level for casting from ByteType to `toType`, taking the eval mode
// into account.
// Boolean, short, int, long, float, double and any DecimalType are reported
// as Compatible(). BinaryType is Compatible() only when `evalMode` is
// CometEvalMode.LEGACY; in any other mode it falls through to the default
// case. Every remaining target is handed to
// unsupported(DataTypes.ByteType, toType).
private def canCastFromByte(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
toType match {
case DataTypes.BooleanType =>
Compatible()
case DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
Compatible()
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
Compatible()
case _ =>
unsupported(DataTypes.ByteType, toType)
}

// Support level for casting from ShortType to `toType` (single-argument form).
// Boolean, the other integer types (byte, int, long), float, double and any
// DecimalType are reported as Compatible(); every other target is handed to
// unsupported(DataTypes.ShortType, toType).
private def canCastFromShort(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType =>
Compatible()
case DataTypes.ByteType | DataTypes.IntegerType | DataTypes.LongType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
Compatible()
case _ =>
unsupported(DataTypes.ShortType, toType)
}
// Support level for casting from ShortType to `toType`, taking the eval mode
// into account.
// Boolean, byte, int, long, float, double and any DecimalType are reported as
// Compatible(). BinaryType is Compatible() only when `evalMode` is
// CometEvalMode.LEGACY; in any other mode it falls through to the default
// case. Every remaining target is handed to
// unsupported(DataTypes.ShortType, toType).
private def canCastFromShort(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
toType match {
case DataTypes.BooleanType =>
Compatible()
case DataTypes.ByteType | DataTypes.IntegerType | DataTypes.LongType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
Compatible()
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
Compatible()
case _ =>
unsupported(DataTypes.ShortType, toType)
}

// Support level for casting from IntegerType to `toType` (single-argument form).
// Boolean, the other integer types (byte, short, long), float, double and any
// DecimalType are reported as Compatible(); every other target is handed to
// unsupported(DataTypes.IntegerType, toType).
private def canCastFromInt(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType =>
Compatible()
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.LongType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
Compatible()
case _: DecimalType =>
Compatible()
case _ =>
unsupported(DataTypes.IntegerType, toType)
}
// Support level for casting from IntegerType to `toType`, taking the eval
// mode into account.
// Boolean, byte, short, long, float, double and any DecimalType are reported
// as Compatible(). BinaryType is Compatible() only when `evalMode` is
// CometEvalMode.LEGACY; in any other mode it falls through to the default
// case. Every remaining target is handed to
// unsupported(DataTypes.IntegerType, toType).
private def canCastFromInt(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
toType match {
case DataTypes.BooleanType =>
Compatible()
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.LongType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
Compatible()
case _: DecimalType =>
Compatible()
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
case _ =>
unsupported(DataTypes.IntegerType, toType)
}

// Support level for casting from LongType to `toType` (single-argument form).
// Boolean, the narrower integer types (byte, short, int), float, double and
// any DecimalType are reported as Compatible(); every other target is handed
// to unsupported(DataTypes.LongType, toType).
private def canCastFromLong(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType =>
Compatible()
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
Compatible()
case _: DecimalType =>
Compatible()
case _ =>
unsupported(DataTypes.LongType, toType)
}
// Support level for casting from LongType to `toType`, taking the eval mode
// into account.
// Boolean, byte, short, int, float, double and any DecimalType are reported
// as Compatible(). BinaryType is Compatible() only when `evalMode` is
// CometEvalMode.LEGACY; in any other mode it falls through to the default
// case. Every remaining target is handed to
// unsupported(DataTypes.LongType, toType).
private def canCastFromLong(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
toType match {
case DataTypes.BooleanType =>
Compatible()
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
Compatible()
case _: DecimalType =>
Compatible()
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
case _ =>
unsupported(DataTypes.LongType, toType)
}

private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
Expand Down
Loading