diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala index 7acad9e7b7187..607eb6fd5661f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala @@ -40,25 +40,30 @@ import org.apache.spark.sql.internal.SQLConf * COALESCE_BUCKETS_IN_SHUFFLED_HASH_JOIN_MAX_BUCKET_RATIO (`ShuffledHashJoin`). */ case class CoalesceBucketsInJoin(conf: SQLConf) extends Rule[SparkPlan] { + private def updateNumCoalescedBucketsInScan( + plan: SparkPlan, + numCoalescedBuckets: Int): SparkPlan = { + plan transformUp { + case f: FileSourceScanExec => + f.copy(optionalNumCoalescedBuckets = Some(numCoalescedBuckets)) + } + } + private def updateNumCoalescedBuckets( join: BaseJoinExec, numLeftBuckets: Int, numRightBucket: Int, numCoalescedBuckets: Int): BaseJoinExec = { if (numCoalescedBuckets != numLeftBuckets) { - val leftCoalescedChild = join.left transformUp { - case f: FileSourceScanExec => - f.copy(optionalNumCoalescedBuckets = Some(numCoalescedBuckets)) - } + val leftCoalescedChild = + updateNumCoalescedBucketsInScan(join.left, numCoalescedBuckets) join match { case j: SortMergeJoinExec => j.copy(left = leftCoalescedChild) case j: ShuffledHashJoinExec => j.copy(left = leftCoalescedChild) } } else { - val rightCoalescedChild = join.right transformUp { - case f: FileSourceScanExec => - f.copy(optionalNumCoalescedBuckets = Some(numCoalescedBuckets)) - } + val rightCoalescedChild = + updateNumCoalescedBucketsInScan(join.right, numCoalescedBuckets) join match { case j: SortMergeJoinExec => j.copy(right = rightCoalescedChild) case j: ShuffledHashJoinExec => j.copy(right = rightCoalescedChild) @@ -160,12 +165,12 @@ object ExtractJoinWithBuckets { def unapply(plan: SparkPlan): Option[(BaseJoinExec, Int, Int)] = { plan match { - case s: BaseJoinExec if isApplicable(s) => - val leftBucket = getBucketSpec(s.left) - val rightBucket = getBucketSpec(s.right) + case j: BaseJoinExec if isApplicable(j) => + val leftBucket = getBucketSpec(j.left) + val rightBucket = getBucketSpec(j.right) if (leftBucket.isDefined && rightBucket.isDefined && isDivisible(leftBucket.get.numBuckets, rightBucket.get.numBuckets)) { - Some(s, leftBucket.get.numBuckets, rightBucket.get.numBuckets) + Some(j, leftBucket.get.numBuckets, rightBucket.get.numBuckets) } else { None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoinSuite.scala index 588e74ff3adeb..317a34e5157c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoinSuite.scala @@ -30,9 +30,9 @@ import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.sql.types.{IntegerType, StructType} class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { - private val sortMergeJoin = "sortMergeJoin" - private val shuffledHashJoin = "shuffledHashJoin" - private val broadcastHashJoin = "broadcastHashJoin" + private val SORT_MERGE_JOIN = "sortMergeJoin" + private val SHUFFLED_HASH_JOIN = "shuffledHashJoin" + private val BROADCAST_HASH_JOIN = "broadcastHashJoin" case class RelationSetting( cols: Seq[Attribute], @@ -58,7 +58,7 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { def apply( l: RelationSetting, r: RelationSetting, - joinOperator: String = sortMergeJoin, + joinOperator: String = SORT_MERGE_JOIN, shjBuildSide: Option[BuildSide] = None): JoinSetting = { JoinSetting(l.cols, r.cols, l, r, joinOperator, shjBuildSide) } @@ -82,7 +82,7 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { leftRelation = setting.rightRelation, rightRelation = setting.leftRelation) - val settings = if (setting.joinOperator != shuffledHashJoin) { + val settings = if (setting.joinOperator != SHUFFLED_HASH_JOIN) { Seq(setting, swappedSetting) } else { Seq(setting) @@ -90,9 +90,9 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { settings.foreach { s => val lScan = newFileSourceScanExec(s.leftRelation) val rScan = newFileSourceScanExec(s.rightRelation) - val join = if (s.joinOperator == sortMergeJoin) { + val join = if (s.joinOperator == SORT_MERGE_JOIN) { SortMergeJoinExec(s.leftKeys, s.rightKeys, Inner, None, lScan, rScan) - } else if (s.joinOperator == shuffledHashJoin) { + } else if (s.joinOperator == SHUFFLED_HASH_JOIN) { ShuffledHashJoinExec(s.leftKeys, s.rightKeys, Inner, s.shjBuildSide.get, None, lScan, rScan) } else { BroadcastHashJoinExec( @@ -121,19 +121,23 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { test("bucket coalescing - basic") { withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true") { run(JoinSetting( - RelationSetting(4, None), RelationSetting(8, Some(4)), joinOperator = sortMergeJoin)) + RelationSetting(4, None), RelationSetting(8, Some(4)), joinOperator = SORT_MERGE_JOIN)) run(JoinSetting( - RelationSetting(4, None), RelationSetting(8, Some(4)), joinOperator = shuffledHashJoin, + RelationSetting(4, None), RelationSetting(8, Some(4)), joinOperator = SHUFFLED_HASH_JOIN, shjBuildSide = Some(BuildLeft))) - // Coalescing bucket should not happen when the target is on shuffled hash join - // build side. - run(JoinSetting( - RelationSetting(4, None), RelationSetting(8, None), joinOperator = shuffledHashJoin, - shjBuildSide = Some(BuildRight))) } + withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "false") { run(JoinSetting( - RelationSetting(4, None), RelationSetting(8, None), joinOperator = broadcastHashJoin)) + RelationSetting(4, None), RelationSetting(8, None), joinOperator = BROADCAST_HASH_JOIN)) + run(JoinSetting( + RelationSetting(4, None), RelationSetting(8, None), joinOperator = SORT_MERGE_JOIN)) + run(JoinSetting( + RelationSetting(4, None), RelationSetting(8, None), joinOperator = SHUFFLED_HASH_JOIN, + shjBuildSide = Some(BuildLeft))) + run(JoinSetting( + RelationSetting(4, None), RelationSetting(8, None), joinOperator = SHUFFLED_HASH_JOIN, + shjBuildSide = Some(BuildRight))) } } @@ -141,17 +145,25 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { Seq(true, false).foreach { enabled => withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> enabled.toString) { run(JoinSetting( - RelationSetting(4, None), RelationSetting(8, None), joinOperator = broadcastHashJoin)) + RelationSetting(4, None), RelationSetting(8, None), joinOperator = BROADCAST_HASH_JOIN)) } } } + test("bucket coalescing shouldn't be applied to shuffled hash join build side") { + withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true") { + run(JoinSetting( + RelationSetting(4, None), RelationSetting(8, None), joinOperator = SHUFFLED_HASH_JOIN, + shjBuildSide = Some(BuildRight))) + } + } + test("bucket coalescing shouldn't be applied when the number of buckets are the same") { withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true") { run(JoinSetting( - RelationSetting(8, None), RelationSetting(8, None), joinOperator = sortMergeJoin)) + RelationSetting(8, None), RelationSetting(8, None), joinOperator = SORT_MERGE_JOIN)) run(JoinSetting( - RelationSetting(8, None), RelationSetting(8, None), joinOperator = shuffledHashJoin, + RelationSetting(8, None), RelationSetting(8, None), joinOperator = SHUFFLED_HASH_JOIN, shjBuildSide = Some(BuildLeft))) } } @@ -159,9 +171,9 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { test("number of bucket is not divisible by other number of bucket") { withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true") { run(JoinSetting( - RelationSetting(3, None), RelationSetting(8, None), joinOperator = sortMergeJoin)) + RelationSetting(3, None), RelationSetting(8, None), joinOperator = SORT_MERGE_JOIN)) run(JoinSetting( - RelationSetting(3, None), RelationSetting(8, None), joinOperator = shuffledHashJoin, + RelationSetting(3, None), RelationSetting(8, None), joinOperator = SHUFFLED_HASH_JOIN, shjBuildSide = Some(BuildLeft))) } } @@ -170,11 +182,11 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true") { withSQLConf(SQLConf.COALESCE_BUCKETS_IN_SORT_MERGE_JOIN_MAX_BUCKET_RATIO.key -> "2") { run(JoinSetting( - RelationSetting(4, None), RelationSetting(16, None), joinOperator = sortMergeJoin)) + RelationSetting(4, None), RelationSetting(16, None), joinOperator = SORT_MERGE_JOIN)) } withSQLConf(SQLConf.COALESCE_BUCKETS_IN_SHUFFLED_HASH_JOIN_MAX_BUCKET_RATIO.key -> "2") { run(JoinSetting( - RelationSetting(4, None), RelationSetting(16, None), joinOperator = shuffledHashJoin, + RelationSetting(4, None), RelationSetting(16, None), joinOperator = SHUFFLED_HASH_JOIN, shjBuildSide = Some(BuildLeft))) } } @@ -199,7 +211,7 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { rightKeys = Seq(rCols.head), leftRelation = lRel, rightRelation = rRel, - joinOperator = sortMergeJoin, + joinOperator = SORT_MERGE_JOIN, shjBuildSide = None)) run(JoinSetting( @@ -207,7 +219,7 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { rightKeys = Seq(rCols.head), leftRelation = lRel, rightRelation = rRel, - joinOperator = shuffledHashJoin, + joinOperator = SHUFFLED_HASH_JOIN, shjBuildSide = Some(BuildLeft))) // The following should not be coalesced because join keys do not match with output @@ -217,7 +229,7 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { rightKeys = rCols :+ AttributeReference("r3", IntegerType)(), leftRelation = lRel, rightRelation = rRel, - joinOperator = sortMergeJoin, + joinOperator = SORT_MERGE_JOIN, shjBuildSide = None)) run(JoinSetting( @@ -225,7 +237,7 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { rightKeys = rCols :+ AttributeReference("r3", IntegerType)(), leftRelation = lRel, rightRelation = rRel, - joinOperator = shuffledHashJoin, + joinOperator = SHUFFLED_HASH_JOIN, shjBuildSide = Some(BuildLeft))) // The following will be coalesced since ordering should not matter because it will be @@ -235,7 +247,7 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { rightKeys = rCols.reverse, leftRelation = lRel, rightRelation = RelationSetting(rCols, 8, Some(4)), - joinOperator = sortMergeJoin, + joinOperator = SORT_MERGE_JOIN, shjBuildSide = None)) run(JoinSetting( @@ -243,8 +255,16 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { rightKeys = rCols.reverse, leftRelation = lRel, rightRelation = RelationSetting(rCols, 8, Some(4)), - joinOperator = shuffledHashJoin, + joinOperator = SHUFFLED_HASH_JOIN, shjBuildSide = Some(BuildLeft))) + + run(JoinSetting( + leftKeys = rCols.reverse, + rightKeys = lCols.reverse, + leftRelation = RelationSetting(rCols, 8, Some(4)), + rightRelation = lRel, + joinOperator = SHUFFLED_HASH_JOIN, + shjBuildSide = Some(BuildRight))) } }