Skip to content

Commit

Permalink
[SPARK-1495][SQL]add support for left semi join
Browse files Browse the repository at this point in the history
Just submit another solution for apache#395

Author: Daoyuan <[email protected]>
Author: Michael Armbrust <[email protected]>
Author: Daoyuan Wang <[email protected]>

Closes apache#837 from adrian-wang/left-semi-join-support and squashes the following commits:

d39cd12 [Daoyuan Wang] Merge pull request #1 from marmbrus/pr/837
6713c09 [Michael Armbrust] Better debugging for failed query tests.
035b73e [Michael Armbrust] Add test for left semi that can't be done with a hash join.
5ec6fa4 [Michael Armbrust] Add left semi to SQL Parser.
4c726e5 [Daoyuan] improvement according to Michael
8d4a121 [Daoyuan] add golden files for leftsemijoin
83a3c8a [Daoyuan] scala style fix
14cff80 [Daoyuan] add support for left semi join

(cherry picked from commit 0cf6002)
Signed-off-by: Michael Armbrust <[email protected]>
  • Loading branch information
adrian-wang authored and marmbrus committed Jun 9, 2014
1 parent a5848d3 commit 73cd1f8
Show file tree
Hide file tree
Showing 37 changed files with 216 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val OUTER = Keyword("OUTER")
protected val RIGHT = Keyword("RIGHT")
protected val SELECT = Keyword("SELECT")
protected val SEMI = Keyword("SEMI")
protected val STRING = Keyword("STRING")
protected val SUM = Keyword("SUM")
protected val TRUE = Keyword("TRUE")
Expand Down Expand Up @@ -241,6 +242,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {

protected lazy val joinType: Parser[JoinType] =
INNER ^^^ Inner |
LEFT ~ SEMI ^^^ LeftSemi |
LEFT ~ opt(OUTER) ^^^ LeftOuter |
RIGHT ~ opt(OUTER) ^^^ RightOuter |
FULL ~ opt(OUTER) ^^^ FullOuter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ object HashFilteredJoin extends Logging with PredicateHelper {
case FilteredOperation(predicates, join @ Join(left, right, Inner, condition)) =>
logger.debug(s"Considering hash inner join on: ${predicates ++ condition}")
splitPredicates(predicates ++ condition, join)
// All predicates can be evaluated for left semi join (those that are in the WHERE
// clause can only from left table, so they can all be pushed down.)
case FilteredOperation(predicates, join @ Join(left, right, LeftSemi, condition)) =>
logger.debug(s"Considering hash left semi join on: ${predicates ++ condition}")
splitPredicates(predicates ++ condition, join)
case join @ Join(left, right, joinType, condition) =>
logger.debug(s"Considering hash join on: $condition")
splitPredicates(condition.toSeq, join)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ case object Inner extends JoinType
case object LeftOuter extends JoinType
case object RightOuter extends JoinType
case object FullOuter extends JoinType
case object LeftSemi extends JoinType
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType}
import org.apache.spark.sql.catalyst.types._

case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
Expand Down Expand Up @@ -81,7 +81,12 @@ case class Join(
condition: Option[Expression]) extends BinaryNode {

def references = condition.map(_.references).getOrElse(Set.empty)
def output = left.output ++ right.output
def output = joinType match {
case LeftSemi =>
left.output
case _ =>
left.output ++ right.output
}
}

case class InsertIntoTable(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
val strategies: Seq[Strategy] =
TakeOrdered ::
PartialAggregation ::
LeftSemiJoin ::
HashJoin ::
ParquetOperations ::
BasicOperators ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@ import org.apache.spark.sql.parquet._
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SQLContext#SparkPlanner =>

object LeftSemiJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Find left semi joins where at least some predicates can be evaluated by matching hash
// keys using the HashFilteredJoin pattern.
case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
val semiJoin = execution.LeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
execution.LeftSemiJoinBNL(
planLater(left), planLater(right), condition)(sparkContext) :: Nil
case _ => Nil
}
}

object HashJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Find inner joins where at least some predicates can be evaluated by matching hash keys
Expand Down
131 changes: 131 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,137 @@ case class HashJoin(
}
}

