Skip to content

Commit

Permalink
better solution for pushing extra predicates through join
Browse files Browse the repository at this point in the history
  • Loading branch information
gengliangwang committed Jul 16, 2020
1 parent b05f309 commit 16c78d6
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 378 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -201,126 +201,50 @@ trait PredicateHelper extends Logging {
case e => e.children.forall(canEvaluateWithinJoin)
}

/**
* Convert an expression into conjunctive normal form.
* Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form
* CNF can explode exponentially in the size of the input expression when converting [[Or]]
* clauses. Use a configuration [[SQLConf.MAX_CNF_NODE_COUNT]] to prevent such cases.
*
* @param condition to be converted into CNF.
* @return the CNF result as sequence of disjunctive expressions. If the number of expressions
* exceeds threshold on converting `Or`, `Seq.empty` is returned.
/*
* Returns a filter that it's output is a subset of `outputSet` and it contains all possible
* constraints from `condition`. This is used for predicate pushdown.
* When there is no such convertible filter, `None` is returned.
*/
protected def conjunctiveNormalForm(
condition: Expression,
groupExpsFunc: Seq[Expression] => Seq[Expression]): Seq[Expression] = {
val postOrderNodes = postOrderTraversal(condition)
val resultStack = new mutable.Stack[Seq[Expression]]
val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
// Bottom up approach to get CNF of sub-expressions
while (postOrderNodes.nonEmpty) {
val cnf = postOrderNodes.pop() match {
case _: And =>
val right = resultStack.pop()
val left = resultStack.pop()
left ++ right
case _: Or =>
// For each side, there is no need to expand predicates of the same references.
// So here we can aggregate predicates of the same qualifier as one single predicate,
// for reducing the size of pushed down predicates and corresponding codegen.
val right = groupExpsFunc(resultStack.pop())
val left = groupExpsFunc(resultStack.pop())
// Stop the loop whenever the result exceeds the `maxCnfNodeCount`
if (left.size * right.size > maxCnfNodeCount) {
logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " +
"The CNF conversion is skipped and returning Seq.empty now. To avoid this, you can " +
s"raise the limit ${SQLConf.MAX_CNF_NODE_COUNT.key}.")
return Seq.empty
} else {
for { x <- left; y <- right } yield Or(x, y)
}
case other => other :: Nil
protected def convertibleFilter(
condition: Expression,
outputSet: AttributeSet): Option[Expression] = condition match {
case And(left, right) =>
val leftResultOptional = convertibleFilter(left, outputSet)
val rightResultOptional = convertibleFilter(right, outputSet)
(leftResultOptional, rightResultOptional) match {
case (Some(leftResult), Some(rightResult)) => Some(And(leftResult, rightResult))
case (Some(leftResult), None) => Some(leftResult)
case (None, Some(rightResult)) => Some(rightResult)
case _ => None
}
resultStack.push(cnf)
}
if (resultStack.length != 1) {
logWarning("The length of CNF conversion result stack is supposed to be 1. There might " +
"be something wrong with CNF conversion.")
return Seq.empty
}
resultStack.top
}

/**
* Convert an expression to conjunctive normal form when pushing predicates through Join,
* when expand predicates, we can group by the qualifier avoiding generate unnecessary
* expression to control the length of final result since there are multiple tables.
*
* @param condition condition need to be converted
* @return the CNF result as sequence of disjunctive expressions. If the number of expressions
* exceeds threshold on converting `Or`, `Seq.empty` is returned.
*/
def CNFWithGroupExpressionsByQualifier(condition: Expression): Seq[Expression] = {
conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq)
}

/**
* Convert an expression to conjunctive normal form for predicate pushdown and partition pruning.
* When expanding predicates, this method groups expressions by their references for reducing
* the size of pushed down predicates and corresponding codegen. In partition pruning strategies,
* we split filters by [[splitConjunctivePredicates]] and partition filters by judging if it's
* references is subset of partCols, if we combine expressions group by reference when expand
* predicate of [[Or]], it won't impact final predicate pruning result since
* [[splitConjunctivePredicates]] won't split [[Or]] expression.
*
* @param condition condition need to be converted
* @return the CNF result as sequence of disjunctive expressions. If the number of expressions
* exceeds threshold on converting `Or`, `Seq.empty` is returned.
*/
def CNFWithGroupExpressionsByReference(condition: Expression): Seq[Expression] = {
conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq)
}

