Skip to content

Commit

Permalink
refactor semijoin and add plan test
Browse files Browse the repository at this point in the history
  • Loading branch information
adrian-wang committed Jul 15, 2015
1 parent 575a7c8 commit cc09809
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,59 +34,26 @@ case class BroadcastLeftSemiJoinHash(
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan,
condition: Option[Expression]) extends BinaryNode with HashJoin {

override val buildSide: BuildSide = BuildRight

override def output: Seq[Attribute] = left.output

@transient private lazy val boundCondition =
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
condition: Option[Expression]) extends BinaryNode with HashSemiJoin {

protected override def doExecute(): RDD[InternalRow] = {
val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator

condition match {
case None =>
val hashSet = new java.util.HashSet[InternalRow]()
var currentRow: InternalRow = null

// Create a Hash set of buildKeys
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
// rowKey may be not serializable (from codegen)
hashSet.add(rowKey.copy())
}
}
}

val broadcastedRelation = sparkContext.broadcast(hashSet)

streamedPlan.execute().mapPartitions { streamIter =>
val joinKeys = streamSideKeyGenerator()
streamIter.filter(current => {
!joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue)
})
}
case _ =>
val hashRelation = HashedRelation(buildIter, buildSideKeyGenerator)
val broadcastedRelation = sparkContext.broadcast(hashRelation)

streamedPlan.execute().mapPartitions { streamIter =>
val joinKeys = streamSideKeyGenerator()
val joinedRow = new JoinedRow

streamIter.filter(current => {
val rowBuffer = broadcastedRelation.value.get(joinKeys.currentValue)
!joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists {
(build: InternalRow) => boundCondition(joinedRow(current, build))
}
})
}
val buildIter = right.execute().map(_.copy()).collect().toIterator

if (condition.isEmpty) {
// rowKey may be not serializable (from codegen)
val hashSet = buildKeyHashSet(buildIter, copy = true)
val broadcastedRelation = sparkContext.broadcast(hashSet)

left.execute().mapPartitions { streamIter =>
hashSemiJoin(streamIter, broadcastedRelation.value)
}
} else {
val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
val broadcastedRelation = sparkContext.broadcast(hashRelation)

left.execute().mapPartitions { streamIter =>
hashSemiJoin(streamIter, broadcastedRelation.value)
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan


trait HashSemiJoin {
self: SparkPlan =>
val leftKeys: Seq[Expression]
val rightKeys: Seq[Expression]
val left: SparkPlan
val right: SparkPlan
val condition: Option[Expression]

override def output: Seq[Attribute] = left.output

@transient protected lazy val rightKeyGenerator: Projection =
newProjection(rightKeys, right.output)

@transient protected lazy val leftKeyGenerator: () => MutableProjection =
newMutableProjection(leftKeys, left.output)

@transient private lazy val boundCondition =
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)

protected def buildKeyHashSet(
buildIter: Iterator[InternalRow],
copy: Boolean): java.util.Set[InternalRow] = {
val hashSet = new java.util.HashSet[InternalRow]()
var currentRow: InternalRow = null

// Create a Hash set of buildKeys
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = rightKeyGenerator(currentRow)
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
if (copy) {
hashSet.add(rowKey.copy())
} else {
// rowKey may be not serializable (from codegen)
hashSet.add(rowKey)
}
}
}
}
hashSet
}

protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
hashedRelation: HashedRelation): Iterator[InternalRow] = {
val joinKeys = leftKeyGenerator()
val joinedRow = new JoinedRow
streamIter.filter(current => {
lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue)
!joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists {
(build: InternalRow) => boundCondition(joinedRow(current, build))
}
})
}

protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = {
val joinKeys = leftKeyGenerator()
val joinedRow = new JoinedRow
streamIter.filter(current => {
!joinKeys(current.copy()).anyNull && hashSet.contains(joinKeys.currentValue)
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,52 +35,19 @@ case class LeftSemiJoinHash(
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan,
condition: Option[Expression]) extends BinaryNode with HashJoin {

override val buildSide: BuildSide = BuildRight
condition: Option[Expression]) extends BinaryNode with HashSemiJoin {

override def requiredChildDistribution: Seq[ClusteredDistribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

override def output: Seq[Attribute] = left.output

@transient private lazy val boundCondition =
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)

protected override def doExecute(): RDD[InternalRow] = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
val joinKeys = streamSideKeyGenerator()
val joinedRow = new JoinedRow

condition match {
case None =>
val hashSet = new java.util.HashSet[InternalRow]()
var currentRow: InternalRow = null

// Create a Hash set of buildKeys
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
hashSet.add(rowKey)
}
}
}

val joinKeys = streamSideKeyGenerator()
streamIter.filter(current => {
!joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
})
case _ =>
val hashRelation = HashedRelation(buildIter, buildSideKeyGenerator)
streamIter.filter(current => {
val rowBuffer = hashRelation.get(joinKeys.currentValue)
!joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists {
(build: InternalRow) => boundCondition(joinedRow(current, build))
}
})
right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
if (condition.isEmpty) {
val hashSet = buildKeyHashSet(buildIter, copy = false)
hashSemiJoin(streamIter, hashSet)
} else {
val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
hashSemiJoin(streamIter, hashRelation)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression}
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}


class SemiJoinSuite extends SparkPlanTest{
val left = Seq(
(1, 2.0),
(1, 2.0),
(2, 1.0),
(2, 1.0),
(3, 3.0)
).toDF("a", "b")

val right = Seq(
(2, 3.0),
(2, 3.0),
(3, 2.0),
(4, 1.0)
).toDF("c", "d")

val leftKeys: List[Expression] = 'a :: Nil
val rightKeys: List[Expression] = 'c :: Nil
val condition = Some(LessThan('b, 'd))

test("left semi join hash") {
checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
Seq(
(2, 1.0),
(2, 1.0)
).map(Row.fromTuple))
}

test("left semi join BNL") {
checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
LeftSemiJoinBNL(left, right, condition),
Seq(
(1, 2.0),
(1, 2.0),
(2, 1.0),
(2, 1.0)
).map(Row.fromTuple))
}

test("broadcast left semi join hash") {
checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
Seq(
(2, 1.0),
(2, 1.0)
).map(Row.fromTuple))
}
}

0 comments on commit cc09809

Please sign in to comment.