diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 46bd60daa1f78..2dda3ad1211fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -921,12 +921,15 @@ class SQLContext(@transient val sparkContext: SparkContext)
   protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1)
 
   /**
-   * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed.
+   * Prepares a planned SparkPlan for execution by inserting shuffle operations and internal
+   * row format conversions as needed.
    */
   @transient
   protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
-    val batches =
-      Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil
+    val batches = Seq(
+      Batch("Add exchange", Once, EnsureRequirements(self)),
+      Batch("Add row converters", Once, EnsureRowFormats)
+    )
   }
 
   protected[sql] def openSession(): SQLSession = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index ba12056ee7a1b..f363e9947d5f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -79,12 +79,36 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Product
   /** Specifies sort order for each partition requirements on the input data for this operator. */
   def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
 
+  /** Specifies whether this operator outputs UnsafeRows */
+  def outputsUnsafeRows: Boolean = false
+
+  /** Specifies whether this operator is capable of processing UnsafeRows */
+  def canProcessUnsafeRows: Boolean = false
+
+  /**
+   * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows
+   * that are not UnsafeRows).
+   */
+  def canProcessSafeRows: Boolean = true
+
   /**
    * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute
    * after adding query plan information to created RDDs for visualization.
    * Concrete implementations of SparkPlan should override doExecute instead.
    */
   final def execute(): RDD[InternalRow] = {
+    if (children.nonEmpty) {
+      val hasUnsafeInputs = children.exists(_.outputsUnsafeRows)
+      val hasSafeInputs = children.exists(!_.outputsUnsafeRows)
+      assert(!(hasSafeInputs && hasUnsafeInputs),
+        "Child operators should output rows in the same format")
+      assert(canProcessSafeRows || canProcessUnsafeRows,
+        "Operator must be able to process at least one row format")
+      assert(!hasSafeInputs || canProcessSafeRows,
+        "Operator will receive safe rows as input but cannot process safe rows")
+      assert(!hasUnsafeInputs || canProcessUnsafeRows,
+        "Operator will receive unsafe rows as input but cannot process unsafe rows")
+    }
     RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
       doExecute()
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 4c063c299ba53..82bef269b069f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -64,6 +64,12 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
   }
 
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+
+  override def canProcessUnsafeRows: Boolean = true
+
+  override def canProcessSafeRows: Boolean = true
 }
 
 /**
@@ -104,6 +110,9 @@ case class Sample(
 case class Union(children: Seq[SparkPlan]) extends SparkPlan {
   // TODO: attributes output by union should be distinct for nullability purposes
   override def output: Seq[Attribute] = children.head.output
+  override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows)
+  override def canProcessUnsafeRows: Boolean = true
+  override def canProcessSafeRows: Boolean = true
   protected override def doExecute(): RDD[InternalRow] =
     sparkContext.union(children.map(_.execute()))
 }
@@ -306,6 +315,8 @@ case class UnsafeExternalSort(
   override def output: Seq[Attribute] = child.output
 
   override def outputOrdering: Seq[SortOrder] = sortOrder
+
+  override def outputsUnsafeRows: Boolean = true
 }
 
 @DeveloperApi
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
new file mode 100644
index 0000000000000..421d510e6782d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * :: DeveloperApi ::
+ * Converts Java-object-based rows into [[UnsafeRow]]s.
+ */
+@DeveloperApi
+case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+  override def outputsUnsafeRows: Boolean = true
+  override def canProcessUnsafeRows: Boolean = false
+  override def canProcessSafeRows: Boolean = true
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitions { iter =>
+      val convertToUnsafe = UnsafeProjection.create(child.schema)
+      iter.map(convertToUnsafe)
+    }
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Converts [[UnsafeRow]]s back into Java-object-based rows.
+ */
+@DeveloperApi
+case class ConvertToSafe(child: SparkPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+  override def outputsUnsafeRows: Boolean = false
+  override def canProcessUnsafeRows: Boolean = true
+  override def canProcessSafeRows: Boolean = false
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitions { iter =>
+      val convertToSafe = FromUnsafeProjection(child.output.map(_.dataType))
+      iter.map(convertToSafe)
+    }
+  }
+}
+
+private[sql] object EnsureRowFormats extends Rule[SparkPlan] {
+
+  private def onlyHandlesSafeRows(operator: SparkPlan): Boolean =
+    operator.canProcessSafeRows && !operator.canProcessUnsafeRows
+
+  private def onlyHandlesUnsafeRows(operator: SparkPlan): Boolean =
+    operator.canProcessUnsafeRows && !operator.canProcessSafeRows
+
+  private def handlesBothSafeAndUnsafeRows(operator: SparkPlan): Boolean =
+    operator.canProcessSafeRows && operator.canProcessUnsafeRows
+
+  override def apply(operator: SparkPlan): SparkPlan = operator.transformUp {
+    case operator: SparkPlan if onlyHandlesSafeRows(operator) =>
+      if (operator.children.exists(_.outputsUnsafeRows)) {
+        operator.withNewChildren {
+          operator.children.map {
+            c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c
+          }
+        }
+      } else {
+        operator
+      }
+    case operator: SparkPlan if onlyHandlesUnsafeRows(operator) =>
+      if (operator.children.exists(!_.outputsUnsafeRows)) {
+        operator.withNewChildren {
+          operator.children.map {
+            c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
+          }
+        }
+      } else {
+        operator
+      }
+    case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) =>
+      if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) {
+        // If this operator's children produce both unsafe and safe rows, then convert everything
+        // to unsafe rows
+        operator.withNewChildren {
+          operator.children.map {
+            c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
+          }
+        }
+      } else {
+        operator
+      }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
new file mode 100644
index 0000000000000..7b75f755918c1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.expressions.IsNull
+import org.apache.spark.sql.test.TestSQLContext
+
+class RowFormatConvertersSuite extends SparkPlanTest {
+
+  private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect {
+    case c: ConvertToUnsafe => c
+    case c: ConvertToSafe => c
+  }
+
+  private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null))
+  assert(!outputsSafe.outputsUnsafeRows)
+  private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null))
+  assert(outputsUnsafe.outputsUnsafeRows)
+
+  test("planner should insert unsafe->safe conversions when required") {
+    val plan = Limit(10, outputsUnsafe)
+    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe])
+  }
+
+  test("filter can process unsafe rows") {
+    val plan = Filter(IsNull(null), outputsUnsafe)
+    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    assert(getConverters(preparedPlan).isEmpty)
+    assert(preparedPlan.outputsUnsafeRows)
+  }
+
+  test("filter can process safe rows") {
+    val plan = Filter(IsNull(null), outputsSafe)
+    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    assert(getConverters(preparedPlan).isEmpty)
+    assert(!preparedPlan.outputsUnsafeRows)
+  }
+
+  test("execute() fails an assertion if input rows are of different formats") {
+    val e = intercept[AssertionError] {
+      Union(Seq(outputsSafe, outputsUnsafe)).execute()
+    }
+    assert(e.getMessage.contains("format"))
+  }
+
+  test("union requires all of its input rows' formats to agree") {
+    val plan = Union(Seq(outputsSafe, outputsUnsafe))
+    assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows)
+    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    assert(preparedPlan.outputsUnsafeRows)
+  }
+
+  test("union can process safe rows") {
+    val plan = Union(Seq(outputsSafe, outputsSafe))
+    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    assert(!preparedPlan.outputsUnsafeRows)
+  }
+
+  test("union can process unsafe rows") {
+    val plan = Union(Seq(outputsUnsafe, outputsUnsafe))
+    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    assert(preparedPlan.outputsUnsafeRows)
+  }
+
+  test("round trip with ConvertToUnsafe and ConvertToSafe") {
+    val input = Seq(("hello", 1), ("world", 2))
+    checkAnswer(
+      TestSQLContext.createDataFrame(input),
+      plan => ConvertToSafe(ConvertToUnsafe(plan)),
+      input.map(Row.fromTuple)
+    )
+  }
+}
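
Reviewer note (not part of the patch): a minimal sketch of how the pieces above fit together, mirroring the first test in RowFormatConvertersSuite; the variable names here are illustrative only.

    import org.apache.spark.sql.execution._
    import org.apache.spark.sql.test.TestSQLContext

    // UnsafeExternalSort overrides outputsUnsafeRows to true, while Limit keeps
    // the SparkPlan defaults (canProcessSafeRows = true, canProcessUnsafeRows = false).
    val unsafeChild = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null))
    val plan = Limit(10, unsafeChild)

    // prepareForExecution now runs the "Add row converters" batch, so
    // EnsureRowFormats splices a ConvertToSafe node between the two operators,
    // yielding (conceptually) Limit(10, ConvertToSafe(UnsafeExternalSort(...))).
    val prepared = TestSQLContext.prepareForExecution.execute(plan)
    assert(prepared.children.head.isInstanceOf[ConvertToSafe])

Without the inserted converter, calling execute() on the prepared plan would trip the new assertion in SparkPlan.execute(), since Limit would receive unsafe rows it cannot process.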