Skip to content


[SPARK-12850][SQL] Support Bucket Pruning (Predicate Pushdown for Buc…
Browse files Browse the repository at this point in the history
…keted Tables)


This PR is to support bucket pruning when the predicates are `EqualTo`, `EqualNullSafe`, `IsNull`, `In`, and `InSet`.

As in Hive, bucket pruning in this PR works only when the bucketing key consists of exactly one column.

So far, I have not found a way to verify how many buckets are actually scanned. However, I did verify it manually while debugging. Could you suggest how to do this properly? Thank you! cloud-fan yhuai rxin marmbrus

BTW, we can add more cases to support complex predicates, including `Or` and `And`. Please let me know if I should do that in this PR.

We may also need to add test cases to verify that bucket pruning works correctly for each data type.

Author: gatorsmile <[email protected]>

Closes #10942 from gatorsmile/pruningBuckets.
  • Loading branch information
gatorsmile authored and rxin committed Feb 5, 2016
1 parent 6dbfc40 commit e3c75c6
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{SerializableConfiguration, Utils}
import org.apache.spark.util.collection.BitSet

* A Strategy for planning scans over data sources defined using the sources API.
Expand Down Expand Up @@ -97,10 +99,15 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
(partitionAndNormalColumnAttrs ++ projects).toSeq

// Prune the buckets based on the pushed filters that do not contain partitioning key
// since the bucketing key is not allowed to use the columns in partitioning key
val bucketSet = getBuckets(pushedFilters, t.getBucketSpec)

val scan = buildPartitionedTableScan(

Expand All @@ -124,11 +131,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
val sharedHadoopConf = SparkHadoopUtil.get.conf
val confBroadcast =
t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf))
// Prune the buckets based on the filters
val bucketSet = getBuckets(filters, t.getBucketSpec)
(a, f) => t.buildInternalScan(, f, t.paths, confBroadcast)) :: Nil
(a, f) =>
t.buildInternalScan(, f, bucketSet, t.paths, confBroadcast)) :: Nil

case l @ LogicalRelation(baseRelation: TableScan, _, _) =>
Expand All @@ -150,6 +160,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
logicalRelation: LogicalRelation,
projections: Seq[NamedExpression],
filters: Seq[Expression],
buckets: Option[BitSet],
partitionColumns: StructType,
partitions: Array[Partition]): SparkPlan = {
val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation]
Expand All @@ -174,7 +185,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
// assuming partition columns data stored in data files are always consistent with those
// partition values encoded in partition directory paths.
val dataRows = relation.buildInternalScan(, filters, Array(dir), confBroadcast), filters, buckets, Array(dir), confBroadcast)

// Merges data values with partition values.
Expand Down Expand Up @@ -251,6 +262,69 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {

// Get the bucket ID based on the bucketing values.
// Restriction: bucket pruning works iff the bucketing key consists of exactly one column.
def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
val mutableRow = new SpecificMutableRow(Seq(bucketColumn.dataType))
mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null)
val bucketIdGeneration = UnsafeProjection.create(
HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil,
bucketColumn :: Nil)