/**
* Iterative post order traversal over a binary tree built by And/Or clauses with two stacks.
* For example, a condition `(a And b) Or c`, the postorder traversal is
* (`a`,`b`, `And`, `c`, `Or`).
* Following is the complete algorithm. After step 2, we get the postorder traversal in
* the second stack.
* 1. Push root to first stack.
* 2. Loop while first stack is not empty
* 2.1 Pop a node from first stack and push it to second stack
* 2.2 Push the children of the popped node to first stack
*
* @param condition to be traversed as binary tree
* @return sub-expressions in post order traversal as a stack.
* The first element of result stack is the leftmost node.
*/
private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = {
val stack = new mutable.Stack[Expression]
val result = new mutable.Stack[Expression]
stack.push(condition)
while (stack.nonEmpty) {
val node = stack.pop()
node match {
case Not(a And b) => stack.push(Or(Not(a), Not(b)))
case Not(a Or b) => stack.push(And(Not(a), Not(b)))
case Not(Not(a)) => stack.push(a)
case a And b =>
result.push(node)
stack.push(a)
stack.push(b)
case a Or b =>
result.push(node)
stack.push(a)
stack.push(b)
case _ =>
result.push(node)
// The Or predicate is convertible when both of its children can be pushed down.
// That is to say, if one/both of the children can be partially pushed down, the Or
// predicate can be partially pushed down as well.
//
// Here is an example used to explain the reason.
// Let's say we have
// (a1 AND a2) OR (b1 AND b2),
// a1 and b1 is convertible, while a2 and b2 is not.
// The predicate can be converted as
// (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2)
// As per the logical in And predicate, we can push down (a1 OR b1).
case Or(left, right) =>
for {
lhs <- convertibleFilter(left, outputSet)
rhs <- convertibleFilter(right, outputSet)
} yield Or(lhs, rhs)

// Here we assume all the `Not` operators is already below all the `And` and `Or` operators
// after the optimization rule `BooleanSimplification`, so that we don't need to handle the
// `Not` operators here.
case other =>
if (other.references.subsetOf(outputSet)) {
Some(other)
} else {
None
}
}
result
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
override protected val excludedOnceBatches: Set[String] =
Set(
"PartitionPruning",
"Extract Python UDFs",
"Push CNF predicate through join")
"Extract Python UDFs")

protected def fixedPoint =
FixedPoint(
Expand Down Expand Up @@ -123,8 +122,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
rulesWithoutInferFiltersFromConstraints: _*) ::
// Set strategy to Once to avoid pushing filter every time because we do not change the
// join condition.
Batch("Push CNF predicate through join", Once,
PushCNFPredicateThroughJoin) :: Nil
Batch("Push extra predicate through join", fixedPoint,
PushExtraPredicateThroughJoin,
PushDownPredicates) :: Nil
}

val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) ::
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -545,19 +545,6 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val MAX_CNF_NODE_COUNT =
buildConf("spark.sql.optimizer.maxCNFNodeCount")
.internal()
.doc("Specifies the maximum allowable number of conjuncts in the result of CNF " +
"conversion. If the conversion exceeds the threshold, an empty sequence is returned. " +
"For example, CNF conversion of (a && b) || (c && d) generates " +
"four conjuncts (a || c) && (a || d) && (b || c) && (b || d).")
.version("3.1.0")
.intConf
.checkValue(_ >= 0,
"The depth of the maximum rewriting conjunction normal form must be positive.")
.createWithDefault(128)

val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
.internal()
.doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
Expand Down Expand Up @@ -2948,8 +2935,6 @@ class SQLConf extends Serializable with Logging {

def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)

def maxCnfNodeCount: Int = getConf(MAX_CNF_NODE_COUNT)

def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)

def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR)
Expand Down
Loading

0 comments on commit 16c78d6

Please sign in to comment.