diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index a760991ab5180..fceb9db411200 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
-import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan
+import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric}
@@ -117,7 +117,10 @@ case class AdaptiveSparkPlanExec(
     // around this case.
     val ensureRequirements =
       EnsureRequirements(requiredDistribution.isDefined, requiredDistribution)
+    // CoalesceBucketsInJoin can help eliminate shuffles and must be run before
+    // EnsureRequirements
     Seq(
+      CoalesceBucketsInJoin,
       RemoveRedundantProjects,
       ensureRequirements,
       AdjustShuffleExchangePosition,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index fc7c4e5761be1..a18c681e0fe42 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan}
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper}
 import org.apache.spark.sql.execution.datasources.BucketingUtils
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
@@ -1010,8 +1010,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
     }
   }
 
-  test("bucket coalescing is applied when join expressions match with partitioning expressions",
-    DisableAdaptiveExecution("Expected shuffle num mismatched")) {
+  test("bucket coalescing is applied when join expressions match with partitioning expressions") {
     withTable("t1", "t2") {
       df1.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("t1")
       df2.write.format("parquet").bucketBy(4, "i", "j").saveAsTable("t2")
@@ -1023,18 +1022,22 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
         query: String,
         expectedNumShuffles: Int,
         expectedCoalescedNumBuckets: Option[Int]): Unit = {
-      val plan = sql(query).queryExecution.executedPlan
-      val shuffles = plan.collect { case s: ShuffleExchangeExec => s }
-      assert(shuffles.length == expectedNumShuffles)
-
-      val scans = plan.collect {
-        case f: FileSourceScanExec if f.optionalNumCoalescedBuckets.isDefined => f
-      }
-      if (expectedCoalescedNumBuckets.isDefined) {
-        assert(scans.length == 1)
-        assert(scans.head.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets)
-      } else {
-        assert(scans.isEmpty)
+      Seq(true, false).foreach { aqeEnabled =>
+        withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled.toString) {
+          val plan = sql(query).queryExecution.executedPlan
+          val shuffles = collect(plan) { case s: ShuffleExchangeExec => s }
+          assert(shuffles.length == expectedNumShuffles)
+
+          val scans = collect(plan) {
+            case f: FileSourceScanExec if f.optionalNumCoalescedBuckets.isDefined => f
+          }
+          if (expectedCoalescedNumBuckets.isDefined) {
+            assert(scans.length == 1)
+            assert(scans.head.optionalNumCoalescedBuckets == expectedCoalescedNumBuckets)
+          } else {
+            assert(scans.isEmpty)
+          }
+        }
       }
     }
 
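A minimal standalone sketch (not part of the patch) of the scenario the test above exercises: two tables bucketed 8 and 4 ways on the same columns, joined on those columns. With CoalesceBucketsInJoin in the AQE rule list, the 8-bucket side can be coalesced to 4 buckets so the sort-merge join needs no shuffle. The demo app name and sample data are illustrative; the config keys are the real SQLConf entries guarding this behavior.

    import org.apache.spark.sql.SparkSession

    object BucketCoalescingDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .appName("bucket-coalescing-demo") // hypothetical demo app
          .enableHiveSupport()
          .getOrCreate()
        import spark.implicits._

        spark.conf.set("spark.sql.adaptive.enabled", "true")
        // Config guarding CoalesceBucketsInJoin (off by default).
        spark.conf.set("spark.sql.bucketing.coalesceBucketsInJoin.enabled", "true")
        // Force a sort-merge join so bucketing actually matters.
        spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

        // Mirror the test setup: bucket counts 8 and 4 on columns (i, j).
        val df = (0 until 100).map(i => (i % 5, i % 13)).toDF("i", "j")
        df.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("t1")
        df.write.format("parquet").bucketBy(4, "i", "j").saveAsTable("t2")

        val joined = spark.sql(
          "SELECT * FROM t1 JOIN t2 ON t1.i = t2.i AND t1.j = t2.j")
        // With bucket coalescing applied under AQE (i.e. with this patch),
        // the plan should show no ShuffleExchange on either join side.
        joined.explain()

        spark.stop()
      }
    }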