Skip to content

Commit

Permalink
SPARK-21479 Outer join filter pushdown in null supplying table when c…
Browse files Browse the repository at this point in the history
…ondition is on one of the joined columns
  • Loading branch information
maryannxue committed Mar 13, 2018
1 parent 918fb9b commit ac17976
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
LimitPushDown,
ColumnPruning,
InferFiltersFromConstraints,
TransitPredicateInOuterJoin,
// Operator combine
CollapseRepartition,
CollapseProject,
Expand Down Expand Up @@ -1071,6 +1072,66 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
}
}

/**
 * Infers predicates from the preserved side of an outer join and transits them to the
 * null-supplying side. The predicates are inferred from the constraints of the preserved
 * side combined with the conjuncts of the join condition. For example, if the preserved
 * side has a constraint of the form 'a > 5' and the join condition is 'a = b', in which 'b'
 * is an attribute from the null-supplying side, a [[Filter]] operator of 'b > 5' will be
 * applied to the null-supplying side.
 *
 * Only deterministic predicates that reference nothing but the output of the null-supplying
 * side are pushed. Predicates that the null-supplying side already guarantees (i.e. that are
 * already part of its constraints) are skipped, so applying the rule repeatedly does not
 * stack duplicate [[Filter]] operators on top of each other.
 *
 * Applying this rule will not change the constraints of the [[Join]] operator, so
 * aside from its child being transformed, there is no side-effect to the [[Join]]
 * operator itself.
 *
 * The rule is a no-op when constraint propagation is disabled, and it leaves joins
 * without a join condition, as well as all non left/right outer joins, untouched.
 */
object TransitPredicateInOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
  def apply(plan: LogicalPlan): LogicalPlan =
    if (!plan.conf.constraintPropagationEnabled) {
      plan
    } else plan transform {
      // Right outer join: the right side is preserved, the left side is null-supplying.
      case j @ Join(left, right, RightOuter, joinCondition @ Some(condition)) =>
        inferFilterForNullSupplyingSide(preserved = right, nullSupplying = left, condition)
          .map(predicate => Join(Filter(predicate, left), right, RightOuter, joinCondition))
          .getOrElse(j)

      // Left outer join: the left side is preserved, the right side is null-supplying.
      case j @ Join(left, right, LeftOuter, joinCondition @ Some(condition)) =>
        inferFilterForNullSupplyingSide(preserved = left, nullSupplying = right, condition)
          .map(predicate => Join(left, Filter(predicate, right), LeftOuter, joinCondition))
          .getOrElse(j)
    }

  /**
   * Builds the conjunction of all predicates that can be pushed to the null-supplying side.
   *
   * @param preserved the child of the outer join whose rows are always preserved
   * @param nullSupplying the child of the outer join that is padded with nulls
   * @param joinCondition the condition of the outer join
   * @return the conjunction of the new predicates, or `None` if there is nothing to push
   */
  private def inferFilterForNullSupplyingSide(
      preserved: LogicalPlan,
      nullSupplying: LogicalPlan,
      joinCondition: Expression): Option[Expression] = {
    val preservedConstraints =
      preserved.constraints.union(splitConjunctivePredicates(joinCondition).toSet)
    val alreadyGuaranteed = nullSupplying.constraints
    ExpressionSet(QueryPlanConstraints.inferAdditionalConstraints(preservedConstraints))
      .filter(_.deterministic)
      .filter(_.references.subsetOf(nullSupplying.outputSet))
      // Skip predicates the child already satisfies; otherwise every re-application of the
      // rule would add the very same Filter again.
      .filterNot(inferred => alreadyGuaranteed.contains(inferred))
      .reduceLeftOption(And)
  }
}

/**
* Combines two adjacent [[Limit]] operators into one, merging the
* expressions into one single expression.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trait QueryPlanConstraints { self: LogicalPlan =>
lazy val allConstraints: ExpressionSet = {
if (conf.constraintPropagationEnabled) {
ExpressionSet(validConstraints
.union(inferAdditionalConstraints(validConstraints))
.union(QueryPlanConstraints.inferAdditionalConstraints(validConstraints))
.union(constructIsNotNullConstraints(validConstraints)))
} else {
ExpressionSet(Set.empty)
Expand Down Expand Up @@ -96,13 +96,16 @@ trait QueryPlanConstraints { self: LogicalPlan =>
case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
case _ => Seq.empty[Attribute]
}
}

object QueryPlanConstraints {

/**
* Infers an additional set of constraints from a given set of equality constraints.
* For example, if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
* additional constraint of the form `b = 5`.
*/
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
var inferredConstraints = Set.empty[Expression]
constraints.foreach {
case eq @ EqualTo(l: Attribute, r: Attribute) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
PushPredicateThroughJoin,
PushDownPredicate,
InferFiltersFromConstraints,
TransitPredicateInOuterJoin,
CombineFilters,
SimplifyBinaryComparison,
BooleanSimplification) :: Nil
Expand Down Expand Up @@ -204,4 +205,40 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

// A filter applied on top of a left outer join that only references the preserved (left)
// side is expected to end up on both join children, together with the IsNotNull checks.
test("SPARK-21479: Outer join after-join filters push down to null-supplying side") {
  val joinCond = Some("x.a".attr === "y.a".attr)
  val lhs = testRelation.subquery('x)
  val rhs = testRelation.subquery('y)

  val query = lhs.join(rhs, LeftOuter, joinCond).where("x.a".attr === 2).analyze

  val expectedLhs = lhs.where(IsNotNull('a) && 'a === 2)
  val expectedRhs = rhs.where(IsNotNull('a) && 'a === 2)
  val expected = expectedLhs.join(expectedRhs, LeftOuter, joinCond).analyze

  comparePlans(Optimize.execute(query), expected)
}

// A filter that already sits on the preserved (right) side of a right outer join is
// expected to be mirrored onto the null-supplying (left) side.
test("SPARK-21479: Outer join pre-existing filters push down to null-supplying side") {
  val joinCond = Some("x.a".attr === "y.a".attr)
  val lhs = testRelation.subquery('x)
  val rhs = testRelation.subquery('y)

  val query = lhs.join(rhs.where("y.a".attr > 5), RightOuter, joinCond).analyze

  val expectedLhs = lhs.where(IsNotNull('a) && 'a > 5)
  val expectedRhs = rhs.where(IsNotNull('a) && 'a > 5)
  val expected = expectedLhs.join(expectedRhs, RightOuter, joinCond).analyze

  comparePlans(Optimize.execute(query), expected)
}

// A filter on the null-supplying (right) side of a left outer join must stay there:
// the expected plan keeps the preserved (left) side unfiltered.
test("SPARK-21479: Outer join no filter push down to preserved side") {
  val joinCond = Some("x.a".attr === "y.a".attr)
  val lhs = testRelation.subquery('x)
  val rhs = testRelation.subquery('y)

  val query = lhs.join(rhs.where("y.a".attr === 1), LeftOuter, joinCond).analyze

  val expectedRhs = rhs.where(IsNotNull('a) && 'a === 1)
  val expected = lhs.join(expectedRhs, LeftOuter, joinCond).analyze

  comparePlans(Optimize.execute(query), expected)
}
}

0 comments on commit ac17976

Please sign in to comment.