From 2f4e7b9438cc8b3184b585175cc6e635f75fd1a8 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 10 Apr 2014 03:07:26 -0700 Subject: [PATCH] Improve column pruning in the optimizer. --- .../sql/catalyst/optimizer/Optimizer.scala | 42 ++++++++++++++++++- .../plans/logical/basicOperators.scala | 2 +- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 37b23ba58289c..e1c1740bccced 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -33,7 +33,47 @@ object Optimizer extends RuleExecutor[LogicalPlan] { Batch("Filter Pushdown", Once, CombineFilters, PushPredicateThroughProject, - PushPredicateThroughInnerJoin) :: Nil + PushPredicateThroughInnerJoin, + ColumnPruning) :: Nil +} + +/** + * Attempts to eliminate the reading of unneeded columns from the query plan using the following + * transformations: + * + * - Inserting Projections beneath the following operators: + * - Aggregate + * - Project <- Join + * - Collapse adjacent projections, performing alias substitution. + */ +object ColumnPruning extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => + a.copy(child = Project(a.references.toSeq, child)) + + case Project(projectList, Join(left, right, joinType, condition)) => + val allReferences: Set[Attribute] = + projectList.flatMap(_.references).toSet ++ condition.map(_.references).getOrElse(Set.empty) + def prunedChild(c: LogicalPlan) = + if ((allReferences.filter(c.outputSet.contains) -- c.outputSet).nonEmpty) { + Project(allReferences.filter(c.outputSet.contains).toSeq, c) + } else { + c + } + + Project(projectList, Join(prunedChild(left), prunedChild(right), joinType, condition)) + + case Project(project1, Project(project2, child)) => + val aliasMap = project2.collect { + case a @ Alias(e, _) => (a.toAttribute: Expression, a) + }.toMap + // TODO: Fix TransformBase. + val substitutedProjection = project1.map(_.transform { + case a if aliasMap.contains(a) => aliasMap(a) + }).asInstanceOf[Seq[NamedExpression]] + + Project(substitutedProjection, child) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index cfc0b0c3a8d98..397473e178867 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -127,7 +127,7 @@ case class Aggregate( extends UnaryNode { def output = aggregateExpressions.map(_.toAttribute) - def references = child.references + def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet } case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode {