Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-31350][SQL] Coalesce bucketed tables for sort merge join if applicable #28123

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2617,6 +2617,26 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val COALESCE_BUCKETS_IN_SORT_MERGE_JOIN_ENABLED =
buildConf("spark.sql.bucketing.coalesceBucketsInSortMergeJoin.enabled")
.doc("When true, if two bucketed tables with the different number of buckets are joined, " +
"the side with a bigger number of buckets will be coalesced to have the same number " +
"of buckets as the other side. Bucket coalescing is applied only to sort-merge joins " +
"and only when the bigger number of buckets is divisible by the smaller number of buckets.")
.version("3.1.0")
.booleanConf
.createWithDefault(false)

val COALESCE_BUCKETS_IN_SORT_MERGE_JOIN_MAX_BUCKET_RATIO =
buildConf("spark.sql.bucketing.coalesceBucketsInSortMergeJoin.maxBucketRatio")
.doc("The ratio of the number of two buckets being coalesced should be less than or " +
"equal to this value for bucket coalescing to be applied. This configuration only " +
s"has an effect when '${COALESCE_BUCKETS_IN_SORT_MERGE_JOIN_ENABLED.key}' is set to true.")
.version("3.1.0")
.intConf
.checkValue(_ > 0, "The difference must be positive.")
.createWithDefault(4)

/**
* Holds information about keys that have been deprecated.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ case class RowDataSourceScanExec(
* @param output Output attributes of the scan, including data attributes and partition attributes.
* @param requiredSchema Required schema of the underlying relation, excluding partition columns.
* @param partitionFilters Predicates to use for partition pruning.
* @param optionalBucketSet Bucket ids for bucket pruning
* @param optionalBucketSet Bucket ids for bucket pruning.
* @param optionalNumCoalescedBuckets Number of coalesced buckets.
* @param dataFilters Filters on non-partition columns.
* @param tableIdentifier identifier for the table in the metastore.
*/
Expand All @@ -165,6 +166,7 @@ case class FileSourceScanExec(
requiredSchema: StructType,
partitionFilters: Seq[Expression],
optionalBucketSet: Option[BitSet],
optionalNumCoalescedBuckets: Option[Int],
cloud-fan marked this conversation as resolved.
Show resolved Hide resolved
dataFilters: Seq[Expression],
tableIdentifier: Option[TableIdentifier])
extends DataSourceScanExec {
Expand Down Expand Up @@ -291,7 +293,8 @@ case class FileSourceScanExec(
// above
val spec = relation.bucketSpec.get
val bucketColumns = spec.bucketColumnNames.flatMap(n => toAttribute(n))
val partitioning = HashPartitioning(bucketColumns, spec.numBuckets)
val numPartitions = optionalNumCoalescedBuckets.getOrElse(spec.numBuckets)
val partitioning = HashPartitioning(bucketColumns, numPartitions)
val sortColumns =
spec.sortColumnNames.map(x => toAttribute(x)).takeWhile(x => x.isDefined).map(_.get)
val shouldCalculateSortOrder =
Expand All @@ -311,7 +314,8 @@ case class FileSourceScanExec(
files.map(_.getPath.getName).groupBy(file => BucketingUtils.getBucketId(file))
val singleFilePartitions = bucketToFilesGrouping.forall(p => p._2.length <= 1)

if (singleFilePartitions) {
// TODO SPARK-24528 Sort order is currently ignored if buckets are coalesced.
if (singleFilePartitions && optionalNumCoalescedBuckets.isEmpty) {
// TODO Currently Spark does not support writing columns sorting in descending order
// so using Ascending order. This can be fixed in future
sortColumns.map(attribute => SortOrder(attribute, Ascending))
Expand Down Expand Up @@ -356,7 +360,8 @@ case class FileSourceScanExec(
spec.numBuckets
}
metadata + ("SelectedBucketsCount" ->
s"$numSelectedBuckets out of ${spec.numBuckets}")
(s"$numSelectedBuckets out of ${spec.numBuckets}" +
optionalNumCoalescedBuckets.map { b => s" (Coalesced to $b)"}.getOrElse("")))
imback82 marked this conversation as resolved.
Show resolved Hide resolved
} getOrElse {
metadata
}
Expand Down Expand Up @@ -544,8 +549,19 @@ case class FileSourceScanExec(
filesGroupedToBuckets
}

val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty))
val filePartitions = optionalNumCoalescedBuckets.map { numCoalescedBuckets =>
logInfo(s"Coalescing to ${numCoalescedBuckets} buckets")
val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets)
Seq.tabulate(numCoalescedBuckets) { bucketId =>
val partitionedFiles = coalescedBuckets.get(bucketId).map {
_.values.flatten.toArray
}.getOrElse(Array.empty)
FilePartition(bucketId, partitionedFiles)
}
} getOrElse {
imback82 marked this conversation as resolved.
Show resolved Hide resolved
Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty))
}
}

new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions)
Expand Down Expand Up @@ -599,6 +615,7 @@ case class FileSourceScanExec(
requiredSchema,
QueryPlan.normalizePredicates(partitionFilters, output),
optionalBucketSet,
optionalNumCoalescedBuckets,
QueryPlan.normalizePredicates(dataFilters, output),
None)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}
import org.apache.spark.sql.execution.bucketing.CoalesceBucketsInSortMergeJoin
import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
Expand Down Expand Up @@ -331,6 +332,7 @@ object QueryExecution {
// as the original plan is hidden behind `AdaptiveSparkPlanExec`.
adaptiveExecutionRule.toSeq ++
Seq(
CoalesceBucketsInSortMergeJoin(sparkSession.sessionState.conf),
PlanDynamicPruningFilters(sparkSession),
PlanSubqueries(sparkSession),
EnsureRequirements(sparkSession.sessionState.conf),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.bucketing

import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{FileSourceScanExec, FilterExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf

/**
* This rule coalesces one side of the `SortMergeJoin` if the following conditions are met:
* - Two bucketed tables are joined.
imback82 marked this conversation as resolved.
Show resolved Hide resolved
* - Join keys match with output partition expressions on their respective sides.
* - The larger bucket number is divisible by the smaller bucket number.
* - COALESCE_BUCKETS_IN_SORT_MERGE_JOIN_ENABLED is set to true.
* - The ratio of the number of buckets is less than the value set in
* COALESCE_BUCKETS_IN_SORT_MERGE_JOIN_MAX_BUCKET_RATIO.
*/
case class CoalesceBucketsInSortMergeJoin(conf: SQLConf) extends Rule[SparkPlan] {
private def mayCoalesce(numBuckets1: Int, numBuckets2: Int, conf: SQLConf): Option[Int] = {
assert(numBuckets1 != numBuckets2)
val (small, large) = (math.min(numBuckets1, numBuckets2), math.max(numBuckets1, numBuckets2))
// A bucket can be coalesced only if the bigger number of buckets is divisible by the smaller
// number of buckets because bucket id is calculated by modding the total number of buckets.
if (large % small == 0 &&
large / small <= conf.getConf(SQLConf.COALESCE_BUCKETS_IN_SORT_MERGE_JOIN_MAX_BUCKET_RATIO)) {
Some(small)
} else {
None
}
}

private def updateNumCoalescedBuckets(plan: SparkPlan, numCoalescedBuckets: Int): SparkPlan = {
plan.transformUp {
case f: FileSourceScanExec =>
f.copy(optionalNumCoalescedBuckets = Some(numCoalescedBuckets))
}
}

def apply(plan: SparkPlan): SparkPlan = {
if (!conf.getConf(SQLConf.COALESCE_BUCKETS_IN_SORT_MERGE_JOIN_ENABLED)) {
return plan
}

plan transform {
case ExtractSortMergeJoinWithBuckets(smj, numLeftBuckets, numRightBuckets)
if numLeftBuckets != numRightBuckets =>
mayCoalesce(numLeftBuckets, numRightBuckets, conf).map { numCoalescedBuckets =>
if (numCoalescedBuckets != numLeftBuckets) {
smj.copy(left = updateNumCoalescedBuckets(smj.left, numCoalescedBuckets))
} else {
smj.copy(right = updateNumCoalescedBuckets(smj.right, numCoalescedBuckets))
}
}.getOrElse(smj)
case other => other
}
}
}

/**
* An extractor that extracts `SortMergeJoinExec` where both sides of the join have the bucketed
* tables and are consisted of only the scan operation.
*/
object ExtractSortMergeJoinWithBuckets {
private def isScanOperation(plan: SparkPlan): Boolean = plan match {
case f: FilterExec => isScanOperation(f.child)
case p: ProjectExec => isScanOperation(p.child)
cloud-fan marked this conversation as resolved.
Show resolved Hide resolved
case _: FileSourceScanExec => true
case _ => false
}

private def getBucketSpec(plan: SparkPlan): Option[BucketSpec] = {
plan.collectFirst {
case f: FileSourceScanExec if f.relation.bucketSpec.nonEmpty &&
f.optionalNumCoalescedBuckets.isEmpty =>
f.relation.bucketSpec.get
}
}

/**
* The join keys should match with expressions for output partitioning. Note that
* the ordering does not matter because it will be handled in `EnsureRequirements`.
cloud-fan marked this conversation as resolved.
Show resolved Hide resolved
*/
private def satisfiesOutputPartitioning(
keys: Seq[Expression],
partitioning: Partitioning): Boolean = {
partitioning match {
case HashPartitioning(exprs, _) if exprs.length == keys.length =>
exprs.forall(e => keys.exists(_.semanticEquals(e)))
case _ => false
}
}

private def isApplicable(s: SortMergeJoinExec): Boolean = {
isScanOperation(s.left) &&
isScanOperation(s.right) &&
satisfiesOutputPartitioning(s.leftKeys, s.left.outputPartitioning) &&
satisfiesOutputPartitioning(s.rightKeys, s.right.outputPartitioning)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not related to this PR but just an idea: we don't need to do bucket scan at all if it can't save shuffles. This can increase parallelism.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea good idea.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a follow-up issue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can give it a shot after this PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need to do bucket scan at all if it can't save shuffles. This can increase parallelism.

@cloud-fan IMO there's other benefit to do bucket scan even though it can't save shuffle, e.g. bucket filter push down. So we probably need to take that into consideration before disabling bucketing.

}

def unapply(plan: SparkPlan): Option[(SortMergeJoinExec, Int, Int)] = {
plan match {
case s: SortMergeJoinExec if isApplicable(s) =>
val leftBucket = getBucketSpec(s.left)
val rightBucket = getBucketSpec(s.right)
if (leftBucket.isDefined && rightBucket.isDefined) {
Some(s, leftBucket.get.numBuckets, rightBucket.get.numBuckets)
} else {
None
}
case _ => None
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ object FileSourceStrategy extends Strategy with Logging {
outputSchema,
partitionKeyFilters.toSeq,
bucketSet,
None,
dataFilters,
table.map(_.identifier))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ class DataFrameJoinSuite extends QueryTest
}
assert(broadcastExchanges.size == 1)
val tables = broadcastExchanges.head.collect {
case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent
case FileSourceScanExec(_, _, _, _, _, _, _, Some(tableIdent)) => tableIdent
}
assert(tables.size == 1)
assert(tables.head === TableIdentifier(table1Name, Some(dbName)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1313,7 +1313,7 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
// need to execute the query before we can examine fs.inputRDDs()
assert(stripAQEPlan(df.queryExecution.executedPlan) match {
case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(
fs @ FileSourceScanExec(_, _, _, partitionFilters, _, _, _)))) =>
fs @ FileSourceScanExec(_, _, _, partitionFilters, _, _, _, _)))) =>
partitionFilters.exists(ExecSubqueryExpression.hasSubquery) &&
fs.inputRDDs().forall(
_.asInstanceOf[FileScanRDD].filePartitions.forall(
Expand Down
Loading