Skip to content

Commit

Permalink
[SPARK-31705][SQL] Push more possible predicates through Join via CNF…
Browse files Browse the repository at this point in the history
… conversion

### What changes were proposed in this pull request?

This PR adds a new rule that pushes more predicates through a join by rewriting the join condition into CNF (conjunctive normal form). The following example shows the steps of this rule:

1. Prepare Table:

```sql
CREATE TABLE x(a INT);
CREATE TABLE y(b INT);
...
SELECT * FROM x JOIN y ON ((a < 0 and a > b) or a > 10);
```

2. Convert the join condition to CNF:
```
(a < 0 or a > 10) and (a > b or a > 10)
```

3. Split conjunctive predicates

Predicates
---|
(a < 0 or a > 10)
(a > b or a > 10)

4. Push predicate

Table | Predicate
--- | ---
x | (a < 0 or a > 10)

### Why are the changes needed?
Improve query performance. PostgreSQL, [Impala](https://issues.apache.org/jira/browse/IMPALA-9183) and Hive support this feature.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit test and benchmark test.

SQL | Before this PR | After this PR
--- | --- | ---
TPCDS 5T Q13 | 84s | 21s
TPCDS 5T q85 | 66s | 34s
TPCH 1T q19 | 37s | 32s

Closes #28733 from gengliangwang/cnf.

Lead-authored-by: Gengliang Wang <[email protected]>
Co-authored-by: Yuming Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
  • Loading branch information
gengliangwang and wangyum committed Jun 11, 2020
1 parent 91cd06b commit 11d3a74
Show file tree
Hide file tree
Showing 6 changed files with 468 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst.expressions

import scala.collection.immutable.TreeSet
import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
Expand Down Expand Up @@ -95,7 +97,7 @@ object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePr
}
}

