Skip to content

Commit

Permalink
[SPARK-23564][SQL] infer additional filters from constraints for join…
Browse files Browse the repository at this point in the history
…'s children

## What changes were proposed in this pull request?

The existing query constraints framework has 2 steps:
1. propagate constraints bottom up.
2. use constraints to infer additional filters for better data pruning.

Step 2 mostly helps with Join, because we can connect the constraints from the children to the join condition and infer powerful filters to prune the data of the join sides. For example, if the left side has the constraint `a = 1` and the join condition is `left.a = right.a`, then we can infer `right.a = 1` for the right side and prune it significantly.

However, the current logic for inferring filters from constraints for Join is pretty weak: it infers the filters from the Join's own constraints. Some joins, such as left semi/anti, exclude the right side's output, so the right-side constraints are lost there.

This PR proposes to check the left and right constraints individually, expand the constraints with the join condition, and add filters to the children of the join directly, instead of adding them to the join condition.

This reverts #20670 , covers #20717 and #20816

This is inspired by the original PRs and the tests are all from these PRs. Thanks to the authors mgaido91 maryannxue KaiXinXiaoLei !

## How was this patch tested?

new tests

Author: Wenchen Fan <[email protected]>

Closes #21083 from cloud-fan/join.
  • Loading branch information
cloud-fan committed Apr 23, 2018
1 parent f70f46d commit d87d30e
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -637,13 +637,11 @@ object CollapseWindow extends Rule[LogicalPlan] {
* constraints. These filters are currently inserted to the existing conditions in the Filter
* operators and on either side of Join operators.
*
* In addition, for left/right outer joins, infer predicate from the preserved side of the Join
* operator and push the inferred filter over to the null-supplying side. For example, if the
* preserved side has constraints of the form 'a > 5' and the join condition is 'a = b', in
* which 'b' is an attribute from the null-supplying side, a [[Filter]] operator of 'b > 5' will
* be applied to the null-supplying side.
* Note: While this optimization is applicable to a lot of types of join, it primarily benefits
* Inner and LeftSemi joins.
*/
object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper {
object InferFiltersFromConstraints extends Rule[LogicalPlan]
with PredicateHelper with ConstraintHelper {

def apply(plan: LogicalPlan): LogicalPlan = {
if (SQLConf.get.constraintPropagationEnabled) {
Expand All @@ -664,53 +662,52 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
}

case join @ Join(left, right, joinType, conditionOpt) =>
// Only consider constraints that can be pushed down completely to either the left or the
// right child
val constraints = join.allConstraints.filter { c =>
c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet)
}
// Remove those constraints that are already enforced by either the left or the right child
val additionalConstraints = constraints -- (left.constraints ++ right.constraints)
val newConditionOpt = conditionOpt match {
case Some(condition) =>
val newFilters = additionalConstraints -- splitConjunctivePredicates(condition)
if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else conditionOpt
case None =>
additionalConstraints.reduceOption(And)
}
// Infer filter for left/right outer joins
val newLeftOpt = joinType match {
case RightOuter if newConditionOpt.isDefined =>
val inferredConstraints = left.getRelevantConstraints(
left.constraints
.union(right.constraints)
.union(splitConjunctivePredicates(newConditionOpt.get).toSet))
val newFilters = inferredConstraints
.filterNot(left.constraints.contains)
.reduceLeftOption(And)
newFilters.map(Filter(_, left))
case _ => None
}
val newRightOpt = joinType match {
case LeftOuter if newConditionOpt.isDefined =>
val inferredConstraints = right.getRelevantConstraints(
right.constraints
.union(left.constraints)
.union(splitConjunctivePredicates(newConditionOpt.get).toSet))
val newFilters = inferredConstraints
.filterNot(right.constraints.contains)
.reduceLeftOption(And)
newFilters.map(Filter(_, right))
case _ => None
}
joinType match {
// For inner join, we can infer additional filters for both sides. LeftSemi is kind of an
// inner join, it just drops the right side in the final output.
case _: InnerLike | LeftSemi =>
val allConstraints = getAllConstraints(left, right, conditionOpt)
val newLeft = inferNewFilter(left, allConstraints)
val newRight = inferNewFilter(right, allConstraints)
join.copy(left = newLeft, right = newRight)

if ((newConditionOpt.isDefined && (newConditionOpt ne conditionOpt))
|| newLeftOpt.isDefined || newRightOpt.isDefined) {
Join(newLeftOpt.getOrElse(left), newRightOpt.getOrElse(right), joinType, newConditionOpt)
} else {
join
// For right outer join, we can only infer additional filters for left side.
case RightOuter =>
val allConstraints = getAllConstraints(left, right, conditionOpt)
val newLeft = inferNewFilter(left, allConstraints)
join.copy(left = newLeft)

// For left outer and left anti joins, we can only infer additional filters for right side.
case LeftOuter | LeftAnti =>
val allConstraints = getAllConstraints(left, right, conditionOpt)
val newRight = inferNewFilter(right, allConstraints)
join.copy(right = newRight)

case _ => join
}
}

private def getAllConstraints(
    left: LogicalPlan,
    right: LogicalPlan,
    conditionOpt: Option[Expression]): Set[Expression] = {
  // Split the join condition (if any) into its individual conjuncts.
  val joinPredicates: Set[Expression] = conditionOpt match {
    case Some(condition) => splitConjunctivePredicates(condition).toSet
    case None => Set.empty
  }
  // Everything known up front: both children's constraints plus the join conjuncts.
  val known = left.constraints.union(right.constraints).union(joinPredicates)
  // Extend the known set with whatever can be inferred from it.
  known.union(inferAdditionalConstraints(known))
}

private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = {
  // A predicate can be applied on top of `plan` only if it is deterministic and references
  // at least one attribute, all of which are produced by `plan`.
  def applicable(c: Expression): Boolean =
    c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic

  val candidates = constraints.union(constructIsNotNullConstraints(constraints, plan.output))
  // Drop anything `plan` already guarantees so that no redundant filter is added.
  val newPredicates = candidates.filter(applicable) -- plan.constraints
  newPredicates.reduceOption(And) match {
    case Some(condition) => Filter(condition, plan)
    case None => plan
  }
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,28 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._


trait QueryPlanConstraints { self: LogicalPlan =>
trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan =>

/**
* An [[ExpressionSet]] that contains an additional set of constraints, such as equality
* constraints and `isNotNull` constraints, etc.
* An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
* example, if this set contains the expression `a = 2` then that expression is guaranteed to
* evaluate to `true` for all rows produced.
*/
lazy val allConstraints: ExpressionSet = {
lazy val constraints: ExpressionSet = {
if (conf.constraintPropagationEnabled) {
ExpressionSet(validConstraints
.union(inferAdditionalConstraints(validConstraints))
.union(constructIsNotNullConstraints(validConstraints)))
ExpressionSet(
validConstraints
.union(inferAdditionalConstraints(validConstraints))
.union(constructIsNotNullConstraints(validConstraints, output))
.filter { c =>
c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
}
)
} else {
ExpressionSet(Set.empty)
}
}

/**
* An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
* example, if this set contains the expression `a = 2` then that expression is guaranteed to
* evaluate to `true` for all rows produced.
*/
lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter(selfReferenceOnly))

/**
* This method can be overridden by any child class of QueryPlan to specify a set of constraints
* based on the given operator's constraint propagation logic. These constraints are then
Expand All @@ -52,30 +51,42 @@ trait QueryPlanConstraints { self: LogicalPlan =>
* See [[Canonicalize]] for more details.
*/
protected def validConstraints: Set[Expression] = Set.empty
}

trait ConstraintHelper {

/**
* Returns an [[ExpressionSet]] that contains an additional set of constraints, such as
* equality constraints and `isNotNull` constraints, etc., and that only contains references
* to this [[LogicalPlan]] node.
* Infers an additional set of constraints from a given set of equality constraints.
* For example, if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
* additional constraint of the form `b = 5`.
*/
def getRelevantConstraints(constraints: Set[Expression]): ExpressionSet = {
val allRelevantConstraints =
if (conf.constraintPropagationEnabled) {
constraints
.union(inferAdditionalConstraints(constraints))
.union(constructIsNotNullConstraints(constraints))
} else {
constraints
}
ExpressionSet(allRelevantConstraints.filter(selfReferenceOnly))
def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
  // Walk every attribute-to-attribute equality and rewrite the remaining constraints in both
  // directions (l -> r and r -> l), accumulating the rewritten expressions.
  val inferred = constraints.foldLeft(Set.empty[Expression]) {
    case (acc, eq @ EqualTo(l: Attribute, r: Attribute)) =>
      val others = constraints - eq
      acc ++ replaceConstraints(others, l, r) ++ replaceConstraints(others, r, l)
    case (acc, _) =>
      // Not an attribute equality: nothing to infer from it.
      acc
  }
  // Report only constraints that were not already part of the input.
  inferred -- constraints
}

private def replaceConstraints(
    constraints: Set[Expression],
    source: Expression,
    destination: Attribute): Set[Expression] = {
  // Rewrite rule: any sub-expression semantically equal to `source` becomes `destination`.
  val substitute: PartialFunction[Expression, Expression] = {
    case expr if expr.semanticEquals(source) => destination
  }
  constraints.map(constraint => constraint.transform(substitute))
}

/**
* Infers a set of `isNotNull` constraints from null intolerant expressions as well as
* non-nullable attributes. For example, if an expression is of the form (`a > 5`), this
* returns a constraint of the form `isNotNull(a)`.
*/
private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
def constructIsNotNullConstraints(
constraints: Set[Expression],
output: Seq[Attribute]): Set[Expression] = {
// First, we propagate constraints from the null intolerant expressions.
var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints)

Expand Down Expand Up @@ -111,32 +122,4 @@ trait QueryPlanConstraints { self: LogicalPlan =>
case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
case _ => Seq.empty[Attribute]
}

/**
 * Derives extra constraints by substituting across attribute equalities. For example, given
 * the constraints (`a = 5`, `a = b`), the result contains `b = 5`. Only constraints that are
 * not already present in the input are returned.
 */
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
  val derived: Set[Expression] = constraints.toSeq.flatMap {
    case eq @ EqualTo(l: Attribute, r: Attribute) =>
      // Rewrite every other constraint in both directions of the equality.
      val others = constraints - eq
      replaceConstraints(others, l, r).toSeq ++ replaceConstraints(others, r, l).toSeq
    case _ =>
      Seq.empty[Expression]
  }.toSet
  derived -- constraints
}

private def replaceConstraints(
    constraints: Set[Expression],
    source: Expression,
    destination: Attribute): Set[Expression] = {
  // For each constraint, swap every sub-expression semantically equal to `source` with
  // `destination`.
  constraints.map { constraint =>
    constraint.transform {
      case candidate if candidate.semanticEquals(source) => destination
    }
  }
}

// True when `e` is deterministic and references at least one attribute, all of which belong
// to `outputSet`.
private def selfReferenceOnly(e: Expression): Boolean = {
  val refs = e.references
  refs.nonEmpty && refs.subsetOf(outputSet) && e.deterministic
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,25 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
InferFiltersFromConstraints,
CombineFilters,
SimplifyBinaryComparison,
BooleanSimplification) :: Nil
BooleanSimplification,
PruneFilters) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

// Joins `x` and `y` on `x.a = y.a` with the given join type, runs the optimizer, and checks
// that the result equals the same join applied to `expectedLeft` and `expectedRight`.
private def testConstraintsAfterJoin(
    x: LogicalPlan,
    y: LogicalPlan,
    expectedLeft: LogicalPlan,
    expectedRight: LogicalPlan,
    joinType: JoinType) = {
  val joinCondition = Some("x.a".attr === "y.a".attr)
  val analyzedInput = x.join(y, joinType, joinCondition).analyze
  val expectedPlan = expectedLeft.join(expectedRight, joinType, joinCondition).analyze
  val actualPlan = Optimize.execute(analyzedInput)
  comparePlans(actualPlan, expectedPlan)
}

test("filter: filter out constraints in condition") {
val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
val correctAnswer = testRelation
Expand Down Expand Up @@ -196,13 +210,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
val condition = Some("x.a".attr === "y.a".attr)
val originalQuery = x.join(y, LeftSemi, condition).analyze
val left = x.where(IsNotNull('a))
val right = y.where(IsNotNull('a))
val correctAnswer = left.join(right, LeftSemi, condition).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi)
}

test("SPARK-21479: Outer join after-join filters push down to null-supplying side") {
Expand Down Expand Up @@ -232,12 +240,27 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
test("SPARK-21479: Outer join no filter push down to preserved side") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
val condition = Some("x.a".attr === "y.a".attr)
val originalQuery = x.join(y.where("y.a".attr === 1), LeftOuter, condition).analyze
val left = x
val right = y.where(IsNotNull('a) && 'a === 1)
val correctAnswer = left.join(right, LeftOuter, condition).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
testConstraintsAfterJoin(
x, y.where("a".attr === 1),
x, y.where(IsNotNull('a) && 'a === 1),
LeftOuter)
}

test("SPARK-23564: left anti join should filter out null join keys on right side") {
  val leftRel = testRelation.subquery('x)
  val rightRel = testRelation.subquery('y)
  // Only the right side is expected to receive the inferred IsNotNull filter.
  testConstraintsAfterJoin(
    x = leftRel,
    y = rightRel,
    expectedLeft = leftRel,
    expectedRight = rightRel.where(IsNotNull('a)),
    joinType = LeftAnti)
}

test("SPARK-23564: left outer join should filter out null join keys on right side") {
  val leftRel = testRelation.subquery('x)
  val rightRel = testRelation.subquery('y)
  // The left side is expected to stay untouched; the right side gets the IsNotNull filter.
  testConstraintsAfterJoin(
    x = leftRel,
    y = rightRel,
    expectedLeft = leftRel,
    expectedRight = rightRel.where(IsNotNull('a)),
    joinType = LeftOuter)
}

test("SPARK-23564: right outer join should filter out null join keys on left side") {
  val leftRel = testRelation.subquery('x)
  val rightRel = testRelation.subquery('y)
  // The right side is expected to stay untouched; the left side gets the IsNotNull filter.
  testConstraintsAfterJoin(
    x = leftRel,
    y = rightRel,
    expectedLeft = leftRel.where(IsNotNull('a)),
    expectedRight = rightRel,
    joinType = RightOuter)
}
}

0 comments on commit d87d30e

Please sign in to comment.