[SPARK-7026] [SQL] fix left semi join with equi key and non-equi condition #5643

Status: Closed · wants to merge 6 commits
SparkStrategies.scala
@@ -38,14 +38,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
       if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
          right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
-    val semiJoin = joins.BroadcastLeftSemiJoinHash(
-      leftKeys, rightKeys, planLater(left), planLater(right))
-    condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
+    joins.BroadcastLeftSemiJoinHash(
+      leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
   // Find left semi joins where at least some predicates can be evaluated by matching join keys
   case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
-    val semiJoin = joins.LeftSemiJoinHash(
-      leftKeys, rightKeys, planLater(left), planLater(right))
-    condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
+    joins.LeftSemiJoinHash(
+      leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
   // no predicate can be evaluated by matching hash keys
   case logical.Join(left, right, LeftSemi, condition) =>
     joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil
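
The core of the fix: a left semi join outputs only the left side's columns, so the old pattern of planning `Filter(condition, semiJoin)` on top of the join could not evaluate a condition that references right-side attributes. Passing the condition into the operator lets it be checked while both sides of each candidate pair are still in hand. A rough usage sketch (the query mirrors the new regression test below; the exact physical plan rendering is an assumption):

```scala
// Equi key (x.b = y.b) plus a non-equi condition (x.a >= y.a + 2).
// After this patch the planner should emit a single LeftSemiJoinHash
// (or BroadcastLeftSemiJoinHash) carrying Some(condition), instead of a
// Filter above the join that could not see y's columns.
val df = sqlContext.sql(
  """SELECT * FROM testData2 x
    |LEFT SEMI JOIN testData2 y ON x.b = y.b AND x.a >= y.a + 2""".stripMargin)
df.explain() // inspect the chosen semi-join operator and its condition
```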
BroadcastLeftSemiJoinHash.scala
@@ -33,37 +33,27 @@ case class BroadcastLeftSemiJoinHash(
   leftKeys: Seq[Expression],
   rightKeys: Seq[Expression],
   left: SparkPlan,
-  right: SparkPlan) extends BinaryNode with HashJoin {
-
-  override val buildSide: BuildSide = BuildRight
-
-  override def output: Seq[Attribute] = left.output
+  right: SparkPlan,
+  condition: Option[Expression]) extends BinaryNode with HashSemiJoin {

   protected override def doExecute(): RDD[InternalRow] = {
-    val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator
-    val hashSet = new java.util.HashSet[InternalRow]()
-    var currentRow: InternalRow = null
-
-    // Create a Hash set of buildKeys
-    while (buildIter.hasNext) {
-      currentRow = buildIter.next()
-      val rowKey = buildSideKeyGenerator(currentRow)
-      if (!rowKey.anyNull) {
-        val keyExists = hashSet.contains(rowKey)
-        if (!keyExists) {
-          // rowKey may be not serializable (from codegen)
-          hashSet.add(rowKey.copy())
-        }
-      }
-    }
-
-    val broadcastedRelation = sparkContext.broadcast(hashSet)
-
-    streamedPlan.execute().mapPartitions { streamIter =>
-      val joinKeys = streamSideKeyGenerator()
-      streamIter.filter(current => {
-        !joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue)
-      })
+    val buildIter = right.execute().map(_.copy()).collect().toIterator
+
+    if (condition.isEmpty) {
+      // rowKey may be not serializable (from codegen)
+      val hashSet = buildKeyHashSet(buildIter, copy = true)
+      val broadcastedRelation = sparkContext.broadcast(hashSet)
+
+      left.execute().mapPartitions { streamIter =>
+        hashSemiJoin(streamIter, broadcastedRelation.value)
+      }
+    } else {
+      val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+      val broadcastedRelation = sparkContext.broadcast(hashRelation)
+
+      left.execute().mapPartitions { streamIter =>
+        hashSemiJoin(streamIter, broadcastedRelation.value)
+      }
     }
   }
 }
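
The rewritten `doExecute` branches on the condition. With no condition, only the distinct join keys have to be broadcast (a plain hash set); with a condition, the matching build rows themselves are needed so the predicate can be evaluated against each (stream, build) pair, hence a `HashedRelation` keyed by the join keys. A minimal plain-Scala model of the two paths (names and types are illustrative, not Spark API):

```scala
// Path 1: no condition. Broadcasting the distinct build keys is enough;
// a stream row survives iff its key appears in the set.
def semiJoinOnKeys[K, S](stream: Iterator[S], keyOf: S => K, keys: Set[K]): Iterator[S] =
  stream.filter(row => keys.contains(keyOf(row)))

// Path 2: with a condition. The predicate inspects both rows of a pair, so the
// broadcast structure must map each key to its matching build rows.
def semiJoinWithCondition[K, S, B](
    stream: Iterator[S],
    keyOf: S => K,
    matches: Map[K, Seq[B]],
    cond: (S, B) => Boolean): Iterator[S] =
  stream.filter { row =>
    matches.getOrElse(keyOf(row), Nil).exists(build => cond(row, build))
  }
```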
HashOuterJoin.scala
@@ -65,8 +65,7 @@ override def outputPartitioning: Partitioning = joinType match {
   @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
   @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
   @transient private[this] lazy val boundCondition =
-    condition.map(
-      newPredicate(_, left.output ++ right.output)).getOrElse((row: InternalRow) => true)
+    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)

   // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
   // iterator for performance purpose.
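
The `boundCondition` rewrite is behavior-preserving: when `condition` is `None`, the predicate is now generated from `Literal(true)` and accepts every row, exactly as the old `getOrElse((row: InternalRow) => true)` default did, while keeping the no-condition case on the same code path as a real predicate. A toy sketch of the equivalence (plain Scala stand-ins, not Spark API):

```scala
object BoundConditionEquivalence {
  // Stand-ins: an "expression" is just a predicate on a row, and
  // newPredicate "compiles" it (identity here).
  type Row = Int
  type Expr = Row => Boolean
  val literalTrue: Expr = _ => true
  def newPredicate(e: Expr): Row => Boolean = e

  def main(args: Array[String]): Unit = {
    val condition: Option[Expr] = None
    // Old formulation: default to a hand-written always-true closure.
    val boundOld: Row => Boolean = condition.map(newPredicate).getOrElse((_: Row) => true)
    // New formulation: default the expression to Literal(true), then compile.
    val boundNew: Row => Boolean = newPredicate(condition.getOrElse(literalTrue))
    assert(boundOld(42) && boundNew(42)) // both accept every row when condition is None
  }
}
```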
HashSemiJoin.scala (new file)
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan


trait HashSemiJoin {
self: SparkPlan =>
val leftKeys: Seq[Expression]
val rightKeys: Seq[Expression]
val left: SparkPlan
val right: SparkPlan
val condition: Option[Expression]

override def output: Seq[Attribute] = left.output

@transient protected lazy val rightKeyGenerator: Projection =
newProjection(rightKeys, right.output)

@transient protected lazy val leftKeyGenerator: () => MutableProjection =
newMutableProjection(leftKeys, left.output)

@transient private lazy val boundCondition =
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)

protected def buildKeyHashSet(
buildIter: Iterator[InternalRow],
copy: Boolean): java.util.Set[InternalRow] = {

[Review comment, Contributor] Long term I wonder if it's actually a win for us to build just a set instead of using a hashed relation everywhere. We have done a bunch of optimization on HashedRelation to make it serialize faster.

[Reply, Author] Maybe we need to implement a version of HashedRelation which only stores the keys.


val hashSet = new java.util.HashSet[InternalRow]()
var currentRow: InternalRow = null

// Create a Hash set of buildKeys
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = rightKeyGenerator(currentRow)
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
if (copy) {
hashSet.add(rowKey.copy())
} else {
// rowKey may be not serializable (from codegen)
hashSet.add(rowKey)
}
}
}
}
hashSet
}

protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
hashedRelation: HashedRelation): Iterator[InternalRow] = {
val joinKeys = leftKeyGenerator()
val joinedRow = new JoinedRow
streamIter.filter(current => {
lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue)
!joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists {
(build: InternalRow) => boundCondition(joinedRow(current, build))
}
})
}

protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = {
val joinKeys = leftKeyGenerator()
val joinedRow = new JoinedRow
streamIter.filter(current => {
!joinKeys(current.copy()).anyNull && hashSet.contains(joinKeys.currentValue)
})
}
}
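
Note that both probe paths skip rows whose join key contains a null (`!rowKey.anyNull` on the build side, `!joinKeys(...).anyNull` on the stream side). This matches SQL semantics: an equality comparison involving NULL is never true, so such rows can never contribute a semi-join match. A tiny model of that rule (plain Scala, not Spark API):

```scala
// SQL three-valued logic collapsed to "matches or not" for a semi join:
// any equality involving NULL is not a match.
def sqlKeyEquals(a: Option[Int], b: Option[Int]): Boolean = (a, b) match {
  case (Some(x), Some(y)) => x == y
  case _                  => false // NULL = anything (including NULL) is not true
}

assert(sqlKeyEquals(Some(1), Some(1)))
assert(!sqlKeyEquals(None, Some(1)))
assert(!sqlKeyEquals(None, None)) // hence null build keys are never added to the set
```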
LeftSemiJoinHash.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

@@ -34,36 +34,21 @@ case class LeftSemiJoinHash(
   leftKeys: Seq[Expression],
   rightKeys: Seq[Expression],
   left: SparkPlan,
-  right: SparkPlan) extends BinaryNode with HashJoin {
-
-  override val buildSide: BuildSide = BuildRight
+  right: SparkPlan,
+  condition: Option[Expression]) extends BinaryNode with HashSemiJoin {

   override def requiredChildDistribution: Seq[ClusteredDistribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

-  override def output: Seq[Attribute] = left.output
-
   protected override def doExecute(): RDD[InternalRow] = {
-    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
-      val hashSet = new java.util.HashSet[InternalRow]()
-      var currentRow: InternalRow = null
-
-      // Create a Hash set of buildKeys
-      while (buildIter.hasNext) {
-        currentRow = buildIter.next()
-        val rowKey = buildSideKeyGenerator(currentRow)
-        if (!rowKey.anyNull) {
-          val keyExists = hashSet.contains(rowKey)
-          if (!keyExists) {
-            hashSet.add(rowKey)
-          }
-        }
-      }
-
-      val joinKeys = streamSideKeyGenerator()
-      streamIter.filter(current => {
-        !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
-      })
+    right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
+      if (condition.isEmpty) {
+        val hashSet = buildKeyHashSet(buildIter, copy = false)
+        hashSemiJoin(streamIter, hashSet)
+      } else {
+        val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+        hashSemiJoin(streamIter, hashRelation)
+      }
     }
   }
 }
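
Unlike the broadcast variant, this operator declares `ClusteredDistribution` on both children, so matching rows are guaranteed to land in partitions with the same index and `zipPartitions` can join each pair of co-located partitions independently, with no broadcast. A per-partition model in plain Scala (illustrative only):

```scala
// Both inputs are hash-partitioned on the join keys, so rows that could match
// are guaranteed to sit in partitions with the same index.
def zipPartitionSemiJoin[K, R](
    leftParts: Seq[Seq[R]],
    rightParts: Seq[Seq[R]],
    keyOf: R => K): Seq[Seq[R]] =
  leftParts.zip(rightParts).map { case (leftPart, rightPart) =>
    val buildKeys = rightPart.map(keyOf).toSet // per-partition build side
    leftPart.filter(row => buildKeys.contains(keyOf(row)))
  }
```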
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (12 additions, 0 deletions)
@@ -395,6 +395,18 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
     )
   }

+  test("left semi greater than predicate and equal operator") {

[Review comment, Contributor] @adrian-wang I suggest you add the case chenghao described in my PR to the unit test.

[Reply, Author] ok

[Review comment, Contributor] Create a PR for your branch.

[Review comment, Contributor] Closed since you have added the test.

+    checkAnswer(
+      sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b and x.a >= y.a + 2"),
+      Seq(Row(3, 1), Row(3, 2))
+    )
+
+    checkAnswer(
+      sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.a and x.a >= y.b + 1"),
+      Seq(Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))
+    )
+  }
+
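
To sanity-check the expected answers: `testData2` holds the rows (1,1), (1,2), (2,1), (2,2), (3,1), (3,2) (recalled from Spark's standard test fixtures; worth verifying against TestData.scala). The predicates can then be replayed outside Spark:

```scala
// testData2 rows as (a, b) pairs, per the standard test fixture (assumption).
val testData2 = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))

// First query: x.b = y.b AND x.a >= y.a + 2. Only x.a = 3 leaves room for y.a <= 1.
val firstAnswer = testData2.filter { case (xa, xb) =>
  testData2.exists { case (ya, yb) => xb == yb && xa >= ya + 2 }
}
assert(firstAnswer == Seq((3, 1), (3, 2)))

// Second query: x.b = y.a AND x.a >= y.b + 1.
val secondAnswer = testData2.filter { case (xa, xb) =>
  testData2.exists { case (ya, yb) => xb == ya && xa >= yb + 1 }
}
assert(secondAnswer == Seq((2, 1), (2, 2), (3, 1), (3, 2)))
```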
[Review comment, Contributor] It would be great to also add some tests using the new SparkPlanTest infrastructure. In particular, it seems like @Sephiroth-Lin found a bug that indicates test coverage is insufficient.


test("index into array of arrays") {
checkAnswer(
sql(
Expand Down
SemiJoinSuite.scala (new file)
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression}
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}


class SemiJoinSuite extends SparkPlanTest {
val left = Seq(
(1, 2.0),
(1, 2.0),
(2, 1.0),
(2, 1.0),
(3, 3.0)
).toDF("a", "b")

val right = Seq(
(2, 3.0),
(2, 3.0),
(3, 2.0),
(4, 1.0)
).toDF("c", "d")

val leftKeys: List[Expression] = 'a :: Nil
val rightKeys: List[Expression] = 'c :: Nil
val condition = Some(LessThan('b, 'd))

test("left semi join hash") {
checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
Seq(
(2, 1.0),
(2, 1.0)
).map(Row.fromTuple))
}

test("left semi join BNL") {
checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
LeftSemiJoinBNL(left, right, condition),
Seq(
(1, 2.0),
(1, 2.0),
(2, 1.0),
(2, 1.0)
).map(Row.fromTuple))
}

test("broadcast left semi join hash") {
checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
Seq(
(2, 1.0),
(2, 1.0)
).map(Row.fromTuple))
}
}
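
The differing expectations fall out of each operator's predicate: the hash variants require the equi key ('a = 'c) in addition to the condition ('b < 'd), while `LeftSemiJoinBNL` has no equi keys and applies the condition alone, which is why (1, 2.0) survives only in the BNL test. A small reference model (a hypothetical helper, not part of the patch) makes the expected rows easy to verify:

```scala
// Reference semantics of a left semi join: keep each left row with at least
// one right match under the predicate (duplicate left rows are preserved).
def semiJoin[L, R](lefts: Seq[L], rights: Seq[R])(p: (L, R) => Boolean): Seq[L] =
  lefts.filter(l => rights.exists(r => p(l, r)))

val lefts  = Seq((1, 2.0), (1, 2.0), (2, 1.0), (2, 1.0), (3, 3.0))
val rights = Seq((2, 3.0), (2, 3.0), (3, 2.0), (4, 1.0))

// Hash variants: equi key a = c plus condition b < d.
assert(semiJoin(lefts, rights) { case ((a, b), (c, d)) => a == c && b < d }
  == Seq((2, 1.0), (2, 1.0)))

// BNL variant: condition b < d only, so (1, 2.0) also survives via d = 3.0,
// while (3, 3.0) finds no d greater than 3.0 and is dropped in both variants.
assert(semiJoin(lefts, rights) { case ((_, b), (_, d)) => b < d }
  == Seq((1, 2.0), (1, 2.0), (2, 1.0), (2, 1.0)))
```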