Skip to content

Commit

Permalink
[SPARK-22944][SQL] improve FoldablePropagation
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

`FoldablePropagation` is a little tricky because it needs to handle attributes that are mis-derived from children, e.g. outer join outputs. The rule currently performs a kind of stoppable tree transform: it stops applying itself as soon as it hits a node that may have mis-derived attributes.

Logically, we should be able to apply this rule above the unsupported nodes by simply treating them as leaf nodes. This PR improves the rule so that it no longer stops the tree transformation, but instead reduces the set of foldable expressions that we want to propagate.

## How was this patch tested?

existing tests

Author: Wenchen Fan <[email protected]>

Closes #20139 from cloud-fan/foldable.

(cherry picked from commit 7d045c5)
Signed-off-by: gatorsmile <[email protected]>
  • Loading branch information
cloud-fan authored and gatorsmile committed Jan 4, 2018
1 parent a51212b commit f51c8fd
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -506,18 +506,21 @@ object NullPropagation extends Rule[LogicalPlan] {


/**
* Propagate foldable expressions:
* Replace attributes with aliases of the original foldable expressions if possible.
* Other optimizations will take advantage of the propagated foldable expressions.
*
* Other optimizations will take advantage of the propagated foldable expressions. For example,
* this rule can optimize
* {{{
* SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3
* ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now()
* }}}
* to
* {{{
* SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now()
* }}}
* and other rules can further optimize it and remove the ORDER BY operator.
*/
object FoldablePropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
val foldableMap = AttributeMap(plan.flatMap {
var foldableMap = AttributeMap(plan.flatMap {
case Project(projectList, _) => projectList.collect {
case a: Alias if a.child.foldable => (a.toAttribute, a)
}
Expand All @@ -530,38 +533,44 @@ object FoldablePropagation extends Rule[LogicalPlan] {
if (foldableMap.isEmpty) {
plan
} else {
var stop = false
CleanupAliases(plan.transformUp {
// A leaf node should not stop the folding process (note that we are traversing up the
// tree, starting at the leaf nodes); so we are allowing it.
case l: LeafNode =>
l

// We can only propagate foldables for a subset of unary nodes.
case u: UnaryNode if !stop && canPropagateFoldables(u) =>
case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) =>
u.transformExpressions(replaceFoldable)

// Allow inner joins. We do not allow outer join, although its output attributes are
// derived from its children, they are actually different attributes: the output of outer
// join is not always picked from its children, but can also be null.
// Join derives its output attributes from its children, but they are not actually the
// same attributes. For example, the output of an outer join is not always picked from its
// children, but can also be null. We should exclude these miss-derived attributes when
// propagating the foldable expressions.
// TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
// of outer join.
case j @ Join(_, _, Inner, _) if !stop =>
j.transformExpressions(replaceFoldable)

// We can fold the projections an expand holds. However expand changes the output columns
// and often reuses the underlying attributes; so we cannot assume that a column is still
// foldable after the expand has been applied.
// TODO(hvanhovell): Expand should use new attributes as the output attributes.
case expand: Expand if !stop =>
val newExpand = expand.copy(projections = expand.projections.map { projection =>
case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty =>
val newJoin = j.transformExpressions(replaceFoldable)
val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match {
case _: InnerLike | LeftExistence(_) => Nil
case LeftOuter => right.output
case RightOuter => left.output
case FullOuter => left.output ++ right.output
})
foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
case (attr, _) => missDerivedAttrsSet.contains(attr)
}.toSeq)
newJoin

// We can not replace the attributes in `Expand.output`. If there are other non-leaf
// operators that have the `output` field, we should put them here too.
case expand: Expand if foldableMap.nonEmpty =>
expand.copy(projections = expand.projections.map { projection =>
projection.map(_.transform(replaceFoldable))
})
stop = true
newExpand

case other =>
stop = true
// For other plans, they are not safe to apply foldable propagation, and they should not
// propagate foldable expressions from children.
case other if foldableMap.nonEmpty =>
val childrenOutputSet = AttributeSet(other.children.flatMap(_.output))
foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
case (attr, _) => childrenOutputSet.contains(attr)
}.toSeq)
other
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ class FoldablePropagationSuite extends PlanTest {
test("Propagate in expand") {
val c1 = Literal(1).as('a)
val c2 = Literal(2).as('b)
val a1 = c1.toAttribute.withNullability(true)
val a2 = c2.toAttribute.withNullability(true)
val a1 = c1.toAttribute.newInstance().withNullability(true)
val a2 = c2.toAttribute.newInstance().withNullability(true)
val expand = Expand(
Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))),
Seq(a1, a2),
Expand All @@ -161,4 +161,23 @@ class FoldablePropagationSuite extends PlanTest {
val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze
comparePlans(optimized, correctAnswer)
}

// Foldable aliases must still be propagated into the condition of a LEFT OUTER join and
// into the operators above it, as long as they come from the side that is not null-padded.
test("Propagate above outer join") {
  // Both join inputs expose one regular column plus one foldable alias (`b` and `d`).
  val leftPlan = LocalRelation('a.int).select('a, Literal(1).as('b))
  val rightPlan = LocalRelation('c.int).select('c, Literal(1).as('d))

  // Plan under test: the join condition and the projection above the join both
  // reference the foldable aliases.
  val originalCondition = 'a === 'c && 'b === 'd
  val outerJoin =
    leftPlan.join(rightPlan, joinType = LeftOuter, condition = Some(originalCondition))
  val optimized = Optimize.execute(outerJoin.select(('b + 3).as('res)).analyze)

  // Expected plan: `b` and `d` are replaced by their literals inside the join condition,
  // and `b` (left side of the left outer join) is also replaced above the join.
  val expectedCondition = 'a === 'c && Literal(1) === Literal(1)
  val expectedPlan = leftPlan
    .join(rightPlan, joinType = LeftOuter, condition = Some(expectedCondition))
    .select((Literal(1) + 3).as('res))
    .analyze

  comparePlans(optimized, expectedPlan)
}
}

0 comments on commit f51c8fd

Please sign in to comment.