[SPARK-7026] [SQL] fix left semi join with equi key and non-equi condition

When the `condition` extracted by `ExtractEquiJoinKeys` contains a join predicate for a left semi join, we cannot plan it as a semi join with the leftover condition in a `Filter` above it. For example:

    SELECT * FROM testData2 x
    LEFT SEMI JOIN testData2 y
    ON x.b = y.b
    AND x.a >= y.a + 2
The condition `x.a >= y.a + 2` cannot be evaluated above the join, because a left semi join outputs only the columns of table `x`; a `Filter` placed on top of the semi join therefore throws errors.
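
For intuition, here is the query's semantics restated in plain Scala (a self-contained sketch; `testData2` is assumed to be the six-row (a, b) fixture from Spark's test data):

    val testData2 = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))

    // A left semi join keeps a row of x iff some row of y satisfies both the
    // equi key (x.b = y.b) and the non-equi condition (x.a >= y.a + 2).
    val result = testData2.filter { case (xa, xb) =>
      testData2.exists { case (ya, yb) => xb == yb && xa >= ya + 2 }
    }
    // result == Seq((3, 1), (3, 2))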

Author: Daoyuan Wang <[email protected]>

Closes apache#5643 from adrian-wang/spark7026 and squashes the following commits:

cc09809 [Daoyuan Wang] refactor semijoin and add plan test
575a7c8 [Daoyuan Wang] fix notserializable
27841de [Daoyuan Wang] fix rebase
10bf124 [Daoyuan Wang] fix style
72baa02 [Daoyuan Wang] fix style
8e0afca [Daoyuan Wang] merge commits for rebase
adrian-wang authored and marmbrus committed Jul 17, 2015
1 parent b13ef77 commit 1707238
Showing 7 changed files with 208 additions and 59 deletions.
@@ -38,14 +38,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
val semiJoin = joins.BroadcastLeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
joins.BroadcastLeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
// Find left semi joins where at least some predicates can be evaluated by matching join keys
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
val semiJoin = joins.LeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
joins.LeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil
@@ -33,37 +33,27 @@ case class BroadcastLeftSemiJoinHash(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode with HashJoin {

override val buildSide: BuildSide = BuildRight

override def output: Seq[Attribute] = left.output
right: SparkPlan,
condition: Option[Expression]) extends BinaryNode with HashSemiJoin {

protected override def doExecute(): RDD[InternalRow] = {
val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator
val hashSet = new java.util.HashSet[InternalRow]()
var currentRow: InternalRow = null
val buildIter = right.execute().map(_.copy()).collect().toIterator

// Create a Hash set of buildKeys
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
// rowKey may not be serializable (from codegen)
hashSet.add(rowKey.copy())
}
}
}
if (condition.isEmpty) {
// rowKey may not be serializable (from codegen)
val hashSet = buildKeyHashSet(buildIter, copy = true)
val broadcastedRelation = sparkContext.broadcast(hashSet)

val broadcastedRelation = sparkContext.broadcast(hashSet)
left.execute().mapPartitions { streamIter =>
hashSemiJoin(streamIter, broadcastedRelation.value)
}
} else {
val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
val broadcastedRelation = sparkContext.broadcast(hashRelation)

streamedPlan.execute().mapPartitions { streamIter =>
val joinKeys = streamSideKeyGenerator()
streamIter.filter(current => {
!joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue)
})
left.execute().mapPartitions { streamIter =>
hashSemiJoin(streamIter, broadcastedRelation.value)
}
}
}
}
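
A plain-Scala sketch of the two execution paths the operator above now chooses between (types and values here are illustrative stand-ins, not Spark's API):

    // condition.isEmpty: broadcasting a set of distinct join keys is enough,
    // since key membership alone decides whether a stream row survives.
    val keySet: Set[Int] = Set(1, 2, 3)
    def withoutCondition(stream: Iterator[(Int, Double)]): Iterator[(Int, Double)] =
      stream.filter { case (key, _) => keySet.contains(key) }

    // condition.isDefined: full build rows must travel with their keys, so the
    // extra predicate can be checked against every candidate match.
    val rowsByKey: Map[Int, Seq[(Int, Double)]] = Map(2 -> Seq((2, 3.0)))
    def withCondition(stream: Iterator[(Int, Double)]): Iterator[(Int, Double)] =
      stream.filter { case (key, b) =>
        rowsByKey.get(key).exists(_.exists { case (_, d) => b < d }) // sample condition: b < d
      }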
@@ -65,8 +65,7 @@ override def outputPartitioning: Partitioning = joinType match {
@transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
@transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
@transient private[this] lazy val boundCondition =
condition.map(
newPredicate(_, left.output ++ right.output)).getOrElse((row: InternalRow) => true)
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)

// TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
// iterator for performance purpose.
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan


trait HashSemiJoin {
self: SparkPlan =>
val leftKeys: Seq[Expression]
val rightKeys: Seq[Expression]
val left: SparkPlan
val right: SparkPlan
val condition: Option[Expression]

override def output: Seq[Attribute] = left.output

@transient protected lazy val rightKeyGenerator: Projection =
newProjection(rightKeys, right.output)

@transient protected lazy val leftKeyGenerator: () => MutableProjection =
newMutableProjection(leftKeys, left.output)

@transient private lazy val boundCondition =
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)

protected def buildKeyHashSet(
buildIter: Iterator[InternalRow],
copy: Boolean): java.util.Set[InternalRow] = {
val hashSet = new java.util.HashSet[InternalRow]()
var currentRow: InternalRow = null

// Create a Hash set of buildKeys
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = rightKeyGenerator(currentRow)
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
if (copy) {
hashSet.add(rowKey.copy())
} else {
// rowKey may not be serializable (from codegen)
hashSet.add(rowKey)
}
}
}
}
hashSet
}

protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
hashedRelation: HashedRelation): Iterator[InternalRow] = {
val joinKeys = leftKeyGenerator()
val joinedRow = new JoinedRow
streamIter.filter(current => {
lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue)
!joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists {
(build: InternalRow) => boundCondition(joinedRow(current, build))
}
})
}

protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = {
val joinKeys = leftKeyGenerator()
val joinedRow = new JoinedRow
streamIter.filter(current => {
!joinKeys(current.copy()).anyNull && hashSet.contains(joinKeys.currentValue)
})
}
}
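
The second `hashSemiJoin` overload above is the conditional path. Its logic, restated as a small generic function (a simplified sketch, not Spark code):

    // Keep a stream row iff its join key is non-null (here: defined), the key
    // matches some build rows, and at least one matched build row also
    // satisfies the join condition.
    def semiJoin[K, R](
        stream: Iterator[R],
        keyOf: R => Option[K],
        buildRows: Map[K, Seq[R]],
        cond: (R, R) => Boolean): Iterator[R] =
      stream.filter { row =>
        keyOf(row).exists(k => buildRows.get(k).exists(_.exists(b => cond(row, b))))
      }

Note that in the trait itself `rowBuffer` is a `lazy val`, so the relation lookup only runs after `joinKeys(current)` has projected the key and the null check has passed.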
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

@@ -34,36 +34,21 @@ case class LeftSemiJoinHash(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode with HashJoin {

override val buildSide: BuildSide = BuildRight
right: SparkPlan,
condition: Option[Expression]) extends BinaryNode with HashSemiJoin {

override def requiredChildDistribution: Seq[ClusteredDistribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

override def output: Seq[Attribute] = left.output

protected override def doExecute(): RDD[InternalRow] = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
val hashSet = new java.util.HashSet[InternalRow]()
var currentRow: InternalRow = null

// Create a Hash set of buildKeys
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
hashSet.add(rowKey)
}
}
right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
if (condition.isEmpty) {
val hashSet = buildKeyHashSet(buildIter, copy = false)
hashSemiJoin(streamIter, hashSet)
} else {
val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
hashSemiJoin(streamIter, hashRelation)
}

val joinKeys = streamSideKeyGenerator()
streamIter.filter(current => {
!joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
})
}
}
}
12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -395,6 +395,18 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
)
}

test("left semi greater than predicate and equal operator") {
checkAnswer(
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b and x.a >= y.a + 2"),
Seq(Row(3, 1), Row(3, 2))
)

checkAnswer(
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.a and x.a >= y.b + 1"),
Seq(Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))
)
}

test("index into array of arrays") {
checkAnswer(
sql(
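
The new tests' expected rows can be sanity-checked in plain Scala against the assumed six-row `testData2` fixture; for the second query above:

    val testData2 = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))
    val expected = testData2.filter { case (xa, xb) =>
      testData2.exists { case (ya, yb) => xb == ya && xa >= yb + 1 }
    }
    // expected == Seq((2, 1), (2, 2), (3, 1), (3, 2))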
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression}
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}


class SemiJoinSuite extends SparkPlanTest {
val left = Seq(
(1, 2.0),
(1, 2.0),
(2, 1.0),
(2, 1.0),
(3, 3.0)
).toDF("a", "b")

val right = Seq(
(2, 3.0),
(2, 3.0),
(3, 2.0),
(4, 1.0)
).toDF("c", "d")

val leftKeys: List[Expression] = 'a :: Nil
val rightKeys: List[Expression] = 'c :: Nil
val condition = Some(LessThan('b, 'd))

test("left semi join hash") {
checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
Seq(
(2, 1.0),
(2, 1.0)
).map(Row.fromTuple))
}

test("left semi join BNL") {
checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
LeftSemiJoinBNL(left, right, condition),
Seq(
(1, 2.0),
(1, 2.0),
(2, 1.0),
(2, 1.0)
).map(Row.fromTuple))
}

test("broadcast left semi join hash") {
checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
Seq(
(2, 1.0),
(2, 1.0)
).map(Row.fromTuple))
}
}
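
As a sanity check, the suite's expected answers follow directly from the fixtures (plain-Scala sketch):

    val left  = Seq((1, 2.0), (1, 2.0), (2, 1.0), (2, 1.0), (3, 3.0))
    val right = Seq((2, 3.0), (2, 3.0), (3, 2.0), (4, 1.0))

    // Hash-based variants use the equi key a == c plus the condition b < d:
    val hashResult = left.filter { case (a, b) =>
      right.exists { case (c, d) => a == c && b < d }
    }
    // hashResult == Seq((2, 1.0), (2, 1.0))

    // LeftSemiJoinBNL is given only the condition b < d (no equi key),
    // hence the larger expected result in the BNL test:
    val bnlResult = left.filter { case (_, b) =>
      right.exists { case (_, d) => b < d }
    }
    // bnlResult == Seq((1, 2.0), (1, 2.0), (2, 1.0), (2, 1.0))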
