diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e59e3b999aa7f..f510d68b3b7bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -118,7 +118,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
       Batch("Infer Filters", Once,
         InferFiltersFromConstraints) ::
       Batch("Operator Optimization after Inferring Filters", fixedPoint,
-        rulesWithoutInferFiltersFromConstraints: _*) :: Nil
+        rulesWithoutInferFiltersFromConstraints: _*) ::
+      Batch("Push predicate through join by conjunctive normal form", Once,
+        PushPredicateThroughJoinByCNF) :: Nil
     }
 
     val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) ::
@@ -1372,6 +1374,96 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
   }
 }
 
+/**
+ * Rewrites a join condition into conjunctive normal form (CNF) so that more predicates can
+ * be pushed below the join. The join keeps its original condition; CNF conjuncts that
+ * reference only one side are additionally applied as a [[Filter]] on that side.
+ */
+object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHelper {
+
+  /**
+   * Rewrites `condition` towards CNF. The recursion stops, returning the current
+   * sub-expression unchanged, once `depth` reaches `SQLConf.get.maxRewritingCNFDepth`.
+   *
+   * Rewrite pattern:
+   * 1. (a && b) || c --> (a || c) && (b || c)
+   * 2. a || (b && c) --> (a || b) && (a || c)
+   * 3. !(a || b) --> !a && !b
+   */
+  private def rewriteToCNF(condition: Expression, depth: Int = 0): Expression = {
+    if (depth < SQLConf.get.maxRewritingCNFDepth) {
+      val nextDepth = depth + 1
+      condition match {
+        case Or(And(a, b), c) =>
+          And(rewriteToCNF(Or(rewriteToCNF(a, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth),
+            rewriteToCNF(Or(rewriteToCNF(b, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth))
+        case Or(a, And(b, c)) =>
+          And(rewriteToCNF(Or(rewriteToCNF(a, nextDepth), rewriteToCNF(b, nextDepth)), nextDepth),
+            rewriteToCNF(Or(rewriteToCNF(a, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth))
+        case Not(Or(a, b)) =>
+          And(rewriteToCNF(Not(rewriteToCNF(a, nextDepth)), nextDepth),
+            rewriteToCNF(Not(rewriteToCNF(b, nextDepth)), nextDepth))
+        case And(a, b) =>
+          And(rewriteToCNF(a, nextDepth), rewriteToCNF(b, nextDepth))
+        case other => other
+      }
+    } else {
+      condition
+    }
+  }
+
+  /**
+   * Wraps `plan` in a [[Filter]] on the conjunction of `joinCondition`. Returns `plan` as is
+   * when `joinCondition` is empty, or when `plan` is already a [[Filter]] whose condition is
+   * semantically equal, which avoids stacking a duplicate filter.
+   */
+  private def maybeWithFilter(joinCondition: Seq[Expression], plan: LogicalPlan): LogicalPlan = {
+    (joinCondition.reduceLeftOption(And), plan) match {
+      case (Some(condition), filter: Filter) if condition.semanticEquals(filter.condition) =>
+        plan
+      case (Some(condition), _) =>
+        Filter(condition, plan)
+      case _ =>
+        plan
+    }
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
+
+  /**
+   * Rewrites one conditioned [[Join]]. Deterministic CNF conjuncts that reference only the
+   * left (or only the right) output are pushed to that side where the join type allows it.
+   */
+  val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
+    // A depth limit of 0 is documented as disabling this rule, so skip the join entirely.
+    case j @ Join(left, right, joinType, Some(joinCondition), hint)
+        if SQLConf.get.maxRewritingCNFDepth > 0 =>
+
+      val pushDownCandidates = splitConjunctivePredicates(rewriteToCNF(joinCondition))
+        .filter(_.deterministic)
+      val (leftEvaluateCondition, rest) =
+        pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
+      val rightEvaluateCondition =
+        rest.filter(_.references.subsetOf(right.outputSet))
+
+      val newLeft = maybeWithFilter(leftEvaluateCondition, left)
+      val newRight = maybeWithFilter(rightEvaluateCondition, right)
+
+      // Only add filters on the sides the join type permits; the condition is kept as is.
+      joinType match {
+        case _: InnerLike | LeftSemi =>
+          Join(newLeft, newRight, joinType, Some(joinCondition), hint)
+        case RightOuter =>
+          Join(newLeft, right, RightOuter, Some(joinCondition), hint)
+        case LeftOuter | LeftAnti | ExistenceJoin(_) =>
+          Join(left, newRight, joinType, Some(joinCondition), hint)
+        case FullOuter => j
+        case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
+        case UsingJoin(_, _) => sys.error("Untransformed Using join node")
+      }
+  }
+}
+
 /**
  * Combines two adjacent [[Limit]] operators into one, merging the
  * expressions into one single expression.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index c739fa516f0c1..aed5d23d197b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -544,6 +544,18 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val MAX_REWRITING_CNF_DEPTH =
+    buildConf("spark.sql.maxRewritingCNFDepth")
+      .internal()
+      .doc("The maximum depth of rewriting a join condition to conjunctive normal form " +
+        "expression. The deeper, the more predicate may be found, but the optimization time " +
+        "will increase. The default is 6. By setting this value to 0 this feature can be disabled.")
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ >= 0,
+        "The depth of the maximum rewriting conjunctive normal form must be non-negative.")
+      .createWithDefault(6)
+
   val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
     .internal()
     .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
@@ -2845,6 +2857,8 @@ class SQLConf extends Serializable with Logging {
 
   def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
 
+  def maxRewritingCNFDepth: Int = getConf(MAX_REWRITING_CNF_DEPTH)
+
   def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
 
   def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 70e29dca46e9e..d8a05279a5df0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -39,7 +39,9 @@ class FilterPushdownSuite extends PlanTest {
         PushPredicateThroughNonJoin,
         BooleanSimplification,
         PushPredicateThroughJoin,
-        CollapseProject) :: Nil
+        CollapseProject) ::
+      Batch("PushPredicateThroughJoinByCNF", Once,
+        PushPredicateThroughJoinByCNF) :: Nil
   }
 
   val attrA = 'a.int
@@ -1230,4 +1232,106 @@ class FilterPushdownSuite extends PlanTest {
 
     comparePlans(Optimize.execute(query.analyze), expected)
   }
+
+  test("inner join: rewrite filter predicates to conjunctive normal form") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y)
+        .where(("x.b".attr === "y.b".attr)
+          && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))
+    }
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val left = testRelation.where(('a > 3 || 'a > 1)).subquery('x)
+    val right = testRelation.where('a > 13 || 'a > 11).subquery('y)
+    val correctAnswer =
+      left.join(right, condition = Some("x.b".attr === "y.b".attr
+        && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("inner join: rewrite join predicates to conjunctive normal form") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, condition = Some(("x.b".attr === "y.b".attr)
+        && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
+    }
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val left = testRelation.where('a > 3 || 'a > 1).subquery('x)
+    val right = testRelation.where('a > 13 || 'a > 11).subquery('y)
+    val correctAnswer =
+      left.join(right, condition = Some("x.b".attr === "y.b".attr
+        && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("inner join: rewrite join predicates(with NOT predicate) to conjunctive normal form") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, condition = Some(("x.b".attr === "y.b".attr)
+        && Not(("x.a".attr > 3)
+          && ("x.a".attr < 2 || ("y.a".attr > 13)) || ("x.a".attr > 1) && ("y.a".attr > 11))))
+    }
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val left = testRelation.where('a <= 3 || 'a >= 2).subquery('x)
+    val right = testRelation.subquery('y)
+    val correctAnswer =
+      left.join(right, condition = Some("x.b".attr === "y.b".attr
+        && (("x.a".attr <= 3) || (("x.a".attr >= 2) && ("y.a".attr <= 13)))
+        && (("x.a".attr <= 1) || ("y.a".attr <= 11))))
+        .analyze
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("left join: rewrite join predicates to conjunctive normal form") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, joinType = LeftOuter, condition = Some(("x.b".attr === "y.b".attr)
+        && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
+    }
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val left = testRelation.subquery('x)
+    val right = testRelation.where('a > 13 || 'a > 11).subquery('y)
+    val correctAnswer =
+      left.join(right, joinType = LeftOuter, condition = Some("x.b".attr === "y.b".attr
+        && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("right join: rewrite join predicates to conjunctive normal form") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, joinType = RightOuter, condition = Some(("x.b".attr === "y.b".attr)
+        && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
+    }
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val left = testRelation.where('a > 3 || 'a > 1).subquery('x)
+    val right = testRelation.subquery('y)
+    val correctAnswer =
+      left.join(right, joinType = RightOuter, condition = Some("x.b".attr === "y.b".attr
+        && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))))
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }