add dynamic partition pruning for Spark 2.0
## What changes were proposed in this pull request?
This patch was originally authored by davies; it implements dynamic partition pruning for branch-2.0.

## How was this patch tested?
An end-to-end test was added in SQLQuerySuite.

Author: Davies Liu <[email protected]>

Closes apache#47 from rxin/rxin-dynamic-partition-pruning.
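
For readers unfamiliar with the technique, the sketch below is illustrative context and not part of the original commit message: the table names (`sales`, `dates`), the partition column `day`, and the session setup are all assumptions. It shows the query shape this rule targets, a join between a fact table partitioned on a column and a dimension table carrying a selective filter.

```scala
// Illustrative only: `sales`, `dates`, and the column names are hypothetical.
import org.apache.spark.sql.SparkSession

object DynamicPruningExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("dpp-example").master("local[*]").getOrCreate()

    // Assume `sales` is a table partitioned by `day` and `dates` is a small
    // dimension table. The filter on `dates` is highly selective, so the
    // PartitionPruning rule added below in SparkOptimizer can insert a
    // predicate subquery on `sales.day` and skip non-matching partitions.
    val df = spark.sql(
      """
        |SELECT s.amount
        |FROM sales s
        |JOIN dates d ON s.day = d.day
        |WHERE d.is_holiday = true
      """.stripMargin)

    df.explain(true) // the physical scan should then show a PartitionFilter entry
    spark.stop()
  }
}
```

Whether pruning actually applies depends on the flag checked as `conf.partitionPruning` in the new rule and on the join and filter shape handled in `PartitionPruning.apply` below.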
davies authored and rxin committed Aug 2, 2016
1 parent 2367ac6 commit 765add8
Showing 15 changed files with 406 additions and 78 deletions.
@@ -102,6 +102,13 @@ case class PredicateSubquery(
override def nullable: Boolean = nullAware
override def plan: LogicalPlan = SubqueryAlias(toString, query)
override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan)
override def semanticEquals(o: Expression): Boolean = o match {
case p: PredicateSubquery =>
query.sameResult(p.query) && nullAware == p.nullAware &&
children.length == p.children.length &&
children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
case _ => false
}
override def toString: String = s"predicate-subquery#${exprId.id} $conditionString"
}

@@ -1681,6 +1681,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {

// Filter the plan by applying left semi and left anti joins.
withSubquery.foldLeft(newFilter) {
case (p, PredicateSubquery(_, Seq(e: Expression), _, _)) if !e.isInstanceOf[Predicate] =>
// This predicate subquery was inserted by the PartitionPruning rule and should not be rewritten.
p
case (p, PredicateSubquery(sub, conditions, _, _)) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
Join(outerPlan, sub, LeftSemi, joinCond)
@@ -300,7 +300,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
*/
lazy val allAttributes: AttributeSeq = children.flatMap(_.output)

private def cleanExpression(e: Expression): Expression = e match {
protected def cleanExpression(e: Expression): Expression = e match {
case a: Alias =>
// As the root of the expression, Alias will always take an arbitrary exprId, we need
// to erase that for equality testing.
@@ -534,9 +534,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {

if (innerChildren.nonEmpty) {
innerChildren.init.foreach(_.generateTreeString(
depth + 2, lastChildren :+ false :+ false, builder, verbose))
depth + 2, lastChildren :+ children.isEmpty :+ false, builder, verbose))
innerChildren.last.generateTreeString(
depth + 2, lastChildren :+ false :+ true, builder, verbose)
depth + 2, lastChildren :+ children.isEmpty :+ true, builder, verbose)
}

if (children.nonEmpty) {
@@ -20,14 +20,14 @@ package org.apache.spark.sql.execution
import org.apache.commons.lang3.StringUtils

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Row, SparkSession, SQLContext}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.datasources.HadoopFsRelation
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, HadoopFsRelation}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
@@ -130,20 +130,57 @@ private[sql] case class RDDScanExec(
}
}

private[sql] trait DataSourceScanExec extends LeafExecNode {
private[sql] trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
val rdd: RDD[InternalRow]
val relation: BaseRelation
val metastoreTableIdentifier: Option[TableIdentifier]
val partitionPredicate: Option[Expression]

override val nodeName: String = {
s"Scan $relation ${metastoreTableIdentifier.map(_.unquotedString).getOrElse("")}"
val pred = if (partitionPredicate.isDefined) {
s"PartitionFilter: ${partitionPredicate.get} "
} else {
""
}
s"Scan $relation $pred${metastoreTableIdentifier.map(_.unquotedString).getOrElse("")}"
}

// Ignore rdd when checking results
override def sameResult(plan: SparkPlan): Boolean = plan match {
case other: DataSourceScanExec => relation == other.relation && metadata == other.metadata
case other: DataSourceScanExec =>
val thisPredicates = partitionPredicate.map(cleanExpression)
val otherPredicates = other.partitionPredicate.map(cleanExpression)
val result = relation == other.relation && metadata == other.metadata &&
thisPredicates.isDefined == otherPredicates.isDefined &&
thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2))
result
case _ => false
}

protected def prunedRdd: RDD[InternalRow] = rdd match {
case scanRDD: FileScanRDD if partitionPredicate.nonEmpty =>
// Only HadoopFsRelation supports dynamic partition pruning
val files = relation.asInstanceOf[HadoopFsRelation]
// The rightmost output columns are the partition columns.
val partitionOutput = output.takeRight(files.partitionSchema.length)
val predicate = newPredicate(partitionPredicate.get, partitionOutput)
var currIndex = 0
val partitions = scanRDD.filePartitions.flatMap { p =>
val pruned = p.files.filter(f => predicate(f.partitionValues))
if (pruned.nonEmpty) {
currIndex += 1
Seq(FilePartition(currIndex - 1, pruned))
} else {
Seq.empty
}
}
new FileScanRDD(scanRDD.sparkSession, scanRDD.readFunction, partitions)
case o => rdd
}

override def inputRDDs(): Seq[RDD[InternalRow]] = {
prunedRdd :: Nil
}
}

/** Physical plan node for scanning data from a relation. */
@@ -153,8 +190,9 @@ private[sql] case class RowDataSourceScanExec(
@transient relation: BaseRelation,
override val outputPartitioning: Partitioning,
override val metadata: Map[String, String],
override val metastoreTableIdentifier: Option[TableIdentifier])
extends DataSourceScanExec with CodegenSupport {
override val metastoreTableIdentifier: Option[TableIdentifier],
override val partitionPredicate: Option[Expression] = None)
extends DataSourceScanExec {

private[sql] override lazy val metrics =
Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -169,9 +207,9 @@ private[sql] case class RowDataSourceScanExec(

protected override def doExecute(): RDD[InternalRow] = {
val unsafeRow = if (outputUnsafeRows) {
rdd
prunedRdd
} else {
rdd.mapPartitionsInternal { iter =>
prunedRdd.mapPartitionsInternal { iter =>
val proj = UnsafeProjection.create(schema)
iter.map(proj)
}
@@ -193,10 +231,6 @@ private[sql] case class RowDataSourceScanExec(
s"${Utils.truncatedString(metadataEntries, " ", ", ", "")}"
}

override def inputRDDs(): Seq[RDD[InternalRow]] = {
rdd :: Nil
}

override protected def doProduce(ctx: CodegenContext): String = {
val numOutputRows = metricTerm(ctx, "numOutputRows")
// PhysicalRDD always just has one input
@@ -228,7 +262,8 @@ private[sql] case class BatchedDataSourceScanExec(
@transient relation: BaseRelation,
override val outputPartitioning: Partitioning,
override val metadata: Map[String, String],
override val metastoreTableIdentifier: Option[TableIdentifier])
override val metastoreTableIdentifier: Option[TableIdentifier],
partitionPredicate: Option[Expression] = None)
extends DataSourceScanExec with CodegenSupport {

private[sql] override lazy val metrics =
@@ -250,10 +285,6 @@ private[sql] case class BatchedDataSourceScanExec(
s"Batched$nodeName${Utils.truncatedString(output, "[", ",", "]")}$metadataStr"
}

override def inputRDDs(): Seq[RDD[InternalRow]] = {
rdd :: Nil
}

private def genCodeColumnVector(ctx: CodegenContext, columnVar: String, ordinal: String,
dataType: DataType, nullable: Boolean): ExprCode = {
val javaType = ctx.javaType(dataType)
@@ -347,7 +378,8 @@ private[sql] object DataSourceScanExec {
rdd: RDD[InternalRow],
relation: BaseRelation,
metadata: Map[String, String] = Map.empty,
metastoreTableIdentifier: Option[TableIdentifier] = None): DataSourceScanExec = {
metastoreTableIdentifier: Option[TableIdentifier] = None,
partitionPredicate: Option[Expression] = None): DataSourceScanExec = {
val outputPartitioning = {
val bucketSpec = relation match {
// TODO: this should be closer to bucket planning.
@@ -373,10 +405,12 @@
case r: HadoopFsRelation
if r.fileFormat.supportBatch(r.sparkSession, StructType.fromAttributes(output)) =>
BatchedDataSourceScanExec(
output, rdd, relation, outputPartitioning, metadata, metastoreTableIdentifier)
output, rdd, relation, outputPartitioning, metadata, metastoreTableIdentifier,
partitionPredicate)
case _ =>
RowDataSourceScanExec(
output, rdd, relation, outputPartitioning, metadata, metastoreTableIdentifier)
output, rdd, relation, outputPartitioning, metadata, metastoreTableIdentifier,
partitionPredicate)
}
}
}
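
To make the pruning step in `prunedRdd` above easier to follow, here is a minimal, self-contained sketch of the same idea using stand-in types (`FileSlice` and `FilePart` are hypothetical, not the Spark classes): files are kept only when the partition predicate accepts their partition values, and the surviving partitions are re-indexed from zero.

```scala
// Stand-in types; the real code works with FileScanRDD / FilePartition.
object PrunePartitionsSketch {
  final case class FileSlice(partitionValues: Map[String, String])
  final case class FilePart(index: Int, files: Seq[FileSlice])

  /** Keep only files whose partition values satisfy `keep`, re-numbering partitions. */
  def prune(parts: Seq[FilePart], keep: Map[String, String] => Boolean): Seq[FilePart] = {
    var currIndex = 0
    parts.flatMap { p =>
      val surviving = p.files.filter(f => keep(f.partitionValues))
      if (surviving.nonEmpty) {
        currIndex += 1
        Seq(FilePart(currIndex - 1, surviving)) // indices end up 0 .. n-1, as above
      } else {
        Seq.empty
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val parts = Seq(
      FilePart(0, Seq(FileSlice(Map("day" -> "2016-01-01")))),
      FilePart(1, Seq(FileSlice(Map("day" -> "2016-01-02")))))
    // Only the 2016-01-02 partition survives, and it is re-indexed to 0.
    println(prune(parts, _.get("day").contains("2016-01-02")))
  }
}
```

Rebuilding a fresh `FileScanRDD` from the surviving partitions, rather than filtering rows later, is what lets whole partitions be skipped at the source.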
@@ -99,7 +99,8 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
PlanSubqueries(sparkSession),
EnsureRequirements(sparkSession.sessionState.conf),
CollapseCodegenStages(sparkSession.sessionState.conf),
ReuseExchange(sparkSession.sessionState.conf))
ReuseExchange(sparkSession.sessionState.conf),
ReuseSubquery(sparkSession.sessionState.conf))

protected def stringOrError[A](f: => A): String =
try f.toString catch { case e: Throwable => e.toString }
@@ -19,7 +19,12 @@ package org.apache.spark.sql.execution

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{CombineFilters, Optimizer, PushDownPredicate, PushPredicateThroughJoin}
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
import org.apache.spark.sql.internal.SQLConf

@@ -30,6 +35,105 @@ class SparkOptimizer(
extends Optimizer(catalog, conf) {

override def batches: Seq[Batch] = super.batches :+
Batch("PartitionPruning", Once,
PartitionPruning(conf),
OptimizeSubqueries) :+
Batch("Pushdown pruning subquery", fixedPoint,
PushPredicateThroughJoin,
PushDownPredicate,
CombineFilters) :+
Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
}

/**
* Inserts a predicate for a partitioned table when one of its partition columns is used as a join key.
*/
case class PartitionPruning(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper {

/**
* Returns whether the given expression references only partition columns.
*/
private def isPartitioned(a: Expression, plan: LogicalPlan): Boolean = {
plan.foreach {
case l: LogicalRelation if a.references.subsetOf(l.outputSet) =>
l.relation match {
case fs: HadoopFsRelation =>
val partitionColumns = AttributeSet(
l.resolve(fs.partitionSchema, fs.sparkSession.sessionState.analyzer.resolver))
if (a.references.subsetOf(partitionColumns)) {
return true
}
case _ =>
}
case _ =>
}
false
}

private def insertPredicate(
partitionedPlan: LogicalPlan,
partitioned: Expression,
otherPlan: LogicalPlan,
value: Expression): LogicalPlan = {
val alias = value match {
case a: Attribute => a
case o => Alias(o, o.toString)()
}
Filter(
PredicateSubquery(Aggregate(Seq(alias), Seq(alias), otherPlan), Seq(partitioned)),
partitionedPlan)
}

def apply(plan: LogicalPlan): LogicalPlan = {
if (!conf.partitionPruning) {
return plan
}
plan transformUp {
case join @ Join(left, right, joinType, Some(condition)) =>
var newLeft = left
var newRight = right
splitConjunctivePredicates(condition).foreach {
case e @ EqualTo(a: Expression, b: Expression) =>
// The operands should come from different sides of the join; otherwise the predicate should have been pushed down.
val (l, r) = if (a.references.subsetOf(left.outputSet) &&
b.references.subsetOf(right.outputSet)) {
a -> b
} else {
b -> a
}
if (isPartitioned(l, left) && hasHighlySelectivePredicate(right) &&
(joinType == Inner || joinType == LeftSemi || joinType == RightOuter) &&
r.references.subsetOf(right.outputSet)) {
newLeft = insertPredicate(newLeft, l, right, r)
} else if (isPartitioned(r, right) && hasHighlySelectivePredicate(left) &&
(joinType == Inner || joinType == LeftOuter) &&
l.references.subsetOf(left.outputSet)) {
newRight = insertPredicate(newRight, r, left, l)
}
case _ =>
}
Join(newLeft, newRight, joinType, Some(condition))
}
}

/**
* Returns whether an expression is highly selective or not.
*/
def isHighlySelective(e: Expression): Boolean = e match {
case Not(expr) => isHighlySelective(expr)
case And(l, r) => isHighlySelective(l) || isHighlySelective(r)
case Or(l, r) => isHighlySelective(l) && isHighlySelective(r)
case _: BinaryComparison => true
case _: In | _: InSet => true
case _: StringPredicate => true
case _ => false
}

def hasHighlySelectivePredicate(plan: LogicalPlan): Boolean = {
plan.find {
case f: Filter => isHighlySelective(f.condition)
case _ => false
}.isDefined
}
}
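
As a rough aid to reading `isHighlySelective` above, the self-contained sketch below restates the heuristic against a toy expression ADT (the case classes here are stand-ins, not Catalyst expressions): comparisons and IN-style predicates are treated as selective, a NOT keeps its child's estimate, an AND is selective if either side is, and an OR only if both sides are.

```scala
// Toy ADT; Catalyst's Expression hierarchy is far richer than this.
object SelectivitySketch {
  sealed trait Expr
  final case class Not(child: Expr) extends Expr
  final case class And(left: Expr, right: Expr) extends Expr
  final case class Or(left: Expr, right: Expr) extends Expr
  final case class Comparison(column: String, value: String) extends Expr
  final case class InList(column: String, values: Seq[String]) extends Expr
  case object Unknown extends Expr // anything we cannot estimate

  def isHighlySelective(e: Expr): Boolean = e match {
    case Not(c)    => isHighlySelective(c)
    case And(l, r) => isHighlySelective(l) || isHighlySelective(r)
    case Or(l, r)  => isHighlySelective(l) && isHighlySelective(r)
    case _: Comparison | _: InList => true
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    val cmp = Comparison("is_holiday", "true")
    println(isHighlySelective(And(cmp, Unknown))) // true: one selective conjunct is enough
    println(isHighlySelective(Or(cmp, Unknown)))  // false: both disjuncts must be selective
  }
}
```

In the rule itself this heuristic only serves as a gate: `hasHighlySelectivePredicate` looks for any `Filter` whose condition passes it, and pruning is inserted only on the partitioned side of a join whose other side has such a filter.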
@@ -142,21 +142,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
* This list is populated by [[prepareSubqueries]], which is called in [[prepare]].
*/
@transient
private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]
private val runningSubqueries = new ArrayBuffer[ExecSubqueryExpression]

/**
* Finds scalar subquery expressions in this plan node and starts evaluating them.
* The list of subqueries is added to [[subqueryResults]].
*/
protected def prepareSubqueries(): Unit = {
val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
val futureResult = Future {
// Each subquery should return only one row (and one column). We take two here and throws
// an exception later if the number of rows is greater than one.
e.executedPlan.executeTake(2)
}(SparkPlan.subqueryExecutionContext)
subqueryResults += e -> futureResult
val allSubqueries = expressions.flatMap(_.collect { case e: ExecSubqueryExpression => e })
allSubqueries.foreach {
case e: ExecSubqueryExpression =>
e.plan.prepare()
runningSubqueries += e
}
}

@@ -165,21 +161,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
*/
protected def waitForSubqueries(): Unit = synchronized {
// fill in the result of subqueries
subqueryResults.foreach { case (e, futureResult) =>
val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf)
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1,
s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
e.updateResult(rows(0).get(0, e.dataType))
} else {
// If there is no rows returned, the result should be null.
e.updateResult(null)
}
runningSubqueries.foreach { sub =>
sub.updateResult(sub.plan.executeCollect())
}
subqueryResults.clear()
runningSubqueries.clear()
}
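
Below is a minimal sketch of the two-phase subquery lifecycle shown above, written with stand-in traits rather than Spark's `ExecSubqueryExpression` and plan types: `prepareSubqueries` registers every subquery expression and lets its plan prepare (which may start work in the background), and `waitForSubqueries` later collects the rows and pushes them into the expressions.

```scala
// Stand-in types only; this is not the Spark API.
import scala.collection.mutable.ArrayBuffer

object SubqueryLifecycleSketch {
  trait SubPlan {
    def prepare(): Unit
    def executeCollect(): Seq[Long]
  }
  trait SubqueryExpr {
    def plan: SubPlan
    def updateResult(rows: Seq[Long]): Unit
  }

  private val runningSubqueries = new ArrayBuffer[SubqueryExpr]

  def prepareSubqueries(exprs: Seq[SubqueryExpr]): Unit =
    exprs.foreach { e =>
      e.plan.prepare() // the plan may kick off background work here
      runningSubqueries += e
    }

  def waitForSubqueries(): Unit = {
    runningSubqueries.foreach(sub => sub.updateResult(sub.plan.executeCollect()))
    runningSubqueries.clear()
  }

  def main(args: Array[String]): Unit = {
    var result: Seq[Long] = Nil
    val expr = new SubqueryExpr {
      val plan: SubPlan = new SubPlan {
        def prepare(): Unit = ()               // no-op in this toy example
        def executeCollect(): Seq[Long] = Seq(42L)
      }
      def updateResult(rows: Seq[Long]): Unit = result = rows
    }
    prepareSubqueries(Seq(expr))
    waitForSubqueries()
    println(result) // List(42)
  }
}
```

The single-row checks that used to live here presumably move into the subquery expressions' own `updateResult` implementations, which are in files not rendered above.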

(The remaining changed files in this commit are not shown here.)
