Review comments.
marmbrus committed Mar 31, 2014
1 parent 1e9fb63 commit 8e6f2a2
Showing 3 changed files with 46 additions and 39 deletions.
@@ -48,8 +48,8 @@ trait Row extends Seq[Any] with Serializable {
   /** Returns true if there are any NULL values in this row. */
   def anyNull: Boolean = {
     var i = 0
-    while(i < length) {
-      if(isNullAt(i)) return true
+    while (i < length) {
+      if (isNullAt(i)) { return true }
       i += 1
     }
     false
@@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.types.{BooleanType, StringType}

+object InterpretedPredicate {
+  def apply(expression: Expression): (Row => Boolean) = {
+    (r: Row) => expression.apply(r).asInstanceOf[Boolean]
+  }
+}
+
 trait Predicate extends Expression {
   self: Product =>
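
For context, the InterpretedPredicate factory added here (relocated from joins.scala, where it was named InterpretCondition) turns a boolean Expression into a plain Row => Boolean closure that can be evaluated once per input row. Below is a minimal, runnable sketch of the same pattern; SimpleRow, Expr, and IntGreaterThan are hypothetical stand-ins for Catalyst's Row and Expression types, not the real API.

// Sketch only: SimpleRow, Expr, and IntGreaterThan are made-up stand-ins for
// Catalyst's Row and Expression; only the closure-building shape is the point.
object InterpretedPredicateSketch {
  type SimpleRow = Seq[Any]

  trait Expr { def apply(row: SimpleRow): Any }

  case class IntGreaterThan(ordinal: Int, bound: Int) extends Expr {
    def apply(row: SimpleRow): Any = row(ordinal).asInstanceOf[Int] > bound
  }

  // Same shape as InterpretedPredicate.apply: close over the expression once
  // and hand back a reusable SimpleRow => Boolean.
  def interpreted(expression: Expr): SimpleRow => Boolean =
    (r: SimpleRow) => expression(r).asInstanceOf[Boolean]

  def main(args: Array[String]): Unit = {
    val pred = interpreted(IntGreaterThan(0, 10))
    val rows = Seq(Seq(5, "a"), Seq(42, "b"))
    println(rows.filter(pred)) // List(List(42, b))
  }
}

BroadcastNestedLoopJoin below uses the real factory the same way: it binds the optional join condition (falling back to Literal(true)) and applies the resulting closure to each joined row.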
75 changes: 38 additions & 37 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -15,29 +15,20 @@
  * limitations under the License.
  */

-package org.apache.spark.sql
-package execution
+package org.apache.spark.sql.execution

 import scala.collection.mutable.{ArrayBuffer, BitSet}

 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext

-import catalyst.errors._
-import catalyst.expressions._
-import catalyst.plans._
-import catalyst.plans.physical.{ClusteredDistribution, Partitioning}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}

 sealed abstract class BuildSide
 case object BuildLeft extends BuildSide
 case object BuildRight extends BuildSide

-object InterpretCondition {
-  def apply(expression: Expression): (Row => Boolean) = {
-    (r: Row) => expression.apply(r).asInstanceOf[Boolean]
-  }
-}
-
 case class HashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
@@ -69,11 +60,12 @@ case class HashJoin(
   def execute() = {

     buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+      // TODO: Use Spark's HashMap implementation.
       val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
       var currentRow: Row = null

       // Create a mapping of buildKeys -> rows
-      while(buildIter.hasNext) {
+      while (buildIter.hasNext) {
         currentRow = buildIter.next()
         val rowKey = buildSideKeyGenerator(currentRow)
         if(!rowKey.anyNull) {
@@ -90,40 +82,49 @@ case class HashJoin(
       }

       new Iterator[Row] {
-        private[this] var currentRow: Row = _
-        private[this] var currentMatches: ArrayBuffer[Row] = _
-        private[this] var currentPosition: Int = -1
+        private[this] var currentStreamedRow: Row = _
+        private[this] var currentHashMatches: ArrayBuffer[Row] = _
+        private[this] var currentMatchPosition: Int = -1

         // Mutable per row objects.
         private[this] val joinRow = new JoinedRow

-        @transient private val joinKeys = streamSideKeyGenerator()
+        private[this] val joinKeys = streamSideKeyGenerator()

-        def hasNext: Boolean =
-          (currentPosition != -1 && currentPosition < currentMatches.size) ||
-          (streamIter.hasNext && fetchNext())
+        override final def hasNext: Boolean =
+          if (currentMatchPosition != -1) {
+            currentMatchPosition < currentHashMatches.size
+          } else {
+            fetchNext()
+          }

-        def next() = {
-          val ret = joinRow(currentRow, currentMatches(currentPosition))
-          currentPosition += 1
+        override final def next() = {
+          val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+          currentMatchPosition += 1
           ret
         }

-        private def fetchNext(): Boolean = {
-          currentMatches = null
-          currentPosition = -1
-
-          while (currentMatches == null && streamIter.hasNext) {
-            currentRow = streamIter.next()
-            if(!joinKeys(currentRow).anyNull) {
-              currentMatches = hashTable.get(joinKeys.currentValue)
+        /**
+         * Searches the streamed iterator for the next row that has at least one match in hashtable.
+         *
+         * @return true if the search is successful, and false if the streamed iterator runs out of
+         *         tuples.
+         */
+        private final def fetchNext(): Boolean = {
+          currentHashMatches = null
+          currentMatchPosition = -1
+
+          while (currentHashMatches == null && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            if (!joinKeys(currentStreamedRow).anyNull) {
+              currentHashMatches = hashTable.get(joinKeys.currentValue)
             }
           }

-          if (currentMatches == null) {
+          if (currentHashMatches == null) {
             false
           } else {
-            currentPosition = 0
+            currentMatchPosition = 0
             true
           }
         }
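
As an aside, the rewritten iterator above is the probe side of a classic build/probe hash join: the build side is hashed into a key -> rows multimap (rows whose keys contain NULL are skipped, since NULL keys never match), and each streamed row is then matched lazily, emitting one joined row per match. Below is a self-contained sketch of the pattern; the KV row type stands in for Catalyst's Row and key generators, and it keeps the pre-change || form of hasNext so the probe continues after a match list is exhausted.

import scala.collection.mutable.ArrayBuffer

// Sketch only: KV is a made-up row type; key plays the role of the generated
// join key, with None standing in for a NULL key.
object HashJoinSketch {
  case class KV(key: Option[Int], value: String)

  def hashJoin(buildIter: Iterator[KV], streamIter: Iterator[KV]): Iterator[(KV, KV)] = {
    // Build phase: hash every build-side row that has a non-NULL key.
    val hashTable = new java.util.HashMap[Int, ArrayBuffer[KV]]()
    buildIter.foreach { row =>
      row.key.foreach { k =>
        var matches = hashTable.get(k)
        if (matches == null) {
          matches = new ArrayBuffer[KV]
          hashTable.put(k, matches)
        }
        matches += row
      }
    }

    // Probe phase: walk the streamed side lazily, one output tuple per match.
    new Iterator[(KV, KV)] {
      private var currentStreamedRow: KV = _
      private var currentHashMatches: ArrayBuffer[KV] = _
      private var currentMatchPosition: Int = -1

      override def hasNext: Boolean =
        (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
          fetchNext()

      override def next(): (KV, KV) = {
        val ret = (currentStreamedRow, currentHashMatches(currentMatchPosition))
        currentMatchPosition += 1
        ret
      }

      // Advance the streamed side to the next row with at least one match.
      private def fetchNext(): Boolean = {
        currentHashMatches = null
        currentMatchPosition = -1
        while (currentHashMatches == null && streamIter.hasNext) {
          currentStreamedRow = streamIter.next()
          currentStreamedRow.key.foreach(k => currentHashMatches = hashTable.get(k))
        }
        if (currentHashMatches == null) false
        else { currentMatchPosition = 0; true }
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val build = Iterator(KV(Some(1), "b1"), KV(Some(1), "b2"), KV(None, "bNull"))
    val stream = Iterator(KV(Some(1), "s1"), KV(Some(2), "s2"))
    hashJoin(build, stream).foreach(println) // s1 pairs with b1 and b2; s2 and bNull drop out
  }
}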
@@ -158,7 +159,7 @@ case class BroadcastNestedLoopJoin(
   def right = broadcast

   @transient lazy val boundCondition =
-    InterpretCondition(
+    InterpretedPredicate(
       condition
         .map(c => BindReferences.bindReference(c, left.output ++ right.output))
         .getOrElse(Literal(true)))
@@ -169,8 +170,8 @@

     val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
       val matchedRows = new ArrayBuffer[Row]
-      val includedBroadcastTuples =
-        new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
+      // TODO: Use Spark's BitSet.
+      val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
       val joinedRow = new JoinedRow

       streamedIter.foreach { streamedRow =>
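
The last hunk's matchedRows buffer and includedBroadcastTuples bit set are per-partition bookkeeping: each streamed partition records which broadcast-side tuples it managed to join, so unmatched broadcast tuples can be accounted for afterwards. A small sketch of that bookkeeping under hypothetical stand-in types; the equal-length join condition is made up for illustration and replaces the bound predicate used by the real operator.

import scala.collection.mutable.{ArrayBuffer, BitSet}

// Sketch only: strings stand in for rows, and the join condition is a made-up
// equal-length test rather than a bound Catalyst predicate.
object BroadcastMatchTrackingSketch {
  def main(args: Array[String]): Unit = {
    val broadcastRows = IndexedSeq("aa", "bbb", "cc") // plays broadcastedRelation.value
    val streamedRows = Iterator("xx", "y")

    val matchedRows = new ArrayBuffer[(String, String)]
    // One bit per broadcast tuple, as with includedBroadcastTuples above.
    val includedBroadcastTuples = new BitSet(broadcastRows.size)

    streamedRows.foreach { streamedRow =>
      var i = 0
      while (i < broadcastRows.size) {
        if (streamedRow.length == broadcastRows(i).length) {
          matchedRows += ((streamedRow, broadcastRows(i)))
          includedBroadcastTuples += i
        }
        i += 1
      }
    }

    println(matchedRows)             // ArrayBuffer((xx,aa), (xx,cc))
    println(includedBroadcastTuples) // BitSet(0, 2): broadcast rows that matched
  }
}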
