Skip to content

Commit

Permalink
Rewriting join condition to conjunctive normal form expression
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed May 18, 2020
1 parent f0e2fc3 commit 21fb7c5
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
Batch("Infer Filters", Once,
InferFiltersFromConstraints) ::
Batch("Operator Optimization after Inferring Filters", fixedPoint,
rulesWithoutInferFiltersFromConstraints: _*) :: Nil
rulesWithoutInferFiltersFromConstraints: _*) ::
Batch("Push predicate through join by conjunctive normal form", Once,
PushPredicateThroughJoinByCNF) :: Nil
}

val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) ::
Expand Down Expand Up @@ -1372,6 +1374,80 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
}
}

/**
 * Rewrites the join condition to conjunctive normal form (CNF) so that more predicates can be
 * pushed below the join.
 *
 * The conjuncts of the CNF form that are deterministic and reference only one side of the join
 * are applied as a [[Filter]] on that side. The join condition itself is kept unchanged.
 */
object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHelper {

  /**
   * Rewrite pattern:
   * 1. (a && b) || c --> (a || c) && (b || c)
   * 2. a || (b && c) --> (a || b) && (a || c)
   * 3. !(a || b) --> !a && !b
   *
   * The recursion stops once `depth` reaches `SQLConf.get.maxRewritingCNFDepth`; sub-expressions
   * at that depth are returned as they are.
   *
   * @param condition the expression to rewrite
   * @param depth the recursion depth this call starts from
   * @return the rewritten expression, or `condition` itself when no pattern applies or the
   *         depth limit has already been reached
   */
  private def rewriteToCNF(condition: Expression, depth: Int = 0): Expression = {
    // Read the limit once per top-level call instead of once per recursive step.
    val maxDepth = SQLConf.get.maxRewritingCNFDepth

    def rewrite(expr: Expression, level: Int): Expression = {
      if (level >= maxDepth) {
        expr
      } else {
        val next = level + 1
        expr match {
          case Or(And(a, b), c) =>
            // `c` appears in both disjunctions: rewrite it once and reuse the result.
            val rewrittenC = rewrite(c, next)
            And(rewrite(Or(rewrite(a, next), rewrittenC), next),
              rewrite(Or(rewrite(b, next), rewrittenC), next))
          case Or(a, And(b, c)) =>
            // `a` appears in both disjunctions: rewrite it once and reuse the result.
            val rewrittenA = rewrite(a, next)
            And(rewrite(Or(rewrittenA, rewrite(b, next)), next),
              rewrite(Or(rewrittenA, rewrite(c, next)), next))
          case Not(Or(a, b)) =>
            And(rewrite(Not(rewrite(a, next)), next),
              rewrite(Not(rewrite(b, next)), next))
          case And(a, b) =>
            And(rewrite(a, next), rewrite(b, next))
          case other => other
        }
      }
    }

    rewrite(condition, depth)
  }

  /**
   * Applies the conjunction of `joinCondition` as a [[Filter]] on top of `plan`.
   *
   * @param joinCondition the predicates to apply; may be empty
   * @param plan the plan to filter
   * @return `plan` unchanged when there are no predicates or when `plan` is already a [[Filter]]
   *         whose condition is semantically equal to the conjunction; otherwise a new [[Filter]]
   */
  private def maybeWithFilter(joinCondition: Seq[Expression], plan: LogicalPlan): LogicalPlan = {
    (joinCondition.reduceLeftOption(And), plan) match {
      case (Some(condition), filter: Filter) if condition.semanticEquals(filter.condition) =>
        plan
      case (Some(condition), _) =>
        Filter(condition, plan)
      case _ =>
        plan
    }
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally

  /**
   * Handles joins that carry a condition. Which side receives the derived filter depends on
   * the join type:
   *  - inner-like and left semi joins: both sides
   *  - right outer joins: the left side only
   *  - left outer, left anti and existence joins: the right side only
   *  - full outer joins: the join is returned unchanged
   */
  val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
    case j @ Join(left, right, joinType, Some(joinCondition), hint) =>

      val pushDownCandidates = splitConjunctivePredicates(rewriteToCNF(joinCondition))
        .filter(_.deterministic)
      val (leftEvaluateCondition, rest) =
        pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
      // Predicates that reference both sides cannot be pushed and are simply dropped here;
      // the original join condition still evaluates them.
      val rightEvaluateCondition = rest.filter(_.references.subsetOf(right.outputSet))

      val newLeft = maybeWithFilter(leftEvaluateCondition, left)
      val newRight = maybeWithFilter(rightEvaluateCondition, right)

      joinType match {
        case _: InnerLike | LeftSemi =>
          Join(newLeft, newRight, joinType, Some(joinCondition), hint)
        case RightOuter =>
          Join(newLeft, right, RightOuter, Some(joinCondition), hint)
        case LeftOuter | LeftAnti | ExistenceJoin(_) =>
          Join(left, newRight, joinType, Some(joinCondition), hint)
        case FullOuter => j
        case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
        case UsingJoin(_, _) => sys.error("Untransformed Using join node")
      }
  }
}

