From 77ee7327e53fbbf06857783f7b5536df12e689c9 Mon Sep 17 00:00:00 2001
From: Kent Yao
Date: Fri, 20 Mar 2026 01:25:52 +0800
Subject: [PATCH] [CORE] Fix multi-key DPP support in ColumnarSubqueryBroadcastExec

Fix the BuildSideRelation path to handle multiple filtering keys instead
of only using the first key (indices(0)). This resolves the TODO/FIXME
that caused multi-key DPP to silently drop extra keys.

For single-key DPP, behavior is unchanged. For multi-key DPP
(SPARK-46946), all keys are now projected via CreateStruct, matching the
HashedRelation path's multi-key support.

This fixes potential DPP loss in queries like TPC-DS q23a/q23b that have
multi-column partition join keys.
---
 .../ColumnarSubqueryBroadcastExec.scala       | 24 +++++++++----
 .../GlutenDynamicPartitionPruningSuite.scala  | 32 +++++++++++++++++
 .../GlutenDynamicPartitionPruningSuite.scala  | 35 +++++++++++++++++++
 3 files changed, 84 insertions(+), 7 deletions(-)

diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
index 611433c51b5b..6257d842328d 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
@@ -89,14 +89,24 @@ case class ColumnarSubqueryBroadcastExec(
         val relation = child.executeBroadcast[Any]().value
         relation match {
           case b: BuildSideRelation =>
-            val index = indices(0) // TODO(): fixme
-            // Transform columnar broadcast value to Array[InternalRow] by key.
-            if (canRewriteAsLongType(buildKeys)) {
-              b.transform(HashJoin.extractKeyExprAt(buildKeys, index)).distinct
+            // Build key expressions for all indices (multi-key DPP support).
+            val keyExprs = if (canRewriteAsLongType(buildKeys)) {
+              indices.map(idx => HashJoin.extractKeyExprAt(buildKeys, idx))
             } else {
-              b.transform(
-                BoundReference(index, buildKeys(index).dataType, buildKeys(index).nullable))
-                .distinct
+              indices.map {
+                idx =>
+                  BoundReference(
+                    idx,
+                    buildKeys(idx).dataType,
+                    buildKeys(idx).nullable): Expression
+              }
+            }
+            if (keyExprs.size == 1) {
+              b.transform(keyExprs.head).distinct
+            } else {
+              // For multi-key DPP, pack all keys into a struct so that
+              // transform() projects all of them in a single pass.
+              b.transform(CreateStruct(keyExprs)).distinct
             }
           case h: HashedRelation =>
             val (iter, exprs) = if (h.isInstanceOf[LongHashedRelation]) {
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDynamicPartitionPruningSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDynamicPartitionPruningSuite.scala
index 4c1dcd2a0960..4c7df48cc847 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDynamicPartitionPruningSuite.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDynamicPartitionPruningSuite.scala
@@ -716,6 +716,38 @@ class GlutenDynamicPartitionPruningV1SuiteAEOn
       }
     }
   }
+
+  testGluten("multi-key DPP with columnar broadcast") {
+    withSQLConf(
+      SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true",
+      SQLConf.ANSI_ENABLED.key -> "false"
+    ) {
+      withTable("fact_mk", "dim_mk") {
+        sql("""CREATE TABLE fact_mk (id BIGINT, value INT, a STRING, b STRING)
+               |USING parquet PARTITIONED BY (a, b)""".stripMargin)
+        sql(
+          "INSERT INTO fact_mk VALUES " +
+            (0 until 10).map(i => s"($i, 1, '1', '1')").mkString(", "))
+        sql(
+          "INSERT INTO fact_mk VALUES " +
+            (0 until 10).map(i => s"($i, 2, '2', '2')").mkString(", "))
+        sql(
+          "INSERT INTO fact_mk VALUES " +
+            (0 until 10).map(i => s"($i, 3, '3', '3')").mkString(", "))
+
+        sql("CREATE TABLE dim_mk (x STRING, y STRING, z INT) USING parquet")
+        sql("INSERT INTO dim_mk VALUES ('1', '1', 10), ('2', '2', 20)")
+
+        val df = sql("""SELECT f.id, f.a, f.b FROM fact_mk f
+                       |JOIN dim_mk d ON f.a = d.x AND f.b = d.y
+                       |WHERE d.z < 15""".stripMargin)
+
+        val result = df.collect()
+        assert(result.length == 10)
+        checkAnswer(df, result)
+      }
+    }
+  }
 }
 
 abstract class GlutenDynamicPartitionPruningV2Suite extends GlutenDynamicPartitionPruningSuiteBase {
diff --git a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenDynamicPartitionPruningSuite.scala b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenDynamicPartitionPruningSuite.scala
index dc96c09bc240..78f16e4cb016 100644
--- a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenDynamicPartitionPruningSuite.scala
+++ b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenDynamicPartitionPruningSuite.scala
@@ -657,6 +657,41 @@ class GlutenDynamicPartitionPruningV1SuiteAEOn
       }
     }
   }
+
+  testGluten("multi-key DPP with columnar broadcast") {
+    withSQLConf(
+      SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true",
+      SQLConf.ANSI_ENABLED.key -> "false"
+    ) {
+      withTable("fact_mk", "dim_mk") {
+        sql(
+          """CREATE TABLE fact_mk (id BIGINT, value INT,
+            | a STRING, b STRING) USING parquet
+            |PARTITIONED BY (a, b)""".stripMargin)
+        sql("INSERT INTO fact_mk VALUES " +
+          (0 until 10).map(i => s"($i, 1, '1', '1')").mkString(", "))
+        sql("INSERT INTO fact_mk VALUES " +
+          (0 until 10).map(i => s"($i, 2, '2', '2')").mkString(", "))
+        sql("INSERT INTO fact_mk VALUES " +
+          (0 until 10).map(i => s"($i, 3, '3', '3')").mkString(", "))
+
+        sql(
+          """CREATE TABLE dim_mk (x STRING, y STRING, z INT)
+            |USING parquet""".stripMargin)
+        sql(
+          "INSERT INTO dim_mk VALUES ('1','1',10), ('2','2',20)")
+
+        val df = sql(
+          """SELECT f.id, f.a, f.b FROM fact_mk f
+            |JOIN dim_mk d ON f.a = d.x AND f.b = d.y
+            |WHERE d.z < 15""".stripMargin)
+
+        val result = df.collect()
+        assert(result.length == 10)
+        checkAnswer(df, result)
+      }
+    }
+  }
 }
 
 abstract class GlutenDynamicPartitionPruningV2Suite extends GlutenDynamicPartitionPruningSuiteBase {
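
Editorial note (not part of the patch, ignored by git am): the following is a
minimal standalone sketch of the key-projection split the commit message
describes, assuming Spark Catalyst is on the classpath. The object and method
names (MultiKeyProjectionSketch, projectionFor) are hypothetical, and the
canRewriteAsLongType fast path kept by the real code is omitted for brevity.

  import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateStruct, Expression}

  object MultiKeyProjectionSketch {
    // For a single requested index, project that key directly; for several,
    // pack all requested keys into one struct so a single projection carries
    // every filtering key, mirroring the BuildSideRelation path above.
    def projectionFor(buildKeys: Seq[Expression], indices: Seq[Int]): Expression = {
      val keyExprs: Seq[Expression] = indices.map {
        idx => BoundReference(idx, buildKeys(idx).dataType, buildKeys(idx).nullable)
      }
      if (keyExprs.size == 1) keyExprs.head else CreateStruct(keyExprs)
    }
  }

Evaluating the returned expression against each broadcast build-side row yields
either the raw key value or a struct row of all keys, so distinct-ing the
projected values preserves every filtering key instead of only indices(0).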