Skip to content

Commit

Permalink
Address all comments besides the separate configs discussion
Browse files Browse the repository at this point in the history
  • Loading branch information
c21 committed Jul 19, 2020
1 parent 2bbc8f8 commit 6aa17dd
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,30 @@ import org.apache.spark.sql.internal.SQLConf
* COALESCE_BUCKETS_IN_SHUFFLED_HASH_JOIN_MAX_BUCKET_RATIO (`ShuffledHashJoin`).
*/
case class CoalesceBucketsInJoin(conf: SQLConf) extends Rule[SparkPlan] {
/**
 * Sets `optionalNumCoalescedBuckets` to `Some(numCoalescedBuckets)` on the
 * `FileSourceScanExec` nodes matched while transforming `plan` bottom-up.
 *
 * @param plan the physical plan whose file source scans should be updated
 * @param numCoalescedBuckets the bucket count to record on each matched scan
 * @return the plan produced by the transformation
 */
private def updateNumCoalescedBucketsInScan(
    plan: SparkPlan,
    numCoalescedBuckets: Int): SparkPlan = {
  // The same value is applied to every matched scan, so build it once.
  val coalescedBuckets = Some(numCoalescedBuckets)
  plan.transformUp {
    case scan: FileSourceScanExec =>
      scan.copy(optionalNumCoalescedBuckets = coalescedBuckets)
  }
}

private def updateNumCoalescedBuckets(
join: BaseJoinExec,
numLeftBuckets: Int,
numRightBucket: Int,
numCoalescedBuckets: Int): BaseJoinExec = {
if (numCoalescedBuckets != numLeftBuckets) {
val leftCoalescedChild = join.left transformUp {
case f: FileSourceScanExec =>
f.copy(optionalNumCoalescedBuckets = Some(numCoalescedBuckets))
}
val leftCoalescedChild =
updateNumCoalescedBucketsInScan(join.left, numCoalescedBuckets)
join match {
case j: SortMergeJoinExec => j.copy(left = leftCoalescedChild)
case j: ShuffledHashJoinExec => j.copy(left = leftCoalescedChild)
}
} else {
val rightCoalescedChild = join.right transformUp {
case f: FileSourceScanExec =>
f.copy(optionalNumCoalescedBuckets = Some(numCoalescedBuckets))
}
val rightCoalescedChild =
updateNumCoalescedBucketsInScan(join.right, numCoalescedBuckets)
join match {
case j: SortMergeJoinExec => j.copy(right = rightCoalescedChild)
case j: ShuffledHashJoinExec => j.copy(right = rightCoalescedChild)
Expand Down Expand Up @@ -160,12 +165,12 @@ object ExtractJoinWithBuckets {

def unapply(plan: SparkPlan): Option[(BaseJoinExec, Int, Int)] = {
plan match {
case s: BaseJoinExec if isApplicable(s) =>
val leftBucket = getBucketSpec(s.left)
val rightBucket = getBucketSpec(s.right)
case j: BaseJoinExec if isApplicable(j) =>
val leftBucket = getBucketSpec(j.left)
val rightBucket = getBucketSpec(j.right)
if (leftBucket.isDefined && rightBucket.isDefined &&
isDivisible(leftBucket.get.numBuckets, rightBucket.get.numBuckets)) {
Some(s, leftBucket.get.numBuckets, rightBucket.get.numBuckets)
Some(j, leftBucket.get.numBuckets, rightBucket.get.numBuckets)
} else {
None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
import org.apache.spark.sql.types.{IntegerType, StructType}

class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession {
private val sortMergeJoin = "sortMergeJoin"
private val shuffledHashJoin = "shuffledHashJoin"
private val broadcastHashJoin = "broadcastHashJoin"
private val SORT_MERGE_JOIN = "sortMergeJoin"
private val SHUFFLED_HASH_JOIN = "shuffledHashJoin"
private val BROADCAST_HASH_JOIN = "broadcastHashJoin"

case class RelationSetting(
cols: Seq[Attribute],
Expand All @@ -58,7 +58,7 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession {
def apply(
l: RelationSetting,
r: RelationSetting,
joinOperator: String = sortMergeJoin,
joinOperator: String = SORT_MERGE_JOIN,
shjBuildSide: Option[BuildSide] = None): JoinSetting = {
JoinSetting(l.cols, r.cols, l, r, joinOperator, shjBuildSide)
}
Expand All @@ -82,17 +82,17 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession {
leftRelation = setting.rightRelation,
rightRelation = setting.leftRelation)

val settings = if (setting.joinOperator != shuffledHashJoin) {
val settings = if (setting.joinOperator != SHUFFLED_HASH_JOIN) {
Seq(setting, swappedSetting)
} else {
Seq(setting)
}
settings.foreach { s =>
val lScan = newFileSourceScanExec(s.leftRelation)
val rScan = newFileSourceScanExec(s.rightRelation)
val join = if (s.joinOperator == sortMergeJoin) {
val join = if (s.joinOperator == SORT_MERGE_JOIN) {
SortMergeJoinExec(s.leftKeys, s.rightKeys, Inner, None, lScan, rScan)
} else if (s.joinOperator == shuffledHashJoin) {
} else if (s.joinOperator == SHUFFLED_HASH_JOIN) {
ShuffledHashJoinExec(s.leftKeys, s.rightKeys, Inner, s.shjBuildSide.get, None, lScan, rScan)
} else {
BroadcastHashJoinExec(
Expand Down Expand Up @@ -121,47 +121,59 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession {
test("bucket coalescing - basic") {
withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true") {
run(JoinSetting(
RelationSetting(4, None), RelationSetting(8, Some(4)), joinOperator = sortMergeJoin))
RelationSetting(4, None), RelationSetting(8, Some(4)), joinOperator = SORT_MERGE_JOIN))
run(JoinSetting(
RelationSetting(4, None), RelationSetting(8, Some(4)), joinOperator = shuffledHashJoin,
RelationSetting(4, None), RelationSetting(8, Some(4)), joinOperator = SHUFFLED_HASH_JOIN,
shjBuildSide = Some(BuildLeft)))
// Coalescing bucket should not happen when the target is on shuffled hash join
// build side.
run(JoinSetting(
RelationSetting(4, None), RelationSetting(8, None), joinOperator = shuffledHashJoin,
shjBuildSide = Some(BuildRight)))
}

withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "false") {
run(JoinSetting(
RelationSetting(4, None), RelationSetting(8, None), joinOperator = broadcastHashJoin))
RelationSetting(4, None), RelationSetting(8, None), joinOperator = BROADCAST_HASH_JOIN))
run(JoinSetting(
RelationSetting(4, None), RelationSetting(8, None), joinOperator = SORT_MERGE_JOIN))
run(JoinSetting(
RelationSetting(4, None), RelationSetting(8, None), joinOperator = SHUFFLED_HASH_JOIN,
shjBuildSide = Some(BuildLeft)))
run(JoinSetting(
RelationSetting(4, None), RelationSetting(8, None), joinOperator = SHUFFLED_HASH_JOIN,
shjBuildSide = Some(BuildRight)))
}
}

test("bucket coalescing should work only for sort merge join and shuffled hash join") {
Seq(true, false).foreach { enabled =>
withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> enabled.toString) {
run(JoinSetting(
RelationSetting(4, None), RelationSetting(8, None), joinOperator = broadcastHashJoin))
RelationSetting(4, None), RelationSetting(8, None), joinOperator = BROADCAST_HASH_JOIN))
}
}
}

// Runs a shuffled hash join with `BuildRight` while bucket coalescing is enabled,
// using relations built as `RelationSetting(4, None)` and `RelationSetting(8, None)`.
// NOTE(review): presumably the first argument is the bucket count and the trailing
// `None` is the expected coalesced bucket count (i.e. no coalescing) — confirm
// against `RelationSetting` and `run`.
test("bucket coalescing shouldn't be applied to shuffled hash join build side") {
withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true") {
run(JoinSetting(
RelationSetting(4, None), RelationSetting(8, None), joinOperator = SHUFFLED_HASH_JOIN,
shjBuildSide = Some(BuildRight)))
}
}

test("bucket coalescing shouldn't be applied when the number of buckets are the same") {
withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true") {
run(JoinSetting(
RelationSetting(8, None), RelationSetting(8, None), joinOperator = sortMergeJoin))
RelationSetting(8, None), RelationSetting(8, None), joinOperator = SORT_MERGE_JOIN))
run(JoinSetting(
RelationSetting(8, None), RelationSetting(8, None), joinOperator = shuffledHashJoin,
RelationSetting(8, None), RelationSetting(8, None), joinOperator = SHUFFLED_HASH_JOIN,
shjBuildSide = Some(BuildLeft)))
}
}

test("number of bucket is not divisible by other number of bucket") {
withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true") {
run(JoinSetting(
RelationSetting(3, None), RelationSetting(8, None), joinOperator = sortMergeJoin))
RelationSetting(3, None), RelationSetting(8, None), joinOperator = SORT_MERGE_JOIN))
run(JoinSetting(
RelationSetting(3, None), RelationSetting(8, None), joinOperator = shuffledHashJoin,
RelationSetting(3, None), RelationSetting(8, None), joinOperator = SHUFFLED_HASH_JOIN,
shjBuildSide = Some(BuildLeft)))
}
}
Expand All @@ -170,11 +182,11 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession {
withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true") {
withSQLConf(SQLConf.COALESCE_BUCKETS_IN_SORT_MERGE_JOIN_MAX_BUCKET_RATIO.key -> "2") {
run(JoinSetting(
RelationSetting(4, None), RelationSetting(16, None), joinOperator = sortMergeJoin))
RelationSetting(4, None), RelationSetting(16, None), joinOperator = SORT_MERGE_JOIN))
}
withSQLConf(SQLConf.COALESCE_BUCKETS_IN_SHUFFLED_HASH_JOIN_MAX_BUCKET_RATIO.key -> "2") {
run(JoinSetting(
RelationSetting(4, None), RelationSetting(16, None), joinOperator = shuffledHashJoin,
RelationSetting(4, None), RelationSetting(16, None), joinOperator = SHUFFLED_HASH_JOIN,
shjBuildSide = Some(BuildLeft)))
}
}
Expand All @@ -199,15 +211,15 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession {
rightKeys = Seq(rCols.head),
leftRelation = lRel,
rightRelation = rRel,
joinOperator = sortMergeJoin,
joinOperator = SORT_MERGE_JOIN,
shjBuildSide = None))