/**
* Combines two adjacent [[Limit]] operators into one, merging the
* expressions into one single expression.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,18 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

// Internal setting: upper bound on the recursion depth used when rewriting a join condition to
// conjunctive normal form. Zero is an accepted value (see the doc text below), so the validation
// message says "non-negative" rather than "positive".
val MAX_REWRITING_CNF_DEPTH =
  buildConf("spark.sql.maxRewritingCNFDepth")
    .internal()
    .doc("The maximum depth of rewriting a join condition to conjunctive normal form " +
      "expression. The deeper, the more predicate may be found, but the optimization time " +
      "will increase. The default is 6. By setting this value to 0 this feature can be disabled.")
    .version("3.1.0")
    .intConf
    .checkValue(_ >= 0,
      "The depth of the maximum rewriting conjunctive normal form must be non-negative.")
    .createWithDefault(6)

val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
.internal()
.doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
Expand Down Expand Up @@ -2845,6 +2857,8 @@ class SQLConf extends Serializable with Logging {

// Value of the CONSTRAINT_PROPAGATION_ENABLED entry as returned by getConf.
def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)

// Value of the MAX_REWRITING_CNF_DEPTH entry as returned by getConf.
def maxRewritingCNFDepth: Int = getConf(MAX_REWRITING_CNF_DEPTH)

// Value of the ESCAPED_STRING_LITERALS entry as returned by getConf.
def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)

// Value of the FILE_COMPRESSION_FACTOR entry as returned by getConf.
def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ class FilterPushdownSuite extends PlanTest {
PushPredicateThroughNonJoin,
BooleanSimplification,
PushPredicateThroughJoin,
CollapseProject) :: Nil
CollapseProject) ::
Batch("PushPredicateThroughJoinByCNF", Once,
PushPredicateThroughJoinByCNF) :: Nil
}

val attrA = 'a.int
Expand Down Expand Up @@ -1230,4 +1232,106 @@ class FilterPushdownSuite extends PlanTest {

comparePlans(Optimize.execute(query.analyze), expected)
}

test("inner join: rewrite filter predicates to conjunctive normal form") {
  // Disjunction of two conjunctions, each mixing one predicate per join side.
  val rangeCond =
    ("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)
  val fullCond = ("x.b".attr === "y.b".attr) && rangeCond

  // The condition starts out as a filter on top of an unconditioned join.
  val inputPlan = testRelation.subquery('x)
    .join(testRelation.subquery('y))
    .where(fullCond)
  val actualPlan = Optimize.execute(inputPlan.analyze)

  // Expected: each side is filtered by the single-side disjunction and the
  // join keeps the whole condition.
  val expectedLeft = testRelation.where('a > 3 || 'a > 1).subquery('x)
  val expectedRight = testRelation.where('a > 13 || 'a > 11).subquery('y)
  val expectedPlan = expectedLeft.join(expectedRight, condition = Some(fullCond)).analyze

  comparePlans(actualPlan, expectedPlan)
}

test("inner join: rewrite join predicates to conjunctive normal form") {
  // Disjunction of two conjunctions, each mixing one predicate per join side.
  val rangeCond =
    ("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)
  val joinCond = ("x.b".attr === "y.b".attr) && rangeCond

  // The condition is supplied directly as the join condition.
  val inputPlan = testRelation.subquery('x)
    .join(testRelation.subquery('y), condition = Some(joinCond))
  val actualPlan = Optimize.execute(inputPlan.analyze)

  // Expected: both sides gain a filter; the join condition is unchanged.
  val expectedLeft = testRelation.where('a > 3 || 'a > 1).subquery('x)
  val expectedRight = testRelation.where('a > 13 || 'a > 11).subquery('y)
  val expectedPlan = expectedLeft.join(expectedRight, condition = Some(joinCond)).analyze

  comparePlans(actualPlan, expectedPlan)
}

test("inner join: rewrite join predicates(with NOT predicate) to conjunctive normal form") {
  // Negation wrapped around a disjunction of conjunctions.
  val negatedCond = Not(
    ("x.a".attr > 3) && (("x.a".attr < 2) || ("y.a".attr > 13)) ||
      ("x.a".attr > 1) && ("y.a".attr > 11))
  val joinCond = ("x.b".attr === "y.b".attr) && negatedCond

  val inputPlan = testRelation.subquery('x)
    .join(testRelation.subquery('y), condition = Some(joinCond))
  val actualPlan = Optimize.execute(inputPlan.analyze)

  // Expected: the join condition appears with the negation pushed down to the
  // comparisons, and only the left side gains a filter.
  val expectedCond = ("x.b".attr === "y.b".attr) &&
    (("x.a".attr <= 3) || (("x.a".attr >= 2) && ("y.a".attr <= 13))) &&
    (("x.a".attr <= 1) || ("y.a".attr <= 11))
  val expectedLeft = testRelation.where('a <= 3 || 'a >= 2).subquery('x)
  val expectedRight = testRelation.subquery('y)
  val expectedPlan = expectedLeft.join(expectedRight, condition = Some(expectedCond)).analyze

  comparePlans(actualPlan, expectedPlan)
}

test("left join: rewrite join predicates to conjunctive normal form") {
  // Disjunction of two conjunctions, each mixing one predicate per join side.
  val rangeCond =
    ("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)
  val joinCond = ("x.b".attr === "y.b".attr) && rangeCond

  val inputPlan = testRelation.subquery('x)
    .join(testRelation.subquery('y), joinType = LeftOuter, condition = Some(joinCond))
  val actualPlan = Optimize.execute(inputPlan.analyze)

  // Expected: only the right side gains a filter for a left outer join.
  val expectedLeft = testRelation.subquery('x)
  val expectedRight = testRelation.where('a > 13 || 'a > 11).subquery('y)
  val expectedPlan = expectedLeft
    .join(expectedRight, joinType = LeftOuter, condition = Some(joinCond))
    .analyze

  comparePlans(actualPlan, expectedPlan)
}

test("right join: rewrite join predicates to conjunctive normal form") {
  // Disjunction of two conjunctions, each mixing one predicate per join side.
  val rangeCond =
    ("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)
  val joinCond = ("x.b".attr === "y.b".attr) && rangeCond

  val inputPlan = testRelation.subquery('x)
    .join(testRelation.subquery('y), joinType = RightOuter, condition = Some(joinCond))
  val actualPlan = Optimize.execute(inputPlan.analyze)

  // Expected: only the left side gains a filter for a right outer join.
  val expectedLeft = testRelation.where('a > 3 || 'a > 1).subquery('x)
  val expectedRight = testRelation.subquery('y)
  val expectedPlan = expectedLeft
    .join(expectedRight, joinType = RightOuter, condition = Some(joinCond))
    .analyze

  comparePlans(actualPlan, expectedPlan)
}
}

0 comments on commit 21fb7c5

Please sign in to comment.