Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-20366] [SQL] Fix recursive join reordering: inside joins are not reordered #17668

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike}
import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, JoinType}
import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
Expand All @@ -47,7 +47,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr
}
// After reordering is finished, convert OrderedJoin back to Join
result transformDown {
case oj: OrderedJoin => oj.join
case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond)
}
}
}
Expand Down Expand Up @@ -87,22 +87,24 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr
}

private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match {
case j @ Join(left, right, _: InnerLike, Some(cond)) =>
case j @ Join(left, right, jt: InnerLike, Some(cond)) =>
val replacedLeft = replaceWithOrderedJoin(left)
val replacedRight = replaceWithOrderedJoin(right)
OrderedJoin(j.copy(left = replacedLeft, right = replacedRight))
OrderedJoin(replacedLeft, replacedRight, jt, Some(cond))
case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) =>
p.copy(child = replaceWithOrderedJoin(j))
case _ =>
plan
}
}

/** This is a wrapper class for a join node that has been ordered. */
private case class OrderedJoin(join: Join) extends BinaryNode {
override def left: LogicalPlan = join.left
override def right: LogicalPlan = join.right
override def output: Seq[Attribute] = join.output
}
/** This is a mimic class for a join node that has been ordered. */
case class OrderedJoin(
left: LogicalPlan,
right: LogicalPlan,
joinType: JoinType,
condition: Option[Expression]) extends BinaryNode {
override def output: Seq[Attribute] = left.output ++ right.output
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED, JOIN_REORDER_ENABLED}
import org.apache.spark.sql.internal.SQLConf.{CBO_ENABLED, JOIN_REORDER_ENABLED}


class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {

override val conf = new SQLConf().copy(
CASE_SENSITIVE -> true, CBO_ENABLED -> true, JOIN_REORDER_ENABLED -> true)
override val conf = new SQLConf().copy(CBO_ENABLED -> true, JOIN_REORDER_ENABLED -> true)

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Expand Down Expand Up @@ -212,6 +211,50 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
}
}

test("reorder recursively") {
// Original order:
// Join
// / \
// Union t5
// / \
// Join t4
// / \
// Join t3
// / \
// t1 t2
val bottomJoins =
t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.select(nameToAttr("t1.v-1-10"))

val originalPlan = bottomJoins
.union(t4.select(nameToAttr("t4.v-1-10")))
.join(t5, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t5.v-1-5")))

// Should be able to reorder the bottom part.
// Best order:
// Join
// / \
// Union t5
// / \
// Join t4
// / \
// Join t2
// / \
// t1 t3
val bestBottomPlan =
t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.select(nameToAttr("t1.k-1-2"), nameToAttr("t1.v-1-10"))
.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.select(nameToAttr("t1.v-1-10"))

val bestPlan = bestBottomPlan
.union(t4.select(nameToAttr("t4.v-1-10")))
.join(t5, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t5.v-1-5")))

assertEqualPlans(originalPlan, bestPlan)
}

private def assertEqualPlans(
originalPlan: LogicalPlan,
groundTruthBestPlan: LogicalPlan): Unit = {
Expand Down