diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index d518e07bfb62c..44b1670ff038b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -196,7 +196,7 @@ class Dataset[T] private[sql](
   }
 
   // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again.
-  @transient private val planWithBarrier = AnalysisBarrier(logicalPlan)
+  @transient private[sql] val planWithBarrier = AnalysisBarrier(logicalPlan)
 
   /**
    * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 6bab21dca0cbd..36f6038aa9485 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -49,7 +49,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   private implicit val kExprEnc = encoderFor(kEncoder)
   private implicit val vExprEnc = encoderFor(vEncoder)
 
-  private def logicalPlan = queryExecution.analyzed
+  private def logicalPlan = AnalysisBarrier(queryExecution.analyzed)
   private def sparkSession = queryExecution.sparkSession
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 6c2be3610ae30..c6449cd5a16b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -63,17 +63,17 @@ class RelationalGroupedDataset protected[sql](
     groupType match {
       case RelationalGroupedDataset.GroupByType =>
         Dataset.ofRows(
-          df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
+          df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.planWithBarrier))
       case RelationalGroupedDataset.RollupType =>
         Dataset.ofRows(
-          df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan))
+          df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.planWithBarrier))
       case RelationalGroupedDataset.CubeType =>
         Dataset.ofRows(
-          df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan))
+          df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.planWithBarrier))
       case RelationalGroupedDataset.PivotType(pivotCol, values) =>
         val aliasedGrps = groupingExprs.map(alias)
         Dataset.ofRows(
-          df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan))
+          df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.planWithBarrier))
     }
   }
 
@@ -433,7 +433,7 @@ class RelationalGroupedDataset protected[sql](
         df.exprEnc.schema,
         groupingAttributes,
         df.logicalPlan.output,
-        df.logicalPlan))
+        df.planWithBarrier))
   }
 
   /**
@@ -459,7 +459,7 @@ class RelationalGroupedDataset protected[sql](
       case other => Alias(other, other.toString)()
     }
     val groupingAttributes = groupingNamedExpressions.map(_.toAttribute)
-    val child = df.logicalPlan
+    val child = df.planWithBarrier
     val project = Project(groupingNamedExpressions ++ child.output, child)
     val output = expr.dataType.asInstanceOf[StructType].toAttributes
     val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project)
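For context on the mechanism every hunk above relies on: AnalysisBarrier wraps an already-analyzed plan in a leaf-like node, so analyzer rules (which pattern-match and recurse over a node's children) never re-enter the resolved subtree. Below is a minimal sketch of the idea, simplified from the Spark 2.3 catalyst sources rather than quoted verbatim; the real node also overrides canonicalization and inner-children bookkeeping.

// Simplified sketch, not the verbatim Spark source. Extending LeafNode is
// the whole trick: rules that recurse over `children` stop here, so the
// wrapped (already-resolved) subtree is analyzed exactly once, no matter
// how many derived Datasets are built on top of it.
case class AnalysisBarrier(child: LogicalPlan) extends LeafNode {
  override def output: Seq[Attribute] = child.output
  override def isStreaming: Boolean = child.isStreaming
}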
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
new file mode 100644
index 0000000000000..147c0b61f5017
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.expressions.PythonUDF
+import org.apache.spark.sql.catalyst.plans.logical.AnalysisBarrier
+import org.apache.spark.sql.functions.udf
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
+
+class GroupedDatasetSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  private val scalaUDF = udf((x: Long) => { x + 1 })
+  private lazy val datasetWithUDF = spark.range(1).toDF("s").select($"s", scalaUDF($"s"))
+
+  private def assertContainsAnalysisBarrier(ds: Dataset[_], atLevel: Int = 1): Unit = {
+    assert(atLevel >= 0)
+    var children = Seq(ds.queryExecution.logical)
+    (1 to atLevel).foreach { _ =>
+      children = children.flatMap(_.children)
+    }
+    val barriers = children.collect {
+      case ab: AnalysisBarrier => ab
+    }
+    assert(barriers.nonEmpty, s"Plan does not contain AnalysisBarrier at level $atLevel:\n" +
+      ds.queryExecution.logical)
+  }
+
+  test("SPARK-24373: avoid running Analyzer rules twice on RelationalGroupedDataset") {
+    val groupByDataset = datasetWithUDF.groupBy()
+    val rollupDataset = datasetWithUDF.rollup("s")
+    val cubeDataset = datasetWithUDF.cube("s")
+    val pivotDataset = datasetWithUDF.groupBy().pivot("s", Seq(1, 2))
+    datasetWithUDF.cache()
+    Seq(groupByDataset, rollupDataset, cubeDataset, pivotDataset).foreach { rgDS =>
+      val df = rgDS.count()
+      assertContainsAnalysisBarrier(df)
+      assertCached(df)
+    }
+
+    val flatMapGroupsInRDF = datasetWithUDF.groupBy().flatMapGroupsInR(
+      Array.emptyByteArray,
+      Array.emptyByteArray,
+      Array.empty,
+      StructType(Seq(StructField("s", LongType))))
+    val flatMapGroupsInPandasDF = datasetWithUDF.groupBy().flatMapGroupsInPandas(PythonUDF(
+      "pyUDF",
+      null,
+      StructType(Seq(StructField("s", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+      true))
+    Seq(flatMapGroupsInRDF, flatMapGroupsInPandasDF).foreach { df =>
+      assertContainsAnalysisBarrier(df, 2)
+      assertCached(df)
+    }
+    datasetWithUDF.unpersist(true)
+  }
+
+  test("SPARK-24373: avoid running Analyzer rules twice on KeyValueGroupedDataset") {
+    val kvDataset = datasetWithUDF.groupByKey(_.getLong(0))
+    datasetWithUDF.cache()
+    val mapValuesKVDataset = kvDataset.mapValues(_.getLong(0)).reduceGroups(_ + _)
+    val keysKVDataset = kvDataset.keys
+    val flatMapGroupsKVDataset = kvDataset.flatMapGroups((k, _) => Seq(k))
+    val aggKVDataset = kvDataset.count()
+    val otherKVDataset = spark.range(1).groupByKey(_ + 1)
+    val cogroupKVDataset = kvDataset.cogroup(otherKVDataset)((k, _, _) => Seq(k))
+    Seq((mapValuesKVDataset, 1),
+      (keysKVDataset, 2),
+      (flatMapGroupsKVDataset, 2),
+      (aggKVDataset, 1),
+      (cogroupKVDataset, 2)).foreach { case (df, analysisBarrierDepth) =>
+      assertContainsAnalysisBarrier(df, analysisBarrierDepth)
+      assertCached(df)
+    }
+    datasetWithUDF.unpersist(true)
+  }
+}
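A note on the suite's level bookkeeping, as a hypothetical spark-shell session reusing the Dataset names from the tests above: the barrier's depth simply mirrors how many operators each API stacks on top of it.

// Hypothetical session, not part of the patch.
// groupBy().count() builds Aggregate directly on the barrier => level 1.
val counted = datasetWithUDF.groupBy().count()
counted.queryExecution.logical.children.head  // AnalysisBarrier(...)
// flatMapGroupsInPandas builds FlatMapGroupsInPandas on a Project on the
// barrier (see the RelationalGroupedDataset hunk above) => level 2.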