trait PredicateHelper {
trait PredicateHelper extends Logging {
protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
condition match {
case And(cond1, cond2) =>
Expand Down Expand Up @@ -198,6 +200,98 @@ trait PredicateHelper {
case e: Unevaluable => false
case e => e.children.forall(canEvaluateWithinJoin)
}

/**
 * Convert an expression into conjunctive normal form.
 * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form
 *
 * The sub-expressions are consumed in the order produced by `postOrderTraversal`, and a stack
 * holds the CNF (a sequence of conjuncts) of every sub-expression processed so far:
 *  - for `And`, the conjuncts of the two children are concatenated;
 *  - for `Or`, the conjuncts of each child are first grouped by `groupExpressionsByQualifier`,
 *    then every left conjunct is combined with every right conjunct through `Or`;
 *  - any other expression is kept as a single conjunct.
 *
 * CNF can explode exponentially in the size of the input expression when converting [[Or]]
 * clauses. Use a configuration [[SQLConf.MAX_CNF_NODE_COUNT]] to prevent such cases.
 *
 * @param condition to be converted into CNF.
 * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
 *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
 */
def conjunctiveNormalForm(condition: Expression): Seq[Expression] = {
  val postOrderNodes = postOrderTraversal(condition)
  val resultStack = new mutable.Stack[Seq[Expression]]
  val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
  // Bottom up approach to get CNF of sub-expressions
  while (postOrderNodes.nonEmpty) {
    val cnf = postOrderNodes.pop() match {
      case _: And =>
        // The right child was pushed last, so it is popped first.
        val right = resultStack.pop()
        val left = resultStack.pop()
        left ++ right
      case _: Or =>
        // For each side, there is no need to expand predicates of the same references.
        // So here we can aggregate predicates of the same qualifier as one single predicate,
        // for reducing the size of pushed down predicates and corresponding codegen.
        val right = groupExpressionsByQualifier(resultStack.pop())
        val left = groupExpressionsByQualifier(resultStack.pop())
        // Stop the loop whenever the result exceeds the `maxCnfNodeCount`.
        // The product is computed as a Long: an Int multiplication could wrap around to a
        // negative number and silently pass the check for very large operands.
        if (left.size.toLong * right.size > maxCnfNodeCount) {
          logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " +
            "The CNF conversion is skipped and returning Seq.empty now. To avoid this, you can " +
            s"raise the limit ${SQLConf.MAX_CNF_NODE_COUNT.key}.")
          return Seq.empty
        } else {
          for { x <- left; y <- right } yield Or(x, y)
        }
      case other => other :: Nil
    }
    resultStack.push(cnf)
  }
  // After all nodes are consumed, exactly one entry (the CNF of the root) should remain.
  if (resultStack.length != 1) {
    logWarning("The length of CNF conversion result stack is supposed to be 1. There might " +
      "be something wrong with CNF conversion.")
    return Seq.empty
  }
  resultStack.top
}

private def groupExpressionsByQualifier(expressions: Seq[Expression]): Seq[Expression] = {
  // Bucket the expressions by the qualifiers of the attributes they reference, then collapse
  // every bucket into a single predicate by And-ing its members in their original order.
  val byQualifiers = expressions.groupBy(expr => expr.references.map(attr => attr.qualifier))
  byQualifiers.map { case (_, sameQualifiers) => sameQualifiers.reduceLeft(And) }.toSeq
}

/**
 * Iterative, two-stack post order traversal of the binary tree formed by And/Or clauses.
 * For example, for the condition `(a And b) Or c` the post order is
 * (`a`, `b`, `And`, `c`, `Or`).
 *
 * Algorithm:
 * 1. Push the root onto the work stack.
 * 2. While the work stack is not empty:
 *    2.1 Pop a node from the work stack and push it onto the output stack.
 *    2.2 Push the children of the popped node onto the work stack.
 * Once the loop finishes, the output stack holds the post order traversal.
 *
 * A `Not` wrapping an `And`, an `Or` or another `Not` is rewritten (De Morgan's laws, double
 * negation elimination) and pushed back onto the work stack instead of being emitted.
 *
 * @param condition to be traversed as binary tree
 * @return sub-expressions in post order traversal as a stack.
 *         The first element of result stack is the leftmost node.
 */
private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = {
  val pending = new mutable.Stack[Expression]
  val postOrder = new mutable.Stack[Expression]
  pending.push(condition)
  while (pending.nonEmpty) {
    pending.pop() match {
      // Negations are rewritten and re-queued, so they are never emitted in this form.
      case Not(And(lhs, rhs)) => pending.push(Or(Not(lhs), Not(rhs)))
      case Not(Or(lhs, rhs)) => pending.push(And(Not(lhs), Not(rhs)))
      case Not(Not(inner)) => pending.push(inner)
      // And/Or nodes are emitted, then both of their children are queued.
      case conjunction @ And(lhs, rhs) =>
        postOrder.push(conjunction)
        pending.push(lhs)
        pending.push(rhs)
      case disjunction @ Or(lhs, rhs) =>
        postOrder.push(disjunction)
        pending.push(lhs)
        pending.push(rhs)
      // Everything else is a leaf of the And/Or tree.
      case leaf =>
        postOrder.push(leaf)
    }
  }
  postOrder
}
}

@ExpressionDescription(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
override protected val blacklistedOnceBatches: Set[String] =
Set(
"PartitionPruning",
"Extract Python UDFs")
"Extract Python UDFs",
"Push CNF predicate through join")

protected def fixedPoint =
FixedPoint(
Expand Down Expand Up @@ -118,7 +119,11 @@ abstract class Optimizer(catalogManager: CatalogManager)
Batch("Infer Filters", Once,
InferFiltersFromConstraints) ::
Batch("Operator Optimization after Inferring Filters", fixedPoint,
rulesWithoutInferFiltersFromConstraints: _*) :: Nil
rulesWithoutInferFiltersFromConstraints: _*) ::
// Set strategy to Once to avoid pushing filter every time because we do not change the
// join condition.
Batch("Push CNF predicate through join", Once,
PushCNFPredicateThroughJoin) :: Nil
}

val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) ::
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

/**
 * Rewrites the condition of a join into conjunctive normal form and pushes the resulting
 * deterministic conjuncts that only reference one side of the join below the join as a
 * [[Filter]] on that side.
 * The join itself keeps its original, unexpanded condition even when a predicate is pushed
 * down, so the join condition never grows because of this rule.
 */
