diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
index 036c5c9aaf..cb0a519442 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -269,8 +269,14 @@ object Utils extends CometTypeShim {
         case c =>
           throw new SparkException(
-            "Comet execution only takes Arrow Arrays, but got " +
-              s"${c.getClass}")
+            s"Comet execution only takes Arrow Arrays, but got ${c.getClass}. " +
+              "This typically happens when a Comet scan falls back to Spark due to unsupported " +
+              "data types (e.g., complex types like structs, arrays, or maps with native_comet). " +
+              "To resolve this, you can: " +
+              "(1) enable spark.comet.scan.allowIncompatible=true to use a compatible native " +
+              "scan variant, or " +
+              "(2) enable spark.comet.convert.parquet.enabled=true to convert Spark Parquet " +
+              "data to Arrow format automatically.")
       }
     }
 
     (fieldVectors, provider)
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index 9c23b1be68..76e741e3bf 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -482,6 +482,21 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   private def convertToComet(op: SparkPlan, handler: CometOperatorSerde[_]): Option[SparkPlan] = {
     val serde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
     if (isOperatorEnabled(serde, op)) {
+      // For operators that require native children (like writes), check if all data-producing
+      // children are CometNativeExec. This prevents runtime failures when the native operator
+      // expects Arrow arrays but receives non-Arrow data (e.g., OnHeapColumnVector).
+      if (serde.requiresNativeChildren && op.children.nonEmpty) {
+        // Get the actual data-producing children (unwrap WriteFilesExec if present)
+        val dataProducingChildren = op.children.flatMap {
+          case writeFiles: WriteFilesExec => Seq(writeFiles.child)
+          case other => Seq(other)
+        }
+        if (!dataProducingChildren.forall(_.isInstanceOf[CometNativeExec])) {
+          withInfo(op, "Cannot perform native operation because input is not in Arrow format")
+          return None
+        }
+      }
+
       val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
       if (op.children.nonEmpty && op.children.forall(_.isInstanceOf[CometNativeExec])) {
         val childOp = op.children.map(_.asInstanceOf[CometNativeExec].nativeOp)
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala
index 3a2494591e..5698919398 100644
--- a/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/CometOperatorSerde.scala
@@ -36,6 +36,13 @@ trait CometOperatorSerde[T <: SparkPlan] {
    */
   def enabledConfig: Option[ConfigEntry[Boolean]]
 
+  /**
+   * Indicates whether this operator requires all of its children to be CometNativeExec. If true
+   * and any child is not a native exec, conversion will be skipped and the operator will fall
+   * back to Spark. This is useful for operators like writes that require Arrow-formatted input.
+   */
+  def requiresNativeChildren: Boolean = false
+
   /**
    * Determine the support level of the operator based on its attributes.
    *
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
index c98f8314a7..31575138f8 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
@@ -49,6 +49,10 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
   override def enabledConfig: Option[ConfigEntry[Boolean]] =
     Some(CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED)
 
+  // Native writes require Arrow-formatted input data. If the scan falls back to Spark
+  // (e.g., due to unsupported complex types), the write must also fall back.
+  override def requiresNativeChildren: Boolean = true
+
   override def getSupportLevel(op: DataWritingCommandExec): SupportLevel = {
     op.cmd match {
       case cmd: InsertIntoHadoopFsRelationCommand =>
diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
index c4856c3cc2..eecc16d712 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
@@ -75,16 +75,15 @@
     withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
       val capturedPlan = writeWithCometNativeWriteExec(inputPath, outputPath)
 
-      capturedPlan.foreach { qe =>
-        val executedPlan = qe.executedPlan
-        val hasNativeScan = executedPlan.exists {
+      capturedPlan.foreach { plan =>
+        val hasNativeScan = plan.exists {
           case _: CometNativeScanExec => true
           case _ => false
         }
 
         assert(
           hasNativeScan,
-          s"Expected CometNativeScanExec in the plan, but got:\n${executedPlan.treeString}")
+          s"Expected CometNativeScanExec in the plan, but got:\n${plan.treeString}")
       }
 
       verifyWrittenFile(outputPath)
@@ -311,6 +310,54 @@
     }
   }
 
+  test("native write falls back when scan produces non-Arrow data") {
+    // This test verifies that when a native scan (like native_comet) doesn't support
+    // certain data types (complex types), the native write correctly falls back to Spark
+    // instead of failing at runtime with "Comet execution only takes Arrow Arrays" error.
+    withTempPath { dir =>
+      val inputPath = new File(dir, "input.parquet").getAbsolutePath
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      // Create data with complex types and write without Comet
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        val df = Seq((1, Seq(1, 2, 3)), (2, Seq(4, 5)), (3, Seq(6, 7, 8, 9)))
+          .toDF("id", "values")
+        df.write.parquet(inputPath)
+      }
+
+      // With native Parquet write enabled but using native_comet scan which doesn't
+      // support complex types, the scan falls back to Spark. The native write should
+      // detect this and also fall back to Spark instead of failing at runtime.
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+        // Use native_comet which doesn't support complex types
+        CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_comet") {
+
+        val plan =
+          captureWritePlan(path => spark.read.parquet(inputPath).write.parquet(path), outputPath)
+
+        // Verify NO CometNativeWriteExec in the plan (should have fallen back to Spark)
+        val hasNativeWrite = plan.exists {
+          case _: CometNativeWriteExec => true
+          case d: DataWritingCommandExec =>
+            d.child.exists(_.isInstanceOf[CometNativeWriteExec])
+          case _ => false
+        }
+
+        assert(
+          !hasNativeWrite,
+          "Expected fallback to Spark write (no CometNativeWriteExec), but found native write " +
+            s"in plan:\n${plan.treeString}")
+
+        // Verify the data was written correctly
+        val result = spark.read.parquet(outputPath).collect()
+        assert(result.length == 3, "Expected 3 rows to be written")
+      }
+    }
+  }
+
   test("parquet write complex types fuzz test") {
     withTempPath { dir =>
       val outputPath = new File(dir, "output.parquet").getAbsolutePath
@@ -347,18 +394,21 @@
       inputPath
     }
 
-  private def writeWithCometNativeWriteExec(
-      inputPath: String,
-      outputPath: String,
-      num_partitions: Option[Int] = None): Option[QueryExecution] = {
-    val df = spark.read.parquet(inputPath)
-
-    // Use a listener to capture the execution plan during write
+  /**
+   * Captures the execution plan during a write operation.
+   *
+   * @param writeOp
+   *   The write operation to execute (takes output path as parameter)
+   * @param outputPath
+   *   The path to write to
+   * @return
+   *   The captured execution plan
+   */
+  private def captureWritePlan(writeOp: String => Unit, outputPath: String): SparkPlan = {
     var capturedPlan: Option[QueryExecution] = None
 
     val listener = new org.apache.spark.sql.util.QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        // Capture plans from write operations
        if (funcName == "save" || funcName.contains("command")) {
          capturedPlan = Some(qe)
        }
@@ -373,8 +423,7 @@
     spark.listenerManager.register(listener)
 
     try {
-      // Perform native write with optional partitioning
-      num_partitions.fold(df)(n => df.repartition(n)).write.parquet(outputPath)
+      writeOp(outputPath)
 
       // Wait for listener to be called with timeout
       val maxWaitTimeMs = 15000
@@ -387,36 +436,45 @@
         iterations += 1
       }
 
-      // Verify that CometNativeWriteExec was used
       assert(
         capturedPlan.isDefined,
         s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
 
-      capturedPlan.foreach { qe =>
-        val executedPlan = stripAQEPlan(qe.executedPlan)
+      stripAQEPlan(capturedPlan.get.executedPlan)
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
 
-        // Count CometNativeWriteExec instances in the plan
-        var nativeWriteCount = 0
-        executedPlan.foreach {
+  private def writeWithCometNativeWriteExec(
+      inputPath: String,
+      outputPath: String,
+      num_partitions: Option[Int] = None): Option[SparkPlan] = {
+    val df = spark.read.parquet(inputPath)
+
+    val plan = captureWritePlan(
+      path => num_partitions.fold(df)(n => df.repartition(n)).write.parquet(path),
+      outputPath)
+
+    // Count CometNativeWriteExec instances in the plan
+    var nativeWriteCount = 0
+    plan.foreach {
+      case _: CometNativeWriteExec =>
+        nativeWriteCount += 1
+      case d: DataWritingCommandExec =>
+        d.child.foreach {
           case _: CometNativeWriteExec =>
            nativeWriteCount += 1
-          case d: DataWritingCommandExec =>
-            d.child.foreach {
-              case _: CometNativeWriteExec =>
-                nativeWriteCount += 1
-              case _ =>
-            }
          case _ =>
        }
-
-        assert(
-          nativeWriteCount == 1,
-          s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${executedPlan.treeString}")
-      }
-    } finally {
-      spark.listenerManager.unregister(listener)
+      case _ =>
     }
-    capturedPlan
+
+    assert(
+      nativeWriteCount == 1,
+      s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${plan.treeString}")
+
+    Some(plan)
   }
 
   private def verifyWrittenFile(outputPath: String): Unit = {
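
A minimal usage sketch (not part of the patch) of the two workarounds named in the improved error message above; the SparkSession value `spark` and the file paths are illustrative assumptions:

  // Option (1) from the error message: allow a compatible native scan variant for complex types.
  spark.conf.set("spark.comet.scan.allowIncompatible", "true")
  // Option (2) from the error message: convert Spark-produced Parquet data to Arrow automatically.
  // spark.conf.set("spark.comet.convert.parquet.enabled", "true")
  // Per the new message, either setting avoids the "Comet execution only takes Arrow Arrays"
  // failure when reading and natively writing data with complex types such as arrays or structs.
  spark.read.parquet("/tmp/input.parquet").write.parquet("/tmp/output.parquet")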