From 29fe1c0f1e2a8055039ed09f42c234927987fa29 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Tue, 20 Jan 2026 19:45:44 -0700
Subject: [PATCH 1/4] feat: [EXPERIMENTAL] direct native shuffle execution optimization
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR introduces an experimental optimization that allows the native
shuffle writer to directly execute the child native plan instead of reading
intermediate batches via JNI. This avoids the JNI round-trip for
single-source native plans.

Current flow:   Native Plan → ColumnarBatch → JNI → ScanExec → ShuffleWriterExec
Optimized flow: Native Plan → (directly in native) → ShuffleWriterExec

The optimization is:
- Disabled by default (spark.comet.exec.shuffle.directNative.enabled=false)
- Only applies to CometNativeShuffle (not columnar JVM shuffle)
- Only applies to single-source native scans (CometNativeScanExec)
- Does not apply to RangePartitioning (requires sampling)

Changes:
- CometShuffleDependency: Added childNativePlan field to pass native plan
- CometShuffleExchangeExec: Added detection logic for single-source native plans
- CometShuffleManager: Pass native plan to shuffle writer
- CometNativeShuffleWriter: Use child native plan directly when available
- CometConf: Added COMET_SHUFFLE_DIRECT_NATIVE_ENABLED config option
- CometDirectNativeShuffleSuite: Comprehensive test suite with 15 tests

Co-Authored-By: Claude Opus 4.5
---
 .../scala/org/apache/comet/CometConf.scala    |  12 +
 .../shuffle/CometNativeShuffleWriter.scala    | 260 ++++++++--------
 .../shuffle/CometShuffleDependency.scala      |   6 +-
 .../shuffle/CometShuffleExchangeExec.scala    | 105 ++++++-
 .../shuffle/CometShuffleManager.scala         |   3 +-
 .../exec/CometDirectNativeShuffleSuite.scala  | 286 ++++++++++++++++++
 6 files changed, 537 insertions(+), 135 deletions(-)
 create mode 100644 spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala

diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 89dbb6468d..65061282c9 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -319,6 +319,18 @@ object CometConf extends ShimCometConf {
     .booleanConf
     .createWithDefault(true)
 
+  val COMET_SHUFFLE_DIRECT_NATIVE_ENABLED: ConfigEntry[Boolean] =
+    conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.directNative.enabled")
+      .category(CATEGORY_SHUFFLE)
+      .doc(
+        "When enabled, the native shuffle writer will directly execute the child native plan " +
+          "instead of reading intermediate batches via JNI. This optimization avoids the " +
+          "JNI round-trip for single-source native plans (e.g., Scan -> Filter -> Project). 
" + + "This is an experimental feature and is disabled by default.") + .internal() + .booleanConf + .createWithDefault(false) + val COMET_SHUFFLE_MODE: ConfigEntry[String] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.mode") .category(CATEGORY_SHUFFLE) .doc( diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index b5d15b41f4..5655c7e492 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -43,6 +43,11 @@ import org.apache.comet.serde.QueryPlanSerde.serializeDataType /** * A [[ShuffleWriter]] that will delegate shuffle write to native shuffle. + * + * @param childNativePlan + * When provided, the shuffle writer will execute this native plan directly and pipe its output + * to the ShuffleWriter, avoiding the JNI round-trip for intermediate batches. This is used for + * direct native execution optimization when the shuffle's child is a single-source native plan. */ class CometNativeShuffleWriter[K, V]( outputPartitioning: Partitioning, @@ -53,7 +58,8 @@ class CometNativeShuffleWriter[K, V]( mapId: Long, context: TaskContext, metricsReporter: ShuffleWriteMetricsReporter, - rangePartitionBounds: Option[Seq[InternalRow]] = None) + rangePartitionBounds: Option[Seq[InternalRow]] = None, + childNativePlan: Option[Operator] = None) extends ShuffleWriter[K, V] with Logging { @@ -163,150 +169,150 @@ class CometNativeShuffleWriter[K, V]( } private def getNativePlan(dataFile: String, indexFile: String): Operator = { - val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") - val opBuilder = OperatorOuterClass.Operator.newBuilder() - - val scanTypes = outputAttributes.flatten { attr => - serializeDataType(attr.dataType) - } - - if (scanTypes.length == outputAttributes.length) { + // When childNativePlan is provided, we use it directly as the input to ShuffleWriter. + // Otherwise, we create a Scan operator that reads from JNI input ("ShuffleWriterInput"). 
+ val inputOperator: Operator = childNativePlan.getOrElse { + val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") + val scanTypes = outputAttributes.flatten { attr => + serializeDataType(attr.dataType) + } + if (scanTypes.length != outputAttributes.length) { + throw new UnsupportedOperationException( + s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.") + } scanBuilder.addAllFields(scanTypes.asJava) + OperatorOuterClass.Operator.newBuilder().setScan(scanBuilder).build() + } - val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder() - shuffleWriterBuilder.setOutputDataFile(dataFile) - shuffleWriterBuilder.setOutputIndexFile(indexFile) + val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder() + shuffleWriterBuilder.setOutputDataFile(dataFile) + shuffleWriterBuilder.setOutputIndexFile(indexFile) - if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { - val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { - case "zstd" => CompressionCodec.Zstd - case "lz4" => CompressionCodec.Lz4 - case "snappy" => CompressionCodec.Snappy - case other => throw new UnsupportedOperationException(s"invalid codec: $other") - } - shuffleWriterBuilder.setCodec(codec) - } else { - shuffleWriterBuilder.setCodec(CompressionCodec.None) + if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { + val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { + case "zstd" => CompressionCodec.Zstd + case "lz4" => CompressionCodec.Lz4 + case "snappy" => CompressionCodec.Snappy + case other => throw new UnsupportedOperationException(s"invalid codec: $other") } - shuffleWriterBuilder.setCompressionLevel( - CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) - shuffleWriterBuilder.setWriteBufferSize( - CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().max(Int.MaxValue).toInt) + shuffleWriterBuilder.setCodec(codec) + } else { + shuffleWriterBuilder.setCodec(CompressionCodec.None) + } + shuffleWriterBuilder.setCompressionLevel( + CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) + shuffleWriterBuilder.setWriteBufferSize( + CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().max(Int.MaxValue).toInt) - outputPartitioning match { - case p if isSinglePartitioning(p) => - val partitioning = PartitioningOuterClass.SinglePartition.newBuilder() + outputPartitioning match { + case p if isSinglePartitioning(p) => + val partitioning = PartitioningOuterClass.SinglePartition.newBuilder() - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setSinglePartition(partitioning).build()) - case _: HashPartitioning => - val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning] + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setSinglePartition(partitioning).build()) + case _: HashPartitioning => + val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning] - val partitioning = PartitioningOuterClass.HashPartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) + val partitioning = PartitioningOuterClass.HashPartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) - val partitionExprs = hashPartitioning.expressions - .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) + val partitionExprs = hashPartitioning.expressions 
+ .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) - if (partitionExprs.length != hashPartitioning.expressions.length) { - throw new UnsupportedOperationException( - s"Partitioning $hashPartitioning is not supported.") - } + if (partitionExprs.length != hashPartitioning.expressions.length) { + throw new UnsupportedOperationException( + s"Partitioning $hashPartitioning is not supported.") + } - partitioning.addAllHashExpression(partitionExprs.asJava) - - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setHashPartition(partitioning).build()) - case _: RangePartitioning => - val rangePartitioning = outputPartitioning.asInstanceOf[RangePartitioning] - - val partitioning = PartitioningOuterClass.RangePartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) - - // Detect duplicates by tracking expressions directly, similar to DataFusion's LexOrdering - // DataFusion will deduplicate identical sort expressions in LexOrdering, - // so we need to transform boundary rows to match the deduplicated structure - val seenExprs = mutable.HashSet[Expression]() - val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() // (originalIndex, isKept) - - rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) => - if (seenExprs.contains(sortOrder.child)) { - deduplicationMap += (idx -> false) // Will be deduplicated by DataFusion - } else { - seenExprs += sortOrder.child - deduplicationMap += (idx -> true) // Will be kept by DataFusion - } + partitioning.addAllHashExpression(partitionExprs.asJava) + + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setHashPartition(partitioning).build()) + case _: RangePartitioning => + val rangePartitioning = outputPartitioning.asInstanceOf[RangePartitioning] + + val partitioning = PartitioningOuterClass.RangePartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) + + // Detect duplicates by tracking expressions directly, similar to DataFusion's LexOrdering + // DataFusion will deduplicate identical sort expressions in LexOrdering, + // so we need to transform boundary rows to match the deduplicated structure + val seenExprs = mutable.HashSet[Expression]() + val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() // (originalIndex, isKept) + + rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) => + if (seenExprs.contains(sortOrder.child)) { + deduplicationMap += (idx -> false) // Will be deduplicated by DataFusion + } else { + seenExprs += sortOrder.child + deduplicationMap += (idx -> true) // Will be kept by DataFusion } + } - { - // Serialize the ordering expressions for comparisons - val orderingExprs = rangePartitioning.ordering - .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) - if (orderingExprs.length != rangePartitioning.ordering.length) { - throw new UnsupportedOperationException( - s"Partitioning $rangePartitioning is not supported.") - } - partitioning.addAllSortOrders(orderingExprs.asJava) + { + // Serialize the ordering expressions for comparisons + val orderingExprs = rangePartitioning.ordering + .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) + if (orderingExprs.length != rangePartitioning.ordering.length) { + throw new UnsupportedOperationException( + s"Partitioning $rangePartitioning is not supported.") } + 
partitioning.addAllSortOrders(orderingExprs.asJava) + } - // Convert Spark's sequence of InternalRows that represent partitioning boundaries to - // sequences of Literals, where each outer entry represents a boundary row, and each - // internal entry is a value in that row. In other words, these are stored in row major - // order, not column major - val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType)) - - // Transform boundary rows to match DataFusion's deduplicated structure - val transformedBoundaryExprs: Seq[Seq[Literal]] = - rangePartitionBounds.get.map((row: InternalRow) => { - // For every InternalRow, map its values to Literals - val allLiterals = - row.toSeq(boundarySchema).zip(boundarySchema).map { case (value, valueType) => - Literal(value, valueType) - } - - // Keep only the literals that correspond to non-deduplicated expressions - allLiterals - .zip(deduplicationMap) - .filter(_._2._2) // Keep only where isKept = true - .map(_._1) // Extract the literal + // Convert Spark's sequence of InternalRows that represent partitioning boundaries to + // sequences of Literals, where each outer entry represents a boundary row, and each + // internal entry is a value in that row. In other words, these are stored in row major + // order, not column major + val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType)) + + // Transform boundary rows to match DataFusion's deduplicated structure + val transformedBoundaryExprs: Seq[Seq[Literal]] = + rangePartitionBounds.get.map((row: InternalRow) => { + // For every InternalRow, map its values to Literals + val allLiterals = + row.toSeq(boundarySchema).zip(boundarySchema).map { case (value, valueType) => + Literal(value, valueType) + } + + // Keep only the literals that correspond to non-deduplicated expressions + allLiterals + .zip(deduplicationMap) + .filter(_._2._2) // Keep only where isKept = true + .map(_._1) // Extract the literal + }) + + { + // Convert the sequences of Literals to a collection of serialized BoundaryRows + val boundaryRows: Seq[PartitioningOuterClass.BoundaryRow] = transformedBoundaryExprs + .map((rowLiterals: Seq[Literal]) => { + // Serialize each sequence of Literals as a BoundaryRow + val rowBuilder = PartitioningOuterClass.BoundaryRow.newBuilder(); + val serializedExprs = + rowLiterals.map(lit_value => + QueryPlanSerde.exprToProto(lit_value, outputAttributes).get) + rowBuilder.addAllPartitionBounds(serializedExprs.asJava) + rowBuilder.build() }) + partitioning.addAllBoundaryRows(boundaryRows.asJava) + } - { - // Convert the sequences of Literals to a collection of serialized BoundaryRows - val boundaryRows: Seq[PartitioningOuterClass.BoundaryRow] = transformedBoundaryExprs - .map((rowLiterals: Seq[Literal]) => { - // Serialize each sequence of Literals as a BoundaryRow - val rowBuilder = PartitioningOuterClass.BoundaryRow.newBuilder(); - val serializedExprs = - rowLiterals.map(lit_value => - QueryPlanSerde.exprToProto(lit_value, outputAttributes).get) - rowBuilder.addAllPartitionBounds(serializedExprs.asJava) - rowBuilder.build() - }) - partitioning.addAllBoundaryRows(boundaryRows.asJava) - } - - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setRangePartition(partitioning).build()) - - case _ => - throw new UnsupportedOperationException( - s"Partitioning $outputPartitioning is not supported.") - } + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + 
shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setRangePartition(partitioning).build()) - val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder() - shuffleWriterOpBuilder - .setShuffleWriter(shuffleWriterBuilder) - .addChildren(opBuilder.setScan(scanBuilder).build()) - .build() - } else { - // There are unsupported scan type - throw new UnsupportedOperationException( - s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.") + case _ => + throw new UnsupportedOperationException( + s"Partitioning $outputPartitioning is not supported.") } + + val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder() + shuffleWriterOpBuilder + .setShuffleWriter(shuffleWriterBuilder) + .addChildren(inputOperator) + .build() } override def stop(success: Boolean): Option[MapStatus] = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala index 2b74e5a168..3528b6d2c9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType +import org.apache.comet.serde.OperatorOuterClass.Operator + /** * A [[ShuffleDependency]] that allows us to identify the shuffle dependency as a Comet shuffle. */ @@ -49,7 +51,9 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val outputAttributes: Seq[Attribute] = Seq.empty, val shuffleWriteMetrics: Map[String, SQLMetric] = Map.empty, val numParts: Int = 0, - val rangePartitionBounds: Option[Seq[InternalRow]] = None) + val rangePartitionBounds: Option[Seq[InternalRow]] = None, + // For direct native execution: the child's native plan to compose with ShuffleWriter + val childNativePlan: Option[Operator] = None) extends ShuffleDependency[K, V, C]( _rdd, partitioner, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 1805711d01..0b829aa6ac 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Uns import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometPlan, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.{CometBatchScanExec, CometMetricNode, CometNativeExec, CometNativeScanExec, CometPlan, CometScanExec, CometSinkPlaceHolder} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} @@ -52,6 +52,7 @@ import org.apache.comet.CometConf import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE} import 
org.apache.comet.CometSparkSessionExtensions.{isCometShuffleManagerEnabled, withInfo} import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported} +import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.operator.CometSink import org.apache.comet.shims.ShimCometShuffleExchangeExec @@ -89,9 +90,85 @@ case class CometShuffleExchangeExec( private lazy val serializer: Serializer = new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) + /** + * Information about direct native execution optimization. When the child is a single-source + * native plan with a fully native scan (CometNativeScanExec), we can pass the child's native + * plan to the shuffle writer and execute: Scan -> Filter -> Project -> ShuffleWriter all in + * native code, avoiding the JNI round-trip for intermediate batches. + * + * Currently only supports CometNativeScanExec (fully native scans that read files directly via + * DataFusion). JVM scan wrappers (CometScanExec, CometBatchScanExec) still require JNI input + * and are not optimized. + */ + @transient private lazy val directNativeExecutionInfo: Option[DirectNativeExecutionInfo] = { + if (!CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.get()) { + None + } else if (shuffleType != CometNativeShuffle) { + None + } else { + // Check if direct native execution is possible + outputPartitioning match { + case _: RangePartitioning => + // RangePartitioning requires sampling the data to compute bounds, + // which requires executing the child plan. Fall back to current behavior. + None + case _ => + child match { + case nativeChild: CometNativeExec => + // Find input sources using foreachUntilCometInput + val inputSources = scala.collection.mutable.ArrayBuffer.empty[SparkPlan] + nativeChild.foreachUntilCometInput(nativeChild)(inputSources += _) + + // Only optimize single-source native scan case for now + // JVM scan wrappers (CometScanExec, CometBatchScanExec) still need JNI input, + // so we don't optimize those yet + if (inputSources.size == 1) { + inputSources.head match { + case scan: CometNativeScanExec => + // Fully native scan - no JNI input needed, native code reads files directly + // Get the partition count from the underlying scan + val numPartitions = scan.originalPlan.inputRDD.getNumPartitions + Some(DirectNativeExecutionInfo(nativeChild.nativeOp, numPartitions)) + case _ => + // Other input sources (JVM scans, shuffle, broadcast, etc.) - fall back + None + } + } else { + // Multiple input sources (joins, unions) - fall back for now + None + } + case _ => + None + } + } + } + } + + /** + * Returns true if direct native execution optimization is being used for this shuffle. This is + * primarily intended for testing to verify the optimization is applied correctly. + */ + def isDirectNativeExecution: Boolean = directNativeExecutionInfo.isDefined + + /** + * Creates an RDD that provides empty iterators for each partition. Used when direct native + * execution is enabled - the shuffle writer will execute the full native plan which reads data + * directly (no JNI input needed). + */ + private def createEmptyPartitionRDD(numPartitions: Int): RDD[ColumnarBatch] = { + sparkContext.parallelize(Seq.empty[ColumnarBatch], numPartitions) + } + @transient lazy val inputRDD: RDD[_] = if (shuffleType == CometNativeShuffle) { - // CometNativeShuffle assumes that the input plan is Comet plan. 
- child.executeColumnar() + directNativeExecutionInfo match { + case Some(info) => + // Direct native execution: create an RDD with empty partitions. + // The shuffle writer will execute the full native plan which reads data directly. + createEmptyPartitionRDD(info.numPartitions) + case None => + // Fall back to current behavior: execute child and pass intermediate batches + child.executeColumnar() + } } else if (shuffleType == CometColumnarShuffle) { // CometColumnarShuffle uses Spark's row-based execute() API. For Spark row-based plans, // rows flow directly. For Comet native plans, their doExecute() wraps with ColumnarToRowExec @@ -142,7 +219,8 @@ case class CometShuffleExchangeExec( child.output, outputPartitioning, serializer, - metrics) + metrics, + directNativeExecutionInfo.map(_.childNativePlan)) metrics("numPartitions").set(dep.partitioner.numPartitions) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates( @@ -538,7 +616,9 @@ object CometShuffleExchangeExec outputAttributes: Seq[Attribute], outputPartitioning: Partitioning, serializer: Serializer, - metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + metrics: Map[String, SQLMetric], + childNativePlan: Option[Operator] = None) + : ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { val numParts = rdd.getNumPartitions // The code block below is mostly brought over from @@ -605,7 +685,8 @@ object CometShuffleExchangeExec outputAttributes = outputAttributes, shuffleWriteMetrics = metrics, numParts = numParts, - rangePartitionBounds = rangePartitionBounds) + rangePartitionBounds = rangePartitionBounds, + childNativePlan = childNativePlan) dependency } @@ -810,3 +891,15 @@ object CometShuffleExchangeExec dependency } } + +/** + * Information needed for direct native execution optimization. + * + * @param childNativePlan + * The child's native operator plan to compose with ShuffleWriter + * @param numPartitions + * The number of partitions (from the underlying scan) + */ +private[shuffle] case class DirectNativeExecutionInfo( + childNativePlan: Operator, + numPartitions: Int) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala index aa47dfa166..367ec4a90e 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala @@ -238,7 +238,8 @@ class CometShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { mapId, context, metrics, - dep.rangePartitionBounds) + dep.rangePartitionBounds, + dep.childNativePlan) case bypassMergeSortHandle: CometBypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new CometBypassMergeSortShuffleWriter( env.blockManager, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala new file mode 100644 index 0000000000..6ab30b94a3 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.exec + +import org.scalactic.source.Position +import org.scalatest.Tag + +import org.apache.spark.sql.{CometTestBase, DataFrame} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.functions.col + +import org.apache.comet.CometConf + +/** + * Test suite for the direct native shuffle execution optimization. + * + * This optimization allows the native shuffle writer to directly execute the child native plan + * instead of reading intermediate batches via JNI. This avoids the JNI round-trip for + * single-source native plans (e.g., Scan -> Filter -> Project -> Shuffle). + */ +class CometDirectNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + super.test(testName, testTags: _*) { + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_MODE.key -> "native", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion", + CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.key -> "true") { + testFun + } + } + } + + import testImplicits._ + + test("direct native execution: simple scan with hash partitioning") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + // Verify the optimization is applied + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1, "Expected exactly one shuffle") + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should be enabled for single-source native scan") + + // Verify correctness + checkSparkAnswer(df) + } + } + + test("direct native execution: scan with filter and project") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i.toString)), "tbl") { + val df = sql("SELECT _1, _2 * 2 as doubled FROM tbl WHERE _1 > 10") + .repartition(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should work with filter and project") + + checkSparkAnswer(df) + } + } + + test("direct native execution: single partition") { + withParquetTable((0 until 50).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(1) + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should work with single partition") + + checkSparkAnswer(df) + } + } + + test("direct native execution: multiple hash columns") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i.toString)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1", $"_2") + + val 
shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should work with multiple hash columns") + + checkSparkAnswer(df) + } + } + + test("direct native execution: aggregation before shuffle") { + withParquetTable((0 until 100).map(i => (i % 10, (i + 1).toLong)), "tbl") { + val df = sql("SELECT _1, SUM(_2) as total FROM tbl GROUP BY _1") + .repartition(5, col("_1")) + + // This involves partial aggregation -> shuffle -> final aggregation + // The direct native execution applies to the shuffle that reads from the partial agg + checkSparkAnswer(df) + } + } + + test("direct native execution disabled: config is false") { + withSQLConf(CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.key -> "false") { + withParquetTable((0 until 50).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + !shuffles.head.isDirectNativeExecution, + "Direct native execution should be disabled when config is false") + + checkSparkAnswer(df) + } + } + } + + test("direct native execution disabled: range partitioning") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartitionByRange(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + !shuffles.head.isDirectNativeExecution, + "Direct native execution should not be used for range partitioning") + + checkSparkAnswer(df) + } + } + + test("direct native execution disabled: JVM columnar shuffle mode") { + withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { + withParquetTable((0 until 50).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + // JVM shuffle mode uses CometColumnarShuffle, not CometNativeShuffle + val shuffles = findShuffleExchanges(df) + shuffles.foreach { shuffle => + assert( + !shuffle.isDirectNativeExecution, + "Direct native execution should not be used with JVM shuffle mode") + } + + checkSparkAnswer(df) + } + } + } + + test("direct native execution: multiple shuffles in same query") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl") + .repartition(10, $"_1") + .select($"_1", $"_2" + 1 as "_2_plus") + .repartition(5, $"_2_plus") + + // First shuffle reads from scan, second reads from previous shuffle output + // Only the first shuffle should use direct native execution + val shuffles = findShuffleExchanges(df) + // AQE might combine some shuffles, so just verify results are correct + checkSparkAnswer(df) + } + } + + test("direct native execution: various data types") { + withParquetTable( + (0 until 50).map(i => + (i, i.toLong, i.toFloat, i.toDouble, i.toString, i % 2 == 0, BigDecimal(i))), + "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert(shuffles.head.isDirectNativeExecution) + + checkSparkAnswer(df) + } + } + + test("direct native execution: complex filter and multiple projections") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i % 5)), "tbl") { + val df = sql(""" + |SELECT _1 * 2 as doubled, + | _2 + _3 as sum_col, + | _1 + _2 as combined + |FROM tbl + |WHERE _1 > 20 AND _3 < 3 + |""".stripMargin) + .repartition(10, col("doubled")) + + val shuffles = findShuffleExchanges(df) + // Note: Native 
shuffle might fall back depending on expression support + // Just verify correctness - the optimization is best-effort + checkSparkAnswer(df) + } + } + + test("direct native execution: results match non-optimized path") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i.toString)), "tbl") { + // Run with optimization enabled + val dfOptimized = sql("SELECT _1, _2 FROM tbl WHERE _1 > 50").repartition(10, $"_1") + val optimizedResult = dfOptimized.collect().sortBy(_.getInt(0)) + + // Run with optimization disabled and collect results + var nonOptimizedResult: Array[org.apache.spark.sql.Row] = Array.empty + withSQLConf(CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.key -> "false") { + val dfNonOptimized = sql("SELECT _1, _2 FROM tbl WHERE _1 > 50").repartition(10, $"_1") + nonOptimizedResult = dfNonOptimized.collect().sortBy(_.getInt(0)) + } + + // Results should match + assert(optimizedResult.length == nonOptimizedResult.length, "Row counts should match") + optimizedResult.zip(nonOptimizedResult).foreach { case (opt, nonOpt) => + assert(opt == nonOpt, s"Rows should match: $opt vs $nonOpt") + } + } + } + + test("direct native execution: large number of partitions") { + withParquetTable((0 until 1000).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(201, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert(shuffles.head.isDirectNativeExecution) + + checkSparkAnswer(df) + } + } + + test("direct native execution: empty table") { + withParquetTable(Seq.empty[(Int, Long)], "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + // Should handle empty tables gracefully + val result = df.collect() + assert(result.isEmpty) + } + } + + test("direct native execution: all rows filtered out") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl WHERE _1 > 1000").repartition(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert(shuffles.head.isDirectNativeExecution) + + val result = df.collect() + assert(result.isEmpty, "Result should be empty when all rows are filtered") + } + } + + /** + * Helper method to find CometShuffleExchangeExec nodes in a DataFrame's execution plan. 
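+   *
+   * Example usage (illustrative, mirroring the assertions in the tests above):
+   * {{{
+   *   val shuffles = findShuffleExchanges(df)
+   *   assert(shuffles.head.isDirectNativeExecution)
+   * }}}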
+ */ + private def findShuffleExchanges(df: DataFrame): Seq[CometShuffleExchangeExec] = { + val plan = stripAQEPlan(df.queryExecution.executedPlan) + plan.collect { case s: CometShuffleExchangeExec => s } + } +} From 15c88da4289e9f21ad2f740bee07f1de712df0c2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 20 Jan 2026 21:14:58 -0700 Subject: [PATCH 2/4] format --- .../sql/comet/execution/shuffle/CometShuffleExchangeExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 0b829aa6ac..6f2bf7d91a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Uns import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.comet.{CometBatchScanExec, CometMetricNode, CometNativeExec, CometNativeScanExec, CometPlan, CometScanExec, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometNativeScanExec, CometPlan, CometSinkPlaceHolder} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} From ce361f843dc534754c292937ac9382858b123609 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 21 Jan 2026 08:26:09 -0700 Subject: [PATCH 3/4] fi --- .../org/apache/comet/exec/CometDirectNativeShuffleSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala index 6ab30b94a3..05c608310c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala @@ -181,7 +181,6 @@ class CometDirectNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlan // First shuffle reads from scan, second reads from previous shuffle output // Only the first shuffle should use direct native execution - val shuffles = findShuffleExchanges(df) // AQE might combine some shuffles, so just verify results are correct checkSparkAnswer(df) } @@ -213,7 +212,6 @@ class CometDirectNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlan |""".stripMargin) .repartition(10, col("doubled")) - val shuffles = findShuffleExchanges(df) // Note: Native shuffle might fall back depending on expression support // Just verify correctness - the optimization is best-effort checkSparkAnswer(df) From 6132098982a107b916472b3a9d51e89b437238ee Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 21 Jan 2026 08:48:36 -0700 Subject: [PATCH 4/4] fix: disable direct native shuffle when plan contains subqueries Subqueries (e.g., bloom filters with might_contain) are registered with the parent execution context ID. 
Direct native shuffle creates a new execution context with a different ID, causing subquery lookup to fail with "Subquery X not found for plan Y" errors. This change detects ScalarSubquery expressions in the child plan and falls back to the standard execution path when present. Co-Authored-By: Claude Opus 4.5 --- .../execution/shuffle/CometShuffleExchangeExec.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 6f2bf7d91a..96b8bcbacc 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometNativeScanExec, CometPlan, CometSinkPlaceHolder} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} @@ -122,7 +123,16 @@ case class CometShuffleExchangeExec( // Only optimize single-source native scan case for now // JVM scan wrappers (CometScanExec, CometBatchScanExec) still need JNI input, // so we don't optimize those yet - if (inputSources.size == 1) { + // Check if the plan contains subqueries (e.g., bloom filters with might_contain). + // Subqueries are registered with the parent execution context ID, but direct + // native shuffle creates a new execution context, so subquery lookup would fail. + val containsSubquery = nativeChild.exists { p => + p.expressions.exists(_.exists(_.isInstanceOf[ScalarSubquery])) + } + if (containsSubquery) { + // Fall back to avoid subquery lookup failures + None + } else if (inputSources.size == 1) { inputSources.head match { case scan: CometNativeScanExec => // Fully native scan - no JNI input needed, native code reads files directly