Skip to content

Commit

Permalink
[SPARK-23301][SQL] data source column pruning should work for arbitra…
Browse files Browse the repository at this point in the history
…ry expressions

## What changes were proposed in this pull request?

This PR fixes a mistake in the `PushDownOperatorsToDataSource` rule: the column pruning logic handles `Project` incorrectly.

## How was this patch tested?

A new test case covers column pruning with arbitrary expressions, and the existing tests are improved to make sure that the `PushDownOperatorsToDataSource` rule really works.

Author: Wenchen Fan <[email protected]>

Closes #20476 from cloud-fan/push-down.
  • Loading branch information
cloud-fan authored and gatorsmile committed Feb 2, 2018
1 parent b3a0428 commit 19c7c7e
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, Expression, NamedExpression, PredicateHelper}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeSet, Expression, NamedExpression, PredicateHelper}
import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
Expand Down Expand Up @@ -81,35 +81,34 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel

// TODO: add more push down rules.

// TODO: nested fields pruning
def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: Seq[Attribute]): Unit = {
plan match {
case Project(projectList, child) =>
val required = projectList.filter(requiredByParent.contains).flatMap(_.references)
pushDownRequiredColumns(child, required)

case Filter(condition, child) =>
val required = requiredByParent ++ condition.references
pushDownRequiredColumns(child, required)

case DataSourceV2Relation(fullOutput, reader) => reader match {
case r: SupportsPushDownRequiredColumns =>
// Match original case of attributes.
val attrMap = AttributeMap(fullOutput.zip(fullOutput))
val requiredColumns = requiredByParent.map(attrMap)
r.pruneColumns(requiredColumns.toStructType)
case _ =>
}
pushDownRequiredColumns(filterPushed, filterPushed.outputSet)
// After column pruning, we may have redundant PROJECT nodes in the query plan, remove them.
RemoveRedundantProject(filterPushed)
}

// TODO: nested fields pruning
private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = {
plan match {
case Project(projectList, child) =>
val required = projectList.flatMap(_.references)
pushDownRequiredColumns(child, AttributeSet(required))

case Filter(condition, child) =>
val required = requiredByParent ++ condition.references
pushDownRequiredColumns(child, required)

// TODO: there may be more operators can be used to calculate required columns, we can add
// more and more in the future.
case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.output))
case relation: DataSourceV2Relation => relation.reader match {
case reader: SupportsPushDownRequiredColumns =>
val requiredColumns = relation.output.filter(requiredByParent.contains)
reader.pruneColumns(requiredColumns.toStructType)

case _ =>
}
}

pushDownRequiredColumns(filterPushed, filterPushed.output)
// After column pruning, we may have redundant PROJECT nodes in the query plan, remove them.
RemoveRedundantProject(filterPushed)
// TODO: there may be more operators that can be used to calculate the required columns. We
// can add more and more in the future.
case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet))
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@

