diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 0e3a8a6bd30a8..f09a9fa91bf05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -119,6 +119,11 @@ object HashFilteredJoin extends Logging with PredicateHelper {
     case FilteredOperation(predicates, join @ Join(left, right, Inner, condition)) =>
       logger.debug(s"Considering hash inner join on: ${predicates ++ condition}")
       splitPredicates(predicates ++ condition, join)
+    // All predicates can be evaluated for left semi join (those in the WHERE clause can
+    // only come from the left table, so they can all be pushed down).
+    case FilteredOperation(predicates, join @ Join(left, right, LeftSemi, condition)) =>
+      logger.debug(s"Considering hash left semi join on: ${predicates ++ condition}")
+      splitPredicates(predicates ++ condition, join)
     case join @ Join(left, right, joinType, condition) =>
       logger.debug(s"Considering hash join on: $condition")
       splitPredicates(condition.toSeq, join)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 1ed2ea709d361..21a41f266c1ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -30,11 +30,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
   object LeftSemiJoin extends Strategy with PredicateHelper {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      // Find leftsemi joins where at least some predicates can be evaluated by matching hash keys
-      // using the HashFilteredJoin pattern.
+      // Find left semi joins where at least some predicates can be evaluated by matching hash
+      // keys using the HashFilteredJoin pattern.
       case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
         val semiJoin = execution.LeftSemiJoinHash(
-          leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
+          leftKeys, rightKeys, planLater(left), planLater(right))
         condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
       // no predicate can be evaluated by matching hash keys
       case logical.Join(left, right, LeftSemi, condition) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index a503875418674..88ff3d49a79b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -142,29 +142,23 @@ case class HashJoin(
 
 /**
  * :: DeveloperApi ::
+ * Builds the right table's join keys into a HashSet, then iterates through the left
+ * table, checking whether each row's join keys are present in that set.
  */
 @DeveloperApi
 case class LeftSemiJoinHash(
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    buildSide: BuildSide,
-    left: SparkPlan,
-    right: SparkPlan) extends BinaryNode {
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode {
 
   override def outputPartitioning: Partitioning = left.outputPartitioning
 
   override def requiredChildDistribution =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
-  val (buildPlan, streamedPlan) = buildSide match {
-    case BuildLeft => (left, right)
-    case BuildRight => (right, left)
-  }
-
-  val (buildKeys, streamedKeys) = buildSide match {
-    case BuildLeft => (leftKeys, rightKeys)
-    case BuildRight => (rightKeys, leftKeys)
-  }
+  val (buildPlan, streamedPlan) = (right, left)
+  val (buildKeys, streamedKeys) = (rightKeys, leftKeys)
 
   def output = left.output
 
@@ -175,24 +169,18 @@ case class LeftSemiJoinHash(
   def execute() = {
     buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
-      // TODO: Use Spark's HashMap implementation.
-      val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
+      val hashTable = new java.util.HashSet[Row]()
       var currentRow: Row = null
 
-      // Create a mapping of buildKeys -> rows
+      // Create a HashSet of buildKeys
       while (buildIter.hasNext) {
         currentRow = buildIter.next()
         val rowKey = buildSideKeyGenerator(currentRow)
         if(!rowKey.anyNull) {
-          val existingMatchList = hashTable.get(rowKey)
-          val matchList = if (existingMatchList == null) {
-            val newMatchList = new ArrayBuffer[Row]()
-            hashTable.put(rowKey, newMatchList)
-            newMatchList
-          } else {
-            existingMatchList
+          val keyExists = hashTable.contains(rowKey)
+          if (!keyExists) {
+            hashTable.add(rowKey)
           }
-          matchList += currentRow.copy()
         }
       }
 
@@ -220,7 +208,7 @@ case class LeftSemiJoinHash(
         while (!currentHashMatched && streamIter.hasNext) {
           currentStreamedRow = streamIter.next()
           if (!joinKeys(currentStreamedRow).anyNull) {
-            currentHashMatched = true
+            currentHashMatched = hashTable.contains(joinKeys.currentValue)
           }
         }
         currentHashMatched
@@ -232,6 +220,8 @@ case class LeftSemiJoinHash(
 
 /**
  * :: DeveloperApi ::
+ * Uses BroadcastNestedLoopJoin to compute the left semi join result when there are no
+ * join keys for a hash join.
  */
 @DeveloperApi
 case class LeftSemiJoinBNL(
@@ -261,7 +251,7 @@ case class LeftSemiJoinBNL(
   def execute() = {
     val broadcastedRelation =
       sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
 
-    val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
+    streamed.execute().mapPartitions { streamedIter =>
       val joinedRow = new JoinedRow
 
       streamedIter.filter(streamedRow => {
@@ -269,7 +259,6 @@ case class LeftSemiJoinBNL(
         var i = 0
         var matched = false
 
         while (i < broadcastedRelation.value.size && !matched) {
-          // TODO: One bitset per partition instead of per row.
           val broadcastedRow = broadcastedRelation.value(i)
           if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
             matched = true
@@ -277,10 +266,8 @@ case class LeftSemiJoinBNL(
           i += 1
         }
         matched
-      }).map(streamedRow => (streamedRow, null))
+      })
     }
-
-    streamedPlusMatches.map(_._1)
   }
 }
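
For readers skimming the patch, here is a minimal standalone sketch (plain Scala, not code from this patch or from Spark; Row and leftSemiJoin are illustrative names) of the technique LeftSemiJoinHash now uses: hash the right side's join keys into a set, then stream the left side and keep each row whose non-null key appears in that set.

object LeftSemiJoinSketch {
  // Integer (not Int) so the key can be null, mirroring nullable join keys.
  final case class Row(key: Integer, value: String)

  def leftSemiJoin(left: Seq[Row], right: Seq[Row]): Seq[Row] = {
    // Build phase: collect the right side's keys, skipping null keys,
    // since a null key can never satisfy an equi-join condition.
    val hashSet = new java.util.HashSet[Integer]()
    right.foreach { r => if (r.key != null) hashSet.add(r.key) }

    // Stream phase: emit a left row iff its non-null key has a match.
    // Each left row is emitted at most once, regardless of how many
    // right rows share its key.
    left.filter(l => l.key != null && hashSet.contains(l.key))
  }

  def main(args: Array[String]): Unit = {
    val left = Seq(Row(1, "a"), Row(2, "b"), Row(null, "c"), Row(2, "d"))
    val right = Seq(Row(2, "x"), Row(3, "y"), Row(2, "z"))
    // Prints Row(2,b) and Row(2,d): key 2 matches, once per left row.
    leftSemiJoin(left, right).foreach(println)
  }
}

This also shows why the patch can replace HashJoin's HashMap of key -> matching rows with a plain HashSet: a left semi join only needs to know whether a match exists, since it never emits right-side columns and never duplicates a left row.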