This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-1075] Dynamically adjust input partition size #1076

Open · wants to merge 8 commits into base: main
Changes from 6 commits
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.v2.arrow

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex


object ScanUtils {

// Exposes the partition schema of a PartitioningAwareFileIndex as attribute
// references, for use by the Arrow data source.
def toAttributes(fileIndex: PartitioningAwareFileIndex): Seq[AttributeReference] = {
fileIndex.partitionSchema.toAttributes
}
}
@@ -16,15 +16,21 @@
*/
package com.intel.oap.spark.sql.execution.datasources.v2.arrow

import com.intel.oap.sql.shims.SparkShimLoader

import java.util.Locale

import scala.collection.JavaConverters._

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.execution.PartitionedFileUtil
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionDirectory, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.execution.datasources.v2.arrow.ScanUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -41,6 +47,9 @@ case class ArrowScan(
dataFilters: Seq[Expression] = Seq.empty)
extends FileScan {

// Use the default value for org.apache.spark.internal.config.IO_WARNING_LARGEFILETHRESHOLD.
val IO_WARNING_LARGEFILETHRESHOLD: Long = 1024 * 1024 * 1024

override def isSplitable(path: Path): Boolean = {
ArrowUtils.isOriginalFormatSplitable(
new ArrowOptions(new CaseInsensitiveStringMap(options).asScala.toMap))
@@ -63,4 +72,90 @@
override def withFilters(partitionFilters: Seq[Expression],
dataFilters: Seq[Expression]): FileScan =
this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)

// Compute the maximum split size in bytes. Unlike FilePartition.maxSplitBytes,
// the result is kept within a preferred per-partition size range so that input
// partitions become neither too small nor too large.
def maxSplitBytes(sparkSession: SparkSession,
selectedPartitions: Seq[PartitionDirectory]): Long = {
val defaultMaxSplitBytes = sparkSession.sessionState.conf.filesMaxPartitionBytes
val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes
// leafNodeDefaultParallelism is resolved through the shim layer because the
// API differs across the supported Spark versions.
val minPartitionNum = sparkSession.sessionState.conf.filesMinPartitionNum
.getOrElse(SparkShimLoader.getSparkShims.leafNodeDefaultParallelism(sparkSession))
val PREFERRED_PARTITION_SIZE_LOWER_BOUND: Long = 128 * 1024 * 1024
val PREFERRED_PARTITION_SIZE_UPPER_BOUND: Long = 512 * 1024 * 1024
Contributor

Maybe add a new config for these two values?

Collaborator (Author)

Thanks for your advice. PREFERRED_PARTITION_SIZE_UPPER_BOUND may impose the same limit as Spark's max partition size configuration, so the two could be unified. Maybe we can make PREFERRED_PARTITION_SIZE_LOWER_BOUND configurable (see the sketch after this function).

val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
val maxBytesPerCore = totalBytes / minPartitionNum
var bytesPerCoreFinal = maxBytesPerCore
var bytesPerCore = maxBytesPerCore
// Divide the per-core share by 2, 3, 4, ... until it drops below the upper
// bound, keeping the last value that still exceeds the lower bound.
var i = 2
while (bytesPerCore > PREFERRED_PARTITION_SIZE_UPPER_BOUND) {
bytesPerCore = maxBytesPerCore / i
if (bytesPerCore > PREFERRED_PARTITION_SIZE_LOWER_BOUND) {
bytesPerCoreFinal = bytesPerCore
}
i = i + 1
}
// Vanilla Spark would return Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore));
// here the result is capped at the preferred upper bound instead.
Math.min(PREFERRED_PARTITION_SIZE_UPPER_BOUND, bytesPerCoreFinal)
}
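
Following up on the review suggestion above, the two bounds could be read from the session configuration instead of being hard-coded. The sketch below is not part of this PR; the config key names are hypothetical, and the snippet assumes it would live inside ArrowScan where sparkSession and SQLConf are already in scope.

// Hypothetical sketch only: the config keys below are not defined anywhere in this PR.
private def preferredPartitionSizeBounds(sparkSession: SparkSession): (Long, Long) = {
  val conf = sparkSession.sessionState.conf
  // SQLConf.getConfString(key, default) returns the default when the key is unset.
  val lower = conf.getConfString(
    "spark.oap.sql.columnar.preferredPartitionSizeLowerBound",
    (128L * 1024 * 1024).toString).toLong
  val upper = conf.getConfString(
    "spark.oap.sql.columnar.preferredPartitionSizeUpperBound",
    (512L * 1024 * 1024).toString).toLong
  (lower, upper)
}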

override def partitions: Seq[FilePartition] = {
// This override mirrors FileScan.partitions, but uses the adjusted
// maxSplitBytes computed above instead of FilePartition.maxSplitBytes.
val selectedPartitions = fileIndex.listFiles(partitionFilters, dataFilters)
val maxSplitBytes = this.maxSplitBytes(sparkSession, selectedPartitions)
val partitionAttributes = ScanUtils.toAttributes(fileIndex)
val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap
val readPartitionAttributes = readPartitionSchema.map { readField =>
attributeMap.get(normalizeName(readField.name)).getOrElse {
// Same message as vanilla Spark, raised without depending on QueryCompilationErrors.
throw new RuntimeException(s"Can't find required partition column ${readField.name} " +
s"in partition schema ${fileIndex.partitionSchema}")
}
}
lazy val partitionValueProject =
GenerateUnsafeProjection.generate(readPartitionAttributes, partitionAttributes)
val splitFiles = selectedPartitions.flatMap { partition =>
// Prune partition values if part of the partition columns are not required.
val partitionValues = if (readPartitionAttributes != partitionAttributes) {
partitionValueProject(partition.values).copy()
} else {
partition.values
}
partition.files.flatMap { file =>
val filePath = file.getPath
PartitionedFileUtil.splitFiles(
sparkSession = sparkSession,
file = file,
filePath = filePath,
isSplitable = isSplitable(filePath),
maxSplitBytes = maxSplitBytes,
partitionValues = partitionValues
)
}.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
}

if (splitFiles.length == 1) {
val path = new Path(splitFiles(0).filePath)
if (!isSplitable(path) && splitFiles(0).length >
IO_WARNING_LARGEFILETHRESHOLD) {
logWarning(s"Loading one large unsplittable file ${path.toString} with only one " +
s"partition, the reason is: ${getFileUnSplittableReason(path)}")
}
}

FilePartition.getFilePartitions(sparkSession, splitFiles, maxSplitBytes)
}

private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis

private def normalizeName(name: String): String = {
if (isCaseSensitive) {
name
} else {
name.toLowerCase(Locale.ROOT)
}
}
}
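
To make the effect of the adjustment concrete, here is a small self-contained sketch (not code from this PR) that reproduces the sizing loop with the same 128 MB / 512 MB constants and prints the result for a hypothetical 10 GB input with a minimum of 8 partitions.

object PartitionSizeExample {
  // Same preferred bounds as ArrowScan.maxSplitBytes above.
  val LowerBound: Long = 128L * 1024 * 1024
  val UpperBound: Long = 512L * 1024 * 1024

  // Keep dividing the per-core share until it falls below the upper bound,
  // remembering the last value that still exceeds the lower bound.
  def adjustedSplitBytes(totalBytes: Long, minPartitionNum: Int): Long = {
    val maxBytesPerCore = totalBytes / minPartitionNum
    var bytesPerCoreFinal = maxBytesPerCore
    var bytesPerCore = maxBytesPerCore
    var i = 2
    while (bytesPerCore > UpperBound) {
      bytesPerCore = maxBytesPerCore / i
      if (bytesPerCore > LowerBound) {
        bytesPerCoreFinal = bytesPerCore
      }
      i += 1
    }
    math.min(UpperBound, bytesPerCoreFinal)
  }

  def main(args: Array[String]): Unit = {
    // 10 GB over at least 8 partitions: the 1280 MB per-core share is divided
    // by 2 (640 MB, still above 512 MB) and then by 3 (about 427 MB), which
    // lands inside the preferred 128 MB to 512 MB range.
    val splitBytes = adjustedSplitBytes(10L * 1024 * 1024 * 1024, 8)
    println(s"adjusted maxSplitBytes = ${splitBytes / (1024 * 1024)} MB")
  }
}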
@@ -28,8 +28,7 @@ import org.apache.spark.shuffle.MigratableResolver
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.shuffle.sort.SortShuffleWriter
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
@@ -121,4 +120,7 @@ trait SparkShims {
def getEndMapIndexOfCoalescedMapperPartitionSpec(spec: ShufflePartitionSpec): Int

def getNumReducersOfCoalescedMapperPartitionSpec(spec: ShufflePartitionSpec): Int

def leafNodeDefaultParallelism(sparkSession: SparkSession): Int

}
@@ -205,4 +205,8 @@ class Spark311Shims extends SparkShims {
throw new RuntimeException("This method should not be invoked in spark 3.1.")
}

override def leafNodeDefaultParallelism(sparkSession: SparkSession): Int = {
// spark.sql.leafNodeDefaultParallelism does not exist in Spark 3.1, so fall
// back to the RDD default parallelism.
sparkSession.sparkContext.defaultParallelism
}

}
@@ -235,4 +235,8 @@ class Spark321Shims extends SparkShims {
}
}

override def leafNodeDefaultParallelism(sparkSession: SparkSession): Int = {
// Delegate to ShimUtils, which reads spark.sql.leafNodeDefaultParallelism
// (available since Spark 3.2) and falls back to the RDD default parallelism.
org.apache.spark.sql.util.ShimUtils.leafNodeDefaultParallelism(sparkSession)
}

}
@@ -0,0 +1,28 @@
/*
* Copyright 2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.util

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object ShimUtils {

def leafNodeDefaultParallelism(sparkSession: SparkSession): Int = {
sparkSession.conf.get(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM).getOrElse(
sparkSession.sparkContext.defaultParallelism)
}
}
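
For context, a minimal usage sketch (assuming a local Spark 3.2 session; not part of this PR) showing how the setting read by this shim feeds the minimum partition count used by ArrowScan.maxSplitBytes when spark.sql.files.minPartitionNum is unset:

import org.apache.spark.sql.SparkSession

object LeafParallelismExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("leaf-parallelism-example")
      // Available since Spark 3.2; used as the fallback for the minimum
      // partition count when spark.sql.files.minPartitionNum is not set.
      .config("spark.sql.leafNodeDefaultParallelism", "16")
      .getOrCreate()

    // ShimUtils.leafNodeDefaultParallelism(spark) would return 16 here; without
    // the setting it falls back to spark.sparkContext.defaultParallelism (4 for local[4]).
    println(spark.conf.get("spark.sql.leafNodeDefaultParallelism"))
    spark.stop()
  }
}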