run(JoinSetting(
leftKeys = Seq(lCols.head),
rightKeys = Seq(rCols.head),
leftRelation = lRel,
rightRelation = rRel,
joinOperator = shuffledHashJoin,
joinOperator = SHUFFLED_HASH_JOIN,
shjBuildSide = Some(BuildLeft)))

// The following should not be coalesced because join keys do not match with output
Expand All @@ -217,15 +229,15 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession {
rightKeys = rCols :+ AttributeReference("r3", IntegerType)(),
leftRelation = lRel,
rightRelation = rRel,
joinOperator = sortMergeJoin,
joinOperator = SORT_MERGE_JOIN,
shjBuildSide = None))

run(JoinSetting(
leftKeys = lCols :+ AttributeReference("l3", IntegerType)(),
rightKeys = rCols :+ AttributeReference("r3", IntegerType)(),
leftRelation = lRel,
rightRelation = rRel,
joinOperator = shuffledHashJoin,
joinOperator = SHUFFLED_HASH_JOIN,
shjBuildSide = Some(BuildLeft)))

// The following will be coalesced since ordering should not matter because it will be
Expand All @@ -235,16 +247,24 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession {
rightKeys = rCols.reverse,
leftRelation = lRel,
rightRelation = RelationSetting(rCols, 8, Some(4)),
joinOperator = sortMergeJoin,
joinOperator = SORT_MERGE_JOIN,
shjBuildSide = None))

run(JoinSetting(
leftKeys = lCols.reverse,
rightKeys = rCols.reverse,
leftRelation = lRel,
rightRelation = RelationSetting(rCols, 8, Some(4)),
joinOperator = shuffledHashJoin,
joinOperator = SHUFFLED_HASH_JOIN,
shjBuildSide = Some(BuildLeft)))

run(JoinSetting(
leftKeys = rCols.reverse,
rightKeys = lCols.reverse,
leftRelation = RelationSetting(rCols, 8, Some(4)),
rightRelation = lRel,
joinOperator = SHUFFLED_HASH_JOIN,
shjBuildSide = Some(BuildRight)))
}
}

Expand Down

0 comments on commit 6aa17dd

Please sign in to comment.