enable decaying batch size
Madhukar525722 committed Oct 3, 2024
1 parent dc26820 commit af4adda
Showing 1 changed file with 48 additions and 23 deletions.
@@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, S
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.Path
@@ -1149,32 +1150,56 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
   private def getPartitionsInBatches(
       hive: Hive,
       table: Table,
-      size: Int,
+      initialBatchSize: Int,
       partNames: Seq[String]): java.util.Collection[Partition] = {
     logInfo(s"Starting the batch processing of partitions.")
 
-    val retryCount = SQLConf.get.metastorePartitionBatchRetryCount
-    val batches = partNames.grouped(size).toSeq
-    val processedPartitions = batches.flatMap { batch =>
-      var partitions: java.util.Collection[Partition] = null
-      var retry = 0
-      while (partitions == null && retry < retryCount) {
-        try {
-          partitions = hive.getPartitionsByNames(table, batch.asJava)
-        } catch {
-          case ex: Exception =>
-            logWarning(
-              s"Caught exception while fetching partition metadata for batch '$batch'.",
-              ex
-            )
-            retry += 1
-            if (retry > retryCount) {
-              logError(s"Failed to fetch all the partition metadata. Retries count exceeded.")
-            }
-        }
-      }
-      partitions.asScala
-    }
+    val maxRetries = SQLConf.get.metastorePartitionBatchRetryCount
+    val decayingFactor = 2
+
+    if (initialBatchSize <= 0) {
+      throw new IllegalArgumentException(s"Invalid batch size $initialBatchSize provided " +
+        s"for fetching partitions. Batch size must be greater than 0")
+    }
+
+    if (maxRetries < 0) {
+      throw new IllegalArgumentException(s"Invalid number of maximum retries $maxRetries " +
+        s"provided for fetching partitions. It must be a non-negative integer value")
+    }
+
+    logInfo(s"Breaking your request into small batches of '$initialBatchSize'.")
+
+    var batchSize = initialBatchSize
+    val processedPartitions = mutable.ListBuffer[Partition]()
+    var retryCount = 0
+    var index = 0
+
+    // Returns the current batch size and halves the next one (floored at 1).
+    def getNextBatchSize(): Int = {
+      val currentBatchSize = batchSize
+      batchSize = (batchSize / decayingFactor) max 1
+      currentBatchSize
+    }
+
+    var currentBatchSize = getNextBatchSize()
+    var partitions: java.util.Collection[Partition] = null
+
+    while (index < partNames.size && retryCount <= maxRetries) {
+      val batch = partNames.slice(index, index + currentBatchSize)
+
+      try {
+        partitions = hive.getPartitionsByNames(table, batch.asJava)
+        processedPartitions ++= partitions.asScala
+        index += batch.size
+      } catch {
+        case ex: Exception =>
+          logWarning(s"Caught exception while fetching partitions for batch '$batch'.", ex)
+          retryCount += 1
+          currentBatchSize = getNextBatchSize()
+          logInfo(s"Further reducing batch size to '$currentBatchSize'.")
+          if (retryCount > maxRetries) {
+            logError(s"Failed to fetch partitions for the request. Retries count exceeded.")
+          }
+      }
+    }
 
     processedPartitions.asJava
   }

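For context, here is a minimal, self-contained sketch of the decaying-batch-size strategy that the added code implements. This is not Spark's code: fetchInBatches, flaky, and the demo values are hypothetical stand-ins for getPartitionsInBatches and the metastore call.

// Sketch only: mirrors the decaying strategy above, assuming a generic
// fetchBatch function in place of hive.getPartitionsByNames.
object DecayingBatchDemo {
  def fetchInBatches[A, B](
      items: Seq[A],
      initialBatchSize: Int,
      maxRetries: Int)(fetchBatch: Seq[A] => Seq[B]): Seq[B] = {
    require(initialBatchSize > 0, "Batch size must be greater than 0")
    require(maxRetries >= 0, "maxRetries must be non-negative")

    val results = scala.collection.mutable.ListBuffer[B]()
    var batchSize = initialBatchSize
    var retries = 0
    var index = 0

    while (index < items.size && retries <= maxRetries) {
      val batch = items.slice(index, index + batchSize)
      try {
        results ++= fetchBatch(batch)
        index += batch.size                  // advance only on success
      } catch {
        case scala.util.control.NonFatal(_) =>
          retries += 1                       // retries accumulate across the whole request
          batchSize = (batchSize / 2) max 1  // decay: halve the batch, floored at 1
      }
    }
    results.toList
  }

  def main(args: Array[String]): Unit = {
    // A backend that fails on any batch larger than 2 items.
    def flaky(batch: Seq[Int]): Seq[Int] =
      if (batch.size > 2) throw new RuntimeException("batch too large") else batch.map(_ * 10)

    // Batch size decays 8 -> 4 -> 2, then every batch of 2 succeeds:
    println(fetchInBatches((1 to 10).toList, initialBatchSize = 8, maxRetries = 5)(flaky))
    // List(10, 20, 30, 40, 50, 60, 70, 80, 90, 100)
  }
}

Two properties of the committed version carry over here: the batch size never grows back after a failure, and the retry budget is shared across the whole request rather than reset per batch, so a persistently failing batch cannot loop forever.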
