diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index a760991ab5180..6d3a9b4ab991f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
-import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan
+import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric}
@@ -118,6 +118,7 @@ case class AdaptiveSparkPlanExec(
     val ensureRequirements =
       EnsureRequirements(requiredDistribution.isDefined, requiredDistribution)
     Seq(
+      CoalesceBucketsInJoin,
       RemoveRedundantProjects,
       ensureRequirements,
       AdjustShuffleExchangePosition,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index fc7c4e5761be1..e9563093a05e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan}
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper}
 import org.apache.spark.sql.execution.datasources.BucketingUtils
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
@@ -1010,8 +1010,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
     }
   }
 
-  test("bucket coalescing is applied when join expressions match with partitioning expressions",
-    DisableAdaptiveExecution("Expected shuffle num mismatched")) {
+  test("bucket coalescing is applied when join expressions match with partitioning expressions") {
     withTable("t1", "t2") {
       df1.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("t1")
       df2.write.format("parquet").bucketBy(4, "i", "j").saveAsTable("t2")
@@ -1024,10 +1023,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
         expectedNumShuffles: Int,
         expectedCoalescedNumBuckets: Option[Int]): Unit = {
       val plan = sql(query).queryExecution.executedPlan
-      val shuffles = plan.collect { case s: ShuffleExchangeExec => s }
+      val shuffles = collect(plan) { case s: ShuffleExchangeExec => s }
       assert(shuffles.length == expectedNumShuffles)
 
-      val scans = plan.collect {
+      val scans = collect(plan) {
         case f: FileSourceScanExec if f.optionalNumCoalescedBuckets.isDefined => f
       }
       if (expectedCoalescedNumBuckets.isDefined) {
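
Note on the test change above: with adaptive execution enabled, the executed plan is wrapped in an AdaptiveSparkPlanExec, which TreeNode.collect treats as a leaf, so a plain plan.collect cannot see operators nested inside query stages. The test therefore switches to the collect(plan) helper from AdaptiveSparkPlanHelper, which recurses into stages. Below is a minimal, hypothetical sketch of the difference; the session setup and data are illustrative only, and AdaptiveSparkPlanHelper is a Spark-internal trait, not a stable public API.

// Sketch only, not part of the patch: contrasts TreeNode.collect with the
// stage-aware collect(plan) helper under AQE. Data and config are assumptions.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

object CollectUnderAqe extends AdaptiveSparkPlanHelper {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .config("spark.sql.adaptive.enabled", "true")
      .getOrCreate()
    import spark.implicits._

    // repartition introduces a shuffle, so AQE wraps the plan
    val q = Seq((1, "a"), (2, "b")).toDF("i", "j").repartition($"i")
    q.collect() // materialize so AQE finalizes the query stages
    val plan = q.queryExecution.executedPlan

    // TreeNode.collect stops at the AdaptiveSparkPlanExec boundary ...
    val direct = plan.collect { case s: ShuffleExchangeExec => s }
    // ... whereas the helper's collect recurses into query stages.
    val nested = collect(plan) { case s: ShuffleExchangeExec => s }
    println(s"plan.collect found ${direct.size}, collect(plan) found ${nested.size}")

    spark.stop()
  }
}

Since the shuffle and scan counts are now gathered through the stage-aware traversal, the DisableAdaptiveExecution tag (and its import) can be dropped from the test.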