[SPARK-13184][SQL] Add a datasource-specific option minPartitions in HadoopFsRelation#options #13320

Closed
wants to merge 4 commits into from
Closed
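In short, summarizing the diff below: the options passed to DataFrameReader are merged with the entries of SparkConf and SQLConf before they reach HadoopFsRelation, so a session-wide setting such as spark.default.parallelism can be overridden for a single read. A minimal usage sketch in a spark-shell style session, based on the new FileSourceStrategySuite test (the format and path are illustrative):

val df = spark.read
  .format("parquet")                          // any file-based source goes through the same path
  .option("spark.default.parallelism", "3")   // per-datasource parallelism for this read only
  .load("/path/to/data")                      // illustrative path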
5 changes: 5 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._
import scala.collection.immutable
import scala.collection.mutable.LinkedHashSet

import org.apache.avro.{Schema, SchemaNormalization}
@@ -384,6 +385,10 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Serializable
.map { case (k, v) => (k.substring(prefix.length), v) }
}

/** Get all parameters as an immutable Map */
def getAllAsMap: immutable.Map[String, String] = {
settings.asScala.toMap
}

/** Get a parameter as an integer, falling back to a default if not set */
def getInt(key: String, defaultValue: Int): Int = {
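A small sketch of how the new accessor compares with the existing getAll, which returns Array[(String, String)]; with this patch applied, and with an illustrative key and value:

import org.apache.spark.SparkConf

val conf = new SparkConf(loadDefaults = false).set("spark.default.parallelism", "8")
conf.getAll        // Array(("spark.default.parallelism", "8"))
conf.getAllAsMap   // Map("spark.default.parallelism" -> "8"), as an immutable Map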
@@ -149,7 +149,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
paths = paths,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap).resolveRelation())
options = optionsOverriddenWith(extraOptions.toMap)).resolveRelation())
}

/**
@@ -551,4 +551,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {

private var extraOptions = new scala.collection.mutable.HashMap[String, String]

// Returns all options set in the `SparkConf`, the `SQLConf`, and the given data source `options`.
// If the same key appears in more than one of them, the value in `options` takes precedence.
private def optionsOverriddenWith(options: Map[String, String]): Map[String, String] = {
sparkSession.sparkContext.conf.getAllAsMap ++ sparkSession.conf.getAll ++ options
Review comment (Member):
Here, we just cover the code path for DataFrameReader APIs.

}

}
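The precedence comes from Scala's Map ++, where the right-hand operand wins on duplicate keys. A standalone sketch with illustrative keys and values:

// Mirrors optionsOverriddenWith: SparkConf entries, then SQLConf entries, then per-read options.
val fromSparkConf  = Map("spark.default.parallelism" -> "8")
val fromSQLConf    = Map("spark.sql.parquet.mergeSchema" -> "false")
val perReadOptions = Map("spark.default.parallelism" -> "3")

val merged = fromSparkConf ++ fromSQLConf ++ perReadOptions
// merged("spark.default.parallelism") == "3"          -- the per-read option wins on a duplicate key
// merged("spark.sql.parquet.mergeSchema") == "false"  -- untouched keys pass through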
@@ -470,7 +470,8 @@ case class FileSourceScanExec(
val defaultMaxSplitBytes =
fsRelation.sparkSession.sessionState.conf.filesMaxPartitionBytes
val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes
val defaultParallelism = fsRelation.sparkSession.sparkContext.defaultParallelism
val defaultParallelism = fsRelation.options.get("spark.default.parallelism").map(_.toInt)
.getOrElse(fsRelation.sparkSession.sparkContext.defaultParallelism)
val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
val bytesPerCore = totalBytes / defaultParallelism
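With the option in place, the per-datasource value feeds straight into bytesPerCore. A back-of-the-envelope sketch using the numbers from the new test (nine 1-byte files, parallelism 3) and the assumed defaults of 4 MB for spark.sql.files.openCostInBytes and 128 MB for spark.sql.files.maxPartitionBytes; the final clamp to maxSplitBytes sits just below this hunk and is paraphrased here, so treat it as an assumption:

// Illustrative arithmetic only.
val openCostInBytes      = 4L * 1024 * 1024      // assumed default of spark.sql.files.openCostInBytes
val defaultMaxSplitBytes = 128L * 1024 * 1024    // assumed default of spark.sql.files.maxPartitionBytes
val totalBytes    = 9 * (1L + openCostInBytes)   // nine 1-byte files, each padded by the open cost
val bytesPerCore  = totalBytes / 3               // "spark.default.parallelism" -> "3" from the options
val maxSplitBytes = math.min(defaultMaxSplitBytes, math.max(openCostInBytes, bytesPerCore))
// maxSplitBytes lands around 12 MB, so the nine files pack into 3 partitions of 3 files each,
// which is what the new FileSourceStrategySuite test asserts.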

@@ -232,7 +232,7 @@ class ParquetFileFormat
.orElse(filesByType.data.headOption)
.toSeq
}
ParquetFileFormat.mergeSchemasInParallel(filesToTouch, sparkSession)
ParquetFileFormat.mergeSchemasInParallel(filesToTouch, sparkSession, parameters)
}

case class FileTypes(
@@ -561,9 +561,10 @@ object ParquetFileFormat extends Logging {
* slow. And basically locality is not available when using S3 (you can't run computation on
* S3 nodes).
*/
def mergeSchemasInParallel(
private def mergeSchemasInParallel(
filesToTouch: Seq[FileStatus],
sparkSession: SparkSession): Option[StructType] = {
sparkSession: SparkSession,
parameters: Map[String, String]): Option[StructType] = {
val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString
val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp
val writeLegacyParquetFormat = sparkSession.sessionState.conf.writeLegacyParquetFormat
@@ -584,8 +585,9 @@ object ParquetFileFormat extends Logging {

// Set the number of partitions to prevent following schema reads from generating many tasks
// in case of a small number of parquet files.
val numParallelism = Math.min(Math.max(partialFileStatusInfo.size, 1),
sparkSession.sparkContext.defaultParallelism)
val defaultParallelism = parameters.get("spark.default.parallelism").map(_.toInt)
.getOrElse(sparkSession.sparkContext.defaultParallelism)
val numParallelism = Math.min(Math.max(partialFileStatusInfo.size, 1), defaultParallelism)

// Issues a Spark job to read Parquet schema in parallel.
val partiallyMergedSchemas =
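The same key also caps the parallelism of the schema-merging job above. A tiny sketch of the clamp with illustrative numbers:

// Illustrative: two Parquet footers to read, per-datasource parallelism of 8.
val footerCount        = 2    // partialFileStatusInfo.size in the real code
val defaultParallelism = 8    // parameters("spark.default.parallelism"), else the SparkContext value
val numParallelism     = math.min(math.max(footerCount, 1), defaultParallelism)
// numParallelism == 2: never more tasks than footers, never more than the configured parallelism.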
@@ -40,7 +40,8 @@ private[parquet] class ParquetOptions(
* Acceptable values are defined in [[shortParquetCompressionCodecNames]].
*/
val compressionCodecClassName: String = {
val codecName = parameters.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase
val codecName = parameters.getOrElse("compression", parameters.getOrElse(
"spark.sql.parquet.compression.codec", sqlConf.parquetCompressionCodec)).toLowerCase
if (!shortParquetCompressionCodecNames.contains(codecName)) {
val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase)
throw new IllegalArgumentException(s"Codec [$codecName] " +
@@ -55,8 +56,9 @@ private[parquet] class ParquetOptions(
*/
val mergeSchema: Boolean = parameters
.get(MERGE_SCHEMA)
.map(_.toBoolean)
.getOrElse(sqlConf.isParquetSchemaMergingEnabled)
.map(_.toBoolean)
.getOrElse(parameters.get("spark.sql.parquet.mergeSchema").map(_.toBoolean)
.getOrElse(sqlConf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)))
}
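With these fallbacks, the long SQLConf keys can now be supplied per read as data source options. A sketch with illustrative values:

val df = spark.read
  .format("parquet")
  .option("spark.sql.parquet.mergeSchema", "true")   // consulted when the short "mergeSchema" key is absent
  .load("/illustrative/path")
// The short key still wins when both are present: .option("mergeSchema", "false") would override the line above.
// The compression codec falls back the same way: "compression" first, then
// "spark.sql.parquet.compression.codec", then the SQLConf default.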


@@ -274,6 +274,29 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper
}
}

test("datasource-specific minPartitions") {
val table =
createTable(
files = Seq(
"file1" -> 1,
"file2" -> 1,
"file3" -> 1,
"file4" -> 1,
"file5" -> 1,
"file6" -> 1,
"file7" -> 1,
"file8" -> 1,
"file9" -> 1),
options = Map("spark.default.parallelism" -> "3"))

checkScan(table.select('c1)) { partitions =>
assert(partitions.size == 3)
assert(partitions(0).files.size == 3)
assert(partitions(1).files.size == 3)
assert(partitions(2).files.size == 3)
}
}

test("Locality support for FileScanRDD") {
val partition = FilePartition(0, Seq(
PartitionedFile(InternalRow.empty, "fakePath0", 0, 10, Array("host0", "host1")),
@@ -526,7 +549,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper
*/
def createTable(
files: Seq[(String, Int)],
buckets: Int = 0): DataFrame = {
buckets: Int = 0,
options: Map[String, String] = Map.empty): DataFrame = {
val tempDir = Utils.createTempDir()
files.foreach {
case (name, size) =>
@@ -537,6 +561,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper

val df = spark.read
.format(classOf[TestFileFormat].getName)
.options(options)
.load(tempDir.getCanonicalPath)

if (buckets > 0) {