diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 680cf80c75..c06488253c 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -68,6 +68,7 @@ use std::collections::HashMap;
 use std::path::PathBuf;
 use std::time::{Duration, Instant};
 use std::{sync::Arc, task::Poll};
+use datafusion::functions::datetime::to_date::ToDateFunc;
 use tokio::runtime::Runtime;
 
 use crate::execution::memory_pools::{
@@ -351,6 +352,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkHex::default()));
+    session_ctx.register_udf(ScalarUDF::new_from_impl(ToDateFunc::default()));
 }
 
 /// Prepares arrow arrays for output.
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 760dc3570f..f9268e62a3 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -36,6 +36,7 @@ use datafusion::physical_plan::ColumnarValue;
 use std::any::Any;
 use std::fmt::Debug;
 use std::sync::Arc;
+use datafusion::functions::datetime::to_date::ToDateFunc;
 
 macro_rules! make_comet_scalar_udf {
     ($name:expr, $func:ident, $data_type:ident) => {{
@@ -196,6 +197,7 @@ fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
         Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())),
+        Arc::new(ScalarUDF::new_from_impl(ToDateFunc::default())),
     ]
 }
 
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 066680456e..d2c2c6e723 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -205,7 +205,9 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[WeekDay] -> CometWeekDay,
     classOf[DayOfYear] -> CometDayOfYear,
     classOf[WeekOfYear] -> CometWeekOfYear,
-    classOf[Quarter] -> CometQuarter)
+    classOf[Quarter] -> CometQuarter,
+    classOf[ParseToDate] -> CometParseToDate
+  )
 
   private val conversionExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
     classOf[Cast] -> CometCast)
diff --git a/spark/src/main/scala/org/apache/comet/serde/datetime.scala b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
index a623146916..835c17e37a 100644
--- a/spark/src/main/scala/org/apache/comet/serde/datetime.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
@@ -20,11 +20,9 @@ package org.apache.comet.serde
 
 import java.util.Locale
-
-import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateDiff, DateFormatClass, DateSub, DayOfMonth, DayOfWeek, DayOfYear, GetDateField, Hour, LastDay, Literal, Minute, Month, Quarter, Second, TruncDate, TruncTimestamp, UnixDate, UnixTimestamp, WeekDay, WeekOfYear, Year}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateDiff, DateFormatClass, DateSub, DayOfMonth, DayOfWeek, DayOfYear, GetDateField, Hour, LastDay, Literal, Minute, Month, ParseToDate, Quarter, Second, TruncDate, TruncTimestamp, UnixDate, UnixTimestamp, WeekDay, WeekOfYear, Year}
 import org.apache.spark.sql.types.{DateType, IntegerType, StringType, TimestampType}
 import org.apache.spark.unsafe.types.UTF8String
-
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.CometGetDateField.CometGetDateField
 import org.apache.comet.serde.ExprOuterClass.Expr
 
@@ -176,6 +174,29 @@ object CometQuarter extends CometExpressionSerde[Quarter] with CometExprGetDateF
   }
 }
 
+object CometParseToDate extends CometExpressionSerde[ParseToDate] {
+  /**
+   * Convert a Spark expression into a protocol buffer representation that can be passed into
+   * native code.
+   *
+   * @param expr
+   *   The Spark expression.
+   * @param inputs
+   *   The input attributes.
+   * @param binding
+   *   Whether the attributes are bound (this is only relevant in aggregate expressions).
+   * @return
+   *   Protocol buffer representation, or None if the expression could not be converted. In this
+   *   case it is expected that the input expression will have been tagged with reasons why it
+   *   could not be converted.
+   */
+  override def convert(expr: ParseToDate, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
+    val childExpr: Option[Expr] = exprToProtoInternal(expr.left, inputs, binding)
+    val failOnErrorExpr: Option[Expr] = exprToProtoInternal(Literal(expr.ansiEnabled), inputs, binding)
+    ???
+  }
+}
+
 object CometHour extends CometExpressionSerde[Hour] {
   override def convert(
       expr: Hour,
diff --git a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala
index 1ae6926e05..cb489e8f16 100644
--- a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala
@@ -173,6 +173,26 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
     }
   }
 
+  test("to_date - string input") {
+    withTempView("string_tbl") {
+      // Create test data with date strings
+      val schema = StructType(Seq(StructField("dt_str", DataTypes.StringType, true)))
+      val data = Seq(
+        Row("2020-01-01"),
+        Row("2021-06-15"),
+        Row("2022-12-31"),
+        Row(null))
+      spark
+        .createDataFrame(spark.sparkContext.parallelize(data), schema)
+        .createOrReplaceTempView("string_tbl")
+
+      // String input with a custom format should also fall back
+      checkSparkAnswerAndFallbackReason(
+        "SELECT dt_str, to_date(dt_str, 'yyyy-MM-dd') from string_tbl",
+        "to_date does not support input type: StringType")
+    }
+  }
+
   private def createTimestampTestData = {
     val r = new Random(42)
     val schema = StructType(
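
The body of `CometParseToDate.convert` above is still a stub: it computes `childExpr` and `failOnErrorExpr` and then hits `???`, which throws `NotImplementedError` at runtime. Below is a minimal sketch of one way the method could be completed; it is an illustration, not the PR's final implementation. It assumes the `QueryPlanSerde` helpers `exprToProtoInternal`, `scalarFunctionExprToProto`, and `withInfo` are in scope (they are used elsewhere in `datetime.scala`), that the native lookup name `"to_date"` resolves to the `ToDateFunc` UDF registered in `jni_api.rs` and `comet_scalar_funcs.rs` above, and that the `"to_date with a format string..."` message is a hypothetical fallback reason.

```scala
object CometParseToDate extends CometExpressionSerde[ParseToDate] {
  override def convert(
      expr: ParseToDate,
      inputs: Seq[Attribute],
      binding: Boolean): Option[Expr] = {
    if (expr.format.isDefined) {
      // Spark's datetime patterns differ from the Chrono patterns that
      // DataFusion expects, so a custom format is not handed off. Tag the
      // expression with the reason and return None so Comet falls back to
      // Spark for this expression.
      withInfo(expr, "to_date with a format string is not supported")
      return None
    }
    // Serialize the input expression to protobuf.
    val childExpr = exprToProtoInternal(expr.left, inputs, binding)
    // Hand off to the DataFusion to_date UDF registered on the native side.
    // NOTE: expr.ansiEnabled (the failOnErrorExpr computed in the diff) is
    // not wired through here; reconciling Spark's default null-on-failure
    // semantics with DataFusion's error-on-failure behavior is left open in
    // this sketch.
    scalarFunctionExprToProto("to_date", childExpr)
  }
}
```

Tagging with `withInfo` and returning `None` is the standard Comet pattern for falling back to Spark with a visible reason, and it is the same mechanism the new test asserts on via `checkSparkAnswerAndFallbackReason`.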