diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 91208479be03b..b63468ac24a68 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -58,6 +58,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
       LimitPushDown,
       ColumnPruning,
       InferFiltersFromConstraints,
+      TransitPredicateInOuterJoin,
       // Operator combine
       CollapseRepartition,
       CollapseProject,
@@ -1071,6 +1072,63 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
   }
 }
 
+/**
+ * Infer and transit predicate from the preserved side to the null-supplying side
+ * of an outer join. The predicate is inferred from the preserved side based on the
+ * join condition and will be pushed over to the null-supplying side. For example,
+ * if the preserved side has constraints of the form 'a > 5' and the join condition
+ * is 'a = b', in which 'b' is an attribute from the null-supplying side, a [[Filter]]
+ * operator of 'b > 5' will be applied to the null-supplying side.
+ *
+ * Predicates that are already among the constraints of the null-supplying side are
+ * not added again.
+ *
+ * Applying this rule will not change the constraints of the [[Join]] operator, so
+ * aside from its child being transformed, there is no side-effect to the [[Join]]
+ * operator itself.
+ */
+object TransitPredicateInOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan =
+    if (!plan.conf.constraintPropagationEnabled) {
+      plan
+    } else plan transform {
+      // Right outer join: the right side is preserved, so only the left side is filtered.
+      case j @ Join(left, right, RightOuter, joinCondition @ Some(condition)) =>
+        inferFilter(preserved = right, nullSupplying = left, condition)
+          .map(Join(_, right, RightOuter, joinCondition))
+          .getOrElse(j)
+      // Left outer join: the mirror image, only the right side is filtered.
+      case j @ Join(left, right, LeftOuter, joinCondition @ Some(condition)) =>
+        inferFilter(preserved = left, nullSupplying = right, condition)
+          .map(Join(left, _, LeftOuter, joinCondition))
+          .getOrElse(j)
+    }
+
+  /**
+   * Builds the [[Filter]] to place on the null-supplying side of an outer join.
+   *
+   * @param preserved the join child whose constraints are the source of the inference
+   * @param nullSupplying the join child that receives the inferred predicates
+   * @param condition the join condition
+   * @return `nullSupplying` wrapped in a [[Filter]], or `None` if nothing new is inferred
+   */
+  private def inferFilter(
+      preserved: LogicalPlan,
+      nullSupplying: LogicalPlan,
+      condition: Expression): Option[LogicalPlan] = {
+    val knownConstraints = preserved.constraints.union(
+      splitConjunctivePredicates(condition).toSet)
+    // Drop predicates the child already has as constraints; without this check the
+    // rule would add the same predicates again each time it is applied.
+    val newPredicates = ExpressionSet(
+      QueryPlanConstraints.inferAdditionalConstraints(knownConstraints))
+      .filter(_.deterministic)
+      .filter(_.references.subsetOf(nullSupplying.outputSet))
+      .filterNot(p => nullSupplying.constraints.contains(p))
+    newPredicates.reduceLeftOption(And).map(Filter(_, nullSupplying))
+  }
+}
+
 /**
  * Combines two adjacent [[Limit]] operators into one, merging the
  * expressions into one single expression.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
index 046848875548b..de49d749d86e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
@@ -29,7 +29,7 @@ trait QueryPlanConstraints { self: LogicalPlan =>
   lazy val allConstraints: ExpressionSet = {
     if (conf.constraintPropagationEnabled) {
       ExpressionSet(validConstraints
-        .union(inferAdditionalConstraints(validConstraints))
+        .union(QueryPlanConstraints.inferAdditionalConstraints(validConstraints))
         .union(constructIsNotNullConstraints(validConstraints)))
     } else {
       ExpressionSet(Set.empty)
@@ -96,13 +96,16 @@ trait QueryPlanConstraints { self: LogicalPlan =>
     case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
     case _ => Seq.empty[Attribute]
   }
+}
+
+object QueryPlanConstraints {
 
   /**
    * Infers an additional set of constraints from a given set of equality constraints.
    * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
    * additional constraint of the form `b = 5`.
    */
-  private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
+  def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
     var inferredConstraints = Set.empty[Expression]
     constraints.foreach {
       case eq @ EqualTo(l: Attribute, r: Attribute) =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index f78c2356e35a5..c877fffdec57d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -33,6 +33,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
         PushPredicateThroughJoin,
         PushDownPredicate,
         InferFiltersFromConstraints,
+        TransitPredicateInOuterJoin,
         CombineFilters,
         SimplifyBinaryComparison,
         BooleanSimplification) :: Nil
@@ -204,4 +205,40 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     val optimized = Optimize.execute(originalQuery)
     comparePlans(optimized, correctAnswer)
   }
+
+  test("SPARK-21479: Outer join after-join filters push down to null-supplying side") {
+    val lhs = testRelation.subquery('x)
+    val rhs = testRelation.subquery('y)
+    val joinCond = Some("x.a".attr === "y.a".attr)
+    val input = lhs.join(rhs, LeftOuter, joinCond).where("x.a".attr === 2).analyze
+    // Both join inputs are expected to carry the filter on column a.
+    val expected = lhs.where(IsNotNull('a) && 'a === 2)
+      .join(rhs.where(IsNotNull('a) && 'a === 2), LeftOuter, joinCond)
+      .analyze
+    comparePlans(Optimize.execute(input), expected)
+  }
+
+  test("SPARK-21479: Outer join pre-existing filters push down to null-supplying side") {
+    val lhs = testRelation.subquery('x)
+    val rhs = testRelation.subquery('y)
+    val joinCond = Some("x.a".attr === "y.a".attr)
+    // Only the right input is filtered in the input plan.
+    val input = lhs.join(rhs.where("y.a".attr > 5), RightOuter, joinCond).analyze
+    val expected = lhs.where(IsNotNull('a) && 'a > 5)
+      .join(rhs.where(IsNotNull('a) && 'a > 5), RightOuter, joinCond)
+      .analyze
+    comparePlans(Optimize.execute(input), expected)
+  }
+
+  test("SPARK-21479: Outer join no filter push down to preserved side") {
+    val lhs = testRelation.subquery('x)
+    val rhs = testRelation.subquery('y)
+    val joinCond = Some("x.a".attr === "y.a".attr)
+    val input = lhs.join(rhs.where("y.a".attr === 1), LeftOuter, joinCond).analyze
+    // The left input must stay unfiltered in the expected plan.
+    val expected = lhs
+      .join(rhs.where(IsNotNull('a) && 'a === 1), LeftOuter, joinCond)
+      .analyze
+    comparePlans(Optimize.execute(input), expected)
+  }
 }