public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport {

class Reader implements DataSourceReader, SupportsPushDownRequiredColumns,
public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns,
SupportsPushDownFilters {

private StructType requiredSchema = new StructType().add("i", "int").add("j", "int");
private Filter[] filters = new Filter[0];
// Exposed for testing.
public StructType requiredSchema = new StructType().add("i", "int").add("j", "int");
public Filter[] filters = new Filter[0];

@Override
public StructType readSchema() {
Expand All @@ -50,8 +51,26 @@ public void pruneColumns(StructType requiredSchema) {

/**
 * Splits the given filters into those this reader accepts and those it does not.
 *
 * A filter is accepted only when it is a {@code GreaterThan} on attribute "i" whose value is
 * an {@code Integer}. Accepted filters are stored in {@code this.filters}; every other filter
 * is returned to the caller.
 *
 * @param filters the candidate filters
 * @return the filters that were not accepted
 */
@Override
public Filter[] pushFilters(Filter[] filters) {
// NOTE(review): the next two statements look like the pre-change body kept by the diff view.
// Everything after an unconditional `return` is unreachable -- confirm which lines belong to
// the final version.
this.filters = filters;
return new Filter[0];
// Accepted: GreaterThan on "i" with an Integer value.
Filter[] supported = Arrays.stream(filters).filter(f -> {
if (f instanceof GreaterThan) {
GreaterThan gt = (GreaterThan) f;
return gt.attribute().equals("i") && gt.value() instanceof Integer;
} else {
return false;
}
}).toArray(Filter[]::new);

// Rejected: the exact complement of the predicate above, so each filter lands in one array.
Filter[] unsupported = Arrays.stream(filters).filter(f -> {
if (f instanceof GreaterThan) {
GreaterThan gt = (GreaterThan) f;
return !gt.attribute().equals("i") || !(gt.value() instanceof Integer);
} else {
return true;
}
}).toArray(Filter[]::new);

this.filters = supported;
return unsupported;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ import java.util.{ArrayList, List => JList}

import test.org.apache.spark.sql.sources.v2._

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.functions._
import org.apache.spark.sql.sources.{Filter, GreaterThan}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning}
Expand All @@ -48,14 +50,72 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
}

// Runs the same queries against the Scala and the Java test sources and, for each query, checks
// both the returned rows and what ended up in the reader (`filters` and `requiredSchema`).
test("advanced implementation") {
// Returns the reader of the first DataSourceV2ScanExec collected from the executed plan, cast
// to the Scala test source's reader type.
def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = {
query.queryExecution.executedPlan.collect {
case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
}.head
}

// Same as getReader, but casts to the Java test source's reader type.
def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = {
query.queryExecution.executedPlan.collect {
case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader]
}.head
}

Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls =>
// withClue adds the class name to any failure message from the assertions below.
withClue(cls.getName) {
val df = spark.read.format(cls.getName).load()
checkAnswer(df, (0 until 10).map(i => Row(i, -i)))
// NOTE(review): the four checks below overlap with the q1-q4 checks further down and look
// like pre-change lines kept by the diff view -- confirm whether they belong here.
checkAnswer(df.select('j), (0 until 10).map(i => Row(-i)))
checkAnswer(df.filter('i > 3), (4 until 10).map(i => Row(i, -i)))
checkAnswer(df.select('j).filter('i > 6), (7 until 10).map(i => Row(-i)))
checkAnswer(df.select('i).filter('i > 10), Nil)

// q1: only column j is selected, so no filter is pushed and only "j" is required.
val q1 = df.select('j)
checkAnswer(q1, (0 until 10).map(i => Row(-i)))
if (cls == classOf[AdvancedDataSourceV2]) {
val reader = getReader(q1)
assert(reader.filters.isEmpty)
assert(reader.requiredSchema.fieldNames === Seq("j"))
} else {
val reader = getJavaReader(q1)
assert(reader.filters.isEmpty)
assert(reader.requiredSchema.fieldNames === Seq("j"))
}

// q2: a filter on i with no projection; the pushed filters reference only "i" and both
// columns stay required.
val q2 = df.filter('i > 3)
checkAnswer(q2, (4 until 10).map(i => Row(i, -i)))
if (cls == classOf[AdvancedDataSourceV2]) {
val reader = getReader(q2)
assert(reader.filters.flatMap(_.references).toSet == Set("i"))
assert(reader.requiredSchema.fieldNames === Seq("i", "j"))
} else {
val reader = getJavaReader(q2)
assert(reader.filters.flatMap(_.references).toSet == Set("i"))
assert(reader.requiredSchema.fieldNames === Seq("i", "j"))
}

// q3: projection and filter both use i only, so only "i" is required.
val q3 = df.select('i).filter('i > 6)
checkAnswer(q3, (7 until 10).map(i => Row(i)))
if (cls == classOf[AdvancedDataSourceV2]) {
val reader = getReader(q3)
assert(reader.filters.flatMap(_.references).toSet == Set("i"))
assert(reader.requiredSchema.fieldNames === Seq("i"))
} else {
val reader = getJavaReader(q3)
assert(reader.filters.flatMap(_.references).toSet == Set("i"))
assert(reader.requiredSchema.fieldNames === Seq("i"))
}

// q4: the filter is on j, which the test sources do not accept, so nothing is pushed.
val q4 = df.select('j).filter('j < -10)
checkAnswer(q4, Nil)
if (cls == classOf[AdvancedDataSourceV2]) {
val reader = getReader(q4)
// 'j < -10 is not supported by the testing data source.
assert(reader.filters.isEmpty)
assert(reader.requiredSchema.fieldNames === Seq("j"))
} else {
val reader = getJavaReader(q4)
// 'j < -10 is not supported by the testing data source.
assert(reader.filters.isEmpty)
assert(reader.requiredSchema.fieldNames === Seq("j"))
}
}
}
}
Expand Down Expand Up @@ -223,6 +283,39 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
val df2 = df.select(($"i" + 1).as("k"), $"j")
checkAnswer(df.join(df2, "j"), (0 until 10).map(i => Row(-i, i, i + 1)))
}

// Verifies that column pruning reaches the reader when the projection is an arbitrary
// expression rather than a plain column reference.
test("SPARK-23301: column pruning with arbitrary expressions") {
  // Collects the reader of every V2 scan node in the executed plan and returns the first one.
  def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = {
    val readers = query.queryExecution.executedPlan.collect {
      case scan: DataSourceV2ScanExec => scan.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
    }
    readers.head
  }

  val sourceName = classOf[AdvancedDataSourceV2].getName
  val df = spark.read.format(sourceName).load()

  // An arithmetic expression over 'i requires only column "i".
  val plusOne = df.select('i + 1)
  checkAnswer(plusOne, (1 to 10).map(Row(_)))
  val plusOneReader = getReader(plusOne)
  assert(plusOneReader.requiredSchema.fieldNames === Seq("i"))

  // A literal projection references no column at all, so nothing is required.
  val literalOnly = df.select(lit(1))
  checkAnswer(literalOnly, Seq.fill(10)(Row(1)))
  val literalReader = getReader(literalOnly)
  assert(literalReader.requiredSchema.isEmpty)

  // 'j === -1 can't be pushed down, but we should still be able to do column pruning.
  val filteredThenDoubled = df.filter('j === -1).select('j * 2)
  checkAnswer(filteredThenDoubled, Row(-2))
  val filteredReader = getReader(filteredThenDoubled)
  assert(filteredReader.filters.isEmpty)
  assert(filteredReader.requiredSchema.fieldNames === Seq("j"))

  // Column pruning should also work through other operators (sort and limit here).
  val sortedLimited = df.sort('i).limit(1).select('i + 1)
  checkAnswer(sortedLimited, Row(1))
  val sortedReader = getReader(sortedLimited)
  assert(sortedReader.requiredSchema.fieldNames === Seq("i"))
}
}

class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
Expand Down Expand Up @@ -270,8 +363,12 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
}

/**
 * Splits `filters` into the ones this reader accepts and the ones it does not.
 *
 * Only a `GreaterThan` on column "i" with an `Int` value is accepted. Accepted filters are kept
 * in `this.filters`; all other filters are returned to the caller.
 */
override def pushFilters(filters: Array[Filter]): Array[Filter] = {
// NOTE(review): the two statements below look like the pre-change body kept by the diff view:
// the assignment is overwritten further down and the empty array is discarded. Confirm.
this.filters = filters
Array.empty
// `partition` puts the filters matching the first case on the left and the rest on the right.
val (supported, unsupported) = filters.partition {
case GreaterThan("i", _: Int) => true
case _ => false
}
this.filters = supported
unsupported
}

override def pushedFilters(): Array[Filter] = filters
Expand Down

0 comments on commit 19c7c7e

Please sign in to comment.