[SPARK-49827][CORE] Adding batches with retry mechanism for fetching all partitions from metastore
Madhukar525722 committed Sep 29, 2024
1 parent a49d6f4 commit dc26820
Showing 3 changed files with 115 additions and 2 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4423,6 +4423,29 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val HMS_BATCH_SIZE = buildConf("spark.sql.hive.metastore.batchSize")
.internal()
.doc("This setting defines the batch size for fetching metadata partitions from the" +
"Hive Metastore. A value of -1 disables batching by default. To enable batching," +
"specify a positive integer, which will determine the batch size for partition fetching."
)
.version("3.5.3")
.intConf
.createWithDefault(-1)

val METASTORE_PARTITION_BATCH_RETRY_COUNT = buildConf(
"spark.sql.metastore.partition.batch.retry.count")
.internal()
.doc(
"This setting specifies the number of retries for fetching partitions from the metastore" +
"in case of failure to fetch batch metadata. This retry mechanism is applicable only" +
"when HMS_BATCH_SIZE is enabled. It defines the count for the number of " +
"retries to be done."
)
.version("3.5.3")
.intConf
.createWithDefault(3)

/**
* Holds information about keys that have been deprecated.
*
@@ -5278,6 +5301,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
getConf(SQLConf.LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT)
}

def getHiveMetaStoreBatchSize: Int = getConf(HMS_BATCH_SIZE)

def metastorePartitionBatchRetryCount: Int = getConf(METASTORE_PARTITION_BATCH_RETRY_COUNT)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
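For context, a rough usage sketch of the two new settings (the `spark` session and table name below are hypothetical, not part of this change):

// Hypothetical usage: enable batched partition fetching with up to 3 retries per batch.
// Assumes an existing Hive-enabled SparkSession named `spark`.
spark.conf.set("spark.sql.hive.metastore.batchSize", "1000")
spark.conf.set("spark.sql.metastore.partition.batch.retry.count", "3")

// When a query has to load all partition metadata for a Hive table (and the table has
// at least 1000 partitions), the partition objects are fetched in batches of 1000 names.
spark.sql("SELECT count(*) FROM some_partitioned_table").collect()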
sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -1140,6 +1140,44 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
partitions.asScala.toSeq
}

private def getPartitionNamesWithCount(hive: Hive, table: Table): (Int, Seq[String]) = {
val partitionNames = hive.getPartitionNames(
table.getDbName, table.getTableName, -1).asScala.toSeq
(partitionNames.length, partitionNames)
}

private def getPartitionsInBatches(
hive: Hive,
table: Table,
size: Int,
partNames: Seq[String]): java.util.Collection[Partition] = {
logInfo(s"Starting the batch processing of partitions.")

val retryCount = SQLConf.get.metastorePartitionBatchRetryCount
val batches = partNames.grouped(size).toSeq
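// Fetch each batch of partition names from the metastore, retrying a failed
// batch up to retryCount times before giving up.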
val processedPartitions = batches.flatMap { batch =>
var partitions: java.util.Collection[Partition] = null
var retry = 0
while (partitions == null && retry < retryCount) {
try {
partitions = hive.getPartitionsByNames(table, batch.asJava)
} catch {
case ex: Exception =>
logWarning(
s"Caught exception while fetching partition metadata for batch '$batch'.",
ex
)
retry += 1
if (retry >= retryCount) {
logError(s"Failed to fetch partition metadata for batch '$batch' after $retryCount retries.")
throw ex
}
}
}
partitions.asScala
}
processedPartitions.asJava
}

private def prunePartitionsFastFallback(
hive: Hive,
table: Table,
@@ -1157,11 +1195,19 @@
}
}

val batchSize = SQLConf.get.getHiveMetaStoreBatchSize

if (!SQLConf.get.metastorePartitionPruningFastFallback ||
predicates.isEmpty ||
predicates.exists(hasTimeZoneAwareExpression)) {
val (count, partNames) = getPartitionNamesWithCount(hive, table)
recordHiveCall()
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
if(count < batchSize || batchSize == -1) {
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
}
else {
getPartitionsInBatches(hive, table, batchSize, partNames)
}
} else {
try {
val partitionSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
@@ -1193,8 +1239,14 @@
case ex: HiveException if ex.getCause.isInstanceOf[MetaException] =>
logWarning("Caught Hive MetaException attempting to get partition metadata by " +
"filter from client side. Falling back to fetching all partition metadata", ex)
val (count, partNames) = getPartitionNamesWithCount(hive, table)
recordHiveCall()
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
if(count < batchSize || batchSize == -1) {
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
}
else {
getPartitionsInBatches(hive, table, batchSize, partNames)
}
}
}
}
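For readers skimming the shim change above, the core pattern is roughly this Hive-free sketch (illustration only; fetchAllInBatches and fetchBatch are hypothetical stand-ins, the latter for hive.getPartitionsByNames):

// Split the names into fixed-size batches and retry each batch up to retryCount
// times, rethrowing the last failure once the retries are exhausted.
def fetchAllInBatches[A, B](items: Seq[A], batchSize: Int, retryCount: Int)
    (fetchBatch: Seq[A] => Seq[B]): Seq[B] = {
  items.grouped(batchSize).toSeq.flatMap { batch =>
    var result: Option[Seq[B]] = None
    var retry = 0
    while (result.isEmpty && retry < retryCount) {
      try {
        result = Some(fetchBatch(batch))
      } catch {
        case ex: Exception =>
          retry += 1
          if (retry >= retryCount) throw ex
      }
    }
    result.get
  }
}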
sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala
@@ -673,6 +673,40 @@ class HivePartitionFilteringSuite(version: String)
}
}

test("getPartitionsByFilter: getPartitionsInBatches") {
var filteredPartitions: Seq[CatalogTablePartition] = Seq()
var filteredPartitionsNoBatch: Seq[CatalogTablePartition] = Seq()
var filteredPartitionsHighBatch: Seq[CatalogTablePartition] = Seq()

withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "1") {
filteredPartitions = client.getPartitionsByFilter(
client.getRawHiveTable("default", "test"),
Seq(attr("ds") === 20170101)
)
}
withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "-1") {
filteredPartitionsNoBatch = client.getPartitionsByFilter(
client.getRawHiveTable("default", "test"),
Seq(attr("ds") === 20170101)
)
}
withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "5000") {
filteredPartitionsHighBatch = client.getPartitionsByFilter(
client.getRawHiveTable("default", "test"),
Seq(attr("ds") === 20170101)
)
}

assert(filteredPartitions.size == filteredPartitionsNoBatch.size)
assert(filteredPartitions.size == filteredPartitionsHighBatch.size)
assert(
filteredPartitions.map(_.spec.toSet).toSet ==
filteredPartitionsNoBatch.map(_.spec.toSet).toSet)
assert(
filteredPartitions.map(_.spec.toSet).toSet ==
filteredPartitionsHighBatch.map(_.spec.toSet).toSet)
}

private def testMetastorePartitionFiltering(
filterExpr: Expression,
expectedDs: Seq[Int],
