[SPARK-24172][SQL]: Push projection and filters once when converting to physical plan. #21262

Closed
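
In outline, the change removes the PushDownOperatorsToDataSource optimizer rule and instead pushes the required columns and the filters a single time, while DataSourceV2Strategy converts the logical relation to a physical scan. The following is a minimal sketch of that flow using simplified stand-in types; the Reader and Plan types below are illustrative only and are not Spark's DataSourceReader or SparkPlan interfaces.

object PushOnceSketch {

  trait Reader {
    // Column pruning: the source keeps only the required columns.
    def pruneColumns(required: Seq[String]): Unit
    // Filter pushdown: returns the filters that must still be run after the scan.
    def pushFilters(filters: Seq[String]): Seq[String]
    // The columns the pruned scan will actually produce.
    def readSchema(): Seq[String]
  }

  sealed trait Plan { def output: Seq[String] }
  case class Scan(output: Seq[String]) extends Plan
  case class Filter(conditions: Seq[String], child: Plan) extends Plan {
    def output: Seq[String] = child.output
  }
  case class Project(output: Seq[String], child: Plan) extends Plan

  // Push the projection and the filters exactly once, while converting to the
  // physical plan, then add Filter/Project nodes only for what the source
  // could not handle.
  def plan(
      project: Seq[String],
      filters: Seq[String],
      filterColumns: Seq[String],
      reader: Reader): Plan = {
    reader.pruneColumns((project ++ filterColumns).distinct)
    val postScanFilters = reader.pushFilters(filters)

    val scan: Plan = Scan(reader.readSchema())
    val withFilter =
      if (postScanFilters.nonEmpty) Filter(postScanFilters, scan) else scan
    // A final projection is only needed if the scan's output differs from the
    // columns (and order) the query asked for.
    if (withFilter.output == project) withFilter else Project(project, withFilter)
  }
}

A source that handles everything collapses to a bare Scan; one that handles nothing still ends up under a single Filter and Project. Either way the pushdown itself happens exactly once, at planning time. The real strategy, in the DataSourceV2Strategy.scala diff below, does the same with catalyst expressions and FilterExec/ProjectExec.
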
SparkOptimizer.scala

@@ -21,7 +21,6 @@ import org.apache.spark.sql.ExperimentalMethods
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
-import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource
 import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate

 class SparkOptimizer(
@@ -32,8 +31,7 @@ class SparkOptimizer(
   override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
     Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
-    Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
-    Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++
+    Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++
     postHocOptimizationBatches :+
     Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)

DataSourceV2Relation.scala

@@ -22,7 +22,6 @@ import scala.collection.JavaConverters._
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
-import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
@@ -32,79 +31,35 @@ import org.apache.spark.sql.types.StructType

 case class DataSourceV2Relation(
     source: DataSourceV2,
+    output: Seq[AttributeReference],
     options: Map[String, String],
-    projection: Seq[AttributeReference],
-    filters: Option[Seq[Expression]] = None,
     userSpecifiedSchema: Option[StructType] = None)
   extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat {

   import DataSourceV2Relation._

-  override def simpleString: String = "RelationV2 " + metadataString
-
-  override lazy val schema: StructType = reader.readSchema()
-
-  override lazy val output: Seq[AttributeReference] = {
-    // use the projection attributes to avoid assigning new ids. fields that are not projected
-    // will be assigned new ids, which is okay because they are not projected.
-    val attrMap = projection.map(a => a.name -> a).toMap
-    schema.map(f => attrMap.getOrElse(f.name,
-      AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()))
-  }
-
-  private lazy val v2Options: DataSourceOptions = makeV2Options(options)
+  override def pushedFilters: Seq[Expression] = Seq.empty

-  // postScanFilters: filters that need to be evaluated after the scan.
-  // pushedFilters: filters that will be pushed down and evaluated in the underlying data sources.
-  // Note: postScanFilters and pushedFilters can overlap, e.g. the parquet row group filter.
-  lazy val (
-      reader: DataSourceReader,
-      postScanFilters: Seq[Expression],
-      pushedFilters: Seq[Expression]) = {
-    val newReader = userSpecifiedSchema match {
-      case Some(s) =>
-        source.asReadSupportWithSchema.createReader(s, v2Options)
-      case _ =>
-        source.asReadSupport.createReader(v2Options)
-    }
-
-    DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType)
-
-    val (postScanFilters, pushedFilters) = filters match {
-      case Some(filterSeq) =>
-        DataSourceV2Relation.pushFilters(newReader, filterSeq)
-      case _ =>
-        (Nil, Nil)
-    }
-    logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}")
-    logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}")
-
-    (newReader, postScanFilters, pushedFilters)
-  }
-
-  override def doCanonicalize(): LogicalPlan = {
-    val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation]
+  override def simpleString: String = "RelationV2 " + metadataString

-    // override output with canonicalized output to avoid attempting to configure a reader
-    val canonicalOutput: Seq[AttributeReference] = this.output
-        .map(a => QueryPlan.normalizeExprId(a, projection))
+  lazy val v2Options: DataSourceOptions = makeV2Options(options)

-    new DataSourceV2Relation(c.source, c.options, c.projection) {
-      override lazy val output: Seq[AttributeReference] = canonicalOutput
-    }
-  }
+  def newReader: DataSourceReader = userSpecifiedSchema match {
+    case Some(userSchema) =>
+      source.asReadSupportWithSchema.createReader(userSchema, v2Options)
+    case None =>
+      source.asReadSupport.createReader(v2Options)
+  }

-  override def computeStats(): Statistics = reader match {
+  override def computeStats(): Statistics = newReader match {
     case r: SupportsReportStatistics =>
       Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
     case _ =>
       Statistics(sizeInBytes = conf.defaultSizeInBytes)
   }

   override def newInstance(): DataSourceV2Relation = {
-    // projection is used to maintain id assignment.
-    // if projection is not set, use output so the copy is not equal to the original
-    copy(projection = projection.map(_.newInstance()))
+    copy(output = output.map(_.newInstance()))
   }
 }

@@ -206,21 +161,27 @@ object DataSourceV2Relation {
   def create(
       source: DataSourceV2,
       options: Map[String, String],
-      filters: Option[Seq[Expression]] = None,
       userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = {
-    val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes
-    DataSourceV2Relation(source, options, projection, filters, userSpecifiedSchema)
+    val output = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes
+    DataSourceV2Relation(source, output, options, userSpecifiedSchema)
   }

-  private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = {
+  def pushRequiredColumns(
+      relation: DataSourceV2Relation,
+      reader: DataSourceReader,
+      struct: StructType): Seq[AttributeReference] = {
     reader match {
       case projectionSupport: SupportsPushDownRequiredColumns =>
         projectionSupport.pruneColumns(struct)
+        // return the output columns from the relation that were projected
+        val attrMap = relation.output.map(a => a.name -> a).toMap
+        projectionSupport.readSchema().map(f => attrMap(f.name))
       case _ =>
+        relation.output
     }
   }

-  private def pushFilters(
+  def pushFilters(
       reader: DataSourceReader,
       filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
     reader match {
@@ -248,7 +209,7 @@ object DataSourceV2Relation {
       // the data source cannot guarantee the rows returned can pass these filters.
       // As a result we must return it so Spark can plan an extra filter operator.
       val postScanFilters =
-          r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr)
+        r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr)
       // The filters which are marked as pushed to this data source
       val pushedFilters = r.pushedFilters().map(translatedFilterToExpr)
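
As the comment above notes, the post-scan set and the pushed set returned by pushFilters can overlap when a source accepts filters it cannot fully guarantee (for example parquet row-group statistics). A toy reader showing why; this is only a stand-in with the same shape as the real SupportsPushDownFilters contract, not Spark's interface.

// Toy reader: accepts every filter for best-effort pruning, but hands them all
// back because pruning by statistics cannot guarantee that every surviving row
// actually matches.
class RowGroupStatsReader {
  private var accepted: Seq[String] = Seq.empty

  // Returns the filters the source cannot fully evaluate on its own.
  def pushFilters(filters: Seq[String]): Seq[String] = {
    accepted = filters
    filters
  }

  // Reports the filters the source did accept (best-effort).
  def pushedFilters(): Seq[String] = accepted
}

For a reader like this, postScanFilters and pushedFilters overlap completely, so the planner keeps a FilterExec on top of the scan even though the same filters were also pushed down.
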
DataSourceV2Strategy.scala

@@ -17,15 +17,56 @@

 package org.apache.spark.sql.execution.datasources.v2

-import org.apache.spark.sql.Strategy
+import org.apache.spark.sql.{execution, Strategy}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet}
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}

 object DataSourceV2Strategy extends Strategy {
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-    case r: DataSourceV2Relation =>
-      DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil
+    case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
+      val projectSet = AttributeSet(project.flatMap(_.references))
+      val filterSet = AttributeSet(filters.flatMap(_.references))
+
+      val projection = if (filterSet.subsetOf(projectSet) &&
+          AttributeSet(relation.output) == projectSet) {
+        // When the required projection contains all of the filter columns and column pruning alone
+        // can produce the required projection, push the required projection.
+        // A final projection may still be needed if the data source produces a different column
+        // order or if it cannot prune all of the nested columns.
+        relation.output
+      } else {
+        // When there are filter columns not already in the required projection or when the required
+        // projection is more complicated than column pruning, base column pruning on the set of
+        // all columns needed by both.
+        (projectSet ++ filterSet).toSeq
+      }
+
+      val reader = relation.newReader
+
+      val output = DataSourceV2Relation.pushRequiredColumns(relation, reader,
+        projection.asInstanceOf[Seq[AttributeReference]].toStructType)
+
+      val (postScanFilters, pushedFilters) = DataSourceV2Relation.pushFilters(reader, filters)
+
+      logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}")
+      logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}")
+
+      val scan = DataSourceV2ScanExec(
+        output, relation.source, relation.options, pushedFilters, reader)
+
+      val filter = postScanFilters.reduceLeftOption(And)
+      val withFilter = filter.map(execution.FilterExec(_, scan)).getOrElse(scan)
+
+      val withProjection = if (withFilter.output != project) {
+        execution.ProjectExec(project, withFilter)
+      } else {
+        withFilter
+      }
+
+      withProjection :: Nil

     case r: StreamingDataSourceV2Relation =>
       DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil
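
To see how the two projection branches above behave end to end, two illustrative queries against the AdvancedDataSourceV2 test source used in the suite changes below. This assumes the DataSourceV2Suite test environment: a SparkSession named spark with its implicits in scope, and a source exposing the columns i and j that supports both column pruning and filter pushdown.

// Illustrative only; relies on the DataSourceV2Suite test setup described above.
import spark.implicits._

val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()

// The projection is the full relation output and already contains the filter column,
// so the first branch pushes relation.output; a final ProjectExec is only added if the
// source returns the columns in a different order.
df.filter('i > 3).explain()

// The filter needs 'j, which the select drops, so the second branch pushes {i, j} as
// the required columns and a ProjectExec trims the result back to 'i.
df.select('i).filter('j > 3).explain()
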

PushDownOperatorsToDataSource.scala

This file was deleted.

DataSourceV2Suite.scala

@@ -323,21 +323,22 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
   }

   test("SPARK-23315: get output from canonicalized data source v2 related plans") {
-    def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = {
+    def checkCanonicalizedOutput(
+        df: DataFrame, logicalNumOutput: Int, physicalNumOutput: Int): Unit = {
       val logical = df.queryExecution.optimizedPlan.collect {
         case d: DataSourceV2Relation => d
       }.head
-      assert(logical.canonicalized.output.length == numOutput)
+      assert(logical.canonicalized.output.length == logicalNumOutput)

       val physical = df.queryExecution.executedPlan.collect {
         case d: DataSourceV2ScanExec => d
       }.head
-      assert(physical.canonicalized.output.length == numOutput)
+      assert(physical.canonicalized.output.length == physicalNumOutput)
     }

     val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
-    checkCanonicalizedOutput(df, 2)
-    checkCanonicalizedOutput(df.select('i), 1)
+    checkCanonicalizedOutput(df, 2, 2)
+    checkCanonicalizedOutput(df.select('i), 2, 1)
   }
 }