[SPARK-49827][SQL] Fetching all partitions from hive metastore in batches
Madhukar525722 committed Oct 3, 2024
1 parent d97acc1 commit c43af04
Showing 3 changed files with 138 additions and 2 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -5199,6 +5199,29 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

  val HMS_BATCH_SIZE = buildConf("spark.sql.hive.metastore.batchSize")
    .internal()
    .doc("This setting defines the batch size for fetching metadata partitions from the " +
      "Hive Metastore. A value of -1 (the default) disables batching. To enable batching, " +
      "specify a positive integer, which determines the batch size for partition fetching.")
    .version("4.0.0")
    .intConf
    .createWithDefault(-1)

  val METASTORE_PARTITION_BATCH_RETRY_COUNT = buildConf(
    "spark.sql.metastore.partition.batch.retry.count")
    .internal()
    .doc("This setting specifies the number of retries when a batch of partitions cannot " +
      "be fetched from the metastore. The retry mechanism applies only when " +
      "spark.sql.hive.metastore.batchSize is set to a positive value.")
    .version("4.0.0")
    .intConf
    .createWithDefault(3)
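
For illustration, a minimal sketch of how these settings could be enabled from a Spark session (the table name below is hypothetical, not part of this change):

  // Fetch partition metadata in batches of 1000, retrying a failed batch up to 3 times.
  spark.conf.set("spark.sql.hive.metastore.batchSize", "1000")
  spark.conf.set("spark.sql.metastore.partition.batch.retry.count", "3")
  spark.sql("SELECT * FROM sales WHERE ds = 20170101").collect()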

/**
* Holds information about keys that have been deprecated.
*
@@ -6177,6 +6200,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def legacyEvalCurrentTime: Boolean = getConf(SQLConf.LEGACY_EVAL_CURRENT_TIME)

  def hiveMetastoreBatchSize: Int = getConf(HMS_BATCH_SIZE)

  def metastorePartitionBatchRetryCount: Int = getConf(METASTORE_PARTITION_BATCH_RETRY_COUNT)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -23,6 +23,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap}
import java.util.concurrent.TimeUnit

import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import org.apache.hadoop.fs.Path
@@ -389,6 +390,66 @@ private[client] class Shim_v2_0 extends Shim with Logging {
partitions.asScala.toSeq
}

  private def getPartitionNamesWithCount(hive: Hive, table: Table): (Int, Seq[String]) = {
    recordHiveCall()
    val partitionNames = hive.getPartitionNames(
      table.getDbName, table.getTableName, -1).asScala.toSeq
    (partitionNames.length, partitionNames)
  }

  private def getPartitionsInBatches(
      hive: Hive,
      table: Table,
      initialBatchSize: Int,
      partNames: Seq[String]): java.util.Collection[Partition] = {
    val maxRetries = SQLConf.get.metastorePartitionBatchRetryCount
    val decayingFactor = 2

    if (initialBatchSize <= 0) {
      throw new IllegalArgumentException(s"Invalid batch size $initialBatchSize provided " +
        "for fetching partitions. Batch size must be greater than 0")
    }

    if (maxRetries < 0) {
      throw new IllegalArgumentException(s"Invalid number of maximum retries $maxRetries " +
        "provided for fetching partitions. It must be a non-negative integer value")
    }

    logInfo(s"Breaking the partition request into batches of $initialBatchSize.")

    var batchSize = initialBatchSize
    val processedPartitions = mutable.ListBuffer[Partition]()
    // Retries are counted across the whole request, not per batch.
    var retryCount = 0
    var index = 0

    while (index < partNames.size) {
      val batch = partNames.slice(index, index + batchSize)
      try {
        recordHiveCall()
        val partitions = hive.getPartitionsByNames(table, batch.asJava)
        processedPartitions ++= partitions.asScala
        index += batch.size
      } catch {
        case ex: Exception =>
          if (retryCount >= maxRetries) {
            logError("Failed to fetch partitions for the request. Retry count exceeded.")
            throw ex
          }
          // Decay the batch size so the failing range is retried in smaller pieces.
          batchSize = (batchSize / decayingFactor) max 1
          retryCount += 1
          logWarning(s"Caught exception while fetching partition batch '$batch'. " +
            s"Retrying with batch size $batchSize.", ex)
      }
    }
    processedPartitions.asJava
  }
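
For intuition, a minimal self-contained sketch (not part of this change) of the decaying schedule implemented above, where each failed attempt halves the batch size with a floor of 1:

  // With initialBatchSize = 8 and maxRetries = 3, failing attempts use batches of 8, 4, 2, 1.
  def retrySchedule(initialBatchSize: Int, maxRetries: Int, decayingFactor: Int = 2): Seq[Int] =
    Iterator.iterate(initialBatchSize)(s => (s / decayingFactor) max 1)
      .take(maxRetries + 1)
      .toSeq

  assert(retrySchedule(8, 3) == Seq(8, 4, 2, 1))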

private def prunePartitionsFastFallback(
hive: Hive,
table: Table,
@@ -406,11 +467,19 @@
}
}

    val batchSize = SQLConf.get.hiveMetastoreBatchSize

    if (!SQLConf.get.metastorePartitionPruningFastFallback ||
        predicates.isEmpty ||
        predicates.exists(hasTimeZoneAwareExpression)) {
      val (count, partNames) = getPartitionNamesWithCount(hive, table)
      if (batchSize == -1 || count < batchSize) {
        recordHiveCall()
        hive.getAllPartitionsOf(table)
      } else {
        getPartitionsInBatches(hive, table, batchSize, partNames)
      }
} else {
try {
val partitionSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
@@ -442,8 +511,14 @@
case ex: HiveException if ex.getCause.isInstanceOf[MetaException] =>
logWarning("Caught Hive MetaException attempting to get partition metadata by " +
"filter from client side. Falling back to fetching all partition metadata", ex)
            val (count, partNames) = getPartitionNamesWithCount(hive, table)
            if (batchSize == -1 || count < batchSize) {
              recordHiveCall()
              hive.getAllPartitionsOf(table)
            } else {
              getPartitionsInBatches(hive, table, batchSize, partNames)
            }
}
}
}
sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala
@@ -726,6 +726,40 @@ class HivePartitionFilteringSuite(version: String)
}
}

test("getPartitionsByFilter: getPartitionsInBatches") {
var filteredPartitions: Seq[CatalogTablePartition] = Seq()
var filteredPartitionsNoBatch: Seq[CatalogTablePartition] = Seq()
var filteredPartitionsHighBatch: Seq[CatalogTablePartition] = Seq()

withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "1") {
filteredPartitions = client.getPartitionsByFilter(
client.getRawHiveTable("default", "test"),
Seq(attr("ds") === 20170101)
)
}
withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "-1") {
filteredPartitionsNoBatch = client.getPartitionsByFilter(
client.getRawHiveTable("default", "test"),
Seq(attr("ds") === 20170101)
)
}
withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "5000") {
filteredPartitionsHighBatch = client.getPartitionsByFilter(
client.getRawHiveTable("default", "test"),
Seq(attr("ds") === 20170101)
)
}

assert(filteredPartitions.size == filteredPartitionsNoBatch.size)
assert(filteredPartitions.size == filteredPartitionsHighBatch.size)
assert(
filteredPartitions.map(_.spec.toSet).toSet ==
filteredPartitionsNoBatch.map(_.spec.toSet).toSet)
assert(
filteredPartitions.map(_.spec.toSet).toSet ==
filteredPartitionsHighBatch.map(_.spec.toSet).toSet)
}

private def testMetastorePartitionFiltering(
filterExpr: Expression,
expectedDs: Seq[Int],
