[SPARK-12850][SQL] Support Bucket Pruning (Predicate Pushdown for Bucketed Tables)

JIRA: https://issues.apache.org/jira/browse/SPARK-12850

This PR adds support for bucket pruning when the predicates are `EqualTo`, `EqualNullSafe`, `IsNull`, `In`, and `InSet`.

As in Hive, bucket pruning in this PR works only when the bucketing key consists of one and only one column.
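For illustration, a minimal sketch (not part of this patch) of the kind of write and query this enables, using the DataFrame writer API exercised in the new tests; the table name, column name, and the `sqlContext`/`df` values are hypothetical:

    import org.apache.spark.sql.{DataFrame, SQLContext}

    // Hypothetical setup: a SQLContext and a DataFrame `df` with an integer column "j".
    def bucketPruningExample(sqlContext: SQLContext, df: DataFrame): Unit = {
      import sqlContext.implicits._
      // Write the data bucketed by the single column "j" into 8 buckets.
      df.write.format("parquet").bucketBy(8, "j").saveAsTable("bucketed_table")
      // An equality predicate on the bucketing column lets the scan read only the
      // bucket files whose bucket id matches the one computed for the literal 5.
      sqlContext.table("bucketed_table").filter($"j" === 5).collect()
    }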

So far, I have not found a way to verify how many buckets are actually scanned; however, I did verify it while debugging. Could you suggest how to do this properly? Thank you! cloud-fan yhuai rxin marmbrus

BTW, we can add more cases to support complex predicates, including `Or` and `And`. Please let me know if I should do that in this PR.

Maybe we also need to add test cases to verify whether bucket pruning works well for each data type.

Author: gatorsmile <[email protected]>

Closes #10942 from gatorsmile/pruningBuckets.
gatorsmile authored and rxin committed Feb 5, 2016
1 parent 6dbfc40 commit e3c75c6
Showing 3 changed files with 245 additions and 10 deletions.
@@ -29,12 +29,14 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{SerializableConfiguration, Utils}
import org.apache.spark.util.collection.BitSet

/**
* A Strategy for planning scans over data sources defined using the sources API.
@@ -97,10 +99,15 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
(partitionAndNormalColumnAttrs ++ projects).toSeq
}

// Prune the buckets based on the pushed filters that do not contain the partitioning key,
// since the bucketing key is not allowed to use columns from the partitioning key
val bucketSet = getBuckets(pushedFilters, t.getBucketSpec)

val scan = buildPartitionedTableScan(
l,
partitionAndNormalColumnProjs,
pushedFilters,
bucketSet,
t.partitionSpec.partitionColumns,
selectedPartitions)

@@ -124,11 +131,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
val sharedHadoopConf = SparkHadoopUtil.get.conf
val confBroadcast =
t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf))
// Prune the buckets based on the filters
val bucketSet = getBuckets(filters, t.getBucketSpec)
pruneFilterProject(
l,
projects,
filters,
(a, f) => t.buildInternalScan(a.map(_.name).toArray, f, t.paths, confBroadcast)) :: Nil
(a, f) =>
t.buildInternalScan(a.map(_.name).toArray, f, bucketSet, t.paths, confBroadcast)) :: Nil

case l @ LogicalRelation(baseRelation: TableScan, _, _) =>
execution.PhysicalRDD.createFromDataSource(
@@ -150,6 +160,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
logicalRelation: LogicalRelation,
projections: Seq[NamedExpression],
filters: Seq[Expression],
buckets: Option[BitSet],
partitionColumns: StructType,
partitions: Array[Partition]): SparkPlan = {
val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation]
@@ -174,7 +185,7 @@
// assuming partition columns data stored in data files are always consistent with those
// partition values encoded in partition directory paths.
val dataRows = relation.buildInternalScan(
requiredDataColumns.map(_.name).toArray, filters, Array(dir), confBroadcast)
requiredDataColumns.map(_.name).toArray, filters, buckets, Array(dir), confBroadcast)

// Merges data values with partition values.
mergeWithPartitionValues(
@@ -251,6 +262,69 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
}
}

// Get the bucket ID based on the bucketing value.
// Restriction: Bucket pruning works iff the bucketing key has one and only one column.
def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
val mutableRow = new SpecificMutableRow(Seq(bucketColumn.dataType))
mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null)
val bucketIdGeneration = UnsafeProjection.create(
HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil,
bucketColumn :: Nil)

bucketIdGeneration(mutableRow).getInt(0)
}
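As an aside (not part of this commit), a minimal usage sketch of getBucketId, assuming the private[sql] members above are reachable (for example, from code in the org.apache.spark.sql package) and a hypothetical integer bucketing column "j":

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.execution.datasources.DataSourceStrategy
    import org.apache.spark.sql.types.IntegerType

    // Hypothetical single bucketing column "j" of an 8-bucket table.
    val bucketColumn = AttributeReference("j", IntegerType)()
    val numBuckets = 8
    // The bucket that HashPartitioning's partitionIdExpression assigns to the literal 5,
    // i.e. the only bucket a pushed-down filter `j = 5` needs to scan.
    val bucketId = DataSourceStrategy.getBucketId(bucketColumn, numBuckets, 5)
    assert(bucketId >= 0 && bucketId < numBuckets)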

// Get the bucket BitSet by reading the filters that only contain the bucketing key.
// Note: When the returned BitSet is None, no pruning is possible.
// Restriction: Bucket pruning works iff the bucketing key has one and only one column.
private def getBuckets(
filters: Seq[Expression],
bucketSpec: Option[BucketSpec]): Option[BitSet] = {

if (bucketSpec.isEmpty ||
bucketSpec.get.numBuckets == 1 ||
bucketSpec.get.bucketColumnNames.length != 1) {
// None means all the buckets need to be scanned
return None
}

// Just get the first column name, because bucket pruning only works when the bucketing key has a single column
val bucketColumnName = bucketSpec.get.bucketColumnNames.head
val numBuckets = bucketSpec.get.numBuckets
val matchedBuckets = new BitSet(numBuckets)
matchedBuckets.clear()

filters.foreach {
case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
matchedBuckets.set(getBucketId(a, numBuckets, v))
case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
matchedBuckets.set(getBucketId(a, numBuckets, v))
case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
matchedBuckets.set(getBucketId(a, numBuckets, v))
case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
matchedBuckets.set(getBucketId(a, numBuckets, v))
// Because we only convert In to InSet in the Optimizer when there are more than a certain
// number of items, it is possible that we still get an In expression here that needs to be
// pushed down.
case expressions.In(a: Attribute, list)
if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
val hSet = list.map(e => e.eval(EmptyRow))
hSet.foreach(e => matchedBuckets.set(getBucketId(a, numBuckets, e)))
case expressions.IsNull(a: Attribute) if a.name == bucketColumnName =>
matchedBuckets.set(getBucketId(a, numBuckets, null))
case _ =>
}

logInfo {
val selected = matchedBuckets.cardinality()
val percentPruned = (1 - selected.toDouble / numBuckets.toDouble) * 100
s"Selected $selected buckets out of $numBuckets, pruned $percentPruned% partitions."
}

// None means all the buckets need to be scanned
if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets)
}
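For clarity (not part of this commit), a sketch of pushed-down Catalyst filters over a hypothetical integer bucketing column "j" that match the cases handled above; anything else falls through to `case _` and leaves the BitSet unchanged:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    val j = AttributeReference("j", IntegerType)()   // hypothetical bucketing column
    val prunableFilters = Seq(
      EqualTo(j, Literal(5)),                          // j = 5
      EqualNullSafe(Literal(5), j),                    // 5 <=> j
      In(j, Seq(Literal(1), Literal(2), Literal(3))),  // j IN (1, 2, 3)
      IsNull(j)                                        // j IS NULL
    )
    // A range filter such as GreaterThan(j, Literal(5)) does not restrict the buckets.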

protected def prunePartitions(
predicates: Seq[Expression],
partitionSpec: PartitionSpec): Seq[Partition] = {
@@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.collection.BitSet

/**
* ::DeveloperApi::
@@ -722,6 +723,7 @@ abstract class HadoopFsRelation private[sql](
final private[sql] def buildInternalScan(
requiredColumns: Array[String],
filters: Array[Filter],
bucketSet: Option[BitSet],
inputPaths: Array[String],
broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = {
val inputStatuses = inputPaths.flatMap { input =>
@@ -743,9 +745,16 @@
// id from file name. Then read these files into an RDD (use a one-partition empty RDD for an empty
// bucket), and coalesce it to one partition. Finally, union all bucket RDDs into one result.
val perBucketRows = (0 until maybeBucketSpec.get.numBuckets).map { bucketId =>
groupedBucketFiles.get(bucketId).map { inputStatuses =>
buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1)
}.getOrElse(sqlContext.emptyResult)
// If the current bucketId is not set in the bucket bitSet, skip scanning it.
if (bucketSet.nonEmpty && !bucketSet.get.get(bucketId)) {
sqlContext.emptyResult
} else {
// When all the buckets need a scan (i.e., bucketSet is None)
// or when the current bucket needs a scan (i.e., the bit for bucketId is set to true)
groupedBucketFiles.get(bucketId).map { inputStatuses =>
buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1)
}.getOrElse(sqlContext.emptyResult)
}
}

new UnionRDD(sqlContext.sparkContext, perBucketRows)
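To illustrate the shape of the per-bucket union described above, a sketch with hypothetical inputs that uses only public RDD APIs rather than the HadoopFsRelation internals:

    import scala.collection.BitSet
    import scala.reflect.ClassTag

    import org.apache.spark.SparkContext
    import org.apache.spark.rdd.RDD

    // One RDD per bucket id, coalesced to a single partition so that output partition i
    // holds exactly the rows of bucket i; pruned or missing buckets become empty
    // one-partition RDDs, which keeps the bucket-to-partition mapping intact.
    def perBucketUnion[T: ClassTag](
        sc: SparkContext,
        numBuckets: Int,
        rowsPerBucket: Map[Int, RDD[T]],  // hypothetical: rows already grouped by bucket id
        bucketSet: Option[BitSet]): RDD[T] = {
      val perBucket = (0 until numBuckets).map { bucketId =>
        if (bucketSet.exists(bits => !bits(bucketId))) {
          sc.parallelize(Seq.empty[T], 1)  // pruned bucket: skip the scan entirely
        } else {
          rowsPerBucket.get(bucketId)
            .map(_.coalesce(1))
            .getOrElse(sc.parallelize(Seq.empty[T], 1))
        }
      }
      sc.union(perBucket)
    }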
@@ -19,22 +19,28 @@ package org.apache.spark.sql.sources

import java.io.File

import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLConf}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.Exchange
import org.apache.spark.sql.execution.datasources.BucketSpec
import org.apache.spark.sql.execution.{Exchange, PhysicalRDD}
import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSourceStrategy}
import org.apache.spark.sql.execution.joins.SortMergeJoin
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.BitSet

class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._

private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
private val nullDF = (for {
i <- 0 to 50
s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g")
} yield (i % 5, s, i % 13)).toDF("i", "j", "k")

test("read bucketed data") {
val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
withTable("bucketed_table") {
df.write
.format("parquet")
@@ -59,6 +65,152 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
}

// To verify that bucket pruning works, this function checks two conditions:
// 1) the pruned buckets (before filtering) are empty, and
// 2) the final result is the same as the expected one.
private def checkPrunedAnswers(
bucketSpec: BucketSpec,
bucketValues: Seq[Integer],
filterCondition: Column,
originalDataFrame: DataFrame): Unit = {

val bucketedDataFrame = hiveContext.table("bucketed_table").select("i", "j", "k")
val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec
// Limitation: bucket pruning only works when the bucketing key has one and only one column
assert(bucketColumnNames.length == 1)
val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head)
val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex)
val matchedBuckets = new BitSet(numBuckets)
bucketValues.foreach { value =>
matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value))
}

// A Filter could hide a bug in bucket pruning, so check the scan output before any filters are applied
val rdd = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan
.find(_.isInstanceOf[PhysicalRDD])
assert(rdd.isDefined)

val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) =>
if (matchedBuckets.get(index % numBuckets)) Iterator(true) else Iterator(iter.isEmpty)
}
// checking if all the pruned buckets are empty
assert(checkedResult.collect().forall(_ == true))

checkAnswer(
bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"),
originalDataFrame.filter(filterCondition).orderBy("i", "j", "k"))
}

test("read partitioning bucketed tables with bucket pruning filters") {
withTable("bucketed_table") {
val numBuckets = 8
val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
// json does not support predicate push-down, and thus json is used here
df.write
.format("json")
.partitionBy("i")
.bucketBy(numBuckets, "j")
.saveAsTable("bucketed_table")

for (j <- 0 until 13) {
// Case 1: EqualTo
checkPrunedAnswers(
bucketSpec,
bucketValues = j :: Nil,
filterCondition = $"j" === j,
df)

// Case 2: EqualNullSafe
checkPrunedAnswers(
bucketSpec,
bucketValues = j :: Nil,
filterCondition = $"j" <=> j,
df)

// Case 3: In
checkPrunedAnswers(
bucketSpec,
bucketValues = Seq(j, j + 1, j + 2, j + 3),
filterCondition = $"j".isin(j, j + 1, j + 2, j + 3),
df)
}
}
}

test("read non-partitioning bucketed tables with bucket pruning filters") {
withTable("bucketed_table") {
val numBuckets = 8
val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
// json does not support predicate push-down, and thus json is used here
df.write
.format("json")
.bucketBy(numBuckets, "j")
.saveAsTable("bucketed_table")

for (j <- 0 until 13) {
checkPrunedAnswers(
bucketSpec,
bucketValues = j :: Nil,
filterCondition = $"j" === j,
df)
}
}
}

test("read partitioning bucketed tables having null in bucketing key") {
withTable("bucketed_table") {
val numBuckets = 8
val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
// json does not support predicate push-down, and thus json is used here
nullDF.write
.format("json")
.partitionBy("i")
.bucketBy(numBuckets, "j")
.saveAsTable("bucketed_table")

// Case 1: isNull
checkPrunedAnswers(
bucketSpec,
bucketValues = null :: Nil,
filterCondition = $"j".isNull,
nullDF)

// Case 2: <=> null
checkPrunedAnswers(
bucketSpec,
bucketValues = null :: Nil,
filterCondition = $"j" <=> null,
nullDF)
}
}

test("read partitioning bucketed tables having composite filters") {
withTable("bucketed_table") {
val numBuckets = 8
val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
// json does not support predicate push-down, and thus json is used here
df.write
.format("json")
.partitionBy("i")
.bucketBy(numBuckets, "j")
.saveAsTable("bucketed_table")

for (j <- 0 until 13) {
checkPrunedAnswers(
bucketSpec,
bucketValues = j :: Nil,
filterCondition = $"j" === j && $"k" > $"j",
df)

checkPrunedAnswers(
bucketSpec,
bucketValues = j :: Nil,
filterCondition = $"j" === j && $"i" > j % 5,
df)
}
}
}

private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")

