From 30e743364313d4b81c99de8f9a7170f5bca2771c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 12 Nov 2015 08:14:08 -0800 Subject: [PATCH] [SPARK-11673][SQL] Remove the normal Project physical operator (and keep TungstenProject) Also make full outer join being able to produce UnsafeRows. Author: Reynold Xin Closes #9643 from rxin/SPARK-11673. --- .../execution/UnsafeExternalRowSorter.java | 7 --- .../scala/org/apache/spark/sql/Encoder.scala | 8 +-- .../spark/sql/catalyst/ScalaReflection.scala | 6 +-- .../expressions/EquivalentExpressions.scala | 2 +- .../sql/catalyst/expressions/Projection.scala | 10 ---- .../apache/spark/sql/execution/Exchange.scala | 5 +- .../spark/sql/execution/SparkPlanner.scala | 6 +-- .../spark/sql/execution/SparkStrategies.scala | 14 +---- .../apache/spark/sql/execution/Window.scala | 6 +-- .../aggregate/SortBasedAggregate.scala | 6 +-- .../aggregate/TungstenAggregate.scala | 3 +- .../spark/sql/execution/basicOperators.scala | 26 --------- .../datasources/DataSourceStrategy.scala | 3 +- .../spark/sql/execution/joins/HashJoin.scala | 31 +++-------- .../sql/execution/joins/HashOuterJoin.scala | 54 +++++++------------ .../sql/execution/joins/HashSemiJoin.scala | 25 ++------- .../sql/execution/joins/SortMergeJoin.scala | 41 +++----------- .../execution/joins/SortMergeOuterJoin.scala | 38 +++---------- .../execution/local/BinaryHashJoinNode.scala | 6 +-- .../sql/execution/local/HashJoinNode.scala | 21 ++------ .../sql/execution/rowFormatConverters.scala | 18 ++----- .../org/apache/spark/sql/execution/sort.scala | 9 ---- .../spark/sql/ColumnExpressionSuite.scala | 3 +- .../sql/execution/TungstenSortSuite.scala | 1 - .../execution/local/HashJoinNodeSuite.scala | 10 +--- .../execution/HiveTypeCoercionSuite.scala | 6 ++- .../ParquetHadoopFsRelationSuite.scala | 2 +- 27 files changed, 80 insertions(+), 287 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index f7063d1e5c829..3986d6e18f770 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -170,13 +170,6 @@ public Iterator sort(Iterator inputIterator) throws IOExce return sort(); } - /** - * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise. - */ - public static boolean supportsSchema(StructType schema) { - return UnsafeProjection.canSupport(schema); - } - private static final class RowComparator extends RecordComparator { private final Ordering ordering; private final int numFields; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 1ff7340557e6c..6134f9e036638 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql +import scala.reflect.ClassTag + import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{ObjectType, StructField, StructType} import org.apache.spark.util.Utils -import scala.reflect.ClassTag - /** * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. * @@ -123,9 +123,9 @@ object Encoders { new ExpressionEncoder[Any]( schema, - false, + flat = false, extractExpressions, constructExpression, - ClassTag.apply(cls)) + ClassTag(cls)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0b8a8abd02d67..6d822261b050a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -153,18 +153,18 @@ trait ScalaReflection { */ def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) - protected def constructorFor( + private def constructorFor( tpe: `Type`, path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String) = + def addToPath(part: String): Expression = path .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) .getOrElse(UnresolvedAttribute(part)) /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType) = + def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal)) .getOrElse(BoundReference(ordinal, dataType, false)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index f83df494ba8a6..f7162e420d19a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -77,7 +77,7 @@ class EquivalentExpressions { * an empty collection if there are none. */ def getEquivalentExprs(e: Expression): Seq[Expression] = { - equivalenceMap.get(Expr(e)).getOrElse(mutable.MutableList()) + equivalenceMap.getOrElse(Expr(e), mutable.MutableList()) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 9f0b7821ae74a..053e612f3ecb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -102,16 +102,6 @@ abstract class UnsafeProjection extends Projection { object UnsafeProjection { - /* - * Returns whether UnsafeProjection can support given StructType, Array[DataType] or - * Seq[Expression]. - */ - def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType)) - def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray) - private def canSupport(types: Array[DataType]): Boolean = { - types.forall(GenerateUnsafeProjection.canSupport) - } - /** * Returns an UnsafeProjection for given StructType. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index d0e4e068092f9..bc252d98e7144 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -57,10 +57,7 @@ case class Exchange( /** * Returns true iff we can support the data type, and we are not doing range partitioning. */ - private lazy val tungstenMode: Boolean = { - GenerateUnsafeProjection.canSupport(child.schema) && - !newPartitioning.isInstanceOf[RangePartitioning] - } + private lazy val tungstenMode: Boolean = !newPartitioning.isInstanceOf[RangePartitioning] override def outputPartitioning: Partitioning = newPartitioning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index cf482ae4a05ee..b7c5476346b2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -18,12 +18,10 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.datasources.DataSourceStrategy -@Experimental class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { val sparkContext: SparkContext = sqlContext.sparkContext @@ -64,7 +62,7 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { val projectSet = AttributeSet(projectList.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = + val filterCondition: Option[Expression] = prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) // Right now we still use a projection even if the only evaluation is applying an alias @@ -82,7 +80,7 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { filterCondition.map(Filter(_, scan)).getOrElse(scan) } else { val scan = scanBuilder((projectSet ++ filterSet).toSeq) - Project(projectList, filterCondition.map(Filter(_, scan)).getOrElse(scan)) + TungstenProject(projectList, filterCondition.map(Filter(_, scan)).getOrElse(scan)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 90989f2cee9a6..a99ae4674bb12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -309,11 +309,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * if necessary. */ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { - if (TungstenSort.supportsSchema(child.schema)) { - execution.TungstenSort(sortExprs, global, child) - } else { - execution.Sort(sortExprs, global, child) - } + execution.TungstenSort(sortExprs, global, child) } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { @@ -347,13 +343,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Sort(sortExprs, global, child) => getSortOperator(sortExprs, global, planLater(child)):: Nil case logical.Project(projectList, child) => - // If unsafe mode is enabled and we support these data types in Unsafe, use the - // Tungsten project. Otherwise, use the normal project. - if (UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) { - execution.TungstenProject(projectList, planLater(child)) :: Nil - } else { - execution.Project(projectList, planLater(child)) :: Nil - } + execution.TungstenProject(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil case e @ logical.Expand(_, _, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 53c5ccf8fa37e..b1280c32a6a43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -247,11 +247,7 @@ case class Window( // Get all relevant projections. val result = createResultProjection(unboundExpressions) - val grouping = if (child.outputsUnsafeRows) { - UnsafeProjection.create(partitionSpec, child.output) - } else { - newProjection(partitionSpec, child.output) - } + val grouping = UnsafeProjection.create(partitionSpec, child.output) // Manage the stream and the grouping. var nextRow: InternalRow = EmptyRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index fb7f30c2aec99..c8ccbb933df61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -78,11 +78,9 @@ case class SortBasedAggregate( // so return an empty iterator. Iterator[InternalRow]() } else { - val groupingKeyProjection = if (UnsafeProjection.canSupport(groupingExpressions)) { + val groupingKeyProjection = UnsafeProjection.create(groupingExpressions, child.output) - } else { - newMutableProjection(groupingExpressions, child.output)() - } + val outputIter = new SortBasedAggregationIterator( groupingKeyProjection, groupingExpressions.map(_.toAttribute), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 1edde1e5a16d9..920de615e1d86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -139,7 +139,6 @@ object TungstenAggregate { groupingExpressions: Seq[Expression], aggregateBufferAttributes: Seq[Attribute]): Boolean = { val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeProjection.canSupport(groupingExpressions) + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 303d636164adb..ae08fb71bf4cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -30,32 +30,6 @@ import org.apache.spark.util.random.PoissonSampler import org.apache.spark.{HashPartitioner, SparkEnv} -case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = projectList.map(_.toAttribute) - - override private[sql] lazy val metrics = Map( - "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) - - @transient lazy val buildProjection = newMutableProjection(projectList, child.output) - - protected override def doExecute(): RDD[InternalRow] = { - val numRows = longMetric("numRows") - child.execute().mapPartitions { iter => - val reusableProjection = buildProjection() - iter.map { row => - numRows += 1 - reusableProjection(row) - } - } - } - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering -} - - -/** - * A variant of [[Project]] that returns [[UnsafeRow]]s. - */ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override private[sql] lazy val metrics = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index d7c01b6e6f07e..824c89a90eb8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -343,7 +343,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { requestedColumns, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation) - execution.Project(projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) + execution.TungstenProject( + projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 997f7f494f4a3..fb961d97c3c3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,27 +44,15 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - protected[this] def isUnsafeMode: Boolean = { - UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema) - } - - override def outputsUnsafeRows: Boolean = isUnsafeMode - override def canProcessUnsafeRows: Boolean = isUnsafeMode - override def canProcessSafeRows: Boolean = !isUnsafeMode + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false protected def buildSideKeyGenerator: Projection = - if (isUnsafeMode) { - UnsafeProjection.create(buildKeys, buildPlan.output) - } else { - newMutableProjection(buildKeys, buildPlan.output)() - } + UnsafeProjection.create(buildKeys, buildPlan.output) protected def streamSideKeyGenerator: Projection = - if (isUnsafeMode) { - UnsafeProjection.create(streamedKeys, streamedPlan.output) - } else { - newMutableProjection(streamedKeys, streamedPlan.output)() - } + UnsafeProjection.create(streamedKeys, streamedPlan.output) protected def hashJoin( streamIter: Iterator[InternalRow], @@ -79,13 +67,8 @@ trait HashJoin { // Mutable per row objects. private[this] val joinRow = new JoinedRow - private[this] val resultProjection: (InternalRow) => InternalRow = { - if (isUnsafeMode) { - UnsafeProjection.create(self.schema) - } else { - identity[InternalRow] - } - } + private[this] val resultProjection: (InternalRow) => InternalRow = + UnsafeProjection.create(self.schema) private[this] val joinKeys = streamSideKeyGenerator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 3633f356b014b..ed626fef56af7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -64,38 +64,18 @@ trait HashOuterJoin { s"HashOuterJoin should not take $x as the JoinType") } - protected[this] def isUnsafeMode: Boolean = { - joinType != FullOuter && - UnsafeProjection.canSupport(buildKeys) && - UnsafeProjection.canSupport(self.schema) - } - - override def outputsUnsafeRows: Boolean = isUnsafeMode - override def canProcessUnsafeRows: Boolean = isUnsafeMode - override def canProcessSafeRows: Boolean = !isUnsafeMode + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false protected def buildKeyGenerator: Projection = - if (isUnsafeMode) { - UnsafeProjection.create(buildKeys, buildPlan.output) - } else { - newMutableProjection(buildKeys, buildPlan.output)() - } + UnsafeProjection.create(buildKeys, buildPlan.output) - protected[this] def streamedKeyGenerator: Projection = { - if (isUnsafeMode) { - UnsafeProjection.create(streamedKeys, streamedPlan.output) - } else { - newProjection(streamedKeys, streamedPlan.output) - } - } + protected[this] def streamedKeyGenerator: Projection = + UnsafeProjection.create(streamedKeys, streamedPlan.output) - protected[this] def resultProjection: InternalRow => InternalRow = { - if (isUnsafeMode) { - UnsafeProjection.create(self.schema) - } else { - identity[InternalRow] - } - } + protected[this] def resultProjection: InternalRow => InternalRow = + UnsafeProjection.create(self.schema) @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @@ -173,8 +153,12 @@ trait HashOuterJoin { } protected[this] def fullOuterIterator( - key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], - joinedRow: JoinedRow, numOutputRows: LongSQLMetric): Iterator[InternalRow] = { + key: InternalRow, + leftIter: Iterable[InternalRow], + rightIter: Iterable[InternalRow], + joinedRow: JoinedRow, + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. @@ -191,7 +175,7 @@ trait HashOuterJoin { matched = true // if the row satisfy the join condition, add its index into the matched set rightMatchedSet.add(idx) - joinedRow.copy() + resultProjection(joinedRow) } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // 2. For those unmatched records in left, append additional records with empty right. @@ -201,7 +185,7 @@ trait HashOuterJoin { // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. numOutputRows += 1 - joinedRow.withRight(rightNullRow).copy() + resultProjection(joinedRow.withRight(rightNullRow)) }) } ++ rightIter.zipWithIndex.collect { // 3. For those unmatched records in right, append additional records with empty left. @@ -210,15 +194,15 @@ trait HashOuterJoin { // in the matched set. case (r, idx) if !rightMatchedSet.contains(idx) => numOutputRows += 1 - joinedRow(leftNullRow, r).copy() + resultProjection(joinedRow(leftNullRow, r)) } } else { leftIter.iterator.map[InternalRow] { l => numOutputRows += 1 - joinedRow(l, rightNullRow).copy() + resultProjection(joinedRow(l, rightNullRow)) } ++ rightIter.iterator.map[InternalRow] { r => numOutputRows += 1 - joinedRow(leftNullRow, r).copy() + resultProjection(joinedRow(leftNullRow, r)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index c7d13e0a72a87..f23a1830e91c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -33,30 +33,15 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output - protected[this] def supportUnsafe: Boolean = { - UnsafeProjection.canSupport(leftKeys) && - UnsafeProjection.canSupport(rightKeys) && - UnsafeProjection.canSupport(left.schema) && - UnsafeProjection.canSupport(right.schema) - } - - override def outputsUnsafeRows: Boolean = supportUnsafe - override def canProcessUnsafeRows: Boolean = supportUnsafe - override def canProcessSafeRows: Boolean = !supportUnsafe + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false protected def leftKeyGenerator: Projection = - if (supportUnsafe) { - UnsafeProjection.create(leftKeys, left.output) - } else { - newMutableProjection(leftKeys, left.output)() - } + UnsafeProjection.create(leftKeys, left.output) protected def rightKeyGenerator: Projection = - if (supportUnsafe) { - UnsafeProjection.create(rightKeys, right.output) - } else { - newMutableProjection(rightKeys, right.output)() - } + UnsafeProjection.create(rightKeys, right.output) @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 7aee8e3dd3fce..4bf7b521c77d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -53,15 +53,9 @@ case class SortMergeJoin( override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - protected[this] def isUnsafeMode: Boolean = { - UnsafeProjection.canSupport(leftKeys) && - UnsafeProjection.canSupport(rightKeys) && - UnsafeProjection.canSupport(schema) - } - - override def outputsUnsafeRows: Boolean = isUnsafeMode - override def canProcessUnsafeRows: Boolean = isUnsafeMode - override def canProcessSafeRows: Boolean = !isUnsafeMode + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. @@ -76,26 +70,10 @@ case class SortMergeJoin( left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => new RowIterator { // The projection used to extract keys from input rows of the left child. - private[this] val leftKeyGenerator = { - if (isUnsafeMode) { - // It is very important to use UnsafeProjection if input rows are UnsafeRows. - // Otherwise, GenerateProjection will cause wrong results. - UnsafeProjection.create(leftKeys, left.output) - } else { - newProjection(leftKeys, left.output) - } - } + private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output) // The projection used to extract keys from input rows of the right child. - private[this] val rightKeyGenerator = { - if (isUnsafeMode) { - // It is very important to use UnsafeProjection if input rows are UnsafeRows. - // Otherwise, GenerateProjection will cause wrong results. - UnsafeProjection.create(rightKeys, right.output) - } else { - newProjection(rightKeys, right.output) - } - } + private[this] val rightKeyGenerator = UnsafeProjection.create(rightKeys, right.output) // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) @@ -112,13 +90,8 @@ case class SortMergeJoin( numRightRows ) private[this] val joinRow = new JoinedRow - private[this] val resultProjection: (InternalRow) => InternalRow = { - if (isUnsafeMode) { - UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } - } + private[this] val resultProjection: (InternalRow) => InternalRow = + UnsafeProjection.create(schema) override def advanceNext(): Boolean = { if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 5f1590c463836..efaa69c1d3227 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -89,31 +89,15 @@ case class SortMergeOuterJoin( keys.map(SortOrder(_, Ascending)) } - private def isUnsafeMode: Boolean = { - UnsafeProjection.canSupport(leftKeys) && - UnsafeProjection.canSupport(rightKeys) && - UnsafeProjection.canSupport(schema) - } - - override def outputsUnsafeRows: Boolean = isUnsafeMode - override def canProcessUnsafeRows: Boolean = isUnsafeMode - override def canProcessSafeRows: Boolean = !isUnsafeMode + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false - private def createLeftKeyGenerator(): Projection = { - if (isUnsafeMode) { - UnsafeProjection.create(leftKeys, left.output) - } else { - newProjection(leftKeys, left.output) - } - } + private def createLeftKeyGenerator(): Projection = + UnsafeProjection.create(leftKeys, left.output) - private def createRightKeyGenerator(): Projection = { - if (isUnsafeMode) { - UnsafeProjection.create(rightKeys, right.output) - } else { - newProjection(rightKeys, right.output) - } - } + private def createRightKeyGenerator(): Projection = + UnsafeProjection.create(rightKeys, right.output) override def doExecute(): RDD[InternalRow] = { val numLeftRows = longMetric("numLeftRows") @@ -130,13 +114,7 @@ case class SortMergeOuterJoin( (r: InternalRow) => true } } - val resultProj: InternalRow => InternalRow = { - if (isUnsafeMode) { - UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } - } + val resultProj: InternalRow => InternalRow = UnsafeProjection.create(schema) joinType match { case LeftOuter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala index 52dcb9e43c4e8..3dcef94095647 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala @@ -50,11 +50,7 @@ case class BinaryHashJoinNode( private def buildSideKeyGenerator: Projection = { // We are expecting the data types of buildKeys and streamedKeys are the same. assert(buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)) - if (isUnsafeMode) { - UnsafeProjection.create(buildKeys, buildNode.output) - } else { - newMutableProjection(buildKeys, buildNode.output)() - } + UnsafeProjection.create(buildKeys, buildNode.output) } protected override def doOpen(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala index aef655727fbbb..fd7948ffa9a9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala @@ -45,17 +45,8 @@ trait HashJoinNode { private[this] var hashed: HashedRelation = _ private[this] var joinKeys: Projection = _ - protected def isUnsafeMode: Boolean = { - UnsafeProjection.canSupport(schema) && UnsafeProjection.canSupport(streamedKeys) - } - - private def streamSideKeyGenerator: Projection = { - if (isUnsafeMode) { - UnsafeProjection.create(streamedKeys, streamedNode.output) - } else { - newMutableProjection(streamedKeys, streamedNode.output)() - } - } + private def streamSideKeyGenerator: Projection = + UnsafeProjection.create(streamedKeys, streamedNode.output) /** * Sets the HashedRelation used by this node. This method needs to be called after @@ -73,13 +64,7 @@ trait HashJoinNode { override def open(): Unit = { doOpen() joinRow = new JoinedRow - resultProjection = { - if (isUnsafeMode) { - UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } - } + resultProjection = UnsafeProjection.create(schema) joinKeys = streamSideKeyGenerator } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala index 0e601cd2cab5d..5f8fc2de8b46d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -28,8 +28,6 @@ import org.apache.spark.sql.catalyst.rules.Rule */ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { - require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe") - override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -97,18 +95,10 @@ private[sql] object EnsureRowFormats extends Rule[SparkPlan] { case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { // If this operator's children produce both unsafe and safe rows, - // convert everything unsafe rows if all the schema of them are support by UnsafeRow - if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) { - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c - } - } - } else { - operator.withNewChildren { - operator.children.map { - c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c - } + // convert everything unsafe rows. + operator.withNewChildren { + operator.children.map { + c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c } } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 1a3832a698b61..47fe70ab154ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -145,12 +145,3 @@ case class TungstenSort( } } - -object TungstenSort { - /** - * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. - */ - def supportsSchema(schema: StructType): Boolean = { - UnsafeExternalRowSorter.supportsSchema(schema) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index fa559c9c64005..010df2a341589 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.scalatest.Matchers._ -import org.apache.spark.sql.execution.{Project, TungstenProject} +import org.apache.spark.sql.execution.TungstenProject import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -615,7 +615,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { val projects = df.queryExecution.executedPlan.collect { - case project: Project => project case tungstenProject: TungstenProject => tungstenProject } assert(projects.size === expectedNumProjects) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 85486c08894c9..7c860d1d58d5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -74,7 +74,6 @@ class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) - assert(TungstenSort.supportsSchema(inputDf.schema)) checkThatPlansAgree( inputDf, plan => ConvertToSafe( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index 44b0d9d4102a1..c30327185e169 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -42,15 +42,7 @@ class HashJoinNodeSuite extends LocalNodeTest { buildKeys: Seq[Expression], buildNode: LocalNode): HashedRelation = { - val isUnsafeMode = UnsafeProjection.canSupport(buildKeys) - - val buildSideKeyGenerator = - if (isUnsafeMode) { - UnsafeProjection.create(buildKeys, buildNode.output) - } else { - new InterpretedMutableProjection(buildKeys, buildNode.output) - } - + val buildSideKeyGenerator = UnsafeProjection.create(buildKeys, buildNode.output) buildNode.prepare() buildNode.open() val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 197e9bfb02c4e..4cf4e13890294 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.catalyst.expressions.{Cast, EqualTo} -import org.apache.spark.sql.execution.Project +import org.apache.spark.sql.execution.TungstenProject import org.apache.spark.sql.hive.test.TestHive /** @@ -43,7 +43,9 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { test("[SPARK-2210] boolean cast on boolean value should be removed") { val q = "select cast(cast(key=0 as boolean) as boolean) from src" - val project = TestHive.sql(q).queryExecution.executedPlan.collect { case e: Project => e }.head + val project = TestHive.sql(q).queryExecution.executedPlan.collect { + case e: TungstenProject => e + }.head // No cast expression introduced project.transformAllExpressions { case c: Cast => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index e866493ee6c96..b6db6225805a1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -151,7 +151,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { val df = sqlContext.read.parquet(path).filter('a === 0).select('b) val physicalPlan = df.queryExecution.executedPlan - assert(physicalPlan.collect { case p: execution.Project => p }.length === 1) + assert(physicalPlan.collect { case p: execution.TungstenProject => p }.length === 1) assert(physicalPlan.collect { case p: execution.Filter => p }.length === 1) } }