diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e1c1740bccced..146ce3011b1f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -49,11 +49,14 @@ object Optimizer extends RuleExecutor[LogicalPlan] { object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => + // Project away references that are not needed to calculate the required aggregates. a.copy(child = Project(a.references.toSeq, child)) case Project(projectList, Join(left, right, joinType, condition)) => + // Collect the list of off references required either above or to evaluate the condition. val allReferences: Set[Attribute] = projectList.flatMap(_.references).toSet ++ condition.map(_.references).getOrElse(Set.empty) + /** Applies a projection when the child is producing unnecessary attributes */ def prunedChild(c: LogicalPlan) = if ((allReferences.filter(c.outputSet.contains) -- c.outputSet).nonEmpty) { Project(allReferences.filter(c.outputSet.contains).toSeq, c) @@ -64,10 +67,16 @@ object ColumnPruning extends Rule[LogicalPlan] { Project(projectList, Join(prunedChild(left), prunedChild(right), joinType, condition)) case Project(project1, Project(project2, child)) => + // Create a map of Aliases to their values from the child projection. + // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). val aliasMap = project2.collect { case a @ Alias(e, _) => (a.toAttribute: Expression, a) }.toMap - // TODO: Fix TransformBase. + + // Substitute any attributes that are produced by the child projection, so that we safely + // eliminate it. + // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' + // TODO: Fix TransformBase to avoid the cast below. val substitutedProjection = project1.map(_.transform { case a if aliasMap.contains(a) => aliasMap(a) }).asInstanceOf[Seq[NamedExpression]]