Skip to content

Commit

Permalink
Followup of comments over PR #19257
Browse files Browse the repository at this point in the history
  • Loading branch information
tejasapatil committed Dec 21, 2017
1 parent 0114c89 commit 8c37d46
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -252,54 +252,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
operator.withNewChildren(children)
}

/**
* When the physical operators are created for JOIN, the ordering of join keys is based on order
* in which the join keys appear in the user query. That might not match with the output
* partitioning of the join node's children (thus leading to extra sort / shuffle being
* introduced). This rule will change the ordering of the join keys to match with the
* partitioning of the join nodes' children.
*/
def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
def reorderJoinKeys(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
leftPartitioning: Partitioning,
rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {

def reorder(expectedOrderOfKeys: Seq[Expression],
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
val leftKeysBuffer = ArrayBuffer[Expression]()
val rightKeysBuffer = ArrayBuffer[Expression]()
private def reorder(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
expectedOrderOfKeys: Seq[Expression],
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
val leftKeysBuffer = ArrayBuffer[Expression]()
val rightKeysBuffer = ArrayBuffer[Expression]()

expectedOrderOfKeys.foreach(expression => {
val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
leftKeysBuffer.append(leftKeys(index))
rightKeysBuffer.append(rightKeys(index))
})
(leftKeysBuffer, rightKeysBuffer)
}
expectedOrderOfKeys.foreach(expression => {
val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
leftKeysBuffer.append(leftKeys(index))
rightKeysBuffer.append(rightKeys(index))
})
(leftKeysBuffer, rightKeysBuffer)
}

if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
leftPartitioning match {
case HashPartitioning(leftExpressions, _)
if leftExpressions.length == leftKeys.length &&
leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
reorder(leftExpressions, leftKeys)
private def reorderJoinKeys(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
leftPartitioning: Partitioning,
rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
leftPartitioning match {
case HashPartitioning(leftExpressions, _)
if leftExpressions.length == leftKeys.length &&
leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
reorder(leftKeys, rightKeys, leftExpressions, leftKeys)

case _ => rightPartitioning match {
case HashPartitioning(rightExpressions, _)
if rightExpressions.length == rightKeys.length &&
rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
reorder(rightExpressions, rightKeys)
case _ => rightPartitioning match {
case HashPartitioning(rightExpressions, _)
if rightExpressions.length == rightKeys.length &&
rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
reorder(leftKeys, rightKeys, rightExpressions, rightKeys)

case _ => (leftKeys, rightKeys)
}
case _ => (leftKeys, rightKeys)
}
} else {
(leftKeys, rightKeys)
}
} else {
(leftKeys, rightKeys)
}
}

/**
* When the physical operators are created for JOIN, the ordering of join keys is based on order
* in which the join keys appear in the user query. That might not match with the output
* partitioning of the join node's children (thus leading to extra sort / shuffle being
* introduced). This rule will change the ordering of the join keys to match with the
* partitioning of the join nodes' children.
*/
private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
plan.transformUp {
case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
right) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -620,15 +620,15 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
|) ab
|JOIN table2 c
|ON ab.i = c.i
|""".stripMargin),
""".stripMargin),
sql("""
|SELECT a.i, a.j, a.k, c.i, c.j, c.k
|FROM bucketed_table a
|JOIN table1 b
|ON a.i = b.i
|JOIN table2 c
|ON a.i = c.i
|""".stripMargin))
""".stripMargin))
}
}
}
Expand Down

0 comments on commit 8c37d46

Please sign in to comment.