[SPARK-22252][SQL] FileFormatWriter should respect the input query schema

## What changes were proposed in this pull request?

In apache#18064, we allowed `RunnableCommand` to have children in order to fix some UI issues. We then made the `InsertIntoXXX` commands take the input `query` as a child; when doing the actual writing, we just passed the physical plan to the writer (`FileFormatWriter.write`).

However, this is problematic. In Spark SQL, the optimizer and planner are allowed to change schema names slightly. For example, the `ColumnPruning` rule removes no-op `Project`s such as `Project("A", Scan("a"))`, and thus changes the output schema from `<A: int>` to `<a: int>`. When it comes to writing, especially for self-describing data formats like Parquet, we may write the wrong schema to the file and cause null values on the read path.
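
To make this concrete, here is a minimal sketch of the failure mode, modeled on the regression test added in this PR (the local session setup and table names are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch of the bug scenario, assuming a local SparkSession.
val spark = SparkSession.builder().master("local[*]").appName("SPARK-22252").getOrCreate()
import spark.implicits._

// t1 is stored with lower-case column names col1 and col2.
spark.range(1).select($"id".as("col1"), $"id".as("col2")).write.saveAsTable("t1")

// The analyzed plan of this query carries the user-typed names COL1 and COL2, but
// `ColumnPruning` may remove the now no-op Project and expose col1 and col2 instead.
// Before this patch, the writer could pick up the optimized names and write them into
// the Parquet footer, which could surface as null values when reading t2 back.
spark.sql("SELECT COL1, COL2 FROM t1").write.saveAsTable("t2")

spark.table("t2").show()  // with the fix: a single row (0, 0)
```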

Fortunately, in apache#18450 we decided to allow nested execution, so one query can map to multiple executions in the UI. This lifts the major restriction from apache#18604, and now we no longer have to make the input `query` a child of the `InsertIntoXXX` commands.

So the fix is simple: this PR partially reverts apache#18064 and makes the `InsertIntoXXX` commands leaf nodes again.
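
Concretely, `FileFormatWriter.write` now receives the whole `QueryExecution` and takes the output schema from the analyzed plan rather than from the physical plan. A quick way to see why that matters, reusing the session and table from the sketch above:

```scala
// Compare the column names seen by the analyzer and by the optimizer
// (sketch only; assumes `spark` and table t1 from the previous example).
val qe = spark.sql("SELECT COL1, COL2 FROM t1").queryExecution
println(qe.analyzed.output.map(_.name))       // List(COL1, COL2): the user-specified case
println(qe.optimizedPlan.output.map(_.name))  // may be List(col1, col2) after ColumnPruning

// FileFormatWriter.write now uses qe.analyzed.output for the schema it writes,
// so the user-specified names are what end up in the files.
```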

## How was this patch tested?

A new regression test in `FileFormatWriterSuite`.

Author: Wenchen Fan <[email protected]>

Closes apache#19474 from cloud-fan/bug.
cloud-fan committed Oct 12, 2017
1 parent ccdf21f commit 274f0ef
Showing 22 changed files with 85 additions and 76 deletions.

@@ -178,7 +178,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
})
}

override def innerChildren: Seq[QueryPlan[_]] = subqueries
override protected def innerChildren: Seq[QueryPlan[_]] = subqueries

/**
* Returns a plan where a best effort attempt has been made to transform `this` in a way

@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
* commands can be used by parsers to represent DDL operations. Commands, unlike queries, are
* eagerly executed.
*/
trait Command extends LogicalPlan {
trait Command extends LeafNode {
override def output: Seq[Attribute] = Seq.empty
override def children: Seq[LogicalPlan] = Seq.empty
}

@@ -119,7 +119,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
* `SparkSQLDriver` for CLI applications.
*/
def hiveResultString(): Seq[String] = executedPlan match {
case ExecutedCommandExec(desc: DescribeTableCommand, _) =>
case ExecutedCommandExec(desc: DescribeTableCommand) =>
// If it is a describe command for a Hive table, we want to have the output format
// be similar with Hive.
desc.run(sparkSession).map {
@@ -130,7 +130,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
.mkString("\t")
}
// SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp.
case command @ ExecutedCommandExec(s: ShowTablesCommand, _) if !s.isExtended =>
case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended =>
command.executeCollect().map(_.getString(1))
case other =>
val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq

@@ -364,7 +364,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// Can we automate these 'pass through' operations?
object BasicOperators extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case r: RunnableCommand => ExecutedCommandExec(r, r.children.map(planLater)) :: Nil
case r: RunnableCommand => ExecutedCommandExec(r) :: Nil

case MemoryPlan(sink, output) =>
val encoder = RowEncoder(sink.schema)

@@ -62,7 +62,8 @@ case class InMemoryRelation(
@transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator)
extends logical.LeafNode with MultiInstanceRelation {
override def innerChildren: Seq[SparkPlan] = Seq(child)

override protected def innerChildren: Seq[SparkPlan] = Seq(child)

override def producedAttributes: AttributeSet = outputSet


@@ -34,7 +34,7 @@ case class InMemoryTableScanExec(
@transient relation: InMemoryRelation)
extends LeafExecNode {

override def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren
override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.command
import org.apache.hadoop.conf.Configuration

import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.util.SerializableConfiguration
@@ -30,6 +31,18 @@ import org.apache.spark.util.SerializableConfiguration
*/
trait DataWritingCommand extends RunnableCommand {

/**
* The input query plan that produces the data to be written.
*/
def query: LogicalPlan

// We make the input `query` an inner child instead of a child in order to hide it from the
// optimizer. This is because optimizer may not preserve the output schema names' case, and we
// have to keep the original analyzed plan here so that we can pass the corrected schema to the
// writer. The schema of analyzed plan is what user expects(or specifies), so we should respect
// it when writing.
override protected def innerChildren: Seq[LogicalPlan] = query :: Nil

override lazy val metrics: Map[String, SQLMetric] = {
val sparkContext = SparkContext.getActive.get
Map(

@@ -21,7 +21,6 @@ import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources._

/**
@@ -45,10 +44,9 @@ case class InsertIntoDataSourceDirCommand(
query: LogicalPlan,
overwrite: Boolean) extends RunnableCommand {

override def children: Seq[LogicalPlan] = Seq(query)
override protected def innerChildren: Seq[LogicalPlan] = query :: Nil

override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
assert(children.length == 1)
override def run(sparkSession: SparkSession): Seq[Row] = {
assert(storage.locationUri.nonEmpty, "Directory path is required")
assert(provider.nonEmpty, "Data source is required")


@@ -30,7 +30,7 @@ case class CacheTableCommand(
require(plan.isEmpty || tableIdent.database.isEmpty,
"Database name is not allowed in CACHE TABLE AS SELECT")

override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq
override protected def innerChildren: Seq[QueryPlan[_]] = plan.toSeq

override def run(sparkSession: SparkSession): Seq[Row] = {
plan.foreach { logicalPlan =>

@@ -24,9 +24,9 @@ import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.execution.LeafExecNode
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
@@ -37,29 +37,22 @@ import org.apache.spark.sql.types._
* A logical command that is executed for its side-effects. `RunnableCommand`s are
* wrapped in `ExecutedCommand` during execution.
*/
trait RunnableCommand extends logical.Command {
trait RunnableCommand extends Command {

// The map used to record the metrics of running the command. This will be passed to
// `ExecutedCommand` during query planning.
lazy val metrics: Map[String, SQLMetric] = Map.empty

def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
throw new NotImplementedError
}

def run(sparkSession: SparkSession): Seq[Row] = {
throw new NotImplementedError
}
def run(sparkSession: SparkSession): Seq[Row]
}

/**
* A physical operator that executes the run method of a `RunnableCommand` and
* saves the result to prevent multiple executions.
*
* @param cmd the `RunnableCommand` this operator will run.
* @param children the children physical plans ran by the `RunnableCommand`.
*/
case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) extends SparkPlan {
case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode {

override lazy val metrics: Map[String, SQLMetric] = cmd.metrics

@@ -74,19 +67,14 @@ case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) e
*/
protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
val converter = CatalystTypeConverters.createToCatalystConverter(schema)
val rows = if (children.isEmpty) {
cmd.run(sqlContext.sparkSession)
} else {
cmd.run(sqlContext.sparkSession, children)
}
rows.map(converter(_).asInstanceOf[InternalRow])
cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow])
}

override def innerChildren: Seq[QueryPlan[_]] = cmd.innerChildren
override protected def innerChildren: Seq[QueryPlan[_]] = cmd :: Nil

override def output: Seq[Attribute] = cmd.output

override def nodeName: String = cmd.nodeName
override def nodeName: String = "Execute " + cmd.nodeName

override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray


@@ -120,7 +120,7 @@ case class CreateDataSourceTableAsSelectCommand(
query: LogicalPlan)
extends RunnableCommand {

override def innerChildren: Seq[LogicalPlan] = Seq(query)
override protected def innerChildren: Seq[LogicalPlan] = Seq(query)

override def run(sparkSession: SparkSession): Seq[Row] = {
assert(table.tableType != CatalogTableType.VIEW)

@@ -98,7 +98,7 @@ case class CreateViewCommand(

import ViewHelper._

override def innerChildren: Seq[QueryPlan[_]] = Seq(child)
override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child)

if (viewType == PersistedView) {
require(originalText.isDefined, "'originalText' must be provided to create permanent view")
@@ -267,7 +267,7 @@ case class AlterViewAsCommand(

import ViewHelper._

override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

override def run(session: SparkSession): Seq[Row] = {
// If the plan cannot be analyzed, throw an exception and don't proceed.

@@ -453,6 +453,17 @@ case class DataSource(
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive)


// SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does
// not need to have the query as child, to avoid to analyze an optimized query,
// because InsertIntoHadoopFsRelationCommand will be optimized first.
val partitionAttributes = partitionColumns.map { name =>
data.output.find(a => equality(a.name, name)).getOrElse {
throw new AnalysisException(
s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]")
}
}

val fileIndex = catalogTable.map(_.identifier).map { tableIdent =>
sparkSession.table(tableIdent).queryExecution.analyzed.collect {
case LogicalRelation(t: HadoopFsRelation, _, _, _) => t.location
@@ -465,7 +476,7 @@ case class DataSource(
outputPath = outputPath,
staticPartitions = Map.empty,
ifPartitionNotExists = false,
partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted),
partitionColumns = partitionAttributes,
bucketSpec = bucketSpec,
fileFormat = format,
options = options,

@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.{SerializableConfiguration, Utils}

@@ -101,7 +101,7 @@ object FileFormatWriter extends Logging {
*/
def write(
sparkSession: SparkSession,
plan: SparkPlan,
queryExecution: QueryExecution,
fileFormat: FileFormat,
committer: FileCommitProtocol,
outputSpec: OutputSpec,
@@ -117,7 +117,9 @@
job.setOutputValueClass(classOf[InternalRow])
FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))

val allColumns = plan.output
// Pick the attributes from analyzed plan, as optimizer may not preserve the output schema
// names' case.
val allColumns = queryExecution.analyzed.output
val partitionSet = AttributeSet(partitionColumns)
val dataColumns = allColumns.filterNot(partitionSet.contains)

@@ -158,7 +160,7 @@ object FileFormatWriter extends Logging {
// We should first sort by partition columns, then bucket id, and finally sorting columns.
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
// the sort order doesn't matter
val actualOrdering = plan.outputOrdering.map(_.child)
val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
false
} else {
@@ -176,12 +178,12 @@

try {
val rdd = if (orderingMatched) {
plan.execute()
queryExecution.toRdd
} else {
SortExec(
requiredOrdering.map(SortOrder(_, Ascending)),
global = false,
child = plan).execute()
child = queryExecution.executedPlan).execute()
}
val ret = new Array[WriteTaskResult](rdd.partitions.length)
sparkSession.sparkContext.runJob(

@@ -33,7 +33,7 @@ case class InsertIntoDataSourceCommand(
overwrite: Boolean)
extends RunnableCommand {

override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

override def run(sparkSession: SparkSession): Seq[Row] = {
val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]

@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.util.SchemaUtils

@@ -57,11 +56,7 @@ case class InsertIntoHadoopFsRelationCommand(
extends DataWritingCommand {
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName

override def children: Seq[LogicalPlan] = query :: Nil

override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
assert(children.length == 1)

override def run(sparkSession: SparkSession): Seq[Row] = {
// Most formats don't do well with duplicate columns, so lets not allow that
SchemaUtils.checkSchemaColumnNameDuplication(
query.schema,
@@ -144,7 +139,7 @@ case class InsertIntoHadoopFsRelationCommand(
val updatedPartitionPaths =
FileFormatWriter.write(
sparkSession = sparkSession,
plan = children.head,
queryExecution = Dataset.ofRows(sparkSession, query).queryExecution,
fileFormat = fileFormat,
committer = committer,
outputSpec = FileFormatWriter.OutputSpec(

@@ -38,7 +38,7 @@ case class SaveIntoDataSourceCommand(
options: Map[String, String],
mode: SaveMode) extends RunnableCommand {

override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

override def run(sparkSession: SparkSession): Seq[Row] = {
dataSource.createRelation(

@@ -121,7 +121,7 @@ class FileStreamSink(

FileFormatWriter.write(
sparkSession = sparkSession,
plan = data.queryExecution.executedPlan,
queryExecution = data.queryExecution,
fileFormat = fileFormat,
committer = committer,
outputSpec = FileFormatWriter.OutputSpec(path, Map.empty),

@@ -17,10 +17,11 @@

package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

class FileFormatWriterSuite extends QueryTest with SharedSQLContext {
import testImplicits._

test("empty file should be skipped while write to file") {
withTempPath { path =>
@@ -30,4 +31,17 @@ class FileFormatWriterSuite extends QueryTest with SharedSQLContext {
assert(partFiles.length === 2)
}
}

test("FileFormatWriter should respect the input query schema") {
withTable("t1", "t2", "t3", "t4") {
spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1")
spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2")
checkAnswer(spark.table("t2"), Row(0, 0))

// Test picking part of the columns when writing.
spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3")
spark.sql("select COL1, COL2 from t3").write.saveAsTable("t4")
checkAnswer(spark.table("t4"), Row(0, 0))
}
}
}