diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index c3674a8c76be8..fbd98ef0380e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -377,13 +377,22 @@ private[sql] class DefaultWriterContainer(
   override def outputWriterForRow(row: Row): OutputWriter = writer
 
   override def commitTask(): Unit = {
-    writer.close()
-    super.commitTask()
+    try {
+      writer.close()
+      super.commitTask()
+    } catch {
+      case cause: Throwable =>
+        super.abortTask()
+        throw new RuntimeException("Failed to commit task", cause)
+    }
   }
 
   override def abortTask(): Unit = {
-    writer.close()
-    super.abortTask()
+    try {
+      writer.close()
+    } finally {
+      super.abortTask()
+    }
   }
 }
 
@@ -422,13 +431,21 @@ private[sql] class DynamicPartitionWriterContainer(
   }
 
   override def commitTask(): Unit = {
-    outputWriters.values.foreach(_.close())
-    super.commitTask()
+    try {
+      outputWriters.values.foreach(_.close())
+      super.commitTask()
+    } catch { case cause: Throwable =>
+      super.abortTask()
+      throw new RuntimeException("Failed to commit task", cause)
+    }
   }
 
   override def abortTask(): Unit = {
-    outputWriters.values.foreach(_.close())
-    super.abortTask()
+    try {
+      outputWriters.values.foreach(_.close())
+    } finally {
+      super.abortTask()
+    }
   }
 }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 2d69b89fd9a9c..de907846b9180 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.{Row, SQLContext}
 
 /**
@@ -67,7 +67,9 @@ class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends
     recordWriter.write(null, new Text(serialized))
   }
 
-  override def close(): Unit = recordWriter.close(context)
+  override def close(): Unit = {
+    recordWriter.close(context)
+  }
 }
 
 /**
@@ -120,3 +122,39 @@ class SimpleTextRelation(
     }
   }
 }
+
+/**
+ * A simple example [[HadoopFsRelationProvider]].
+ */
+class CommitFailureTestSource extends HadoopFsRelationProvider {
+  override def createRelation(
+      sqlContext: SQLContext,
+      paths: Array[String],
+      schema: Option[StructType],
+      partitionColumns: Option[StructType],
+      parameters: Map[String, String]): HadoopFsRelation = {
+    new CommitFailureTestRelation(paths, schema, partitionColumns, parameters)(sqlContext)
+  }
+}
+
+class CommitFailureTestRelation(
+    override val paths: Array[String],
+    maybeDataSchema: Option[StructType],
+    override val userDefinedPartitionColumns: Option[StructType],
+    parameters: Map[String, String])(
+    @transient sqlContext: SQLContext)
+  extends SimpleTextRelation(
+    paths, maybeDataSchema, userDefinedPartitionColumns, parameters)(sqlContext) {
+  override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory {
+    override def newInstance(
+        path: String,
+        dataSchema: StructType,
+        context: TaskAttemptContext): OutputWriter = {
+      new SimpleTextOutputWriter(path, context) {
+        override def close(): Unit = {
+          sys.error("Intentional task commitment failure for testing purpose.")
+        }
+      }
+    }
+  }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 32226905bca9d..70328e1ef810d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -18,7 +18,9 @@ package org.apache.spark.sql.sources
 
 import org.apache.hadoop.fs.Path
+import org.scalatest.FunSuite
 
+import org.apache.spark.SparkException
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.sql._
 import org.apache.spark.sql.hive.test.TestHive
@@ -477,6 +479,26 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
   }
 }
 
+class CommitFailureTestRelationSuite extends FunSuite with SQLTestUtils {
+  import TestHive.implicits._
+
+  override val sqlContext = TestHive
+
+  val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
+
+  test("SPARK-7684: commitTask() failure should fallback to abortTask()") {
+    withTempPath { file =>
+      val df = (1 to 3).map(i => i -> s"val_$i").toDF("a", "b")
+      intercept[SparkException] {
+        df.write.format(dataSourceName).save(file.getCanonicalPath)
+      }
+
+      val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+      assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+    }
+  }
+}
+
 class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
   override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName
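Note: the sketch below is not part of the patch. It is a minimal, self-contained Scala illustration of the commit/abort contract the change above adopts: close the writer and commit, fall back to abortTask() if either step fails, and always abort even when close() throws during an abort. The names TaskLifecycle, WriterContainerSketch, and CommitFallbackDemo are illustrative only and are not Spark APIs.

// Stand-ins for the base container's task-commit protocol.
trait TaskLifecycle {
  def commitTask(): Unit = ()   // analogous to super.commitTask()
  def abortTask(): Unit = ()    // analogous to super.abortTask()
}

class WriterContainerSketch(writer: java.io.Closeable) extends TaskLifecycle {
  override def commitTask(): Unit = {
    try {
      writer.close()            // flush and release the output writer first
      super.commitTask()        // then promote the task's output
    } catch {
      case cause: Throwable =>
        super.abortTask()       // fall back: discard the partially written output
        throw new RuntimeException("Failed to commit task", cause)
    }
  }

  override def abortTask(): Unit = {
    try {
      writer.close()            // best-effort close
    } finally {
      super.abortTask()         // always clean up temporary output, even if close() fails
    }
  }
}

object CommitFallbackDemo extends App {
  // A writer whose close() always fails, mirroring CommitFailureTestSource above.
  val failing = new java.io.Closeable { def close(): Unit = sys.error("boom") }
  try new WriterContainerSketch(failing).commitTask()
  catch { case e: RuntimeException => println(s"fell back to abortTask(): ${e.getMessage}") }
}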