diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6f2f0088fccd1..e86f6fe086653 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4423,6 +4423,26 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val HMS_BATCH_SIZE = buildConf("spark.sql.hive.metastore.batchSize")
+    .internal()
+    .doc("The batch size to use when fetching partition metadata from the Hive " +
+      "Metastore. The default value of -1 disables batching. To enable batching, " +
+      "set a positive integer, which determines how many partitions are fetched " +
+      "per metastore call.")
+    .version("3.5.3")
+    .intConf
+    .createWithDefault(-1)
+
+  val METASTORE_PARTITION_BATCH_RETRY_COUNT = buildConf(
+    "spark.sql.metastore.partition.batch.retry.count")
+    .internal()
+    .doc("The maximum number of attempts per batch when fetching partition metadata " +
+      "from the metastore. This retry mechanism only applies when batching is " +
+      s"enabled via ${HMS_BATCH_SIZE.key}.")
+    .version("3.5.3")
+    .intConf
+    .createWithDefault(3)
+
   /**
    * Holds information about keys that have been deprecated.
    *
@@ -5278,6 +5298,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
     getConf(SQLConf.LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT)
   }
 
+  def hiveMetastoreBatchSize: Int = getConf(HMS_BATCH_SIZE)
+
+  def metastorePartitionBatchRetryCount: Int = getConf(METASTORE_PARTITION_BATCH_RETRY_COUNT)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 60ff9ec42f29d..ad8e57b500595 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -1140,6 +1140,46 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
     partitions.asScala.toSeq
   }
 
+  private def getPartitionNamesWithCount(hive: Hive, table: Table): (Int, Seq[String]) = {
+    recordHiveCall()
+    val partitionNames = hive.getPartitionNames(
+      table.getDbName, table.getTableName, -1).asScala.toSeq
+    (partitionNames.length, partitionNames)
+  }
+
+  private def getPartitionsInBatches(
+      hive: Hive,
+      table: Table,
+      size: Int,
+      partNames: Seq[String]): java.util.Collection[Partition] = {
+    logInfo(s"Fetching metadata for ${partNames.length} partitions in batches of $size.")
+
+    val retryCount = SQLConf.get.metastorePartitionBatchRetryCount
+    val batches = partNames.grouped(size).toSeq
+    val processedPartitions = batches.flatMap { batch =>
+      var partitions: java.util.Collection[Partition] = null
+      var retry = 0
+      while (partitions == null) {
+        try {
+          recordHiveCall()
+          partitions = hive.getPartitionsByNames(table, batch.asJava)
+        } catch {
+          case ex: Exception =>
+            retry += 1
+            if (retry >= retryCount) {
+              logError(s"Failed to fetch partition metadata for batch '$batch'. " +
+                "Retry count exceeded.", ex)
+              throw ex
+            }
+            logWarning("Caught exception while fetching partition metadata for " +
+              s"batch '$batch'. Retrying (attempt $retry of $retryCount).", ex)
+        }
+      }
+      partitions.asScala
+    }
+    processedPartitions.asJava
+  }
+
   private def prunePartitionsFastFallback(
       hive: Hive,
       table: Table,
@@ -1157,11 +1197,18 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
       }
     }
 
+    val batchSize = SQLConf.get.hiveMetastoreBatchSize
+
     if (!SQLConf.get.metastorePartitionPruningFastFallback ||
         predicates.isEmpty ||
         predicates.exists(hasTimeZoneAwareExpression)) {
-      recordHiveCall()
-      getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
+      val (count, partNames) = getPartitionNamesWithCount(hive, table)
+      if (batchSize == -1 || count < batchSize) {
+        recordHiveCall()
+        getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
+      } else {
+        getPartitionsInBatches(hive, table, batchSize, partNames)
+      }
     } else {
       try {
         val partitionSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
@@ -1193,8 +1240,13 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
         case ex: HiveException if ex.getCause.isInstanceOf[MetaException] =>
           logWarning("Caught Hive MetaException attempting to get partition metadata by " +
             "filter from client side. Falling back to fetching all partition metadata", ex)
-          recordHiveCall()
-          getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
+          val (count, partNames) = getPartitionNamesWithCount(hive, table)
+          if (batchSize == -1 || count < batchSize) {
+            recordHiveCall()
+            getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
+          } else {
+            getPartitionsInBatches(hive, table, batchSize, partNames)
+          }
       }
     }
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala
index b96d28d22cc7f..daac2eb11cbd7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala
@@ -673,6 +673,38 @@ class HivePartitionFilteringSuite(version: String)
     }
   }
 
+  test("getPartitionsByFilter: getPartitionsInBatches") {
+    var filteredPartitions: Seq[CatalogTablePartition] = Seq.empty
+    var filteredPartitionsNoBatch: Seq[CatalogTablePartition] = Seq.empty
+    var filteredPartitionsHighBatch: Seq[CatalogTablePartition] = Seq.empty
+
+    // A batch size of 1 forces one metastore call per partition.
+    withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "1") {
+      filteredPartitions = client.getPartitionsByFilter(
+        client.getRawHiveTable("default", "test"),
+        Seq(attr("ds") === 20170101))
+    }
+    // -1 (the default) disables batching entirely.
+    withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "-1") {
+      filteredPartitionsNoBatch = client.getPartitionsByFilter(
+        client.getRawHiveTable("default", "test"),
+        Seq(attr("ds") === 20170101))
+    }
+    // A batch size larger than the partition count falls back to a single full fetch.
+    withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "5000") {
+      filteredPartitionsHighBatch = client.getPartitionsByFilter(
+        client.getRawHiveTable("default", "test"),
+        Seq(attr("ds") === 20170101))
+    }
+
+    assert(filteredPartitions.size == filteredPartitionsNoBatch.size)
+    assert(filteredPartitions.size == filteredPartitionsHighBatch.size)
+    assert(filteredPartitions.map(_.spec).toSet ==
+      filteredPartitionsNoBatch.map(_.spec).toSet)
+    assert(filteredPartitions.map(_.spec).toSet ==
+      filteredPartitionsHighBatch.map(_.spec).toSet)
+  }
+
   private def testMetastorePartitionFiltering(
       filterExpr: Expression,
       expectedDs: Seq[Int],
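
Reviewer's note: below is a minimal, self-contained sketch of the batch-plus-bounded-retry
pattern that the new getPartitionsInBatches method implements. The names fetchInBatches and
fetchBatch are hypothetical stand-ins (fetchBatch plays the role of hive.getPartitionsByNames);
this illustrates the control flow under those assumptions and is not the patched code itself.

    // Groups `names` into batches of `batchSize` and fetches each batch with
    // `fetchBatch`, retrying a failed batch up to `retryCount` attempts before
    // rethrowing the last exception (instead of silently returning null).
    object BatchFetchSketch {
      def fetchInBatches[A, B](
          names: Seq[A],
          batchSize: Int,
          retryCount: Int)(fetchBatch: Seq[A] => Seq[B]): Seq[B] = {
        names.grouped(batchSize).toSeq.flatMap { batch =>
          var result: Option[Seq[B]] = None
          var attempts = 0
          while (result.isEmpty) {
            try {
              result = Some(fetchBatch(batch))
            } catch {
              case ex: Exception =>
                attempts += 1
                if (attempts >= retryCount) throw ex // retries exhausted
            }
          }
          result.get
        }
      }

      def main(args: Array[String]): Unit = {
        // Simulated fetch that fails once, to exercise the retry path.
        var calls = 0
        val parts = fetchInBatches(Seq("ds=1", "ds=2", "ds=3"), batchSize = 2, retryCount = 3) {
          batch =>
            calls += 1
            if (calls == 1) throw new RuntimeException("transient HMS failure")
            batch.map(n => s"Partition($n)")
        }
        println(parts) // List(Partition(ds=1), Partition(ds=2), Partition(ds=3))
      }
    }

One design consequence worth noting: because a batch that exhausts its retries now rethrows
the last exception, callers see the metastore failure directly rather than the
NullPointerException that a silently-null batch would have produced.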