[SPARK-28169][SQL] Convert scan predicate condition to CNF
### What changes were proposed in this pull request?
Spark can't push down a scan predicate that contains an **Or** across different columns.
For example, given a table `default.test` whose partition column is `dt`, and the query:
```
select * from default.test
where dt=20190625 or (dt = 20190626 and id in (1,2,3) )
```

In this case, Spark resolves the **Or** condition as a single expression, and since that expression references `id`, it can't be pushed down.

Based on PR #28733, this PR handles SQL such as
`select * from default.test`
`where dt = 20190626 or (dt = 20190627 and xxx="a")`

The condition `dt = 20190626 or (dt = 20190627 and xxx="a")` is converted to CNF:
```
(dt = 20190626 or dt = 20190627) and (dt = 20190626 or xxx = "a" )
```
The conjunct `dt = 20190626 or dt = 20190627` can then be pushed down during partition pruning.
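The conversion itself is repeated distribution of `Or` over `And`. Below is a minimal, self-contained sketch of that rule on a toy expression tree (illustrative names only; the actual `conjunctiveNormalForm` in `PredicateHelper` works on Catalyst expressions with a post-order traversal and a `maxCnfNodeCount` cap):

```scala
// Toy sketch, not Spark code: distribute Or over And to obtain the CNF conjuncts.
sealed trait Expr
case class Leaf(pred: String) extends Expr                 // e.g. "dt = 20190626"
case class And(left: Expr, right: Expr) extends Expr
case class Or(left: Expr, right: Expr) extends Expr

// Returns the conjuncts whose And-combination is the CNF form of `e`.
def toCnf(e: Expr): Seq[Expr] = e match {
  case And(l, r) => toCnf(l) ++ toCnf(r)
  // (a AND b) OR (c AND d) == (a OR c) AND (a OR d) AND (b OR c) AND (b OR d)
  case Or(l, r)  => for (x <- toCnf(l); y <- toCnf(r)) yield Or(x, y)
  case leaf      => Seq(leaf)
}

// dt = 20190626 or (dt = 20190627 and xxx = 'a')
val cond = Or(Leaf("dt = 20190626"), And(Leaf("dt = 20190627"), Leaf("xxx = 'a'")))
// List(Or(dt = 20190626, dt = 20190627), Or(dt = 20190626, xxx = 'a'))
println(toCnf(cond))
```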

### Why are the changes needed?
Optimize partition pruning

### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?
Added UT

Closes #28805 from AngersZhuuuu/cnf-for-partition-pruning.

Lead-authored-by: angerszhu <[email protected]>
Co-authored-by: AngersZhuuuu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
AngersZhuuuu authored and cloud-fan committed Jul 1, 2020
1 parent a4ba344 commit 15fb5d7
Showing 9 changed files with 152 additions and 27 deletions.
@@ -211,7 +211,9 @@ trait PredicateHelper extends Logging {
* @return the CNF result as sequence of disjunctive expressions. If the number of expressions
* exceeds threshold on converting `Or`, `Seq.empty` is returned.
*/
protected def conjunctiveNormalForm(condition: Expression): Seq[Expression] = {
protected def conjunctiveNormalForm(
condition: Expression,
groupExpsFunc: Seq[Expression] => Seq[Expression]): Seq[Expression] = {
val postOrderNodes = postOrderTraversal(condition)
val resultStack = new mutable.Stack[Seq[Expression]]
val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
@@ -226,8 +228,8 @@ trait PredicateHelper extends Logging {
// For each side, there is no need to expand predicates of the same references.
// So here we can aggregate predicates of the same qualifier as one single predicate,
// for reducing the size of pushed down predicates and corresponding codegen.
val right = groupExpressionsByQualifier(resultStack.pop())
val left = groupExpressionsByQualifier(resultStack.pop())
val right = groupExpsFunc(resultStack.pop())
val left = groupExpsFunc(resultStack.pop())
// Stop the loop whenever the result exceeds the `maxCnfNodeCount`
if (left.size * right.size > maxCnfNodeCount) {
logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " +
@@ -249,8 +251,36 @@ trait PredicateHelper extends Logging {
resultStack.top
}

private def groupExpressionsByQualifier(expressions: Seq[Expression]): Seq[Expression] = {
expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
/**
* Convert an expression to conjunctive normal form when pushing predicates through Join.
* When expanding predicates, we group expressions by their qualifier so that predicates on
* the same table are combined into one, avoiding unnecessary expressions and keeping the
* final result short when multiple tables are involved.
*
* @param condition the condition to be converted
* @return the CNF result as a sequence of disjunctive expressions. If the number of expressions
* exceeds the threshold on converting `Or`, `Seq.empty` is returned.
*/
def CNFWithGroupExpressionsByQualifier(condition: Expression): Seq[Expression] = {
conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq)
}

/**
* Convert an expression to conjunctive normal form for predicate pushdown and partition pruning.
* When expanding predicates, this method groups expressions by their references, to reduce
* the size of pushed-down predicates and the corresponding codegen. In partition pruning
* strategies, we split filters with [[splitConjunctivePredicates]] and pick partition filters
* by checking whether their references are a subset of partCols; combining expressions that
* share the same references when expanding an [[Or]] predicate does not change the final
* pruning result, since [[splitConjunctivePredicates]] never splits an [[Or]] expression.
*
* @param condition the condition to be converted
* @return the CNF result as a sequence of disjunctive expressions. If the number of expressions
* exceeds the threshold on converting `Or`, `Seq.empty` is returned.
*/
def CNFWithGroupExpressionsByReference(condition: Expression): Seq[Expression] = {
conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq)
}

/**
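The two entry points above differ only in how conjuncts are collapsed before the `left.size * right.size` cross product in `conjunctiveNormalForm`: `CNFWithGroupExpressionsByQualifier` groups by table qualifier (join pushdown), while `CNFWithGroupExpressionsByReference` groups by the exact attribute set (partition pruning). A rough, self-contained sketch of why that grouping keeps the CNF result small (toy types, not Catalyst):

```scala
// Toy sketch: predicates tagged with the qualifier (table alias) they reference.
case class Pred(qualifier: String, text: String)

// Combine conjuncts that touch the same qualifier into a single predicate, mirroring
// expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq above.
def groupByQualifier(conjuncts: Seq[Pred]): Seq[Pred] =
  conjuncts.groupBy(_.qualifier).map { case (q, ps) =>
    Pred(q, ps.map(_.text).mkString("(", " AND ", ")"))
  }.toSeq

val leftSide  = Seq(Pred("t1", "t1.a = 1"), Pred("t1", "t1.b = 2"), Pred("t2", "t2.c = 3"))
val rightSide = Seq(Pred("t2", "t2.d = 4"), Pred("t2", "t2.e = 5"))

// Expanding an Or over these sides produces size(left) * size(right) disjunctions:
// 3 * 2 = 6 without grouping, but only 2 * 1 = 2 after grouping by qualifier.
println(groupByQualifier(leftSide).size * groupByQualifier(rightSide).size)  // 2
```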
@@ -38,7 +38,7 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case j @ Join(left, right, joinType, Some(joinCondition), hint)
if canPushThrough(joinType) =>
val predicates = conjunctiveNormalForm(joinCondition)
val predicates = CNFWithGroupExpressionsByQualifier(joinCondition)
if (predicates.isEmpty) {
j
} else {
@@ -43,7 +43,7 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe

// Check CNF conversion with expected expression, assuming the input has non-empty result.
private def checkCondition(input: Expression, expected: Expression): Unit = {
val cnf = conjunctiveNormalForm(input)
val cnf = CNFWithGroupExpressionsByQualifier(input)
assert(cnf.nonEmpty)
val result = cnf.reduceLeft(And)
assert(result.semanticEquals(expected))
@@ -113,14 +113,14 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe
Seq(8, 9, 10, 35, 36, 37).foreach { maxCount =>
withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> maxCount.toString) {
if (maxCount < 36) {
assert(conjunctiveNormalForm(input).isEmpty)
assert(CNFWithGroupExpressionsByQualifier(input).isEmpty)
} else {
assert(conjunctiveNormalForm(input).nonEmpty)
assert(CNFWithGroupExpressionsByQualifier(input).nonEmpty)
}
if (maxCount < 9) {
assert(conjunctiveNormalForm(input2).isEmpty)
assert(CNFWithGroupExpressionsByQualifier(input2).isEmpty)
} else {
assert(conjunctiveNormalForm(input2).nonEmpty)
assert(CNFWithGroupExpressionsByQualifier(input2).nonEmpty)
}
}
}
@@ -207,7 +207,7 @@ case class FileSourceScanExec(
private def isDynamicPruningFilter(e: Expression): Boolean =
e.find(_.isInstanceOf[PlanExpression[_]]).isDefined

@transient private lazy val selectedPartitions: Array[PartitionDirectory] = {
@transient lazy val selectedPartitions: Array[PartitionDirectory] = {
val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
val startTime = System.nanoTime()
val ret =
@@ -39,7 +39,8 @@ import org.apache.spark.sql.types.StructType
* its underlying [[FileScan]]. And the partition filters will be removed in the filters of
* returned logical plan.
*/
private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
private[sql] object PruneFileSourcePartitions
extends Rule[LogicalPlan] with PredicateHelper {

private def getPartitionKeyFiltersAndDataFilters(
sparkSession: SparkSession,
@@ -87,8 +88,12 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
_,
_))
if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And))
val finalPredicates = if (predicates.nonEmpty) predicates else filters
val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters(
fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output)
fsRelation.sparkSession, logicalRelation, partitionSchema, finalPredicates,
logicalRelation.output)

if (partitionKeyFilters.nonEmpty) {
val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
val prunedFsRelation =
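`PruneFileSourcePartitions` above (and `PruneHiveTablePartitions` below) wrap the conversion in the same fallback: if `CNFWithGroupExpressionsByReference` gives up and returns `Seq.empty` (the intermediate result exceeded `SQLConf.MAX_CNF_NODE_COUNT`), pruning proceeds with the original filters. A schematic sketch of that pattern, with plain strings standing in for Catalyst expressions:

```scala
// Schematic only: `toCnf` stands in for CNFWithGroupExpressionsByReference and the
// predicates are plain strings rather than Catalyst Expressions.
def predicatesForPruning(filters: Seq[String], toCnf: String => Seq[String]): Seq[String] = {
  // Combine all filters into one condition, like filters.reduceLeft(And) in the rule above.
  val combined = filters.reduceLeft((l, r) => s"($l) AND ($r)")
  val predicates = toCnf(combined)
  // Fall back to the original filters when the CNF conversion bailed out.
  if (predicates.nonEmpty) predicates else filters
}
```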
@@ -21,8 +21,8 @@ import org.apache.hadoop.hive.common.StatsSetupConst

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, ExternalCatalogUtils, HiveTableRelation}
import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, SubqueryExpression}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -41,7 +41,7 @@ import org.apache.spark.sql.internal.SQLConf
* TODO: merge this with PruneFileSourcePartitions after we completely make hive as a data source.
*/
private[sql] class PruneHiveTablePartitions(session: SparkSession)
extends Rule[LogicalPlan] with CastSupport {
extends Rule[LogicalPlan] with CastSupport with PredicateHelper {

override val conf: SQLConf = session.sessionState.conf

@@ -103,7 +103,9 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession)
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation)
if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty =>
val partitionKeyFilters = getPartitionKeyFilters(filters, relation)
val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And))
val finalPredicates = if (predicates.nonEmpty) predicates else filters
val partitionKeyFilters = getPartitionKeyFilters(finalPredicates, relation)
if (partitionKeyFilters.nonEmpty) {
val newPartitions = prunePartitions(relation, partitionKeyFilters)
val newTableMeta = updateTableMeta(relation.tableMeta, newPartitions)
@@ -19,22 +19,22 @@ package org.apache.spark.sql.hive.execution

import org.scalatest.Matchers._

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.functions.broadcast
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.StructType

class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {

override def format: String = "parquet"

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("PruneFileSourcePartitions", Once, PruneFileSourcePartitions) :: Nil
@@ -108,4 +108,10 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te
}
}
}

override def getScanExecPartitionSize(plan: SparkPlan): Long = {
plan.collectFirst {
case p: FileSourceScanExec => p
}.get.selectedPartitions.length
}
}
@@ -17,22 +17,22 @@

package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.execution.SparkPlan

class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase {

override def format(): String = "hive"

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("PruneHiveTablePartitions", Once,
EliminateSubqueryAliases, new PruneHiveTablePartitions(spark)) :: Nil
}

test("SPARK-15616 statistics pruned after going throuhg PruneHiveTablePartitions") {
test("SPARK-15616: statistics pruned after going through PruneHiveTablePartitions") {
withTable("test", "temp") {
sql(
s"""
@@ -54,4 +54,10 @@ class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with Tes
Optimize.execute(analyzed2).stats.sizeInBytes)
}
}

override def getScanExecPartitionSize(plan: SparkPlan): Long = {
plan.collectFirst {
case p: HiveTableScanExec => p
}.get.prunedPartitions.size
}
}
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils

abstract class PrunePartitionSuiteBase extends QueryTest with SQLTestUtils with TestHiveSingleton {

protected def format: String

test("SPARK-28169: Convert scan predicate condition to CNF") {
withTempView("temp") {
withTable("t") {
sql(
s"""
|CREATE TABLE t(i INT, p STRING)
|USING $format
|PARTITIONED BY (p)""".stripMargin)

spark.range(0, 1000, 1).selectExpr("id as col")
.createOrReplaceTempView("temp")

for (part <- Seq(1, 2, 3, 4)) {
sql(
s"""
|INSERT OVERWRITE TABLE t PARTITION (p='$part')
|SELECT col FROM temp""".stripMargin)
}

assertPrunedPartitions(
"SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)", 2)

wangyum (Member) commented on Jul 6, 2020:

@AngersZhuuuu It seems the pushed partition filters are not in PartitionFilters:

== Physical Plan ==
*(1) Filter ((p#21 = 1) OR ((p#21 = 2) AND (i#20 = 1)))
+- *(1) ColumnarToRow
   +- FileScan parquet default.t[i#20,p#21] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/private/var/folders/tg/f5mz46090wg7swzgdc69f8q03965_0/T/warehouse-be43db8..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<i:int>

AngersZhuuuu (Author, Contributor) replied on Jul 6, 2020:

Got it, will fix this problem quickly.

assertPrunedPartitions(
"SELECT * FROM t WHERE (p = '1' AND i = 2) OR (i = 1 OR p = '2')", 4)
assertPrunedPartitions(
"SELECT * FROM t WHERE (p = '1' AND i = 2) OR (p = '3' AND i = 3 )", 2)
assertPrunedPartitions(
"SELECT * FROM t WHERE (p = '1' AND i = 2) OR (p = '2' OR p = '3')", 3)
assertPrunedPartitions(
"SELECT * FROM t", 4)
assertPrunedPartitions(
"SELECT * FROM t WHERE p = '1' AND i = 2", 1)
assertPrunedPartitions(
"""
|SELECT i, COUNT(1) FROM (
|SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)
|) tmp GROUP BY i
""".stripMargin, 2)
}
}
}

protected def assertPrunedPartitions(query: String, expected: Long): Unit = {
val plan = sql(query).queryExecution.sparkPlan
assert(getScanExecPartitionSize(plan) == expected)
}

protected def getScanExecPartitionSize(plan: SparkPlan): Long
}
