diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index c7c3e1672f034..ae7637951f6d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -206,7 +206,7 @@ trait HashJoin extends BaseJoinExec {
         existenceJoin(streamedIter, hashed)
       case x =>
         throw new IllegalArgumentException(
-          s"BroadcastHashJoin should not take $x as the JoinType")
+          s"HashJoin should not take $x as the JoinType")
     }
 
     val resultProj = createResultProjection
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 2b7cd65e7d96f..05991bbb9862b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -40,15 +40,14 @@ case class ShuffledHashJoinExec(
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan)
-  extends HashJoin {
+  extends HashJoin with ShuffledJoin {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
     "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
     "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))
 
-  override def requiredChildDistribution: Seq[Distribution] =
-    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+  override def outputPartitioning: Partitioning = super[ShuffledJoin].outputPartitioning
 
   private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
     val buildDataSize = longMetric("buildDataSize")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
new file mode 100644
index 0000000000000..7035ddc35be9c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, LeftExistence, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution, Partitioning, PartitioningCollection, UnknownPartitioning}
+
+/**
+ * Holds common logic for join operators by shuffling two child relations
+ * using the join keys.
+ */
+trait ShuffledJoin extends BaseJoinExec {
+  override def requiredChildDistribution: Seq[Distribution] = {
+    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+  }
+
+  override def outputPartitioning: Partitioning = joinType match {
+    case _: InnerLike =>
+      PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case LeftExistence(_) => left.outputPartitioning
+    case x =>
+      throw new IllegalArgumentException(
+        s"ShuffledJoin should not take $x as the JoinType")
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 2c57956de5bca..b9f6684447dd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -41,7 +41,7 @@ case class SortMergeJoinExec(
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan,
-    isSkewJoin: Boolean = false) extends BaseJoinExec with CodegenSupport {
+    isSkewJoin: Boolean = false) extends ShuffledJoin with CodegenSupport {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -72,26 +72,13 @@ case class SortMergeJoinExec(
     }
   }
 
-  override def outputPartitioning: Partitioning = joinType match {
-    case _: InnerLike =>
-      PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
-    // For left and right outer joins, the output is partitioned by the streamed input's join keys.
-    case LeftOuter => left.outputPartitioning
-    case RightOuter => right.outputPartitioning
-    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
-    case LeftExistence(_) => left.outputPartitioning
-    case x =>
-      throw new IllegalArgumentException(
-        s"${getClass.getSimpleName} should not take $x as the JoinType")
-  }
-
   override def requiredChildDistribution: Seq[Distribution] = {
     if (isSkewJoin) {
       // We re-arrange the shuffle partitions to deal with skew join, and the new children
       // partitioning doesn't satisfy `HashClusteredDistribution`.
       UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
     } else {
-      HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+      super.requiredChildDistribution
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index f24da6df67ca0..b4f626270cfc9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrd
 import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.execution.python.BatchEvalPythonExec
 import org.apache.spark.sql.internal.SQLConf
@@ -1086,4 +1087,21 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       assert(df2.join(df1, "id").collect().isEmpty)
     }
   }
+
+  test("SPARK-32330: Preserve shuffled hash join build side partitioning") {
+    withSQLConf(
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "2",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
+      val df1 = spark.range(10).select($"id".as("k1"))
+      val df2 = spark.range(30).select($"id".as("k2"))
+      Seq("inner", "cross").foreach(joinType => {
+        val plan = df1.join(df2, $"k1" === $"k2", joinType).groupBy($"k1").count()
+          .queryExecution.executedPlan
+        assert(plan.collect { case _: ShuffledHashJoinExec => true }.size === 1)
+        // No extra shuffle before aggregate
+        assert(plan.collect { case _: ShuffleExchangeExec => true }.size === 2)
+      })
+    }
+  }
 }
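
For context, a minimal spark-shell sketch (not part of the patch) that mirrors the added JoinSuite test. The column names, data sizes, and configuration values are copied from the test above; the SparkSession `spark` is assumed, as in spark-shell. The expectation is that the final plan contains exactly two ShuffleExchangeExec nodes (one per join child) rather than a third one before the aggregate.

// Minimal sketch mirroring the SPARK-32330 test; assumes a running SparkSession named `spark`.
import spark.implicits._

// Keep the join from being broadcast and prefer shuffled hash join over sort merge join.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50")
spark.conf.set("spark.sql.shuffle.partitions", "2")
spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")

val df1 = spark.range(10).select($"id".as("k1"))
val df2 = spark.range(30).select($"id".as("k2"))

// Join on k1 = k2, then aggregate on the join key. Because ShuffledHashJoinExec now reports
// its children's hash partitioning via ShuffledJoin.outputPartitioning, the groupBy can reuse
// that partitioning; the plan should show two ShuffleExchangeExec nodes, not three.
df1.join(df2, $"k1" === $"k2").groupBy($"k1").count().explain()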