Skip to content

Commit

Permalink
[SPARK-26673][SQL] File source V2 writes: create framework and migrat…
Browse files Browse the repository at this point in the history
…e ORC

## What changes were proposed in this pull request?

Create a framework for write path of File Source V2.
Also, migrate write path of ORC to V2.

Supported:
* Write to file as Dataframe

Not Supported:
* Partitioning, which is still under development in the data source V2 project.
* Bucketing, which is still under development in the data source V2 project.
* Catalog.

## How was this patch tested?

Unit test

Closes #23601 from gengliangwang/orc_write.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
gengliangwang authored and cloud-fan committed Jan 31, 2019
1 parent b3b62ba commit df4c53e
Show file tree
Hide file tree
Showing 16 changed files with 486 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1440,6 +1440,14 @@ object SQLConf {
.stringConf
.createWithDefault("")

val USE_V1_SOURCE_WRITER_LIST = buildConf("spark.sql.sources.write.useV1SourceList")
.internal()
.doc("A comma-separated list of data source short names or fully qualified data source" +
" register class names for which data source V2 write paths are disabled. Writes from these" +
" sources will fall back to the V1 sources.")
.stringConf
.createWithDefault("")

val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers")
.doc("A comma-separated list of fully qualified data source register class names for which" +
" StreamWriteSupport is disabled. Writes to these sources will fall back to the V1 Sinks.")
Expand Down Expand Up @@ -2026,6 +2034,8 @@ class SQLConf extends Serializable with Logging {

def userV1SourceReaderList: String = getConf(USE_V1_SOURCE_READER_LIST)

def userV1SourceWriterList: String = getConf(USE_V1_SOURCE_WRITER_LIST)

def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS)

def disabledV2StreamingMicroBatchReaders: String =
Expand Down
17 changes: 14 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable,
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, WriteToDataSourceV2}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode
Expand Down Expand Up @@ -243,8 +243,19 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
assertNotBucketed("save")

