[SPARK-22252][SQL] FileFormatWriter should respect the input query schema #19474

Closed · wants to merge 2 commits · showing changes from 1 commit
@@ -178,7 +178,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
})
}

- override def innerChildren: Seq[QueryPlan[_]] = subqueries
+ override protected def innerChildren: Seq[QueryPlan[_]] = subqueries

/**
* Returns a plan where a best effort attempt has been made to transform `this` in a way
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
* commands can be used by parsers to represent DDL operations. Commands, unlike queries, are
* eagerly executed.
*/
- trait Command extends LogicalPlan {
+ trait Command extends LeafNode {
override def output: Seq[Attribute] = Seq.empty
- override def children: Seq[LogicalPlan] = Seq.empty
}
@@ -119,7 +119,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
* `SparkSQLDriver` for CLI applications.
*/
def hiveResultString(): Seq[String] = executedPlan match {
- case ExecutedCommandExec(desc: DescribeTableCommand, _) =>
+ case ExecutedCommandExec(desc: DescribeTableCommand) =>
// If it is a describe command for a Hive table, we want to have the output format
// be similar with Hive.
desc.run(sparkSession).map {
@@ -130,7 +130,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
.mkString("\t")
}
// SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp.
- case command @ ExecutedCommandExec(s: ShowTablesCommand, _) if !s.isExtended =>
+ case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended =>
command.executeCollect().map(_.getString(1))
case other =>
val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq
@@ -364,7 +364,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// Can we automate these 'pass through' operations?
object BasicOperators extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case r: RunnableCommand => ExecutedCommandExec(r, r.children.map(planLater)) :: Nil
+ case r: RunnableCommand => ExecutedCommandExec(r) :: Nil

case MemoryPlan(sink, output) =>
val encoder = RowEncoder(sink.schema)
@@ -62,7 +62,8 @@ case class InMemoryRelation(
@transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator)
extends logical.LeafNode with MultiInstanceRelation {
- override def innerChildren: Seq[SparkPlan] = Seq(child)
+
+ override protected def innerChildren: Seq[SparkPlan] = Seq(child)

override def producedAttributes: AttributeSet = outputSet

@@ -34,7 +34,7 @@ case class InMemoryTableScanExec(
@transient relation: InMemoryRelation)
extends LeafExecNode {

- override def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren
+ override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.command
import org.apache.hadoop.conf.Configuration

import org.apache.spark.SparkContext
+ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.util.SerializableConfiguration
@@ -30,6 +31,15 @@ import org.apache.spark.util.SerializableConfiguration
*/
trait DataWritingCommand extends RunnableCommand {

+ def query: LogicalPlan
Review comment (Member): Add one line description for query?

+ // We make the input `query` an inner child instead of a child in order to hide it from the
+ // optimizer. This is because optimizer may change the output schema names, and we have to keep
Review comment (Member): You will scare others. :) -> "may not preserve the output schema names' case"
+ // the original analyzed plan here so that we can pass the corrected schema to the writer. The
+ // schema of analyzed plan is what user expects(or specifies), so we should respect it when
+ // writing.
+ override protected def innerChildren: Seq[LogicalPlan] = query :: Nil

override lazy val metrics: Map[String, SQLMetric] = {
val sparkContext = SparkContext.getActive.get
Map(
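The code comment above is the heart of this change: the analyzed plan's schema is what the user asked for, and later phases may not preserve it. A minimal sketch of the behaviour it relies on (assuming a local SparkSession named spark and a table t1 with lower-case columns col1 and col2, as in the test at the bottom of this diff); the exact optimized output is not guaranteed, which is precisely the point:

    // Sketch only, not part of this patch.
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("schema-case-sketch").getOrCreate()
    import spark.implicits._

    spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1")

    val qe = spark.sql("select COL1, COL2 from t1").queryExecution
    // The analyzed plan keeps the case the user typed in the query.
    println(qe.analyzed.output.map(_.name))      // expected: COL1, COL2
    // The optimized plan may rewrite these attributes back to the table's
    // lower-case names, so the writer must not take its schema from it.
    println(qe.optimizedPlan.output.map(_.name)) // possibly: col1, col2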
@@ -21,7 +21,6 @@ import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
- import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources._

/**
@@ -45,10 +44,9 @@ case class InsertIntoDataSourceDirCommand(
query: LogicalPlan,
overwrite: Boolean) extends RunnableCommand {

- override def children: Seq[LogicalPlan] = Seq(query)
+ override protected def innerChildren: Seq[LogicalPlan] = query :: Nil

- override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
- assert(children.length == 1)
+ override def run(sparkSession: SparkSession): Seq[Row] = {
assert(storage.locationUri.nonEmpty, "Directory path is required")
assert(provider.nonEmpty, "Data source is required")

@@ -30,7 +30,7 @@ case class CacheTableCommand(
require(plan.isEmpty || tableIdent.database.isEmpty,
"Database name is not allowed in CACHE TABLE AS SELECT")

- override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq
+ override protected def innerChildren: Seq[QueryPlan[_]] = plan.toSeq

override def run(sparkSession: SparkSession): Seq[Row] = {
plan.foreach { logicalPlan =>
@@ -24,9 +24,9 @@ import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
- import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan}
- import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
- import org.apache.spark.sql.execution.SparkPlan
+ import org.apache.spark.sql.catalyst.plans.QueryPlan
+ import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
+ import org.apache.spark.sql.execution.LeafExecNode
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
@@ -37,29 +37,22 @@ import org.apache.spark.sql.types._
* A logical command that is executed for its side-effects. `RunnableCommand`s are
* wrapped in `ExecutedCommand` during execution.
*/
- trait RunnableCommand extends logical.Command {
+ trait RunnableCommand extends Command {

// The map used to record the metrics of running the command. This will be passed to
// `ExecutedCommand` during query planning.
lazy val metrics: Map[String, SQLMetric] = Map.empty

- def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
- throw new NotImplementedError
- }
-
- def run(sparkSession: SparkSession): Seq[Row] = {
- throw new NotImplementedError
- }
+ def run(sparkSession: SparkSession): Seq[Row]
}

/**
* A physical operator that executes the run method of a `RunnableCommand` and
* saves the result to prevent multiple executions.
*
* @param cmd the `RunnableCommand` this operator will run.
- * @param children the children physical plans ran by the `RunnableCommand`.
*/
- case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) extends SparkPlan {
+ case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode {

override lazy val metrics: Map[String, SQLMetric] = cmd.metrics

@@ -74,19 +67,14 @@ case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) extends SparkPlan {
*/
protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
val converter = CatalystTypeConverters.createToCatalystConverter(schema)
- val rows = if (children.isEmpty) {
- cmd.run(sqlContext.sparkSession)
- } else {
- cmd.run(sqlContext.sparkSession, children)
- }
- rows.map(converter(_).asInstanceOf[InternalRow])
+ cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow])
}

- override def innerChildren: Seq[QueryPlan[_]] = cmd.innerChildren
+ override protected def innerChildren: Seq[QueryPlan[_]] = cmd :: Nil

override def output: Seq[Attribute] = cmd.output

- override def nodeName: String = cmd.nodeName
+ override def nodeName: String = "Execute " + cmd.nodeName

override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray

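With the Seq[SparkPlan] overload gone, a command implements the single run(sparkSession) entry point and is always planned as a leaf ExecutedCommandExec. A hedged sketch of the new contract (ShowAppNameCommand is a made-up example, not a command that exists in Spark):

    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
    import org.apache.spark.sql.execution.command.RunnableCommand
    import org.apache.spark.sql.types.StringType

    // Hypothetical command that returns the current application name.
    case class ShowAppNameCommand() extends RunnableCommand {

      override def output: Seq[Attribute] =
        Seq(AttributeReference("appName", StringType, nullable = false)())

      // The only entry point now; commands that wrap a query (see DataWritingCommand
      // above) keep it as an inner child instead of receiving planned children here.
      override def run(sparkSession: SparkSession): Seq[Row] =
        Seq(Row(sparkSession.sparkContext.appName))
    }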
@@ -120,7 +120,7 @@ case class CreateDataSourceTableAsSelectCommand(
query: LogicalPlan)
extends RunnableCommand {

- override def innerChildren: Seq[LogicalPlan] = Seq(query)
+ override protected def innerChildren: Seq[LogicalPlan] = Seq(query)

override def run(sparkSession: SparkSession): Seq[Row] = {
assert(table.tableType != CatalogTableType.VIEW)
@@ -98,7 +98,7 @@ case class CreateViewCommand(

import ViewHelper._

- override def innerChildren: Seq[QueryPlan[_]] = Seq(child)
+ override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child)

if (viewType == PersistedView) {
require(originalText.isDefined, "'originalText' must be provided to create permanent view")
@@ -267,7 +267,7 @@ case class AlterViewAsCommand(

import ViewHelper._

- override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
+ override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

override def run(session: SparkSession): Seq[Row] = {
// If the plan cannot be analyzed, throw an exception and don't proceed.
@@ -453,6 +453,17 @@ case class DataSource(
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive)


+ // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does
+ // not need to have the query as child, to avoid to analyze an optimized query,
+ // because InsertIntoHadoopFsRelationCommand will be optimized first.
+ val partitionAttributes = partitionColumns.map { name =>
+ data.output.find(a => equality(a.name, name)).getOrElse {
+ throw new AnalysisException(
+ s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]")
+ }
+ }

val fileIndex = catalogTable.map(_.identifier).map { tableIdent =>
sparkSession.table(tableIdent).queryExecution.analyzed.collect {
case LogicalRelation(t: HadoopFsRelation, _, _, _) => t.location
@@ -465,7 +476,7 @@ case class DataSource(
outputPath = outputPath,
staticPartitions = Map.empty,
ifPartitionNotExists = false,
- partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted),
+ partitionColumns = partitionAttributes,
bucketSpec = bucketSpec,
fileFormat = format,
options = options,
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
- import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
+ import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.{SerializableConfiguration, Utils}

@@ -101,7 +101,7 @@ object FileFormatWriter extends Logging {
*/
def write(
sparkSession: SparkSession,
- plan: SparkPlan,
+ queryExecution: QueryExecution,
fileFormat: FileFormat,
committer: FileCommitProtocol,
outputSpec: OutputSpec,
@@ -117,7 +117,7 @@
job.setOutputValueClass(classOf[InternalRow])
FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))

- val allColumns = plan.output
+ val allColumns = queryExecution.logical.output
Review comment (Member): I think it'd be good to leave a comment that we should not use optimized output here in case it will be changed in the future.
Review comment (Member): Btw, shall we use queryExecution.analyzed.output?
Review comment (Member): Explicitly using analyzed's schema is better here.
val partitionSet = AttributeSet(partitionColumns)
Review comment (Member): You might need to double check the partitionColumns in all the other files are also from analyzed plans.
val dataColumns = allColumns.filterNot(partitionSet.contains)

@@ -158,7 +158,7 @@ object FileFormatWriter extends Logging {
// We should first sort by partition columns, then bucket id, and finally sorting columns.
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
// the sort order doesn't matter
- val actualOrdering = plan.outputOrdering.map(_.child)
+ val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
false
} else {
@@ -176,12 +176,12 @@

try {
val rdd = if (orderingMatched) {
- plan.execute()
+ queryExecution.toRdd
} else {
SortExec(
requiredOrdering.map(SortOrder(_, Ascending)),
global = false,
- child = plan).execute()
+ child = queryExecution.executedPlan).execute()
}
val ret = new Array[WriteTaskResult](rdd.partitions.length)
sparkSession.sparkContext.runJob(
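Summarizing the FileFormatWriter change: the writer now receives the whole QueryExecution and reads different things from different plans. A simplified sketch of that split (writeColumnsAndRows is a hypothetical helper, not the real write signature):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.execution.QueryExecution

    // Hypothetical helper mirroring what write() does with its new argument.
    def writeColumnsAndRows(qe: QueryExecution): (Seq[Attribute], RDD[InternalRow]) = {
      // Column names (including their case) come from the plan the user wrote...
      val allColumns = qe.logical.output
      // ...while the rows come from the optimized, executable plan.
      val rows = qe.toRdd
      (allColumns, rows)
    }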
@@ -33,7 +33,7 @@ case class InsertIntoDataSourceCommand(
overwrite: Boolean)
extends RunnableCommand {

- override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
+ override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

override def run(sparkSession: SparkSession): Seq[Row] = {
val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
- import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.util.SchemaUtils

@@ -57,11 +56,7 @@ case class InsertIntoHadoopFsRelationCommand(
extends DataWritingCommand {
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName

- override def children: Seq[LogicalPlan] = query :: Nil
-
- override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
- assert(children.length == 1)
-
+ override def run(sparkSession: SparkSession): Seq[Row] = {
// Most formats don't do well with duplicate columns, so lets not allow that
SchemaUtils.checkSchemaColumnNameDuplication(
query.schema,
@@ -144,7 +139,7 @@ case class InsertIntoHadoopFsRelationCommand(
val updatedPartitionPaths =
FileFormatWriter.write(
sparkSession = sparkSession,
- plan = children.head,
+ queryExecution = Dataset.ofRows(sparkSession, query).queryExecution,
fileFormat = fileFormat,
committer = committer,
outputSpec = FileFormatWriter.OutputSpec(
@@ -38,7 +38,7 @@ case class SaveIntoDataSourceCommand(
options: Map[String, String],
mode: SaveMode) extends RunnableCommand {

- override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
+ override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

override def run(sparkSession: SparkSession): Seq[Row] = {
dataSource.createRelation(
@@ -121,7 +121,7 @@ class FileStreamSink(

FileFormatWriter.write(
sparkSession = sparkSession,
- plan = data.queryExecution.executedPlan,
+ queryExecution = data.queryExecution,
fileFormat = fileFormat,
committer = committer,
outputSpec = FileFormatWriter.OutputSpec(path, Map.empty),
@@ -17,10 +17,11 @@

package org.apache.spark.sql.execution.datasources

- import org.apache.spark.sql.QueryTest
+ import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

class FileFormatWriterSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._

test("empty file should be skipped while write to file") {
withTempPath { path =>
@@ -30,4 +31,12 @@ class FileFormatWriterSuite extends QueryTest with SharedSQLContext {
assert(partFiles.length === 2)
}
}

test("FileFormatWriter should respect the input query schema") {
withTable("t1", "t2") {
spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1")
Review comment (Member): Also add another case here?
    spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3")
spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2")
checkAnswer(spark.table("t2"), Row(0, 0))
}
}
}
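For reference, a sketch of the extra case the reviewer asks for above. It is not part of this commit, the table names t3 and t4 are illustrative, and it would sit inside FileFormatWriterSuite next to the test above:

    test("FileFormatWriter should respect the input query schema - extra column") {
      withTable("t3", "t4") {
        spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3")
        spark.sql("select COL1, COL2 from t3").write.saveAsTable("t4")
        checkAnswer(spark.table("t4"), Row(0, 0))
      }
    }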