object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case join @ Join(left, right, joinType, Some(joinCondition), hint) =>
      val cnfPredicates = conjunctiveNormalForm(joinCondition)
      if (cnfPredicates.nonEmpty) {
        val candidates = cnfPredicates.filter(_.deterministic)

        // Both sides are lazy: each join type below only evaluates the side(s) it pushes to.
        lazy val filteredLeft = {
          val leftOnly = candidates.filter(_.references.subsetOf(left.outputSet))
          leftOnly.reduceLeftOption(And) match {
            case Some(leftCondition) => Filter(leftCondition, left)
            case None => left
          }
        }
        lazy val filteredRight = {
          val rightOnly = candidates.filter(_.references.subsetOf(right.outputSet))
          rightOnly.reduceLeftOption(And) match {
            case Some(rightCondition) => Filter(rightCondition, right)
            case None => right
          }
        }

        joinType match {
          // Both children may be filtered.
          case _: InnerLike | LeftSemi =>
            Join(filteredLeft, filteredRight, joinType, Some(joinCondition), hint)
          // Only the left child is filtered.
          case RightOuter =>
            Join(filteredLeft, right, RightOuter, Some(joinCondition), hint)
          // Only the right child is filtered.
          case LeftOuter | LeftAnti | ExistenceJoin(_) =>
            Join(left, filteredRight, joinType, Some(joinCondition), hint)
          // Nothing is pushed for a full outer join.
          case FullOuter => join
          case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
          case UsingJoin(_, _) => sys.error("Untransformed Using join node")
        }
      } else {
        // The CNF conversion produced nothing, so the join is left untouched.
        join
      }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,19 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

// Internal limit on the number of conjuncts a CNF conversion may produce. The validation
// below accepts any non-negative value, including 0.
val MAX_CNF_NODE_COUNT =
  buildConf("spark.sql.optimizer.maxCNFNodeCount")
    .internal()
    .doc("Specifies the maximum allowable number of conjuncts in the result of CNF " +
      "conversion. If the conversion exceeds the threshold, an empty sequence is returned. " +
      "For example, CNF conversion of (a && b) || (c && d) generates " +
      "four conjuncts (a || c) && (a || d) && (b || c) && (b || d).")
    .version("3.1.0")
    .intConf
    // The message now matches the check: the value is a conjunct count (not a depth) and
    // zero is accepted.
    .checkValue(_ >= 0,
      "The maximum number of conjuncts in the result of CNF conversion must not be negative.")
    .createWithDefault(128)

val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
.internal()
.doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
Expand Down Expand Up @@ -2874,6 +2887,8 @@ class SQLConf extends Serializable with Logging {

def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)

/** Returns the current value of the `MAX_CNF_NODE_COUNT` configuration entry. */
def maxCnfNodeCount: Int = getConf(MAX_CNF_NODE_COUNT)

def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)

def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.BooleanType

/**
 * Tests for `PredicateHelper.conjunctiveNormalForm`: handling of `Not`, `And` and `Or`,
 * aggregation of predicates that share the same qualifier, and the behavior when the
 * `SQLConf.MAX_CNF_NODE_COUNT` threshold is exceeded.
 */
class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHelper with PlanTest {
  // Boolean attributes used as predicates. Every attribute gets its own expression id; `a`
  // to `j` each use a distinct qualifier.
  private val a = AttributeReference("A", BooleanType)(exprId = ExprId(1)).withQualifier(Seq("ta"))
  private val b = AttributeReference("B", BooleanType)(exprId = ExprId(2)).withQualifier(Seq("tb"))
  private val c = AttributeReference("C", BooleanType)(exprId = ExprId(3)).withQualifier(Seq("tc"))
  private val d = AttributeReference("D", BooleanType)(exprId = ExprId(4)).withQualifier(Seq("td"))
  private val e = AttributeReference("E", BooleanType)(exprId = ExprId(5)).withQualifier(Seq("te"))
  private val f = AttributeReference("F", BooleanType)(exprId = ExprId(6)).withQualifier(Seq("tf"))
  private val g = AttributeReference("G", BooleanType)(exprId = ExprId(7)).withQualifier(Seq("tg"))
  private val h = AttributeReference("H", BooleanType)(exprId = ExprId(8)).withQualifier(Seq("th"))
  private val i = AttributeReference("I", BooleanType)(exprId = ExprId(9)).withQualifier(Seq("ti"))
  private val j = AttributeReference("J", BooleanType)(exprId = ExprId(10)).withQualifier(Seq("tj"))
  // Additional attributes that share the qualifier of `a` ("ta") or `b` ("tb").
  private val a1 =
    AttributeReference("a1", BooleanType)(exprId = ExprId(11)).withQualifier(Seq("ta"))
  private val a2 =
    AttributeReference("a2", BooleanType)(exprId = ExprId(12)).withQualifier(Seq("ta"))
  // `b1` needs an expression id of its own; it must not reuse the id of `a2`.
  private val b1 =
    AttributeReference("b1", BooleanType)(exprId = ExprId(13)).withQualifier(Seq("tb"))