val session = df.sparkSession
val cls = DataSource.lookupDataSource(source, session.sessionState.conf)
if (classOf[TableProvider].isAssignableFrom(cls)) {
val useV1Sources =
session.sessionState.conf.userV1SourceWriterList.toLowerCase(Locale.ROOT).split(",")
val lookupCls = DataSource.lookupDataSource(source, session.sessionState.conf)
val cls = lookupCls.newInstance() match {
case f: FileDataSourceV2 if useV1Sources.contains(f.shortName()) ||
useV1Sources.contains(lookupCls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
f.fallBackFileFormat
case _ => lookupCls
}
// In Data Source V2 project, partitioning is still under development.
// Here we fallback to V1 if partitioning columns are specified.
// TODO(SPARK-26778): use V2 implementations when partitioning feature is supported.
if (classOf[TableProvider].isAssignableFrom(cls) && partitioningColumns.isEmpty) {
val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider]
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
provider, session.sessionState.conf)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ object DataSource extends Logging {
* supplied schema is not empty.
* @param schema
*/
private def validateSchema(schema: StructType): Unit = {
def validateSchema(schema: StructType): Unit = {
def hasEmptySchema(schema: StructType): Boolean = {
schema.size == 0 || schema.find {
case StructField(_, b: StructType, _, _) => hasEmptySchema(b)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable
* Replace the ORC V2 data source of table in [[InsertIntoTable]] to V1 [[FileFormat]].
* E.g, with temporary view `t` using [[FileDataSourceV2]], inserting into view `t` fails
* since there is no corresponding physical plan.
* SPARK-23817: This is a temporary hack for making current data source V2 work. It should be
* removed when write path of file data source v2 is finished.
* This is a temporary hack for making current data source V2 work. It should be
* removed when Catalog support of file data source v2 is finished.
*/
class FallbackOrcDataSourceV2(sparkSession: SparkSession) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.SerializableConfiguration

Expand All @@ -37,7 +38,7 @@ import org.apache.spark.util.SerializableConfiguration
abstract class FileFormatDataWriter(
description: WriteJobDescription,
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol) {
committer: FileCommitProtocol) extends DataWriter[InternalRow] {
/**
* Max number of files a single task writes out due to file size. In most cases the number of
* files written should be very small. This is just a safe guard to protect some really bad
Expand Down Expand Up @@ -70,7 +71,7 @@ abstract class FileFormatDataWriter(
* to the driver and used to update the catalog. Other information will be sent back to the
* driver too and used to e.g. update the metrics in UI.
*/
def commit(): WriteTaskResult = {
override def commit(): WriteTaskResult = {
releaseResources()
val summary = ExecutedWriteSummary(
updatedPartitions = updatedPartitions.toSet,
Expand Down Expand Up @@ -301,6 +302,7 @@ class WriteJobDescription(

/** The result of a successful write task. */
case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary)
extends WriterCommitMessage

/**
* Wrapper class for the metrics of writing data out.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ object FileFormatWriter extends Logging {
* For every registered [[WriteJobStatsTracker]], call `processStats()` on it, passing it
* the corresponding [[WriteTaskStats]] from all executors.
*/
private def processStats(
private[datasources] def processStats(
statsTrackers: Seq[WriteJobStatsTracker],
statsPerTask: Seq[Seq[WriteTaskStats]])
: Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types._

private[orc] class OrcOutputWriter(
private[sql] class OrcOutputWriter(
path: String,
dataSchema: StructType,
context: TaskAttemptContext)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2

import org.apache.hadoop.mapreduce.Job

import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.execution.datasources.{WriteJobDescription, WriteTaskResult}
import org.apache.spark.sql.execution.datasources.FileFormatWriter.processStats
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.util.SerializableConfiguration

class FileBatchWrite(
job: Job,
description: WriteJobDescription,
committer: FileCommitProtocol)
extends BatchWrite with Logging {
override def commit(messages: Array[WriterCommitMessage]): Unit = {
val results = messages.map(_.asInstanceOf[WriteTaskResult])
committer.commitJob(job, results.map(_.commitMsg))
logInfo(s"Write Job ${description.uuid} committed.")

processStats(description.statsTrackers, results.map(_.summary.stats))
logInfo(s"Finished processing stats for write job ${description.uuid}.")
}

override def useCommitCoordinator(): Boolean = false

override def abort(messages: Array[WriterCommitMessage]): Unit = {
committer.abortJob(job)
}

override def createBatchWriterFactory(): DataWriterFactory = {
val conf = new SerializableConfiguration(job.getConfiguration)
FileWriterFactory(description, committer, conf)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ import org.apache.hadoop.fs.FileStatus

import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.v2.{SupportsBatchRead, Table}
import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table}
import org.apache.spark.sql.types.StructType

abstract class FileTable(
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex,
userSpecifiedSchema: Option[StructType]) extends Table with SupportsBatchRead {
userSpecifiedSchema: Option[StructType])
extends Table with SupportsBatchRead with SupportsBatchWrite {
def getFileIndex: PartitioningAwareFileIndex = this.fileIndex

lazy val dataSchema: StructType = userSpecifiedSchema.orElse {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.v2

import java.util.UUID

import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, DataSource, OutputWriterFactory, WriteJobDescription}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

abstract class FileWriteBuilder(options: DataSourceOptions)
extends WriteBuilder with SupportsSaveMode {
private var schema: StructType = _
private var queryId: String = _
private var mode: SaveMode = _

override def withInputDataSchema(schema: StructType): WriteBuilder = {
this.schema = schema
this
}

override def withQueryId(queryId: String): WriteBuilder = {
this.queryId = queryId
this
}

override def mode(mode: SaveMode): WriteBuilder = {
this.mode = mode
this
}

override def buildForBatch(): BatchWrite = {
validateInputs()
val pathName = options.paths().head
val path = new Path(pathName)
val sparkSession = SparkSession.active
val optionsAsScala = options.asMap().asScala.toMap
val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(optionsAsScala)
val job = getJobInstance(hadoopConf, path)
val committer = FileCommitProtocol.instantiate(
sparkSession.sessionState.conf.fileCommitProtocolClass,
jobId = java.util.UUID.randomUUID().toString,
outputPath = pathName)
lazy val description =
createWriteJobDescription(sparkSession, hadoopConf, job, pathName, optionsAsScala)

val fs = path.getFileSystem(hadoopConf)
mode match {
case SaveMode.ErrorIfExists if fs.exists(path) =>
val qualifiedOutputPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory)
throw new AnalysisException(s"path $qualifiedOutputPath already exists.")

case SaveMode.Ignore if fs.exists(path) =>
null

case SaveMode.Overwrite =>
committer.deleteWithJob(fs, path, true)
committer.setupJob(job)
new FileBatchWrite(job, description, committer)

case _ =>
committer.setupJob(job)
new FileBatchWrite(job, description, committer)
}
}

/**
* Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
* be put here. For example, user defined output committer can be configured here
* by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass.
*/
def prepareWrite(
sqlConf: SQLConf,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory

private def validateInputs(): Unit = {
assert(schema != null, "Missing input data schema")
assert(queryId != null, "Missing query ID")
assert(mode != null, "Missing save mode")
assert(options.paths().length == 1)
DataSource.validateSchema(schema)
}

private def getJobInstance(hadoopConf: Configuration, path: Path): Job = {
val job = Job.getInstance(hadoopConf)
job.setOutputKeyClass(classOf[Void])
job.setOutputValueClass(classOf[InternalRow])
FileOutputFormat.setOutputPath(job, path)
job
}

private def createWriteJobDescription(
sparkSession: SparkSession,
hadoopConf: Configuration,
job: Job,
pathName: String,
options: Map[String, String]): WriteJobDescription = {
val caseInsensitiveOptions = CaseInsensitiveMap(options)
// Note: prepareWrite has side effect. It sets "job".
val outputWriterFactory =
prepareWrite(sparkSession.sessionState.conf, job, caseInsensitiveOptions, schema)
val allColumns = schema.toAttributes
val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics
val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
val statsTracker = new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
// TODO: after partitioning is supported in V2:
// 1. filter out partition columns in `dataColumns`.
// 2. Don't use Seq.empty for `partitionColumns`.
new WriteJobDescription(
uuid = UUID.randomUUID().toString,
serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
outputWriterFactory = outputWriterFactory,
allColumns = allColumns,
dataColumns = allColumns,
partitionColumns = Seq.empty,
bucketIdExpression = None,
path = pathName,
customPartitionLocations = Map.empty,
maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
.getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
.getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone),
statsTrackers = Seq(statsTracker)
)
}
}

Loading

0 comments on commit df4c53e

Please sign in to comment.