[SPARK-20194] Add support for partition pruning to in-memory catalog #17510

Closed. Wants to merge 3 commits.
@@ -25,6 +25,7 @@ import org.apache.hadoop.util.Shell
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, InterpretedPredicate}

object ExternalCatalogUtils {
// This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't
@@ -125,6 +126,37 @@ object ExternalCatalogUtils {
}
escapePathName(col) + "=" + partitionString
}

def prunePartitionsByFilter(
catalogTable: CatalogTable,
inputPartitions: Seq[CatalogTablePartition],
predicates: Seq[Expression],
defaultTimeZoneId: String): Seq[CatalogTablePartition] = {
if (predicates.isEmpty) {
inputPartitions
} else {
val partitionSchema = catalogTable.partitionSchema
val partitionColumnNames = catalogTable.partitionColumnNames.toSet

val nonPartitionPruningPredicates = predicates.filterNot {
_.references.map(_.name).toSet.subsetOf(partitionColumnNames)
}
if (nonPartitionPruningPredicates.nonEmpty) {
sys.error("Expected only partition pruning predicates: " + nonPartitionPruningPredicates)
Review comment (Member): Nit: Throwing an AnalysisException is preferred.

Review comment (Member): Could you add the negative test cases in your newly added test cases?
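
A minimal sketch of the change the first comment asks for, not part of this diff (AnalysisException is already imported at the top of this file):

if (nonPartitionPruningPredicates.nonEmpty) {
  // Hedged sketch: surface the misuse as an AnalysisException rather than sys.error.
  throw new AnalysisException(
    "Expected only partition pruning predicates: " + nonPartitionPruningPredicates)
}

A sketch of the negative test case requested in the second comment follows the new test in ExternalCatalogSuite further down.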

}

val boundPredicate =
InterpretedPredicate.create(predicates.reduce(And).transform {
case att: AttributeReference =>
val index = partitionSchema.indexWhere(_.name == att.name)
BoundReference(index, partitionSchema(index).dataType, nullable = true)
})

inputPartitions.filter { p =>
boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId))
}
}
}
}

object CatalogUtils {
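
To make the mechanics of the new prunePartitionsByFilter helper concrete, here is a small standalone illustration (not part of the PR) of how an attribute over a partition column is rewritten to a BoundReference at its ordinal in the partition schema and then evaluated against a partition row. It assumes spark-catalyst from the branch this PR targets, where InterpretedPredicate.create returns an InternalRow => Boolean:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object PredicateBindingSketch {
  def main(args: Array[String]): Unit = {
    // Partition schema of a table partitioned by (a INT, b STRING).
    val partitionSchema = StructType(Seq(
      StructField("a", IntegerType),
      StructField("b", StringType)))

    // Predicate `a = 1`, expressed over an AttributeReference as a caller would pass it.
    val predicate = EqualTo(AttributeReference("a", IntegerType)(), Literal(1))

    // Bind the attribute to its position in the partition schema, exactly as
    // prunePartitionsByFilter does above.
    val boundPredicate = InterpretedPredicate.create(predicate.transform {
      case att: AttributeReference =>
        val index = partitionSchema.indexWhere(_.name == att.name)
        BoundReference(index, partitionSchema(index).dataType, nullable = true)
    })

    // A partition with spec (a=1, b=x), in the row form produced by
    // CatalogTablePartition.toRow(partitionSchema, defaultTimeZoneId).
    val partitionRow = InternalRow(1, UTF8String.fromString("x"))
    println(boundPredicate(partitionRow))  // true: this partition survives pruning
  }
}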
@@ -28,7 +28,7 @@ import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.types.StructType
@@ -556,9 +556,9 @@ class InMemoryCatalog(
table: String,
predicates: Seq[Expression],
defaultTimeZoneId: String): Seq[CatalogTablePartition] = {
// TODO: Provide an implementation
throw new UnsupportedOperationException(
"listPartitionsByFilter is not implemented for InMemoryCatalog")
val catalogTable = getTable(db, table)
val allPartitions = listPartitions(db, table)
prunePartitionsByFilter(catalogTable, allPartitions, predicates, defaultTimeZoneId)
}

// --------------------------------------------------------------------------
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.catalog

import java.net.URI
import java.util.TimeZone

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
@@ -28,6 +29,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException}
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -436,6 +438,37 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEach
assert(catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "unknown"))).isEmpty)
}

test("list partitions by filter") {
val tz = TimeZone.getDefault().getID()
Review comment (Member): Nit: val tz = TimeZone.getDefault.getID

val catalog = newBasicCatalog()

def checkAnswer(table: CatalogTable, filters: Seq[Expression],
expected: Set[CatalogTablePartition]): Unit = {
Review comment (Contributor): nit: code style

def checkAnswer(
    param1: XX, param2: XX, ...

assertResult(expected.map(_.spec)) {
catalog.listPartitionsByFilter(table.database, table.identifier.identifier, filters, tz)
.map(_.spec).toSet
}
}

def pcol(table: CatalogTable, name: String): Expression = {
val col = table.partitionSchema(name)
AttributeReference(col.name, col.dataType, col.nullable)()
}
val tbl2 = catalog.getTable("db2", "tbl2")

checkAnswer(tbl2, Seq.empty, Set(part1, part2))
checkAnswer(tbl2, Seq(EqualTo(pcol(tbl2, "a"), Literal(1))), Set(part1))
checkAnswer(tbl2, Seq(EqualTo(pcol(tbl2, "a"), Literal(2))), Set.empty)
checkAnswer(tbl2, Seq(In(pcol(tbl2, "a"), Seq(Literal(3)))), Set(part2))
checkAnswer(tbl2, Seq(Not(In(pcol(tbl2, "a"), Seq(Literal(4))))), Set(part1, part2))
checkAnswer(tbl2, Seq(
EqualTo(pcol(tbl2, "a"), Literal(1)),
EqualTo(pcol(tbl2, "b"), Literal("2"))), Set(part1))
checkAnswer(tbl2, Seq(
EqualTo(pcol(tbl2, "a"), Literal(1)),
EqualTo(pcol(tbl2, "b"), Literal("x"))), Set.empty)
}
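
One way the negative test case requested in the review of prunePartitionsByFilter could look. This is a hedged sketch: it assumes tbl2 also has an ordinary data column (called "col1" here, hypothetical) and that filtering on a non-partition column is reported as an exception.

  test("list partitions by filter: non-partition-column predicates are rejected") {
    val catalog = newBasicCatalog()
    // Assumption: "col1" is a regular data column of tbl2, not a partition column.
    val nonPartCol = AttributeReference("col1", IntegerType, nullable = true)()
    intercept[Exception] {
      catalog.listPartitionsByFilter(
        "db2", "tbl2", Seq(EqualTo(nonPartCol, Literal(1))), TimeZone.getDefault.getID)
    }
  }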

test("drop partitions") {
val catalog = newBasicCatalog()
assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2)))
@@ -217,8 +217,6 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
val table = r.tableMeta
val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table)
val cache = sparkSession.sessionState.catalog.tableRelationCache
val withHiveSupport =
sparkSession.sparkContext.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive"

val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() {
override def call(): LogicalPlan = {
@@ -233,8 +231,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
bucketSpec = table.bucketSpec,
className = table.provider.get,
options = table.storage.properties ++ pathOption,
// TODO: improve `InMemoryCatalog` and remove this limitation.
catalogTable = if (withHiveSupport) Some(table) else None)
catalogTable = Some(table))

LogicalRelation(
dataSource.resolveRelation(checkFilesExist = false),
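
For context on the change above: dropping the Hive-only restriction means a partitioned data source table gets its CatalogTable (and therefore catalog-backed partition pruning) with the default in-memory catalog as well. A rough end-to-end sketch of that scenario, with hypothetical table and column names and assuming a build containing this patch:

import org.apache.spark.sql.SparkSession

object InMemoryCatalogPruningDemo {
  def main(args: Array[String]): Unit = {
    // No enableHiveSupport(): the default in-memory catalog implementation is used.
    val spark = SparkSession.builder().master("local[*]").appName("pruning-demo").getOrCreate()
    spark.sql("CREATE TABLE events (value INT, day STRING) USING parquet PARTITIONED BY (day)")
    spark.sql("INSERT INTO events VALUES (1, '2017-04-01'), (2, '2017-04-02')")
    // A filter on the partition column can now be answered through the catalog's
    // listPartitionsByFilter instead of listing every partition.
    spark.sql("SELECT * FROM events WHERE day = '2017-04-01'").show()
    spark.stop()
  }
}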
@@ -35,7 +35,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -1039,37 +1039,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configuration)
defaultTimeZoneId: String): Seq[CatalogTablePartition] = withClient {
val rawTable = getRawTable(db, table)
val catalogTable = restoreTableMetadata(rawTable)
val partitionColumnNames = catalogTable.partitionColumnNames.toSet
val nonPartitionPruningPredicates = predicates.filterNot {
_.references.map(_.name).toSet.subsetOf(partitionColumnNames)
}

if (nonPartitionPruningPredicates.nonEmpty) {
sys.error("Expected only partition pruning predicates: " +
predicates.reduceLeft(And))
}
val partColNameMap = buildLowerCasePartColNameMap(catalogTable)

val partitionSchema = catalogTable.partitionSchema
val partColNameMap = buildLowerCasePartColNameMap(getTable(db, table))

if (predicates.nonEmpty) {
val clientPrunedPartitions = client.getPartitionsByFilter(rawTable, predicates).map { part =>
val clientPrunedPartitions =
client.getPartitionsByFilter(rawTable, predicates).map { part =>
Review comment (Contributor): if predicates.isEmpty, the previous code will run client.getPartitions. Can you double check there is no performance regression?

Reply from the author (Contributor): A similar optimization is done in the function itself: client.getPartitionsByFilter(), while Hive Shim_v0_12 delegates to getAllPartitions anyway.

I've now made the only piece of non-trivial code along that path lazy, so I think we're good.

part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
}
val boundPredicate =
InterpretedPredicate.create(predicates.reduce(And).transform {
case att: AttributeReference =>
val index = partitionSchema.indexWhere(_.name == att.name)
BoundReference(index, partitionSchema(index).dataType, nullable = true)
})
clientPrunedPartitions.filter { p =>
boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId))
}
} else {
client.getPartitions(catalogTable).map { part =>
part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
}
}
prunePartitionsByFilter(catalogTable, clientPrunedPartitions, predicates, defaultTimeZoneId)
}

// --------------------------------------------------------------------------
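
The author's reply in the thread above does not show the deferral itself. Purely as an illustration of the idea (hypothetical, and assuming the deferred cost is the lower-case partition-column name map used when restoring partition specs), it could be a one-word change:

    // Sketch only: with a lazy val, the map is never built when the client
    // returns no partitions whose specs need restoring.
    lazy val partColNameMap = buildLowerCasePartColNameMap(catalogTable)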
@@ -50,13 +50,6 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite {

import utils._

test("list partitions by filter") {
val catalog = newBasicCatalog()
val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1), "GMT")
assert(selectedPartitions.length == 1)
assert(selectedPartitions.head.spec == part1.spec)
}

test("SPARK-18647: do not put provider in table properties for Hive serde table") {
val catalog = newBasicCatalog()
val hiveTable = CatalogTable(