diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 44842734251da..11e1a5eefe933 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1440,6 +1440,14 @@ object SQLConf { .stringConf .createWithDefault("") + val USE_V1_SOURCE_WRITER_LIST = buildConf("spark.sql.sources.write.useV1SourceList") + .internal() + .doc("A comma-separated list of data source short names or fully qualified data source" + + " register class names for which data source V2 write paths are disabled. Writes from these" + + " sources will fall back to the V1 sources.") + .stringConf + .createWithDefault("") + val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") .doc("A comma-separated list of fully qualified data source register class names for which" + " StreamWriteSupport is disabled. Writes to these sources will fall back to the V1 Sinks.") @@ -2026,6 +2034,8 @@ class SQLConf extends Serializable with Logging { def userV1SourceReaderList: String = getConf(USE_V1_SOURCE_READER_LIST) + def userV1SourceWriterList: String = getConf(USE_V1_SOURCE_WRITER_LIST) + def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS) def disabledV2StreamingMicroBatchReaders: String = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index d9404cd929925..2148875ce2817 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode @@ -243,8 +243,19 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { assertNotBucketed("save") val session = df.sparkSession - val cls = DataSource.lookupDataSource(source, session.sessionState.conf) - if (classOf[TableProvider].isAssignableFrom(cls)) { + val useV1Sources = + session.sessionState.conf.userV1SourceWriterList.toLowerCase(Locale.ROOT).split(",") + val lookupCls = DataSource.lookupDataSource(source, session.sessionState.conf) + val cls = lookupCls.newInstance() match { + case f: FileDataSourceV2 if useV1Sources.contains(f.shortName()) || + useV1Sources.contains(lookupCls.getCanonicalName.toLowerCase(Locale.ROOT)) => + f.fallBackFileFormat + case _ => lookupCls + } + // SPARK-26673: In Data Source V2 project, partitioning is still under development. + // Here we fallback to V1 if the write path if output partitioning is required. + // TODO: use V2 implementations when partitioning feature is supported. + if (classOf[TableProvider].isAssignableFrom(cls) && partitioningColumns.isEmpty) { val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] val sessionOptions = DataSourceV2Utils.extractSessionConfigs( provider, session.sessionState.conf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index d48261e783dc9..8118f08b99ddd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -763,7 +763,7 @@ object DataSource extends Logging { * supplied schema is not empty. * @param schema */ - private def validateSchema(schema: StructType): Unit = { + def validateSchema(schema: StructType): Unit = { def hasEmptySchema(schema: StructType): Boolean = { schema.size == 0 || schema.find { case StructField(_, b: StructType, _, _) => hasEmptySchema(b) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala index cefdfd14b3b30..d22b3d5049c86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable * E.g, with temporary view `t` using [[FileDataSourceV2]], inserting into view `t` fails * since there is no corresponding physical plan. * SPARK-23817: This is a temporary hack for making current data source V2 work. It should be - * removed when write path of file data source v2 is finished. + * removed when Catalog of file data source v2 is finished. */ class FallbackOrcDataSourceV2(sparkSession: SparkSession) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 10733810b6416..b6b0d7a6ced97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage} import org.apache.spark.sql.types.StringType import org.apache.spark.util.SerializableConfiguration @@ -37,7 +38,7 @@ import org.apache.spark.util.SerializableConfiguration abstract class FileFormatDataWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) { + committer: FileCommitProtocol) extends DataWriter[InternalRow] { /** * Max number of files a single task writes out due to file size. In most cases the number of * files written should be very small. This is just a safe guard to protect some really bad @@ -70,7 +71,7 @@ abstract class FileFormatDataWriter( * to the driver and used to update the catalog. Other information will be sent back to the * driver too and used to e.g. update the metrics in UI. */ - def commit(): WriteTaskResult = { + override def commit(): WriteTaskResult = { releaseResources() val summary = ExecutedWriteSummary( updatedPartitions = updatedPartitions.toSet, @@ -301,6 +302,7 @@ class WriteJobDescription( /** The result of a successful write task. */ case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary) + extends WriterCommitMessage /** * Wrapper class for the metrics of writing data out. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala index 7e38fc651a31f..08086bcd91f6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.OutputWriter import org.apache.spark.sql.types._ -private[orc] class OrcOutputWriter( +private[sql] class OrcOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala new file mode 100644 index 0000000000000..feb71c8acdfc9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.hadoop.mapreduce.Job + +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.sql.execution.datasources.{WriteJobDescription, WriteTaskResult} +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.util.SerializableConfiguration + +class FileBatchWrite( + job: Job, + description: WriteJobDescription, + committer: FileCommitProtocol) + extends BatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = { + committer.commitJob(job, messages.map(_.asInstanceOf[WriteTaskResult].commitMsg)) + } + + override def useCommitCoordinator(): Boolean = false + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + committer.abortJob(job) + } + + override def createBatchWriterFactory(): DataWriterFactory = { + val conf = new SerializableConfiguration(job.getConfiguration) + FileWriterFactory(description, committer, conf) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index b1786541a8051..0dbef145f7326 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -20,13 +20,14 @@ import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.sources.v2.{SupportsBatchRead, Table} +import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table} import org.apache.spark.sql.types.StructType abstract class FileTable( sparkSession: SparkSession, fileIndex: PartitioningAwareFileIndex, - userSpecifiedSchema: Option[StructType]) extends Table with SupportsBatchRead { + userSpecifiedSchema: Option[StructType]) + extends Table with SupportsBatchRead with SupportsBatchWrite { def getFileIndex: PartitioningAwareFileIndex = this.fileIndex lazy val dataSchema: StructType = userSpecifiedSchema.orElse { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala new file mode 100644 index 0000000000000..f1e8928c6c0aa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.{Optional, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat + +import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapReduceCommitProtocol} +import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, DataSource, OutputWriterFactory, WriteJobDescription} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchWrite} +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder} +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.SerializableConfiguration + +abstract class FileWriteBuilder(options: DataSourceOptions) + extends WriteBuilder with SupportsSaveMode { + private var schema: StructType = _ + private var queryId: String = _ + private var mode: SaveMode = _ + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.schema = schema + this + } + + override def withQueryId(queryId: String): WriteBuilder = { + this.queryId = queryId + this + } + + override def mode(mode: SaveMode): WriteBuilder = { + this.mode = mode + this + } + + override def buildForBatch(): BatchWrite = { + assert(schema != null, "Missing input data schema") + assert(queryId != null, "Missing query ID") + assert(mode != null, "Missing save mode") + assert(options.paths().length == 1) + DataSource.validateSchema(schema) + val pathName = options.paths().head + val path = new Path(pathName) + val sparkSession = SparkSession.active + val optionsAsScala = options.asMap().asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(optionsAsScala) + val job = Job.getInstance(hadoopConf) + val committer = FileCommitProtocol.instantiate( + sparkSession.sessionState.conf.fileCommitProtocolClass, + jobId = java.util.UUID.randomUUID().toString, + outputPath = pathName) + + job.setOutputKeyClass(classOf[Void]) + job.setOutputValueClass(classOf[InternalRow]) + FileOutputFormat.setOutputPath(job, path) + + val caseInsensitiveOptions = CaseInsensitiveMap(optionsAsScala) + // Note: prepareWrite has side effect. It sets "job". + val outputWriterFactory = + prepareWrite(sparkSession.sessionState.conf, job, caseInsensitiveOptions, schema) + val allColumns = schema.toAttributes + lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics + val serializableHadoopConf = new SerializableConfiguration(hadoopConf) + val statsTracker = new BasicWriteJobStatsTracker(serializableHadoopConf, metrics) + // TODO: after partitioning is supported in V2: + // 1. filter out partition columns in `dataColumns`. + // 2. Don't use Seq.empty for `partitionColumns`. + val description = new WriteJobDescription( + uuid = UUID.randomUUID().toString, + serializableHadoopConf = new SerializableConfiguration(job.getConfiguration), + outputWriterFactory = outputWriterFactory, + allColumns = allColumns, + dataColumns = allColumns, + partitionColumns = Seq.empty, + bucketIdExpression = None, + path = pathName, + customPartitionLocations = Map.empty, + maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong) + .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile), + timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) + .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone), + statsTrackers = Seq(statsTracker) + ) + + val fs = path.getFileSystem(hadoopConf) + mode match { + case SaveMode.ErrorIfExists if (fs.exists(path)) => + val qualifiedOutputPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory) + throw new AnalysisException(s"path $qualifiedOutputPath already exists.") + + case SaveMode.Ignore if (fs.exists(path)) => + null + + case SaveMode.Overwrite => + committer.deleteWithJob(fs, path, true) + committer.setupJob(job) + new FileBatchWrite(job, description, committer) + + case _ => + committer.setupJob(job) + new FileBatchWrite(job, description, committer) + } + } + + /** + * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can + * be put here. For example, user defined output committer can be configured here + * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. + */ + def prepareWrite( + sqlConf: SQLConf, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala new file mode 100644 index 0000000000000..9ee086a0a4bc4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Date + +import org.apache.hadoop.mapreduce.{TaskAttemptID, TaskID, TaskType} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.{DynamicPartitionDataWriter, EmptyDirectoryDataWriter, SingleDirectoryDataWriter, WriteJobDescription} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory} +import org.apache.spark.util.SerializableConfiguration + +case class FileWriterFactory ( + description: WriteJobDescription, + committer: FileCommitProtocol, + conf: SerializableConfiguration) extends DataWriterFactory { + override def createWriter(partitionId: Int, realTaskId: Long): DataWriter[InternalRow] = { + val taskAttemptContext = createTaskAttemptContext(partitionId) + committer.setupTask(taskAttemptContext) + if (description.partitionColumns.isEmpty) { + new SingleDirectoryDataWriter(description, taskAttemptContext, committer) + } else { + new DynamicPartitionDataWriter(description, taskAttemptContext, committer) + } + } + + private def createTaskAttemptContext(partitionId: Int): TaskAttemptContextImpl = { + val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) + val taskId = new TaskID(jobId, TaskType.MAP, partitionId) + val taskAttemptId = new TaskAttemptID(taskId, 0) + // Set up the configuration object + val hadoopConf = conf.value + hadoopConf.set("mapreduce.job.id", jobId.toString) + hadoopConf.set("mapreduce.task.id", taskId.toString) + hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString) + hadoopConf.setBoolean("mapreduce.task.ismap", true) + hadoopConf.setInt("mapreduce.task.partition", 0) + + new TaskAttemptContextImpl(hadoopConf, taskAttemptId) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 406fb8c3a3834..50c5e4f2ad7df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -56,7 +56,14 @@ case class WriteToDataSourceV2Exec(batchWrite: BatchWrite, query: SparkPlan) val writerFactory = batchWrite.createBatchWriterFactory() val useCommitCoordinator = batchWrite.useCommitCoordinator val rdd = query.execute() - val messages = new Array[WriterCommitMessage](rdd.partitions.length) + // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single + // partition rdd to make sure we at least set up one write task to write the metadata. + val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) { + sparkContext.parallelize(Array.empty[InternalRow], 1) + } else { + rdd + } + val messages = new Array[WriterCommitMessage](rddWithNonEmptyPartitions.partitions.length) val totalNumRowsAccumulator = new LongAccumulator() logInfo(s"Start processing data source write support: $batchWrite. " + @@ -64,10 +71,10 @@ case class WriteToDataSourceV2Exec(batchWrite: BatchWrite, query: SparkPlan) try { sparkContext.runJob( - rdd, + rddWithNonEmptyPartitions, (context: TaskContext, iter: Iterator[InternalRow]) => DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator), - rdd.partitions.indices, + rddWithNonEmptyPartitions.partitions.indices, (index, result: DataWritingSparkTaskResult) => { val commitMessage = result.writerCommitMessage messages(index) = commitMessage diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala index 719e757c33cbd..b467e505f1bac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcUtils import org.apache.spark.sql.execution.datasources.v2.FileTable import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.writer.WriteBuilder import org.apache.spark.sql.types.StructType case class OrcTable( @@ -36,4 +37,7 @@ case class OrcTable( override def inferSchema(files: Seq[FileStatus]): Option[StructType] = OrcUtils.readSchema(sparkSession, files) + + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = + new OrcWriteBuilder(options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala new file mode 100644 index 0000000000000..80429d91d5e4d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.orc + +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.orc.OrcConf.{COMPRESS, MAPRED_OUTPUT_SCHEMA} +import org.apache.orc.mapred.OrcStruct + +import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.execution.datasources.orc.{OrcFileFormat, OrcOptions, OrcOutputWriter, OrcUtils} +import org.apache.spark.sql.execution.datasources.v2.FileWriteBuilder +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.types._ + +class OrcWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(options) { + override def prepareWrite( + sqlConf: SQLConf, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val orcOptions = new OrcOptions(options, sqlConf) + + val conf = job.getConfiguration + + conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcFileFormat.getQuotedSchemaString(dataSchema)) + + conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec) + + conf.asInstanceOf[JobConf] + .setOutputFormat(classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]]) + + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new OrcOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + val compressionExtension: String = { + val name = context.getConfiguration.get(COMPRESS.getAttribute) + OrcUtils.extensionsForCompressionCodecNames.getOrElse(name, "") + } + + compressionExtension + ".orc" + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 5e6705094e602..fc87b0462bc25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -329,46 +329,49 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") { withTempDir { dir => val tempDir = new File(dir, "files").getCanonicalPath + // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well. + withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") { + // write path + Seq("csv", "json", "parquet", "orc").foreach { format => + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.contains("Cannot save interval data type into external storage.")) - // write path - Seq("csv", "json", "parquet", "orc").foreach { format => - var msg = intercept[AnalysisException] { - sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.contains("Cannot save interval data type into external storage.")) - - msg = intercept[AnalysisException] { - spark.udf.register("testType", () => new IntervalData()) - sql("select testType()").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) - } + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new IntervalData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support calendarinterval data type.")) + } - // read path - Seq("parquet", "csv").foreach { format => - var msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) - - msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) + // read path + Seq("parquet", "csv").foreach { format => + var msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support calendarinterval data type.")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support calendarinterval data type.")) + } } } } test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") { // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well. - withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "orc") { + withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "orc", + SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") { withTempDir { dir => val tempDir = new File(dir, "files").getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala index f57c581fd800e..fd19a48497fe6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Pa import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.reader.ScanBuilder +import org.apache.spark.sql.sources.v2.writer.WriteBuilder import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -46,9 +47,30 @@ class DummyReadOnlyFileTable extends Table with SupportsBatchRead { } } -class FileDataSourceV2FallBackSuite extends QueryTest with ParquetTest with SharedSQLContext { +class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { + + override def fallBackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat] + + override def shortName(): String = "parquet" + + override def getTable(options: DataSourceOptions): Table = { + new DummyWriteOnlyFileTable + } +} + +class DummyWriteOnlyFileTable extends Table with SupportsBatchWrite { + override def name(): String = "dummy" + + override def schema(): StructType = StructType(Nil) + + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = + throw new AnalysisException("Dummy file writer") +} + +class FileDataSourceV2FallBackSuite extends QueryTest with SharedSQLContext { private val dummyParquetReaderV2 = classOf[DummyReadOnlyFileDataSourceV2].getName + private val dummyParquetWriterV2 = classOf[DummyWriteOnlyFileDataSourceV2].getName test("Fall back to v1 when writing to file with read only FileDataSourceV2") { val df = spark.range(10).toDF() @@ -94,4 +116,47 @@ class FileDataSourceV2FallBackSuite extends QueryTest with ParquetTest with Shar } } } + + test("Fall back to v1 when reading file with write only FileDataSourceV2") { + val df = spark.range(10).toDF() + withTempPath { file => + val path = file.getCanonicalPath + // Dummy File writer should fail as expected. + val exception = intercept[AnalysisException] { + df.write.format(dummyParquetWriterV2).save(path) + } + assert(exception.message.equals("Dummy file writer")) + df.write.parquet(path) + // Fallback reads to V1 + checkAnswer(spark.read.format(dummyParquetWriterV2).load(path), df) + } + } + + test("Fall back write path to v1 with configuration USE_V1_SOURCE_WRITER_LIST") { + val df = spark.range(10).toDF() + Seq( + "foo,parquet,bar", + "ParQuet,bar,foo", + s"foobar,$dummyParquetWriterV2" + ).foreach { fallbackWriters => + withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> fallbackWriters) { + withTempPath { file => + val path = file.getCanonicalPath + // Writes should fall back to v1 and succeed. + df.write.format(dummyParquetWriterV2).save(path) + checkAnswer(spark.read.parquet(path), df) + } + } + } + withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "foo,bar") { + withTempPath { file => + val path = file.getCanonicalPath + // Dummy File reader should fail as USE_V1_SOURCE_READER_LIST doesn't include it. + val exception = intercept[AnalysisException] { + df.write.format(dummyParquetWriterV2).save(path) + } + assert(exception.message.equals("Dummy file writer")) + } + } + } }