Skip to content

Commit

Permalink
[SPARK-22252][SQL][2.2] FileFormatWriter should respect the input que…
Browse files Browse the repository at this point in the history
…ry schema

## What changes were proposed in this pull request?

#18386 fixes SPARK-21165 but breaks SPARK-22252. This PR reverts #18386 and picks the patch from #19483 to fix SPARK-21165.

## How was this patch tested?

new regression test

Author: Wenchen Fan <[email protected]>

Closes #19484 from cloud-fan/bug.
  • Loading branch information
cloud-fan authored and gatorsmile committed Oct 13, 2017
1 parent cfc04e0 commit c9187db
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,16 @@ case class DataSource(
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive)

// SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does
// not need to have the query as child, to avoid to analyze an optimized query,
// because InsertIntoHadoopFsRelationCommand will be optimized first.
val partitionAttributes = partitionColumns.map { name =>
val plan = data.logicalPlan
plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse {
throw new AnalysisException(
s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]")
}.asInstanceOf[Attribute]
}
val fileIndex = catalogTable.map(_.identifier).map { tableIdent =>
sparkSession.table(tableIdent).queryExecution.analyzed.collect {
case LogicalRelation(t: HadoopFsRelation, _, _) => t.location
Expand All @@ -414,7 +424,7 @@ case class DataSource(
outputPath = outputPath,
staticPartitions = Map.empty,
ifPartitionNotExists = false,
partitionColumns = partitionColumns,
partitionColumns = partitionAttributes,
bucketSpec = bucketSpec,
fileFormat = format,
options = options,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,15 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
"Cannot overwrite a path that is also being read from.")
}

val partitionSchema = actualQuery.resolve(
t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
val staticPartitions = parts.filter(_._2.nonEmpty).map { case (k, v) => k -> v.get }

InsertIntoHadoopFsRelationCommand(
outputPath,
staticPartitions,
i.ifPartitionNotExists,
partitionColumns = t.partitionSchema.map(_.name),
partitionSchema,
t.bucketSpec,
t.fileFormat,
t.options,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.{SerializableConfiguration, Utils}


Expand Down Expand Up @@ -101,7 +101,7 @@ object FileFormatWriter extends Logging {
committer: FileCommitProtocol,
outputSpec: OutputSpec,
hadoopConf: Configuration,
partitionColumnNames: Seq[String],
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
refreshFunction: (Seq[TablePartitionSpec]) => Unit,
options: Map[String, String]): Unit = {
Expand All @@ -111,16 +111,9 @@ object FileFormatWriter extends Logging {
job.setOutputValueClass(classOf[InternalRow])
FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))

val allColumns = queryExecution.executedPlan.output
// Get the actual partition columns as attributes after matching them by name with
// the given columns names.
val partitionColumns = partitionColumnNames.map { col =>
val nameEquality = sparkSession.sessionState.conf.resolver
allColumns.find(f => nameEquality(f.name, col)).getOrElse {
throw new RuntimeException(
s"Partition column $col not found in schema ${queryExecution.executedPlan.schema}")
}
}
// Pick the attributes from analyzed plan, as optimizer may not preserve the output schema
// names' case.
val allColumns = queryExecution.analyzed.output
val partitionSet = AttributeSet(partitionColumns)
val dataColumns = allColumns.filterNot(partitionSet.contains)

Expand Down Expand Up @@ -179,8 +172,13 @@ object FileFormatWriter extends Logging {
val rdd = if (orderingMatched) {
queryExecution.toRdd
} else {
// SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
// the physical plan may have different attribute ids due to optimizer removing some
// aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
val orderingExpr = requiredOrdering
.map(SortOrder(_, Ascending)).map(BindReferences.bindReference(_, allColumns))
SortExec(
requiredOrdering.map(SortOrder(_, Ascending)),
orderingExpr,
global = false,
child = queryExecution.executedPlan).execute()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ case class InsertIntoHadoopFsRelationCommand(
outputPath: Path,
staticPartitions: TablePartitionSpec,
ifPartitionNotExists: Boolean,
partitionColumns: Seq[String],
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
fileFormat: FileFormat,
options: Map[String, String],
Expand Down Expand Up @@ -150,7 +150,7 @@ case class InsertIntoHadoopFsRelationCommand(
outputSpec = FileFormatWriter.OutputSpec(
qualifiedOutputPath.toString, customPartitionLocations),
hadoopConf = hadoopConf,
partitionColumnNames = partitionColumns,
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
refreshFunction = refreshPartitionsCallback,
options = options)
Expand All @@ -176,10 +176,10 @@ case class InsertIntoHadoopFsRelationCommand(
customPartitionLocations: Map[TablePartitionSpec, String],
committer: FileCommitProtocol): Unit = {
val staticPartitionPrefix = if (staticPartitions.nonEmpty) {
"/" + partitionColumns.flatMap { col =>
staticPartitions.get(col) match {
"/" + partitionColumns.flatMap { p =>
staticPartitions.get(p.name) match {
case Some(value) =>
Some(escapePathName(col) + "=" + escapePathName(value))
Some(escapePathName(p.name) + "=" + escapePathName(value))
case None =>
None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,23 @@ class FileStreamSink(
case _ => // Do nothing
}

// Get the actual partition columns as attributes after matching them by name with
// the given columns names.
val partitionColumns: Seq[Attribute] = partitionColumnNames.map { col =>
val nameEquality = data.sparkSession.sessionState.conf.resolver
data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse {
throw new RuntimeException(s"Partition column $col not found in schema ${data.schema}")
}
}

FileFormatWriter.write(
sparkSession = sparkSession,
queryExecution = data.queryExecution,
fileFormat = fileFormat,
committer = committer,
outputSpec = FileFormatWriter.OutputSpec(path, Map.empty),
hadoopConf = hadoopConf,
partitionColumnNames = partitionColumnNames,
partitionColumns = partitionColumns,
bucketSpec = None,
refreshFunction = _ => (),
options = options)
Expand Down
13 changes: 13 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1759,4 +1759,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)),
Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2)))
}

test("SPARK-22252: FileFormatWriter should respect the input query schema") {
withTable("t1", "t2", "t3", "t4") {
spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1")
spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2")
checkAnswer(spark.table("t2"), Row(0, 0))

// Test picking part of the columns when writing.
spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3")
spark.sql("select COL1, COL2 from t3").write.saveAsTable("t4")
checkAnswer(spark.table("t4"), Row(0, 0))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,21 @@ case class InsertIntoHiveTable(
outputPath = tmpLocation.toString,
isAppend = false)

val partitionAttributes = partitionColumnNames.takeRight(numDynamicPartitions).map { name =>
query.resolve(name :: Nil, sparkSession.sessionState.analyzer.resolver).getOrElse {
throw new AnalysisException(
s"Unable to resolve $name given [${query.output.map(_.name).mkString(", ")}]")
}.asInstanceOf[Attribute]
}

FileFormatWriter.write(
sparkSession = sparkSession,
queryExecution = Dataset.ofRows(sparkSession, query).queryExecution,
fileFormat = new HiveFileFormat(fileSinkConf),
committer = committer,
outputSpec = FileFormatWriter.OutputSpec(tmpLocation.toString, Map.empty),
hadoopConf = hadoopConf,
partitionColumnNames = partitionColumnNames.takeRight(numDynamicPartitions),
partitionColumns = partitionAttributes,
bucketSpec = None,
refreshFunction = _ => (),
options = Map.empty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
}
}

test("SPARK-21165: the query schema of INSERT is changed after optimization") {
test("SPARK-21165: FileFormatWriter should only rely on attributes from analyzed plan") {
withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
withTable("tab1", "tab2") {
Seq(("a", "b", 3)).toDF("word", "first", "length").write.saveAsTable("tab1")
Expand Down

0 comments on commit c9187db

Please sign in to comment.