Commit 8e0afca

merge commits for rebase

adrian-wang committed Jul 14, 2015
1 parent b7bcbe2 commit 8e0afca

Showing 5 changed files with 95 additions and 46 deletions.
@@ -38,14 +38,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
          if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
            right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
-        val semiJoin = joins.BroadcastLeftSemiJoinHash(
-          leftKeys, rightKeys, planLater(left), planLater(right))
-        condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
+        joins.BroadcastLeftSemiJoinHash(
+          leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
       // Find left semi joins where at least some predicates can be evaluated by matching join keys
       case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
-        val semiJoin = joins.LeftSemiJoinHash(
-          leftKeys, rightKeys, planLater(left), planLater(right))
-        condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
+        joins.LeftSemiJoinHash(
+          leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
       // no predicate can be evaluated by matching hash keys
       case logical.Join(left, right, LeftSemi, condition) =>
         joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil
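This planner change is the heart of the commit. A left semi join outputs only `left.output`, so the removed pattern — wrapping the extra predicate in a `Filter` above the join — leaves the filter unable to reference right-side columns, and a non-equi condition such as `x.a >= y.a + 2` can only be tested inside the operator while both rows are still in scope. A minimal sketch of the semantics being implemented, on plain Scala collections (illustrative names, not Spark's API):

// Left semi join with an extra condition: keep a stream row iff some build
// row shares its join key AND passes the condition; build rows themselves
// never appear in the output.
def leftSemiJoin[K, L, R](
    stream: Seq[L],
    build: Seq[R],
    streamKey: L => K,
    buildKey: R => K,
    condition: (L, R) => Boolean): Seq[L] = {
  val byKey: Map[K, Seq[R]] = build.groupBy(buildKey) // analogue of HashedRelation
  stream.filter(l => byKey.getOrElse(streamKey(l), Nil).exists(r => condition(l, r)))
}

// e.g. with (a, b) pairs, joining on equal b with the extra test a >= a' + 2:
val kept = leftSemiJoin[Int, (Int, Int), (Int, Int)](
  Seq((1, 1), (3, 1)), Seq((1, 1), (2, 2)), _._2, _._2, (x, y) => x._1 >= y._1 + 2)
// kept == Seq((3, 1))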
@@ -33,37 +33,59 @@ case class BroadcastLeftSemiJoinHash(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     left: SparkPlan,
-    right: SparkPlan) extends BinaryNode with HashJoin {
+    right: SparkPlan,
+    condition: Option[Expression]) extends BinaryNode with HashJoin {
 
   override val buildSide: BuildSide = BuildRight
 
   override def output: Seq[Attribute] = left.output
 
+  @transient private lazy val boundCondition =
+    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+
   protected override def doExecute(): RDD[InternalRow] = {
     val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator
-    val hashSet = new java.util.HashSet[InternalRow]()
-    var currentRow: InternalRow = null
-
-    // Create a Hash set of buildKeys
-    while (buildIter.hasNext) {
-      currentRow = buildIter.next()
-      val rowKey = buildSideKeyGenerator(currentRow)
-      if (!rowKey.anyNull) {
-        val keyExists = hashSet.contains(rowKey)
-        if (!keyExists) {
-          // rowKey may be not serializable (from codegen)
-          hashSet.add(rowKey.copy())
-        }
-      }
-    }
-
-    val broadcastedRelation = sparkContext.broadcast(hashSet)
-
-    streamedPlan.execute().mapPartitions { streamIter =>
-      val joinKeys = streamSideKeyGenerator()
-      streamIter.filter(current => {
-        !joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue)
-      })
-    }
+
+    condition match {
+      case None =>
+        val hashSet = new java.util.HashSet[InternalRow]()
+        var currentRow: InternalRow = null
+
+        // Create a Hash set of buildKeys
+        while (buildIter.hasNext) {
+          currentRow = buildIter.next()
+          val rowKey = buildSideKeyGenerator(currentRow)
+          if (!rowKey.anyNull) {
+            val keyExists = hashSet.contains(rowKey)
+            if (!keyExists) {
+              hashSet.add(rowKey)
+            }
+          }
+        }
+
+        val broadcastedRelation = sparkContext.broadcast(hashSet)
+
+        streamedPlan.execute().mapPartitions { streamIter =>
+          val joinKeys = streamSideKeyGenerator()
+          streamIter.filter(current => {
+            !joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue)
+          })
+        }
+      case _ =>
+        val hashRelation = HashedRelation(buildIter, buildSideKeyGenerator)
+        val broadcastedRelation = sparkContext.broadcast(hashRelation)
+
+        streamedPlan.execute().mapPartitions { streamIter =>
+          val joinKeys = streamSideKeyGenerator()
+          val joinedRow = new JoinedRow
+
+          streamIter.filter(current => {
+            // Apply joinKeys to current first, so currentValue holds this row's key
+            !joinKeys(current).anyNull && {
+              val rowBuffer = broadcastedRelation.value.get(joinKeys.currentValue)
+              rowBuffer != null && rowBuffer.exists {
+                (build: InternalRow) => boundCondition(joinedRow(current, build))
+              }
+            }
+          })
+        }
+    }
   }
 }
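The rewritten `doExecute` keeps the original `java.util.HashSet`-of-keys fast path when there is no extra condition, and only builds a `HashedRelation` (join key → matching build rows) when one is present, since evaluating the condition needs the full build rows rather than just their keys. A rough model of that dispatch, again on plain collections with illustrative names (null-key handling elided):

// Two execution paths: a bare key set when there is no condition, and a
// key -> rows map when the predicate must see both the stream row and the
// matching build row.
def semiJoinDispatch[K, R](
    stream: Seq[R],
    build: Seq[R],
    key: R => K,
    condition: Option[(R, R) => Boolean]): Seq[R] = condition match {
  case None =>
    val keySet: Set[K] = build.map(key).toSet // the HashSet path
    stream.filter(r => keySet.contains(key(r)))
  case Some(cond) =>
    val relation: Map[K, Seq[R]] = build.groupBy(key) // the HashedRelation path
    stream.filter(r => relation.getOrElse(key(r), Nil).exists(b => cond(r, b)))
}

Broadcasting the whole relation rather than just its keys is what makes the conditional path more expensive, which is why the keys-only path is kept for the common unconditioned case.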
@@ -65,8 +65,7 @@ override def outputPartitioning: Partitioning = joinType match {
   @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
   @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
   @transient private[this] lazy val boundCondition =
-    condition.map(
-      newPredicate(_, left.output ++ right.output)).getOrElse((row: InternalRow) => true)
+    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
 
   // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
   // iterator for performance purpose.
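The same `Literal(true)` idiom now replaces the outer join's hand-rolled default, so all three operators build `boundCondition` identically; folding the `Option` before calling `newPredicate` yields the same always-true predicate the old `getOrElse((row: InternalRow) => true)` did. A toy equivalence check with plain Scala stand-ins (none of these names are Spark's):

// Stand-ins: an "expression" is just a row predicate and "compiling" it is
// the identity, which is enough to show the two formulations agree.
type Row = Map[String, Int]
type Expr = Row => Boolean
val literalTrue: Expr = _ => true
def newPredicate(e: Expr): Row => Boolean = e

def oldStyle(c: Option[Expr]): Row => Boolean =
  c.map(newPredicate).getOrElse((_: Row) => true)
def newStyle(c: Option[Expr]): Row => Boolean =
  newPredicate(c.getOrElse(literalTrue))

val row: Row = Map("a" -> 1)
val gt: Expr = r => r("a") > 0
assert(oldStyle(None)(row) == newStyle(None)(row))         // both true
assert(oldStyle(Some(gt))(row) == newStyle(Some(gt))(row)) // both true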
@@ -34,7 +34,8 @@ case class LeftSemiJoinHash(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     left: SparkPlan,
-    right: SparkPlan) extends BinaryNode with HashJoin {
+    right: SparkPlan,
+    condition: Option[Expression]) extends BinaryNode with HashJoin {
 
   override val buildSide: BuildSide = BuildRight
 
@@ -43,27 +44,44 @@
 
   override def output: Seq[Attribute] = left.output
 
+  @transient private lazy val boundCondition =
+    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+
   protected override def doExecute(): RDD[InternalRow] = {
     buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
-      val hashSet = new java.util.HashSet[InternalRow]()
-      var currentRow: InternalRow = null
-
-      // Create a Hash set of buildKeys
-      while (buildIter.hasNext) {
-        currentRow = buildIter.next()
-        val rowKey = buildSideKeyGenerator(currentRow)
-        if (!rowKey.anyNull) {
-          val keyExists = hashSet.contains(rowKey)
-          if (!keyExists) {
-            hashSet.add(rowKey)
-          }
-        }
-      }
-
-      val joinKeys = streamSideKeyGenerator()
-      streamIter.filter(current => {
-        !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
-      })
+      val joinKeys = streamSideKeyGenerator()
+      val joinedRow = new JoinedRow
+
+      condition match {
+        case None =>
+          val hashSet = new java.util.HashSet[InternalRow]()
+          var currentRow: InternalRow = null
+
+          // Create a Hash set of buildKeys
+          while (buildIter.hasNext) {
+            currentRow = buildIter.next()
+            val rowKey = buildSideKeyGenerator(currentRow)
+            if (!rowKey.anyNull) {
+              val keyExists = hashSet.contains(rowKey)
+              if (!keyExists) {
+                hashSet.add(rowKey)
+              }
+            }
+          }
+
+          streamIter.filter(current => {
+            !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
+          })
+        case _ =>
+          val hashRelation = HashedRelation(buildIter, buildSideKeyGenerator)
+          streamIter.filter(current => {
+            // Apply joinKeys to current first, so currentValue holds this row's key
+            !joinKeys(current).anyNull && {
+              val rowBuffer = hashRelation.get(joinKeys.currentValue)
+              rowBuffer != null && rowBuffer.exists {
+                (build: InternalRow) => boundCondition(joinedRow(current, build))
+              }
+            }
+          })
+      }
     }
   }
 }
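`LeftSemiJoinHash` mirrors the broadcast variant but works per partition: `zipPartitions` pairs co-partitioned slices of the build and stream sides (both shuffled by join key) instead of shipping one build table everywhere. A compact sketch of that structural difference (illustrative, not Spark's API):

// Broadcast style: every stream partition sees the whole build side.
def broadcastStyle[R](streamParts: Seq[Seq[R]], build: Seq[R])(
    join: (Seq[R], Seq[R]) => Seq[R]): Seq[R] =
  streamParts.flatMap(part => join(part, build))

// Shuffle style: partition i of the stream is joined only against partition i
// of the build side, relying on both being partitioned by the same key.
def shuffleStyle[R](streamParts: Seq[Seq[R]], buildParts: Seq[Seq[R]])(
    join: (Seq[R], Seq[R]) => Seq[R]): Seq[R] =
  streamParts.zip(buildParts).flatMap { case (s, b) => join(s, b) }

The planner (first file above) picks the broadcast form only when the right side's estimated size fits under `autoBroadcastJoinThreshold`; otherwise this shuffled form runs.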
12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -395,6 +395,18 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
     )
   }
 
+  test("left semi greater than predicate and equal operator") {
+    checkAnswer(
+      sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b and x.a >= y.a + 2"),
+      Seq(Row(3, 1), Row(3, 2))
+    )
+
+    checkAnswer(
+      sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.a and x.a >= y.b + 1"),
+      Seq(Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))
+    )
+  }
+
   test("index into array of arrays") {
     checkAnswer(
       sql(
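As a sanity check on the expected answers — assuming the standard `testData2` fixture of (a, b) rows (1,1), (1,2), (2,1), (2,2), (3,1), (3,2), an assumption since the fixture is defined outside this diff — both results can be reproduced with a plain-collection semi join:

// Hand-evaluation of the two tests, assuming testData2 holds the pairs above.
val testData2 = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))

val firstAnswer = testData2.filter { case (xa, xb) =>
  testData2.exists { case (ya, yb) => xb == yb && xa >= ya + 2 }
}
assert(firstAnswer == Seq((3, 1), (3, 2))) // matches Row(3, 1), Row(3, 2)

val secondAnswer = testData2.filter { case (xa, xb) =>
  testData2.exists { case (ya, yb) => xb == ya && xa >= yb + 1 }
}
assert(secondAnswer == Seq((2, 1), (2, 2), (3, 1), (3, 2)))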