// Get the bucket BitSet by reading the filters that only contain the bucketing key.
// Note: When None is returned, no pruning is possible and all the buckets are scanned.
// Restriction: bucket pruning works iff the bucketing key consists of exactly one column.
private def getBuckets(
filters: Seq[Expression],
bucketSpec: Option[BucketSpec]): Option[BitSet] = {

if (bucketSpec.isEmpty ||
bucketSpec.get.numBuckets == 1 ||
bucketSpec.get.bucketColumnNames.length != 1) {
// None means all the buckets need to be scanned
return None

// Take the first name: bucket pruning only applies when there is exactly one bucketing column
val bucketColumnName = bucketSpec.get.bucketColumnNames.head
val numBuckets = bucketSpec.get.numBuckets
val matchedBuckets = new BitSet(numBuckets)

filters.foreach {
case expressions.EqualTo(a: Attribute, Literal(v, _)) if == bucketColumnName =>
matchedBuckets.set(getBucketId(a, numBuckets, v))
case expressions.EqualTo(Literal(v, _), a: Attribute) if == bucketColumnName =>
matchedBuckets.set(getBucketId(a, numBuckets, v))
case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if == bucketColumnName =>
matchedBuckets.set(getBucketId(a, numBuckets, v))
case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if == bucketColumnName =>
matchedBuckets.set(getBucketId(a, numBuckets, v))
// The Optimizer only converts In to InSet when the list has more than a certain number of
// items, so we may still get an In expression here that needs to be pushed
// down.
case expressions.In(a: Attribute, list)
if list.forall(_.isInstanceOf[Literal]) && == bucketColumnName =>
val hSet = => e.eval(EmptyRow))
hSet.foreach(e => matchedBuckets.set(getBucketId(a, numBuckets, e)))
case expressions.IsNull(a: Attribute) if == bucketColumnName =>
matchedBuckets.set(getBucketId(a, numBuckets, null))
case _ =>

logInfo {
val selected = matchedBuckets.cardinality()
val percentPruned = (1 - selected.toDouble / numBuckets.toDouble) * 100
s"Selected $selected buckets out of $numBuckets, pruned $percentPruned% partitions."

// None means all the buckets need to be scanned
if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets)

protected def prunePartitions(
predicates: Seq[Expression],
partitionSpec: PartitionSpec): Seq[Partition] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.collection.BitSet

* ::DeveloperApi::
Expand Down Expand Up @@ -722,6 +723,7 @@ abstract class HadoopFsRelation private[sql](
final private[sql] def buildInternalScan(
requiredColumns: Array[String],
filters: Array[Filter],
bucketSet: Option[BitSet],
inputPaths: Array[String],
broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = {
val inputStatuses = inputPaths.flatMap { input =>
Expand All @@ -743,9 +745,16 @@ abstract class HadoopFsRelation private[sql](
// id from file name. Then read these files into a RDD(use one-partition empty RDD for empty
// bucket), and coalesce it to one partition. Finally union all bucket RDDs to one result.
val perBucketRows = (0 until maybeBucketSpec.get.numBuckets).map { bucketId =>
groupedBucketFiles.get(bucketId).map { inputStatuses =>
buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1)
// If the current bucketId is not set in the bucket bitSet, skip scanning it.
if (bucketSet.nonEmpty && !bucketSet.get.get(bucketId)){
} else {
// When all the buckets need a scan (i.e., bucketSet is None)
// or when the current bucket needs a scan (i.e., the bit of bucketId is set to true)
groupedBucketFiles.get(bucketId).map { inputStatuses =>
buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1)

new UnionRDD(sqlContext.sparkContext, perBucketRows)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,28 @@ package org.apache.spark.sql.sources


import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLConf}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.Exchange
import org.apache.spark.sql.execution.datasources.BucketSpec
import org.apache.spark.sql.execution.{Exchange, PhysicalRDD}
import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSourceStrategy}
import org.apache.spark.sql.execution.joins.SortMergeJoin
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.BitSet

class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._

private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
private val nullDF = (for {
i <- 0 to 50
s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g")
} yield (i % 5, s, i % 13)).toDF("i", "j", "k")

test("read bucketed data") {
val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
withTable("bucketed_table") {
Expand All @@ -59,6 +65,152 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet

// To verify if the bucket pruning works, this function checks two conditions:
// 1) Check if the pruned buckets (before filtering) are empty.
// 2) Verify the final result is the same as the expected one
private def checkPrunedAnswers(
bucketSpec: BucketSpec,
bucketValues: Seq[Integer],
filterCondition: Column,
originalDataFrame: DataFrame): Unit = {

val bucketedDataFrame = hiveContext.table("bucketed_table").select("i", "j", "k")
val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec
// Limit: bucket pruning only works when the bucketing key has exactly one column
assert(bucketColumnNames.length == 1)
val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head)
val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex)
val matchedBuckets = new BitSet(numBuckets)
bucketValues.foreach { value =>
matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value))

// Filters could hide bugs in bucket pruning, so skip all the filters
val rdd = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan

val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) =>
if (matchedBuckets.get(index % numBuckets)) Iterator(true) else Iterator(iter.isEmpty)
// checking if all the pruned buckets are empty
assert(checkedResult.collect().forall(_ == true))

bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"),
originalDataFrame.filter(filterCondition).orderBy("i", "j", "k"))

test("read partitioning bucketed tables with bucket pruning filters") {
withTable("bucketed_table") {
val numBuckets = 8
val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
// json does not support predicate push-down, and thus json is used here
.bucketBy(numBuckets, "j")

for (j <- 0 until 13) {
// Case 1: EqualTo
bucketValues = j :: Nil,
filterCondition = $"j" === j,

// Case 2: EqualNullSafe
bucketValues = j :: Nil,
filterCondition = $"j" <=> j,

// Case 3: In
bucketValues = Seq(j, j + 1, j + 2, j + 3),
filterCondition = $"j".isin(j, j + 1, j + 2, j + 3),

test("read non-partitioning bucketed tables with bucket pruning filters") {
withTable("bucketed_table") {
val numBuckets = 8
val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
// json does not support predicate push-down, and thus json is used here
.bucketBy(numBuckets, "j")

for (j <- 0 until 13) {
bucketValues = j :: Nil,
filterCondition = $"j" === j,

test("read partitioning bucketed tables having null in bucketing key") {
withTable("bucketed_table") {
val numBuckets = 8
val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
// json does not support predicate push-down, and thus json is used here
.bucketBy(numBuckets, "j")

// Case 1: isNull
bucketValues = null :: Nil,
filterCondition = $"j".isNull,

// Case 2: <=> null
bucketValues = null :: Nil,
filterCondition = $"j" <=> null,

test("read partitioning bucketed tables having composite filters") {
withTable("bucketed_table") {
val numBuckets = 8
val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
// json does not support predicate push-down, and thus json is used here
.bucketBy(numBuckets, "j")

for (j <- 0 until 13) {
bucketValues = j :: Nil,
filterCondition = $"j" === j && $"k" > $"j",

bucketValues = j :: Nil,
filterCondition = $"j" === j && $"i" > j % 5,

private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")

Expand Down

0 comments on commit e3c75c6

Please sign in to comment.