diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 13837150ac2d1..2750f58b005ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -34,59 +34,26 @@ case class BroadcastLeftSemiJoinHash(
     rightKeys: Seq[Expression],
     left: SparkPlan,
     right: SparkPlan,
-    condition: Option[Expression]) extends BinaryNode with HashJoin {
-
-  override val buildSide: BuildSide = BuildRight
-
-  override def output: Seq[Attribute] = left.output
-
-  @transient private lazy val boundCondition =
-    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+    condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
 
   protected override def doExecute(): RDD[InternalRow] = {
-    val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator
-
-    condition match {
-      case None =>
-        val hashSet = new java.util.HashSet[InternalRow]()
-        var currentRow: InternalRow = null
-
-        // Create a Hash set of buildKeys
-        while (buildIter.hasNext) {
-          currentRow = buildIter.next()
-          val rowKey = buildSideKeyGenerator(currentRow)
-          if (!rowKey.anyNull) {
-            val keyExists = hashSet.contains(rowKey)
-            if (!keyExists) {
-              // rowKey may be not serializable (from codegen)
-              hashSet.add(rowKey.copy())
-            }
-          }
-        }
-
-        val broadcastedRelation = sparkContext.broadcast(hashSet)
-
-        streamedPlan.execute().mapPartitions { streamIter =>
-          val joinKeys = streamSideKeyGenerator()
-          streamIter.filter(current => {
-            !joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue)
-          })
-        }
-      case _ =>
-        val hashRelation = HashedRelation(buildIter, buildSideKeyGenerator)
-        val broadcastedRelation = sparkContext.broadcast(hashRelation)
-
-        streamedPlan.execute().mapPartitions { streamIter =>
-          val joinKeys = streamSideKeyGenerator()
-          val joinedRow = new JoinedRow
-
-          streamIter.filter(current => {
-            val rowBuffer = broadcastedRelation.value.get(joinKeys.currentValue)
-            !joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists {
-              (build: InternalRow) => boundCondition(joinedRow(current, build))
-            }
-          })
-        }
+    val buildIter = right.execute().map(_.copy()).collect().toIterator
+
+    if (condition.isEmpty) {
+      // rowKey may not be serializable (from codegen), so keep a copy of each key
+      val hashSet = buildKeyHashSet(buildIter, copy = true)
+      val broadcastedRelation = sparkContext.broadcast(hashSet)
+
+      left.execute().mapPartitions { streamIter =>
+        hashSemiJoin(streamIter, broadcastedRelation.value)
+      }
+    } else {
+      val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+      val broadcastedRelation = sparkContext.broadcast(hashRelation)
+
+      left.execute().mapPartitions { streamIter =>
+        hashSemiJoin(streamIter, broadcastedRelation.value)
+      }
     }
   }
 }
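For intuition, here is a minimal plain-Scala sketch of the broadcast strategy the rewritten doExecute above uses: collect the (small) build side once, hash its keys, ship the hash to every stream partition, and filter. Seq stands in for RDDs and Int keys for InternalRow keys; the names here are illustrative, not Spark API.

object BroadcastSemiJoinSketch {
  def main(args: Array[String]): Unit = {
    val build  = Seq((2, "x"), (3, "y"), (3, "z")) // small (right) side
    val stream = Seq((1, "a"), (2, "b"), (3, "c")) // large (left) side

    // "Collect" the build side and hash its keys once on the driver ...
    val keySet: Set[Int] = build.map(_._1).toSet

    // ... then every stream partition filters against the broadcast set.
    val result = stream.filter { case (k, _) => keySet.contains(k) }
    println(result) // List((2,b), (3,c))
  }
}
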
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
new file mode 100644
index 0000000000000..1b983bc3a90f9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+
+
+trait HashSemiJoin {
+  self: SparkPlan =>
+  val leftKeys: Seq[Expression]
+  val rightKeys: Seq[Expression]
+  val left: SparkPlan
+  val right: SparkPlan
+  val condition: Option[Expression]
+
+  override def output: Seq[Attribute] = left.output
+
+  @transient protected lazy val rightKeyGenerator: Projection =
+    newProjection(rightKeys, right.output)
+
+  @transient protected lazy val leftKeyGenerator: () => MutableProjection =
+    newMutableProjection(leftKeys, left.output)
+
+  @transient private lazy val boundCondition =
+    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+
+  protected def buildKeyHashSet(
+      buildIter: Iterator[InternalRow],
+      copy: Boolean): java.util.Set[InternalRow] = {
+    val hashSet = new java.util.HashSet[InternalRow]()
+    var currentRow: InternalRow = null
+
+    // Create a hash set of the build-side join keys
+    while (buildIter.hasNext) {
+      currentRow = buildIter.next()
+      val rowKey = rightKeyGenerator(currentRow)
+      if (!rowKey.anyNull) {
+        val keyExists = hashSet.contains(rowKey)
+        if (!keyExists) {
+          if (copy) {
+            hashSet.add(rowKey.copy())
+          } else {
+            // rowKey may not be serializable (from codegen)
+            hashSet.add(rowKey)
+          }
+        }
+      }
+    }
+    hashSet
+  }
+
+  protected def hashSemiJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation): Iterator[InternalRow] = {
+    val joinKeys = leftKeyGenerator()
+    val joinedRow = new JoinedRow
+    streamIter.filter(current => {
+      lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue)
+      !joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists {
+        (build: InternalRow) => boundCondition(joinedRow(current, build))
+      }
+    })
+  }
+
+  protected def hashSemiJoin(
+      streamIter: Iterator[InternalRow],
+      hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = {
+    val joinKeys = leftKeyGenerator()
+    streamIter.filter(current => {
+      !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
+    })
+  }
+}
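The two hashSemiJoin overloads implement two filtering modes: with no extra condition, key membership in a set decides; with a condition, the stream row survives iff at least one build row under the same key satisfies the predicate. A plain-Scala sketch of those semantics, with simplified (Int, Double) rows and a cond function standing in for boundCondition (illustrative only, not Spark API):

object HashSemiJoinSketch {
  type Row = (Int, Double) // (join key, value)

  // condition.isEmpty: key membership alone decides.
  def semiJoinSet(stream: Seq[Row], keySet: Set[Int]): Seq[Row] =
    stream.filter { case (k, _) => keySet.contains(k) }

  // condition defined: keep a stream row iff some build row with the same
  // key also satisfies the predicate.
  def semiJoinRelation(
      stream: Seq[Row],
      hashed: Map[Int, Seq[Row]],
      cond: (Row, Row) => Boolean): Seq[Row] =
    stream.filter(s => hashed.get(s._1).exists(_.exists(b => cond(s, b))))

  def main(args: Array[String]): Unit = {
    val stream = Seq((1, 2.0), (2, 1.0), (3, 3.0))
    val hashed = Map(2 -> Seq((2, 3.0)), 3 -> Seq((3, 2.0)))
    println(semiJoinSet(stream, hashed.keySet))                      // List((2,1.0), (3,3.0))
    println(semiJoinRelation(stream, hashed, (s, b) => s._2 < b._2)) // List((2,1.0))
  }
}
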
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 602a6c88a0d58..9eaac817d9268 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -35,52 +35,19 @@ case class LeftSemiJoinHash(
     rightKeys: Seq[Expression],
     left: SparkPlan,
     right: SparkPlan,
-    condition: Option[Expression]) extends BinaryNode with HashJoin {
-
-  override val buildSide: BuildSide = BuildRight
+    condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
 
   override def requiredChildDistribution: Seq[ClusteredDistribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
-  override def output: Seq[Attribute] = left.output
-
-  @transient private lazy val boundCondition =
-    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
-
   protected override def doExecute(): RDD[InternalRow] = {
-    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
-      val joinKeys = streamSideKeyGenerator()
-      val joinedRow = new JoinedRow
-
-      condition match {
-        case None =>
-          val hashSet = new java.util.HashSet[InternalRow]()
-          var currentRow: InternalRow = null
-
-          // Create a Hash set of buildKeys
-          while (buildIter.hasNext) {
-            currentRow = buildIter.next()
-            val rowKey = buildSideKeyGenerator(currentRow)
-            if (!rowKey.anyNull) {
-              val keyExists = hashSet.contains(rowKey)
-              if (!keyExists) {
-                hashSet.add(rowKey)
-              }
-            }
-          }
-
-          val joinKeys = streamSideKeyGenerator()
-          streamIter.filter(current => {
-            !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
-          })
-        case _ =>
-          val hashRelation = HashedRelation(buildIter, buildSideKeyGenerator)
-          streamIter.filter(current => {
-            val rowBuffer = hashRelation.get(joinKeys.currentValue)
-            !joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists {
-              (build: InternalRow) => boundCondition(joinedRow(current, build))
-            }
-          })
+    right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
+      if (condition.isEmpty) {
+        val hashSet = buildKeyHashSet(buildIter, copy = false)
+        hashSemiJoin(streamIter, hashSet)
+      } else {
+        val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+        hashSemiJoin(streamIter, hashRelation)
       }
     }
   }
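Unlike the broadcast variant, the shuffled variant relies on requiredChildDistribution to co-partition both sides by join key, then semi-joins each partition pair independently. A plain-Scala sketch of what that computes, with groupBy-by-key-modulo-partition standing in for ClusteredDistribution and an explicit loop standing in for zipPartitions (illustrative only):

object ShuffledSemiJoinSketch {
  def main(args: Array[String]): Unit = {
    val numPartitions = 2
    val left  = Seq((1, "a"), (2, "b"), (3, "c"))
    val right = Seq((2, "x"), (4, "y"))

    // Rows with equal keys land in the same partition on both sides.
    def partition(rows: Seq[(Int, String)]): Map[Int, Seq[(Int, String)]] =
      rows.groupBy(_._1 % numPartitions).withDefaultValue(Seq.empty)

    val (l, r) = (partition(left), partition(right))

    // Each partition pair is semi-joined independently, as zipPartitions does.
    val result = (0 until numPartitions).flatMap { p =>
      val keySet = r(p).map(_._1).toSet // per-partition build side
      l(p).filter { case (k, _) => keySet.contains(k) }
    }
    println(result) // Vector((2,b))
  }
}
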
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
new file mode 100644
index 0000000000000..927e85a7db3dc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan}
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+
+
+class SemiJoinSuite extends SparkPlanTest {
+  val left = Seq(
+    (1, 2.0),
+    (1, 2.0),
+    (2, 1.0),
+    (2, 1.0),
+    (3, 3.0)
+  ).toDF("a", "b")
+
+  val right = Seq(
+    (2, 3.0),
+    (2, 3.0),
+    (3, 2.0),
+    (4, 1.0)
+  ).toDF("c", "d")
+
+  val leftKeys: List[Expression] = 'a :: Nil
+  val rightKeys: List[Expression] = 'c :: Nil
+  val condition = Some(LessThan('b, 'd))
+
+  test("left semi join hash") {
+    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
+      LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
+      Seq(
+        (2, 1.0),
+        (2, 1.0)
+      ).map(Row.fromTuple))
+  }
+
+  test("left semi join BNL") {
+    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
+      LeftSemiJoinBNL(left, right, condition),
+      Seq(
+        (1, 2.0),
+        (1, 2.0),
+        (2, 1.0),
+        (2, 1.0)
+      ).map(Row.fromTuple))
+  }
+
+  test("broadcast left semi join hash") {
+    checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
+      BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
+      Seq(
+        (2, 1.0),
+        (2, 1.0)
+      ).map(Row.fromTuple))
+  }
+}
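The expected answers can be cross-checked with a naive reference: the hash variants match on a = c and then apply b < d, while LeftSemiJoinBNL has no equi-join keys and applies only the condition, which is why its expected set is larger. A plain-Scala sketch (illustrative only):

object SemiJoinExpected {
  def main(args: Array[String]): Unit = {
    val left  = Seq((1, 2.0), (1, 2.0), (2, 1.0), (2, 1.0), (3, 3.0))
    val right = Seq((2, 3.0), (2, 3.0), (3, 2.0), (4, 1.0))

    // Hash variants: equi-join on a = c plus the predicate b < d.
    val hash = left.filter { case (a, b) => right.exists { case (c, d) => a == c && b < d } }
    println(hash) // List((2,1.0), (2,1.0))

    // BNL variant: no join keys, only the predicate b < d.
    val bnl = left.filter { case (_, b) => right.exists { case (_, d) => b < d } }
    println(bnl)  // List((1,2.0), (1,2.0), (2,1.0), (2,1.0))
  }
}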