Skip to content

Commit

Permalink
enable decaying batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
Madhukar525722 committed Oct 2, 2024
1 parent dc26820 commit 06b87d5
Showing 1 changed file with 35 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, S
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.control.NonFatal

import org.apache.hadoop.fs.Path
Expand Down Expand Up @@ -1149,31 +1150,50 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
/**
 * Fetches partition metadata from the Hive metastore in batches whose size decays
 * geometrically (halved each time, floored at 1), retrying failed metastore calls
 * up to `metastorePartitionBatchRetryCount` times in total for the whole request.
 *
 * @param hive             client handle to the Hive metastore
 * @param table            the Hive table whose partitions are being resolved
 * @param initialBatchSize number of partition names to request per metastore call
 *                         at the start; must be greater than 0
 * @param partNames        all partition names to fetch
 * @return the partitions fetched; possibly incomplete if the retry budget was
 *         exhausted before every batch succeeded (a final error is logged then)
 */
private def getPartitionsInBatches(
    hive: Hive,
    table: Table,
    initialBatchSize: Int,
    partNames: Seq[String]): java.util.Collection[Partition] = {
  logInfo(s"Starting the batch processing of partitions.")
  val maxRetries = SQLConf.get.metastorePartitionBatchRetryCount
  val decayingFactor = 2

  require(initialBatchSize > 0,
    s"Invalid batch size $initialBatchSize provided for fetching partitions. " +
      s"Batch size must be greater than 0")
  require(maxRetries >= 0,
    s"Invalid number of maximum retries $maxRetries provided for fetching partitions." +
      s" It must be a non-negative integer value")

  logInfo(s"Breaking your request into small batches of '$initialBatchSize'.")

  var batchSize = initialBatchSize
  val processedPartitions = mutable.ListBuffer[Partition]()
  // Retry budget is shared across the whole request, not per batch.
  var retryCount = 0
  // Index of the first partition name not yet successfully fetched.
  var index = 0

  // Returns the batch size to use now and halves it (floored at 1) for the next call.
  // NOTE(review): the decay is applied on every batch, including successful ones, so
  // large requests quickly degrade to single-partition calls — confirm whether the
  // decay was intended to apply only after a failed attempt.
  def getNextBatchSize(): Int = {
    val currentBatchSize = batchSize
    batchSize = (batchSize / decayingFactor) max 1
    currentBatchSize
  }

  while (index < partNames.size && retryCount <= maxRetries) {
    val currentBatchSize = getNextBatchSize()
    val batch = partNames.slice(index, index + currentBatchSize)
    var partitions: java.util.Collection[Partition] = null

    // Keep retrying this batch until it succeeds or the shared retry budget runs out.
    while (partitions == null && retryCount <= maxRetries) {
      try {
        partitions = hive.getPartitionsByNames(table, batch.asJava)
        processedPartitions ++= partitions.asScala
        index += batch.size
      } catch {
        // NonFatal lets OutOfMemoryError, InterruptedException, etc. propagate
        // instead of being swallowed by the retry loop.
        case NonFatal(ex) =>
          logWarning(s"Caught exception while fetching partitions for batch '$batch'.", ex)
          retryCount += 1
          if (retryCount > maxRetries) {
            logError(s"Failed to fetch partitions for the request. Retries count exceeded.")
          }
      }
    }
  }
  processedPartitions.asJava
}
Expand Down

0 comments on commit 06b87d5

Please sign in to comment.