diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ea187c0316c17..be1d34037c858 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -5199,6 +5199,29 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val HMS_BATCH_SIZE = buildConf("spark.sql.hive.metastore.batchSize")
+    .internal()
+    .doc("This setting defines the batch size for fetching partition metadata from the " +
+      "Hive Metastore. A value of -1 (the default) disables batching. To enable batching, " +
+      "set a positive integer, which determines the batch size for partition fetching."
+    )
+    .version("4.0.0")
+    .intConf
+    .createWithDefault(-1)
+
+  val METASTORE_PARTITION_BATCH_RETRY_COUNT = buildConf(
+    "spark.sql.metastore.partition.batch.retry.count")
+    .internal()
+    .doc(
+      "This setting specifies the number of times a failed batch of partition metadata " +
+      "is re-fetched from the metastore. The retry mechanism applies only when " +
+      s"'${HMS_BATCH_SIZE.key}' is set to a positive value, i.e. when batched " +
+      "fetching is enabled."
+    )
+    .version("4.0.0")
+    .intConf
+    .createWithDefault(3)
+
   /**
    * Holds information about keys that have been deprecated.
    *
@@ -6177,6 +6200,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def legacyEvalCurrentTime: Boolean = getConf(SQLConf.LEGACY_EVAL_CURRENT_TIME)
 
+  def getHiveMetaStoreBatchSize: Int = getConf(HMS_BATCH_SIZE)
+
+  def metastorePartitionBatchRetryCount: Int = getConf(METASTORE_PARTITION_BATCH_RETRY_COUNT)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
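
The two configs above work as a pair: batching stays off until the batch size is positive, and the retry budget only matters once batching is on. A minimal usage sketch, not part of the patch; `spark` is assumed to be a Hive-enabled SparkSession and `sales` a hypothetical partitioned Hive table:

    // Sketch: turn on batched partition-metadata fetching for this session.
    // Both keys are internal; they default to -1 (batching off) and 3 retries.
    spark.conf.set("spark.sql.hive.metastore.batchSize", "1000")
    spark.conf.set("spark.sql.metastore.partition.batch.retry.count", "3")
    // Fallback fetches of all partition metadata now go out in batches
    // (of at most 1000) once the partition count reaches the batch size.
    spark.sql("SELECT count(*) FROM sales WHERE ds = '2024-01-01'").show()
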
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index c03fed4cc3184..5624aab5bcb67 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -23,6 +23,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap}
 import java.util.concurrent.TimeUnit
 
+import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.Path
@@ -389,6 +390,68 @@ private[client] class Shim_v2_0 extends Shim with Logging {
     partitions.asScala.toSeq
   }
 
+  private def getPartitionNamesWithCount(hive: Hive, table: Table): (Int, Seq[String]) = {
+    recordHiveCall()
+    val partitionNames = hive.getPartitionNames(
+      table.getDbName, table.getTableName, -1).asScala.toSeq
+    (partitionNames.length, partitionNames)
+  }
+
+  private def getPartitionsInBatches(
+      hive: Hive,
+      table: Table,
+      initialBatchSize: Int,
+      partNames: Seq[String]): java.util.Collection[Partition] = {
+    val maxRetries = SQLConf.get.metastorePartitionBatchRetryCount
+    val decayingFactor = 2
+
+    if (initialBatchSize <= 0) {
+      throw new IllegalArgumentException(s"Invalid batch size $initialBatchSize provided " +
+        "for fetching partitions. Batch size must be greater than 0.")
+    }
+
+    if (maxRetries < 0) {
+      throw new IllegalArgumentException(s"Invalid number of maximum retries $maxRetries " +
+        "provided for fetching partitions. It must be a non-negative integer value.")
+    }
+
+    logInfo(s"Breaking the partition fetch into batches of '$initialBatchSize'.")
+
+    var batchSize = initialBatchSize
+    val processedPartitions = mutable.ListBuffer[Partition]()
+    var retryCount = 0
+    var index = 0
+
+    def getNextBatchSize(): Int = {
+      val currentBatchSize = batchSize
+      batchSize = (batchSize / decayingFactor) max 1
+      currentBatchSize
+    }
+
+    while (index < partNames.size && retryCount <= maxRetries) {
+      val currentBatchSize = getNextBatchSize()
+      val batch = partNames.slice(index, index + currentBatchSize)
+      var partitions: java.util.Collection[Partition] = null
+
+      while (partitions == null && retryCount <= maxRetries) {
+        try {
+          recordHiveCall()
+          partitions = hive.getPartitionsByNames(table, batch.asJava)
+          processedPartitions ++= partitions.asScala
+          index += batch.size
+        } catch {
+          case ex: Exception =>
+            logWarning(s"Caught exception while fetching partitions for batch '$batch'.", ex)
+            retryCount += 1
+            if (retryCount > maxRetries) {
+              logError("Failed to fetch partitions for the request; retry count exceeded.")
+            }
+        }
+      }
+    }
+    processedPartitions.asJava
+  }
+
   private def prunePartitionsFastFallback(
       hive: Hive,
       table: Table,
@@ -406,11 +469,18 @@ private[client] class Shim_v2_0 extends Shim with Logging {
       }
     }
 
+    val batchSize = SQLConf.get.getHiveMetaStoreBatchSize
+
     if (!SQLConf.get.metastorePartitionPruningFastFallback ||
         predicates.isEmpty ||
         predicates.exists(hasTimeZoneAwareExpression)) {
+      val (count, partNames) = getPartitionNamesWithCount(hive, table)
       recordHiveCall()
-      hive.getAllPartitionsOf(table)
+      if (count < batchSize || batchSize == -1) {
+        hive.getAllPartitionsOf(table)
+      } else {
+        getPartitionsInBatches(hive, table, batchSize, partNames)
+      }
     } else {
       try {
         val partitionSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
@@ -442,8 +512,13 @@ private[client] class Shim_v2_0 extends Shim with Logging {
         case ex: HiveException if ex.getCause.isInstanceOf[MetaException] =>
           logWarning("Caught Hive MetaException attempting to get partition metadata by " +
             "filter from client side. Falling back to fetching all partition metadata", ex)
+          val (count, partNames) = getPartitionNamesWithCount(hive, table)
           recordHiveCall()
-          hive.getAllPartitionsOf(table)
+          if (count < batchSize || batchSize == -1) {
+            hive.getAllPartitionsOf(table)
+          } else {
+            getPartitionsInBatches(hive, table, batchSize, partNames)
+          }
       }
     }
   }
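
For intuition on the schedule getNextBatchSize() produces: the batch size halves after every batch, successful or not, and bottoms out at 1. A self-contained sketch mirroring the shim's loop, with an assumed initial batch size of 8 over 16 partition names:

    // Prints the slices the shim would request: [0, 8), [8, 12), [12, 14),
    // then one name at a time for the remainder.
    object BatchDecaySketch {
      def main(args: Array[String]): Unit = {
        val decayingFactor = 2
        var batchSize = 8
        var index = 0
        val total = 16
        while (index < total) {
          val current = batchSize
          batchSize = (batchSize / decayingFactor) max 1
          println(s"fetch partitions [$index, ${(index + current) min total})")
          index += current
        }
      }
    }

Note that the decay is unconditional: after roughly log2(initialBatchSize) batches the shim is fetching one partition per metastore call.
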
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala
index 1a4eb75547894..ce162a325f44c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HivePartitionFilteringSuite.scala
@@ -726,6 +726,40 @@ class HivePartitionFilteringSuite(version: String)
     }
   }
 
+  test("getPartitionsByFilter: getPartitionsInBatches") {
+    var filteredPartitions: Seq[CatalogTablePartition] = Seq()
+    var filteredPartitionsNoBatch: Seq[CatalogTablePartition] = Seq()
+    var filteredPartitionsHighBatch: Seq[CatalogTablePartition] = Seq()
+
+    withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "1") {
+      filteredPartitions = client.getPartitionsByFilter(
+        client.getRawHiveTable("default", "test"),
+        Seq(attr("ds") === 20170101)
+      )
+    }
+    withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "-1") {
+      filteredPartitionsNoBatch = client.getPartitionsByFilter(
+        client.getRawHiveTable("default", "test"),
+        Seq(attr("ds") === 20170101)
+      )
+    }
+    withSQLConf(SQLConf.HMS_BATCH_SIZE.key -> "5000") {
+      filteredPartitionsHighBatch = client.getPartitionsByFilter(
+        client.getRawHiveTable("default", "test"),
+        Seq(attr("ds") === 20170101)
+      )
+    }
+
+    assert(filteredPartitions.size == filteredPartitionsNoBatch.size)
+    assert(filteredPartitions.size == filteredPartitionsHighBatch.size)
+    assert(
+      filteredPartitions.map(_.spec.toSet).toSet ==
+        filteredPartitionsNoBatch.map(_.spec.toSet).toSet)
+    assert(
+      filteredPartitions.map(_.spec.toSet).toSet ==
+        filteredPartitionsHighBatch.map(_.spec.toSet).toSet)
+  }
+
   private def testMetastorePartitionFiltering(
       filterExpr: Expression,
       expectedDs: Seq[Int],
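
The suite above pins only spark.sql.hive.metastore.batchSize; the retry budget could be pinned the same way. A sketch of such a case, not part of the patch, reusing the suite's withSQLConf helper and the same default.test table:

    test("getPartitionsByFilter: batching with a zero retry budget (sketch)") {
      withSQLConf(
          SQLConf.HMS_BATCH_SIZE.key -> "1",
          SQLConf.METASTORE_PARTITION_BATCH_RETRY_COUNT.key -> "0") {
        // Against a healthy metastore no retry is needed, so batches of one
        // partition each should still return every matching partition.
        val batched = client.getPartitionsByFilter(
          client.getRawHiveTable("default", "test"),
          Seq(attr("ds") === 20170101))
        assert(batched.nonEmpty)
      }
    }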