Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-21979][SQL]Improve QueryPlanConstraints framework #19201

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ abstract class UnaryNode extends LogicalPlan {
case expr: Expression if expr.semanticEquals(e) =>
a.toAttribute
})
allConstraints += EqualNullSafe(e, a.toAttribute)
case _ => // Don't change.
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,91 +106,48 @@ trait QueryPlanConstraints { self: LogicalPlan =>
* Infers an additional set of constraints from a given set of equality constraints.
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
* additional constraint of the form `b = 5`.
*
* [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)`
* as they are often useless and can lead to a non-converging set of constraints.
*/
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
val constraintClasses = generateEquivalentConstraintClasses(constraints)

val aliasedConstraints = eliminateAliasedExpressionInConstraints(constraints)
var inferredConstraints = Set.empty[Expression]
constraints.foreach {
aliasedConstraints.foreach {
case eq @ EqualTo(l: Attribute, r: Attribute) =>
val candidateConstraints = constraints - eq
inferredConstraints ++= candidateConstraints.map(_ transform {
case a: Attribute if a.semanticEquals(l) &&
!isRecursiveDeduction(r, constraintClasses) => r
})
inferredConstraints ++= candidateConstraints.map(_ transform {
case a: Attribute if a.semanticEquals(r) &&
!isRecursiveDeduction(l, constraintClasses) => l
})
val candidateConstraints = aliasedConstraints - eq
inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
case _ => // No inference
}
inferredConstraints -- constraints
}

/**
* Generate a sequence of expression sets from constraints, where each set stores an equivalence
* class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following
* expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal
* to an selected attribute.
* Replace the aliased expression in [[Alias]] with the alias name if both exist in constraints.
* Thus non-converging inference can be prevented.
* E.g. `Alias(b, f(a)), a = b` infers `f(a) = f(f(a))` without eliminating aliased expressions.
* Also, the size of constraints is reduced without losing any information.
* When the inferred filters are pushed down the operators that generate the alias,
* the alias names used in filters are replaced by the aliased expressions.
*/
private def generateEquivalentConstraintClasses(
constraints: Set[Expression]): Seq[Set[Expression]] = {
var constraintClasses = Seq.empty[Set[Expression]]
constraints.foreach {
case eq @ EqualTo(l: Attribute, r: Attribute) =>
// Transform [[Alias]] to its child.
val left = aliasMap.getOrElse(l, l)
val right = aliasMap.getOrElse(r, r)
// Get the expression set for an equivalence constraint class.
val leftConstraintClass = getConstraintClass(left, constraintClasses)
val rightConstraintClass = getConstraintClass(right, constraintClasses)
if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) {
// Combine the two sets.
constraintClasses = constraintClasses
.diff(leftConstraintClass :: rightConstraintClass :: Nil) :+
(leftConstraintClass ++ rightConstraintClass)
} else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty
// Update equivalence class of `left` expression.
constraintClasses = constraintClasses
.diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right)
} else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty
// Update equivalence class of `right` expression.
constraintClasses = constraintClasses
.diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left)
} else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty
// Create new equivalence constraint class since neither expression presents
// in any classes.
constraintClasses = constraintClasses :+ Set(left, right)
}
case _ => // Skip
private def eliminateAliasedExpressionInConstraints(constraints: Set[Expression])
: Set[Expression] = {
val attributesInEqualTo = constraints.flatMap {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make it a set?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a set :)

case EqualTo(l: Attribute, r: Attribute) => l :: r :: Nil
case _ => Nil
}

constraintClasses
}

/**
* Get all expressions equivalent to the selected expression.
*/
private def getConstraintClass(
expr: Expression,
constraintClasses: Seq[Set[Expression]]): Set[Expression] =
constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression])

/**
* Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it
* has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function.
* Here we first get all expressions equal to `attr` and then check whether at least one of them
* is a child of the referenced expression.
*/
private def isRecursiveDeduction(
attr: Attribute,
constraintClasses: Seq[Set[Expression]]): Boolean = {
val expr = aliasMap.getOrElse(attr, attr)
getConstraintClass(expr, constraintClasses).exists { e =>
expr.children.exists(_.semanticEquals(e))
var aliasedConstraints = constraints
attributesInEqualTo.foreach { a =>
if (aliasMap.contains(a)) {
val child = aliasMap.get(a).get
aliasedConstraints = replaceConstraints(aliasedConstraints, child, a)
}
}
aliasedConstraints
}

private def replaceConstraints(
constraints: Set[Expression],
source: Expression,
destination: Attribute): Set[Expression] = constraints.map(_ transform {
case e: Expression if e.semanticEquals(source) => destination
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
.analyze
val correctAnswer = t1
.where(IsNotNull('a) && IsNotNull('b) && 'a <=> 'a && 'b <=> 'b &&'a === 'b)
.where(IsNotNull('a) && IsNotNull('b) &&'a === 'b)
.select('a, 'b.as('d)).as("t")
.join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner,
.join(t2.where(IsNotNull('a)), Inner,
Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
Expand All @@ -176,24 +176,48 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
&& "t.int_col".attr === "t2.a".attr))
.analyze
val correctAnswer = t1
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
&& 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a))
&& Coalesce(Seq('b, 'b)) <=> 'a && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b)))
&& 'a === Coalesce(Seq('a, 'b)) && Coalesce(Seq('a, 'b)) === 'b
&& IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
&& 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)))
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && IsNotNull(Coalesce(Seq('b, 'a)))
&& IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) && IsNotNull(Coalesce(Seq('a, 'b)))
&& 'a === 'b && 'a === Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'b))
&& 'a === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('a, 'b))
&& 'b === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('b, 'b)))
.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col))
.select('int_col, 'd, 'a).as("t")
.join(t2
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
&& 'a <=> Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'a)) && 'a <=> 'a), Inner,
.join(
t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) &&
'a === Coalesce(Seq('a, 'a))),
Inner,
Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr
&& "t.int_col".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("inner join with EqualTo expressions containing part of each other: don't generate " +
"constraints for recursive functions") {
val t1 = testRelation.subquery('t1)
val t2 = testRelation.subquery('t2)

// We should prevent `c = Coalese(a, b)` and `a = Coalese(b, c)` from recursively creating
// complicated constraints through the constraint inference procedure.
val originalQuery = t1
.select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e))
.where('a === 'd && 'c === 'e)
.join(t2, Inner, Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr))
.analyze
val correctAnswer = t1
.where(IsNotNull('a) && IsNotNull('c) && 'a === Coalesce(Seq('b, 'c)) &&
'c === Coalesce(Seq('a, 'b)))
.select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e))
.join(t2.where(IsNotNull('a) && IsNotNull('c)),
Inner,
Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("generate correct filters for alias that don't produce recursive constraints") {
val t1 = testRelation.subquery('t1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
verifyConstraints(aliasedRelation.analyze.constraints,
ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "x")),
resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"),
resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"),
resolveColumn(aliasedRelation.analyze, "z") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "z")))))

Expand Down