/**
* :: DeveloperApi ::
* Build the right table's join keys into a HashSet, and iteratively go through the left
* table, to find the if join keys are in the Hash set.
*/
@DeveloperApi
case class LeftSemiJoinHash(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {

override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

val (buildPlan, streamedPlan) = (right, left)
val (buildKeys, streamedKeys) = (rightKeys, leftKeys)

def output = left.output

@transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
@transient lazy val streamSideKeyGenerator =
() => new MutableProjection(streamedKeys, streamedPlan.output)

def execute() = {

buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
val hashTable = new java.util.HashSet[Row]()
var currentRow: Row = null

// Create a Hash set of buildKeys
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
val keyExists = hashTable.contains(rowKey)
if (!keyExists) {
hashTable.add(rowKey)
}
}
}

new Iterator[Row] {
private[this] var currentStreamedRow: Row = _
private[this] var currentHashMatched: Boolean = false

private[this] val joinKeys = streamSideKeyGenerator()

override final def hasNext: Boolean =
streamIter.hasNext && fetchNext()

override final def next() = {
currentStreamedRow
}

/**
* Searches the streamed iterator for the next row that has at least one match in hashtable.
*
* @return true if the search is successful, and false the streamed iterator runs out of
* tuples.
*/
private final def fetchNext(): Boolean = {
currentHashMatched = false
while (!currentHashMatched && streamIter.hasNext) {
currentStreamedRow = streamIter.next()
if (!joinKeys(currentStreamedRow).anyNull) {
currentHashMatched = hashTable.contains(joinKeys.currentValue)
}
}
currentHashMatched
}
}
}
}
}

/**
* :: DeveloperApi ::
* Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys
* for hash join.
*/
@DeveloperApi
case class LeftSemiJoinBNL(
streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
(@transient sc: SparkContext)
extends BinaryNode {
// TODO: Override requiredChildDistribution.

override def outputPartitioning: Partitioning = streamed.outputPartitioning

override def otherCopyArgs = sc :: Nil

def output = left.output

/** The Streamed Relation */
def left = streamed
/** The Broadcast relation */
def right = broadcast

@transient lazy val boundCondition =
InterpretedPredicate(
condition
.map(c => BindReferences.bindReference(c, left.output ++ right.output))
.getOrElse(Literal(true)))


def execute() = {
val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)

streamed.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow

streamedIter.filter(streamedRow => {
var i = 0
var matched = false

while (i < broadcastedRelation.value.size && !matched) {
val broadcastedRow = broadcastedRelation.value(i)
if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
matched = true
}
i += 1
}
matched
})
}
}
}

/**
* :: DeveloperApi ::
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class QueryTest extends FunSuite {
fail(
s"""
|Exception thrown while executing query:
|${rdd.logicalPlan}
|${rdd.queryExecution}
|== Exception ==
|$e
""".stripMargin)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ class SQLQuerySuite extends QueryTest {
arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq)
}

test("left semi greater than predicate") {
checkAnswer(
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
Seq((3,1), (3,2))
)
}

test("index into array of arrays") {
checkAnswer(
sql(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
DataSinks,
Scripts,
PartialAggregation,
LeftSemiJoin,
HashJoin,
BasicOperators,
CartesianProduct,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,7 @@ private[hive] object HiveQl {
case "TOK_RIGHTOUTERJOIN" => RightOuter
case "TOK_LEFTOUTERJOIN" => LeftOuter
case "TOK_FULLOUTERJOIN" => FullOuter
case "TOK_LEFTSEMIJOIN" => LeftSemi
}
assert(other.size <= 1, "Unhandled join clauses.")
Join(nodeToRelation(relation1),
Expand Down
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Hank 2
Hank 2
Joe 2
Joe 2
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Hank 2
Hank 2
Joe 2
Joe 2
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
2 Tie
2 Tie
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
1
1
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
1
1
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
1
1
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"lateral_view",
"lateral_view_cp",
"lateral_view_ppd",
"leftsemijoin",
"leftsemijoin_mr",
"lineage1",
"literal_double",
"literal_ints",
Expand Down

0 comments on commit 73cd1f8

Please sign in to comment.