  // Check CNF conversion with expected expression, assuming the input has non-empty result.
  private def checkCondition(input: Expression, expected: Expression): Unit = {
    val cnf = conjunctiveNormalForm(input)
    assert(cnf.nonEmpty)
    val result = cnf.reduceLeft(And)
    assert(result.semanticEquals(expected))
  }

  test("Keep non-predicated expressions") {
    checkCondition(a, a)
    checkCondition(Literal(1), Literal(1))
  }

  test("Conversion of Not") {
    checkCondition(!a, !a)
    checkCondition(!(!a), a)
    checkCondition(!(!(a && b)), a && b)
    checkCondition(!(!(a || b)), a || b)
    checkCondition(!(a || b), !a && !b)
    checkCondition(!(a && b), !a || !b)
  }

  test("Conversion of And") {
    checkCondition(a && b, a && b)
    checkCondition(a && b && c, a && b && c)
    checkCondition(a && (b || c), a && (b || c))
    checkCondition((a || b) && c, (a || b) && c)
    checkCondition(a && b && c && d, a && b && c && d)
  }

  test("Conversion of Or") {
    checkCondition(a || b, a || b)
    checkCondition(a || b || c, a || b || c)
    checkCondition(a || b || c || d, a || b || c || d)
    checkCondition((a && b) || c, (a || c) && (b || c))
    checkCondition((a && b) || (c && d), (a || c) && (a || d) && (b || c) && (b || d))
  }

  test("More complex cases") {
    checkCondition(a && !(b || c), a && !b && !c)
    checkCondition((a && b) || !(c && d), (a || !c || !d) && (b || !c || !d))
    checkCondition(a || b || c && d, (a || b || c) && (a || b || d))
    checkCondition(a || (b && c || d), (a || b || d) && (a || c || d))
    checkCondition(a && !(b && c || d && e), a && (!b || !c) && (!d || !e))
    checkCondition(((a && b) || c) || (d || e), (a || c || d || e) && (b || c || d || e))

    checkCondition(
      (a && b && c) || (d && e && f),
      (a || d) && (a || e) && (a || f) && (b || d) && (b || e) && (b || f) &&
      (c || d) && (c || e) && (c || f)
    )
  }

  test("Aggregate predicate of same qualifiers to avoid expanding") {
    checkCondition(((a && b && a1) || c), ((a && a1) || c) && (b || c))
    checkCondition(((a && a1 && b) || c), ((a && a1) || c) && (b || c))
    checkCondition(((b && d && a && a1) || c), ((a && a1) || c) && (b || c) && (d || c))
    checkCondition(((b && a2 && d && a && a1) || c),
      ((a2 && a && a1) || c) && (b || c) && (d || c))
    checkCondition(((b && d && a && a1 && b1) || c),
      ((a && a1) || c) && ((b && b1) || c) && (d || c))
    checkCondition((a && a1) || (b && b1), (a && a1) || (b && b1))
    checkCondition((a && a1 && c) || (b && b1), ((a && a1) || (b && b1)) && (c || (b && b1)))
  }

  test("Return Seq.empty when exceeding MAX_CNF_NODE_COUNT") {
    // The following expression contains 36 conjunctive sub-expressions in CNF
    val input = (a && b && c) || (d && e && f) || (g && h && i && j)
    // The following expression contains 9 conjunctive sub-expressions in CNF
    val input2 = (a && b && c) || (d && e && f)
    // Probe the configured limit just below, at, and just above both sizes.
    Seq(8, 9, 10, 35, 36, 37).foreach { maxCount =>
      withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> maxCount.toString) {
        if (maxCount < 36) {
          assert(conjunctiveNormalForm(input).isEmpty)
        } else {
          assert(conjunctiveNormalForm(input).nonEmpty)
        }
        if (maxCount < 9) {
          assert(conjunctiveNormalForm(input2).isEmpty)
        } else {
          assert(conjunctiveNormalForm(input2).nonEmpty)
        }
      }
    }
  }
}
Loading

0 comments on commit 11d3a74

Please sign in to comment.