@@ -269,8 +269,14 @@ object Utils extends CometTypeShim {

case c =>
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
s"${c.getClass}")
s"Comet execution only takes Arrow Arrays, but got ${c.getClass}. " +
"This typically happens when a Comet scan falls back to Spark due to unsupported " +
"data types (e.g., complex types like structs, arrays, or maps with native_comet). " +
"To resolve this, you can: " +
"(1) enable spark.comet.scan.allowIncompatible=true to use a compatible native " +
"scan variant, or " +
"(2) enable spark.comet.convert.parquet.enabled=true to convert Spark Parquet " +
"data to Arrow format automatically.")
}
}
(fieldVectors, provider)
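For reference, the two remedies named in the new error message correspond to session configs that can be set before running the query. A minimal sketch in Scala — the config keys are quoted verbatim from the message above; everything else is illustrative:

// Option (1): allow a compatible native scan variant so the scan itself
// stays in Comet and produces Arrow arrays.
spark.conf.set("spark.comet.scan.allowIncompatible", "true")

// Option (2): keep the Spark scan but convert its Parquet output to Arrow.
spark.conf.set("spark.comet.convert.parquet.enabled", "true")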
15 changes: 15 additions & 0 deletions spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -482,6 +482,21 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
private def convertToComet(op: SparkPlan, handler: CometOperatorSerde[_]): Option[SparkPlan] = {
val serde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
if (isOperatorEnabled(serde, op)) {
+// For operators that require native children (like writes), check if all data-producing
+// children are CometNativeExec. This prevents runtime failures when the native operator
+// expects Arrow arrays but receives non-Arrow data (e.g., OnHeapColumnVector).
+if (serde.requiresNativeChildren && op.children.nonEmpty) {
+// Get the actual data-producing children (unwrap WriteFilesExec if present)
+val dataProducingChildren = op.children.flatMap {
+case writeFiles: WriteFilesExec => Seq(writeFiles.child)
+case other => Seq(other)
+}
+if (!dataProducingChildren.forall(_.isInstanceOf[CometNativeExec])) {
+withInfo(op, "Cannot perform native operation because input is not in Arrow format")
+return None
+}
+}
+
val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
if (op.children.nonEmpty && op.children.forall(_.isInstanceOf[CometNativeExec])) {
val childOp = op.children.map(_.asInstanceOf[CometNativeExec].nativeOp)
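The unwrap step exists because, since Spark 3.4, a write plan nests the real data-producing child under WriteFilesExec (DataWritingCommandExec -> WriteFilesExec -> scan/exchange), so the guard must look through the wrapper. A condensed restatement of the new condition as a standalone predicate — the types are taken from the diff above, the helper name itself is ours:

// Returns true only when every data-producing child already yields Arrow
// batches, i.e. is a CometNativeExec.
def allDataProducingChildrenNative(op: SparkPlan): Boolean =
  op.children
    .flatMap {
      case w: WriteFilesExec => Seq(w.child) // look through the write wrapper
      case other => Seq(other)
    }
    .forall(_.isInstanceOf[CometNativeExec])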
@@ -36,6 +36,13 @@ trait CometOperatorSerde[T <: SparkPlan] {
*/
def enabledConfig: Option[ConfigEntry[Boolean]]

+/**
+* Indicates whether this operator requires all of its children to be CometNativeExec. If true
+* and any child is not a native exec, conversion will be skipped and the operator will fall
+* back to Spark. This is useful for operators like writes that require Arrow-formatted input.
+*/
+def requiresNativeChildren: Boolean = false
+
/**
* Determine the support level of the operator based on its attributes.
*
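A minimal sketch of a serde opting into the new flag. The operator type here is hypothetical; only enabledConfig, requiresNativeChildren, and getSupportLevel appear in this diff, any other members of CometOperatorSerde are elided, and the Compatible() support level is assumed from Comet's SupportLevel ADT:

object CometHypotheticalWrite extends CometOperatorSerde[HypotheticalWriteExec] {
  override def enabledConfig: Option[ConfigEntry[Boolean]] = None
  // Opt in: conversion is skipped unless every data-producing child is native
  override def requiresNativeChildren: Boolean = true
  override def getSupportLevel(op: HypotheticalWriteExec): SupportLevel = Compatible()
}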
@@ -49,6 +49,10 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
override def enabledConfig: Option[ConfigEntry[Boolean]] =
Some(CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED)

+// Native writes require Arrow-formatted input data. If the scan falls back to Spark
+// (e.g., due to unsupported complex types), the write must also fall back.
+override def requiresNativeChildren: Boolean = true
+
override def getSupportLevel(op: DataWritingCommandExec): SupportLevel = {
op.cmd match {
case cmd: InsertIntoHadoopFsRelationCommand =>
@@ -75,16 +75,15 @@ class CometParquetWriterSuite extends CometTestBase {

withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
val capturedPlan = writeWithCometNativeWriteExec(inputPath, outputPath)
-capturedPlan.foreach { qe =>
-val executedPlan = qe.executedPlan
-val hasNativeScan = executedPlan.exists {
+capturedPlan.foreach { plan =>
+val hasNativeScan = plan.exists {
case _: CometNativeScanExec => true
case _ => false
}

assert(
hasNativeScan,
s"Expected CometNativeScanExec in the plan, but got:\n${executedPlan.treeString}")
s"Expected CometNativeScanExec in the plan, but got:\n${plan.treeString}")
}

verifyWrittenFile(outputPath)
@@ -311,6 +310,54 @@ class CometParquetWriterSuite extends CometTestBase {
}
}

test("native write falls back when scan produces non-Arrow data") {
// This test verifies that when a native scan (like native_comet) doesn't support
// certain data types (complex types), the native write correctly falls back to Spark
// instead of failing at runtime with "Comet execution only takes Arrow Arrays" error.
withTempPath { dir =>
val inputPath = new File(dir, "input.parquet").getAbsolutePath
val outputPath = new File(dir, "output.parquet").getAbsolutePath

// Create data with complex types and write without Comet
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
val df = Seq((1, Seq(1, 2, 3)), (2, Seq(4, 5)), (3, Seq(6, 7, 8, 9)))
.toDF("id", "values")
df.write.parquet(inputPath)
}

// With native Parquet write enabled but using native_comet scan which doesn't
// support complex types, the scan falls back to Spark. The native write should
// detect this and also fall back to Spark instead of failing at runtime.
withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
// Use native_comet which doesn't support complex types
CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_comet") {

val plan =
captureWritePlan(path => spark.read.parquet(inputPath).write.parquet(path), outputPath)

// Verify NO CometNativeWriteExec in the plan (should have fallen back to Spark)
val hasNativeWrite = plan.exists {
case _: CometNativeWriteExec => true
case d: DataWritingCommandExec =>
d.child.exists(_.isInstanceOf[CometNativeWriteExec])
case _ => false
}

assert(
!hasNativeWrite,
"Expected fallback to Spark write (no CometNativeWriteExec), but found native write " +
s"in plan:\n${plan.treeString}")

// Verify the data was written correctly
val result = spark.read.parquet(outputPath).collect()
assert(result.length == 3, "Expected 3 rows to be written")
}
}
}

test("parquet write complex types fuzz test") {
withTempPath { dir =>
val outputPath = new File(dir, "output.parquet").getAbsolutePath
@@ -347,18 +394,21 @@
inputPath
}

-private def writeWithCometNativeWriteExec(
-inputPath: String,
-outputPath: String,
-num_partitions: Option[Int] = None): Option[QueryExecution] = {
-val df = spark.read.parquet(inputPath)
-
-// Use a listener to capture the execution plan during write
+/**
+* Captures the execution plan during a write operation.
+*
+* @param writeOp
+* The write operation to execute (takes output path as parameter)
+* @param outputPath
+* The path to write to
+* @return
+* The captured execution plan
+*/
+private def captureWritePlan(writeOp: String => Unit, outputPath: String): SparkPlan = {
var capturedPlan: Option[QueryExecution] = None

val listener = new org.apache.spark.sql.util.QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-// Capture plans from write operations
if (funcName == "save" || funcName.contains("command")) {
capturedPlan = Some(qe)
}
@@ -373,8 +423,7 @@
spark.listenerManager.register(listener)

try {
-// Perform native write with optional partitioning
-num_partitions.fold(df)(n => df.repartition(n)).write.parquet(outputPath)
+writeOp(outputPath)

// Wait for listener to be called with timeout
val maxWaitTimeMs = 15000
@@ -387,36 +436,45 @@
iterations += 1
}

-// Verify that CometNativeWriteExec was used
assert(
capturedPlan.isDefined,
s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")

-capturedPlan.foreach { qe =>
-val executedPlan = stripAQEPlan(qe.executedPlan)
-
-// Count CometNativeWriteExec instances in the plan
-var nativeWriteCount = 0
-executedPlan.foreach {
-case _: CometNativeWriteExec =>
-nativeWriteCount += 1
-case d: DataWritingCommandExec =>
-d.child.foreach {
-case _: CometNativeWriteExec =>
-nativeWriteCount += 1
-case _ =>
-}
-case _ =>
-}
-
-assert(
-nativeWriteCount == 1,
-s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${executedPlan.treeString}")
-}
-capturedPlan
+stripAQEPlan(capturedPlan.get.executedPlan)
} finally {
spark.listenerManager.unregister(listener)
}
}

+private def writeWithCometNativeWriteExec(
+inputPath: String,
+outputPath: String,
+num_partitions: Option[Int] = None): Option[SparkPlan] = {
+val df = spark.read.parquet(inputPath)
+
+val plan = captureWritePlan(
+path => num_partitions.fold(df)(n => df.repartition(n)).write.parquet(path),
+outputPath)
+
+// Count CometNativeWriteExec instances in the plan
+var nativeWriteCount = 0
+plan.foreach {
+case _: CometNativeWriteExec =>
+nativeWriteCount += 1
+case d: DataWritingCommandExec =>
+d.child.foreach {
+case _: CometNativeWriteExec =>
+nativeWriteCount += 1
+case _ =>
+}
+case _ =>
+}
+
+assert(
+nativeWriteCount == 1,
+s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${plan.treeString}")
+
+Some(plan)
+}

private def verifyWrittenFile(outputPath: String): Unit = {