diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 843ae3816f061..a83d06ad0821e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -466,6 +466,14 @@ def nanvl(col1, col2):
     return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2)))
 
 
+@since(2.2)
+def no_collapse(df):
+    """Marks a DataFrame as non-collapsible: the optimizer will not merge its
+    projections with adjacent ones."""
+    sc = SparkContext._active_spark_context
+    return DataFrame(sc._jvm.functions.no_collapse(df._jdf), df.sql_ctx)
+
+
 @since(1.4)
 def rand(seed=None):
     """Generates a random column with independent and identically distributed (i.i.d.) samples
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 3ad757ebba851..656b0bba1865a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -386,6 +386,13 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
     child.stats(conf).copy(isBroadcastable = true)
 }
 
+/**
+ * A hint for the optimizer that it should not merge (collapse) adjacent projections.
+ */
+case class NoCollapseHint(child: LogicalPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+}
+
 /**
  * A general hint for the child. This node will be eliminated post analysis.
  * A pair of (name, parameters).
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index 587437e9aa81d..d639f34a6eae8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.Rand
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, NoCollapseHint}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 
 class CollapseProjectSuite extends PlanTest {
@@ -119,4 +119,14 @@ class CollapseProjectSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("do not collapse projects wrapped in NoCollapseHint") {
+    val query = NoCollapseHint(testRelation.select(('a * 10).as('a_times_10)))
+      .select(('a_times_10 + 1).as('a_times_10_plus_1), ('a_times_10 + 2).as('a_times_10_plus_2))
+
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = query.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ca2f6dd7a84b2..8a7b4fc88ade4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -433,6 +433,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
+      case NoCollapseHint(child) => planLater(child) :: Nil
       case _ => Nil
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f07e04368389f..589be14a500f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
+import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, NoCollapseHint}
 import org.apache.spark.sql.execution.SparkSqlParser
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.internal.SQLConf
@@ -1022,6 +1022,22 @@ object functions {
     Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc)
   }
 
+  /**
+   * Marks a DataFrame as non-collapsible, preventing the optimizer from merging its projections.
+   *
+   * For example:
+   * {{{
+   *   val df1 = no_collapse(df.select((df.col("qty") * lit(10)).alias("c1")))
+   *   val df2 = df1.select(col("c1") + lit(1), col("c1") + lit(2))
+   * }}}
+   *
+   * @group normal_funcs
+   * @since 2.2.0
+   */
+  def no_collapse[T](df: Dataset[T]): Dataset[T] = {
+    Dataset[T](df.sparkSession, NoCollapseHint(df.logicalPlan))(df.exprEnc)
+  }
+
   /**
    * Returns the first column that is not null, or null if all inputs are null.
    *
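
For reviewers, a minimal end-to-end sketch of the hint in use with this patch applied. The object name, the local `SparkSession` setup, and the toy data are illustrative assumptions, not part of the patch; only `no_collapse` itself comes from the change above.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object NoCollapseDemo {
  def main(args: Array[String]): Unit = {
    // Illustrative local session; not part of the patch.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("no_collapse demo")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(1, 2, 3).toDF("qty")

    // Without the hint, CollapseProject would merge the two selects into one
    // Project and duplicate the `qty * 10` expression in it.
    val df1 = no_collapse(df.select((col("qty") * lit(10)).alias("c1")))
    val df2 = df1.select(col("c1") + lit(1), col("c1") + lit(2))

    // The optimized plan should retain two separate Project nodes,
    // mirroring the new CollapseProjectSuite test.
    df2.explain(true)

    spark.stop()
  }
}
```

This mirrors the scaladoc example in `functions.scala` and the expected behavior asserted by the new `CollapseProjectSuite` test: the plan under `NoCollapseHint` is left unchanged by the optimizer.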