Skip to content

Commit

Permalink
fix left semi join with equi key and non-equi condition
Browse files Browse the repository at this point in the history
  • Loading branch information
adrian-wang committed Apr 23, 2015
1 parent 04525c0 commit d485fe8
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
val semiJoin = joins.BroadcastLeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
joins.BroadcastLeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
// Find left semi joins where at least some predicates can be evaluated by matching join keys
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
val semiJoin = joins.LeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
joins.LeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

/**
Expand All @@ -32,35 +32,50 @@ case class BroadcastLeftSemiJoinHash(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode with HashJoin {
right: SparkPlan,
condition: Option[Expression]) extends BinaryNode with HashJoin {

override val buildSide: BuildSide = BuildRight

override def output: Seq[Attribute] = left.output

/**
 * Predicate form of the optional join `condition`.
 *
 * When a condition is supplied it is bound against `left.output ++ right.output`
 * (left attributes first, then right); when it is absent, `Literal(true)` is used
 * instead, so every candidate pair is accepted.
 */
@transient private lazy val boundCondition = {
  val predicate: Expression = condition match {
    case Some(expr) => BindReferences.bindReference(expr, left.output ++ right.output)
    case None => Literal(true)
  }
  InterpretedPredicate(predicate)
}

override def execute(): RDD[Row] = {
val buildIter= buildPlan.execute().map(_.copy()).collect().toIterator
val hashSet = new java.util.HashSet[Row]()
val hashMap = new java.util.HashMap[Row, scala.collection.mutable.Set[Row]]()
var currentRow: Row = null

// Build a hash map from each build-side join key to the set of build rows sharing that key
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
hashSet.add(rowKey)
if (!hashMap.containsKey(rowKey)) {
val rowSet = scala.collection.mutable.Set[Row]()
rowSet.add(currentRow.copy())
hashMap.put(rowKey, rowSet)
} else {
hashMap.get(rowKey).add(currentRow.copy())
}
}
}

val broadcastedRelation = sparkContext.broadcast(hashSet)
val broadcastedRelation = sparkContext.broadcast(hashMap)

streamedPlan.execute().mapPartitions { streamIter =>
val joinKeys = streamSideKeyGenerator()
val joinedRow = new JoinedRow
streamIter.filter(current => {
!joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue)
!joinKeys(current).anyNull &&
broadcastedRelation.value.containsKey(joinKeys.currentValue) &&
broadcastedRelation.value.get(joinKeys.currentValue).exists {
build: Row => boundCondition(joinedRow(current, build))
}
})
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

Expand All @@ -33,7 +33,8 @@ case class LeftSemiJoinHash(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode with HashJoin {
right: SparkPlan,
condition: Option[Expression]) extends BinaryNode with HashJoin {

override val buildSide: BuildSide = BuildRight

Expand All @@ -42,26 +43,39 @@ case class LeftSemiJoinHash(

override def output: Seq[Attribute] = left.output

/**
 * Predicate form of the optional join `condition`.
 *
 * When a condition is supplied it is bound against `left.output ++ right.output`
 * (left attributes first, then right); when it is absent, `Literal(true)` is used
 * instead, so every candidate pair is accepted.
 */
@transient private lazy val boundCondition = {
  val predicate: Expression = condition match {
    case Some(expr) => BindReferences.bindReference(expr, left.output ++ right.output)
    case None => Literal(true)
  }
  InterpretedPredicate(predicate)
}

override def execute(): RDD[Row] = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
val hashSet = new java.util.HashSet[Row]()
val hashMap = new java.util.HashMap[Row, scala.collection.mutable.Set[Row]]()
var currentRow: Row = null

// Build a hash map from each build-side join key to the set of build rows sharing that key
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
hashSet.add(rowKey)
if (!hashMap.containsKey(rowKey)) {
val rowSet = scala.collection.mutable.Set[Row]()
rowSet.add(currentRow.copy())
hashMap.put(rowKey, rowSet)
} else {
hashMap.get(rowKey).add(currentRow.copy())
}
}
}

val joinKeys = streamSideKeyGenerator()
val joinedRow = new JoinedRow
streamIter.filter(current => {
!joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
!joinKeys(current).anyNull && hashMap.containsKey(joinKeys.currentValue) &&
hashMap.get(joinKeys.currentValue).exists {
build: Row => boundCondition(joinedRow(current, build))
}
})
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
)
}

test("left semi greater than predicate and equal operator") {
  // The ON clause mixes an equi-join key (x.b = y.b) with a non-equi predicate
  // (x.a >= y.a + 2); both must be applied when deciding whether a left row matches.
  val semiJoinResult = sql(
    "SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b and x.a >= y.a + 2")
  val expectedRows = Row(3, 1) :: Row(3, 2) :: Nil
  checkAnswer(semiJoinResult, expectedRows)
}

test("index into array of arrays") {
checkAnswer(
sql(
Expand Down

0 comments on commit d485fe8

Please sign in to comment.