diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 14c40605ea31c..4ffe2151ae638 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -408,16 +408,6 @@ case class DataSource(
       val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
       PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive)
 
-      // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does
-      // not need to have the query as child, to avoid to analyze an optimized query,
-      // because InsertIntoHadoopFsRelationCommand will be optimized first.
-      val partitionAttributes = partitionColumns.map { name =>
-        val plan = data.logicalPlan
-        plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse {
-          throw new AnalysisException(
-            s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]")
-        }.asInstanceOf[Attribute]
-      }
       val fileIndex = catalogTable.map(_.identifier).map { tableIdent =>
         sparkSession.table(tableIdent).queryExecution.analyzed.collect {
           case LogicalRelation(t: HadoopFsRelation, _, _) => t.location
@@ -431,7 +421,7 @@ case class DataSource(
         outputPath = outputPath,
         staticPartitions = Map.empty,
         ifPartitionNotExists = false,
-        partitionColumns = partitionAttributes,
+        partitionColumns = partitionColumns,
         bucketSpec = bucketSpec,
         fileFormat = format,
         options = options,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index e05a8d5f02bd8..ded9303de55fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -188,15 +188,13 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
           "Cannot overwrite a path that is also being read from.")
       }
 
-      val partitionSchema = actualQuery.resolve(
-        t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
       val staticPartitions = parts.filter(_._2.nonEmpty).map { case (k, v) => k -> v.get }
 
       InsertIntoHadoopFsRelationCommand(
         outputPath,
         staticPartitions,
         i.ifPartitionNotExists,
-        partitionSchema,
+        partitionColumns = t.partitionSchema.map(_.name),
         t.bucketSpec,
         t.fileFormat,
         t.options,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 4ec09bff429c5..2c31d2a84c258 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -101,7 +101,7 @@ object FileFormatWriter extends Logging {
       committer: FileCommitProtocol,
       outputSpec: OutputSpec,
       hadoopConf: Configuration,
-      partitionColumns: Seq[Attribute],
+      partitionColumnNames: Seq[String],
       bucketSpec: Option[BucketSpec],
       refreshFunction: (Seq[TablePartitionSpec]) => Unit,
       options: Map[String, String]): Unit = {
@@ -111,9 +111,18 @@
     job.setOutputValueClass(classOf[InternalRow])
     FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
 
-    val allColumns = queryExecution.logical.output
+    val allColumns = queryExecution.executedPlan.output
+    // Get the actual partition columns as attributes after matching them by name with
+    // the given columns names.
+    val partitionColumns = partitionColumnNames.map { col =>
+      val nameEquality = sparkSession.sessionState.conf.resolver
+      allColumns.find(f => nameEquality(f.name, col)).getOrElse {
+        throw new RuntimeException(
+          s"Partition column $col not found in schema ${queryExecution.executedPlan.schema}")
+      }
+    }
     val partitionSet = AttributeSet(partitionColumns)
-    val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains)
+    val dataColumns = allColumns.filterNot(partitionSet.contains)
 
     val bucketIdExpression = bucketSpec.map { spec =>
       val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index c9d31449d3629..ab35fdcbc1f25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -44,7 +44,7 @@ case class InsertIntoHadoopFsRelationCommand(
     outputPath: Path,
     staticPartitions: TablePartitionSpec,
     ifPartitionNotExists: Boolean,
-    partitionColumns: Seq[Attribute],
+    partitionColumns: Seq[String],
     bucketSpec: Option[BucketSpec],
     fileFormat: FileFormat,
     options: Map[String, String],
@@ -150,7 +150,7 @@ case class InsertIntoHadoopFsRelationCommand(
       outputSpec = FileFormatWriter.OutputSpec(
         qualifiedOutputPath.toString, customPartitionLocations),
       hadoopConf = hadoopConf,
-      partitionColumns = partitionColumns,
+      partitionColumnNames = partitionColumns,
       bucketSpec = bucketSpec,
       refreshFunction = refreshPartitionsCallback,
       options = options)
@@ -176,10 +176,10 @@ case class InsertIntoHadoopFsRelationCommand(
       customPartitionLocations: Map[TablePartitionSpec, String],
       committer: FileCommitProtocol): Unit = {
     val staticPartitionPrefix = if (staticPartitions.nonEmpty) {
-      "/" + partitionColumns.flatMap { p =>
-        staticPartitions.get(p.name) match {
+      "/" + partitionColumns.flatMap { col =>
+        staticPartitions.get(col) match {
           case Some(value) =>
-            Some(escapePathName(p.name) + "=" + escapePathName(value))
+            Some(escapePathName(col) + "=" + escapePathName(value))
           case None => None
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 3f4a78580f1eb..45f2a41f24937 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -127,11 +127,11 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
       val resolver = sparkSession.sessionState.conf.resolver
       val tableCols = existingTable.schema.map(_.name)
 
-      // As we are inserting into an existing table, we should respect the existing schema and
-      // adjust the column order of the given dataframe according to it, or throw exception
-      // if the column names do not match.
+      // As we are inserting into an existing table, we should respect the existing schema, preserve
+      // the case and adjust the column order of the given DataFrame according to it, or throw
+      // an exception if the column names do not match.
       val adjustedColumns = tableCols.map { col =>
-        query.resolve(Seq(col), resolver).getOrElse {
+        query.resolve(Seq(col), resolver).map(Alias(_, col)()).getOrElse {
           val inputColumns = query.schema.map(_.name).mkString(", ")
           throw new AnalysisException(
             s"cannot resolve '$col' given input columns: [$inputColumns]")
@@ -168,15 +168,9 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
           """.stripMargin)
       }
 
-      val newQuery = if (adjustedColumns != query.output) {
-        Project(adjustedColumns, query)
-      } else {
-        query
-      }
-
       c.copy(
         tableDesc = existingTable,
-        query = Some(newQuery))
+        query = Some(Project(adjustedColumns, query)))
 
     // Here we normalize partition, bucket and sort column names, w.r.t. the case sensitivity
     // config, and do various checks:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index 6885d0bf67ccb..2a652920c10c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -111,15 +111,6 @@ class FileStreamSink(
       case _ =>  // Do nothing
     }
 
-    // Get the actual partition columns as attributes after matching them by name with
-    // the given columns names.
-    val partitionColumns: Seq[Attribute] = partitionColumnNames.map { col =>
-      val nameEquality = data.sparkSession.sessionState.conf.resolver
-      data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse {
-        throw new RuntimeException(s"Partition column $col not found in schema ${data.schema}")
-      }
-    }
-
     FileFormatWriter.write(
       sparkSession = sparkSession,
       queryExecution = data.queryExecution,
@@ -127,7 +118,7 @@ class FileStreamSink(
       committer = committer,
       outputSpec = FileFormatWriter.OutputSpec(path, Map.empty),
       hadoopConf = hadoopConf,
-      partitionColumns = partitionColumns,
+      partitionColumnNames = partitionColumnNames,
       bucketSpec = None,
       refreshFunction = _ => (),
       options = options)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 35f65e972fe27..797481c879e7a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -314,13 +314,6 @@ case class InsertIntoHiveTable(
       outputPath = tmpLocation.toString,
       isAppend = false)
 
-    val partitionAttributes = partitionColumnNames.takeRight(numDynamicPartitions).map { name =>
-      query.resolve(name :: Nil, sparkSession.sessionState.analyzer.resolver).getOrElse {
-        throw new AnalysisException(
-          s"Unable to resolve $name given [${query.output.map(_.name).mkString(", ")}]")
-      }.asInstanceOf[Attribute]
-    }
-
     FileFormatWriter.write(
       sparkSession = sparkSession,
       queryExecution = Dataset.ofRows(sparkSession, query).queryExecution,
@@ -328,7 +321,7 @@ case class InsertIntoHiveTable(
       committer = committer,
       outputSpec = FileFormatWriter.OutputSpec(tmpLocation.toString, Map.empty),
       hadoopConf = hadoopConf,
-      partitionColumns = partitionAttributes,
+      partitionColumnNames = partitionColumnNames.takeRight(numDynamicPartitions),
       bucketSpec = None,
       refreshFunction = _ => (),
       options = Map.empty)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 58ab0c252bfd7..618e5b68ff8c0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -468,6 +468,28 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
     }
   }
 
+  test("SPARK-21165: the query schema of INSERT is changed after optimization") {
+    withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
+      withTable("tab1", "tab2") {
+        Seq(("a", "b", 3)).toDF("word", "first", "length").write.saveAsTable("tab1")
+
+        spark.sql(
+          """
+            |CREATE TABLE tab2 (word string, length int)
+            |PARTITIONED BY (first string)
+          """.stripMargin)
+
+        spark.sql(
+          """
+            |INSERT INTO TABLE tab2 PARTITION(first)
+            |SELECT word, length, cast(first as string) as first FROM tab1
+          """.stripMargin)
+
+        checkAnswer(spark.table("tab2"), Row("a", 3, "b"))
+      }
+    }
+  }
+
   testPartitionedTable("insertInto() should reject extra columns") { tableName =>
     sql("CREATE TABLE t (a INT, b INT, c INT, d INT, e INT)")
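
Note (not part of the patch): the core of this change is that FileFormatWriter now receives partition column *names* and resolves them against the output of the plan that will actually execute (queryExecution.executedPlan), instead of being handed attributes resolved earlier against the analyzed plan, which the optimizer may have rewritten. The following is a minimal, standalone Scala sketch of that name-based lookup; the Attribute case class and Resolver alias are illustrative stand-ins for Spark's internal types, and the example has no Spark dependency.

// A simplified model of the lookup the patch adds to FileFormatWriter.write:
// each requested partition column name must be found among the plan's output
// attributes, otherwise writing fails fast with a descriptive error.
object PartitionColumnResolutionSketch {

  final case class Attribute(name: String)

  // Mirrors the role of Spark's Resolver: decides whether two column names match.
  type Resolver = (String, String) => Boolean

  val caseInsensitiveResolver: Resolver = (a, b) => a.equalsIgnoreCase(b)

  def resolvePartitionColumns(
      allColumns: Seq[Attribute],
      partitionColumnNames: Seq[String],
      resolver: Resolver): Seq[Attribute] = {
    partitionColumnNames.map { col =>
      allColumns.find(a => resolver(a.name, col)).getOrElse {
        throw new RuntimeException(
          s"Partition column $col not found in schema ${allColumns.map(_.name).mkString(", ")}")
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical output of the executed plan for the test query above:
    // SELECT word, length, cast(first as string) as first FROM tab1
    val executedPlanOutput = Seq(Attribute("word"), Attribute("length"), Attribute("first"))
    val partitionColumns =
      resolvePartitionColumns(executedPlanOutput, Seq("first"), caseInsensitiveResolver)
    println(partitionColumns.map(_.name)) // List(first)
  }
}

A case-insensitive resolver is used here because Spark's session resolver behaves that way under the default spark.sql.caseSensitive=false; the real code obtains it from sparkSession.sessionState.conf.resolver, as shown in the FileFormatWriter hunk.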