diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
index 1ee657dd6f168..d829576f2fe91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
@@ -103,15 +103,8 @@ case class AnalyzePartitionCommand(
 
     // Update the metastore if newly computed statistics are different from those
     // recorded in the metastore.
-
-    val sizes = CommandUtils.calculateMultipleLocationSizes(sparkSession, tableMeta.identifier,
-      partitions.map(_.storage.locationUri))
-    val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) =>
-      val newRowCount = rowCounts.get(p.spec)
-      val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), newRowCount)
-      newStats.map(_ => p.copy(stats = newStats))
-    }
-
+    val (_, newPartitions) = CommandUtils.calculatePartitionStats(
+      sparkSession, tableMeta, partitions, Some(rowCounts))
     if (newPartitions.nonEmpty) {
       sessionState.catalog.alterPartitions(tableMeta.identifier, newPartitions)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
index 8d201ff281aeb..7a18fbdd03d88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
@@ -91,13 +91,8 @@ object CommandUtils extends Logging {
       // Calculate table size as a sum of the visible partitions. See SPARK-21079
       val partitions = sessionState.catalog.listPartitions(catalogTable.identifier)
       logInfo(s"Starting to calculate sizes for ${partitions.length} partitions.")
-      val paths = partitions.map(_.storage.locationUri)
-      val sizes = calculateMultipleLocationSizes(spark, catalogTable.identifier, paths)
-      val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) =>
-        val newRowCount = partitionRowCount.flatMap(_.get(p.spec))
-        val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), newRowCount)
-        newStats.map(_ => p.copy(stats = newStats))
-      }
+      val (sizes, newPartitions) = calculatePartitionStats(spark, catalogTable, partitions,
+        partitionRowCount)
       (sizes.sum, newPartitions)
     }
     logInfo(s"It took ${(System.nanoTime() - startTime) / (1000 * 1000)} ms to calculate" +
@@ -105,6 +100,22 @@ object CommandUtils extends Logging {
     (totalSize, newPartitions)
   }
 
+  def calculatePartitionStats(
+      spark: SparkSession,
+      catalogTable: CatalogTable,
+      partitions: Seq[CatalogTablePartition],
+      partitionRowCount: Option[Map[TablePartitionSpec, BigInt]] = None):
+    (Seq[Long], Seq[CatalogTablePartition]) = {
+    val paths = partitions.map(_.storage.locationUri)
+    val sizes = calculateMultipleLocationSizes(spark, catalogTable.identifier, paths)
+    val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) =>
+      val newRowCount = partitionRowCount.flatMap(_.get(p.spec))
+      val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), newRowCount)
+      newStats.map(_ => p.copy(stats = newStats))
+    }
+    (sizes, newPartitions)
+  }
+
   def calculateSingleLocationSize(
       sessionState: SessionState,
       identifier: TableIdentifier,