From 750fe74765377a340f3be5a07d3ff4e419b962d3 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 21 Apr 2016 17:08:32 -0700 Subject: [PATCH 1/3] [SPARK-14830][SQL] Add RemoveRepetitionFromGroupExpressions optimizer. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 15 ++++++++++++++- .../optimizer/AggregateOptimizeSuite.scala | 14 +++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0b70edec8e37a..4e568668d538c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -68,7 +68,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) ReplaceExceptWithAntiJoin, ReplaceDistinctWithAggregate) :: Batch("Aggregate", fixedPoint, - RemoveLiteralFromGroupExpressions) :: + RemoveLiteralFromGroupExpressions, + RemoveRepetitionFromGroupExpressions) :: Batch("Operator Optimizations", fixedPoint, // Operator push down SetOperationPushDown, @@ -1439,6 +1440,18 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { } } +/** + * Removes repetition from group expressions in [[Aggregate]], as they have no effect to the result + * but only makes the grouping key bigger. + */ +object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(grouping, _, _) => + val newGrouping = grouping.distinct + a.copy(groupingExpressions = newGrouping) + } +} + /** * Computes the current date and time to make sure we return the same result in a single query. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index e458eb8a1d362..20d0d2d554fc8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -28,7 +28,8 @@ class AggregateOptimizeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), - RemoveLiteralFromGroupExpressions) :: Nil + RemoveLiteralFromGroupExpressions, + RemoveRepetitionFromGroupExpressions) :: Nil } test("remove literals in grouping expression") { @@ -42,4 +43,15 @@ class AggregateOptimizeSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("remove repetition in grouping expression") { + val input = LocalRelation('a.int, 'b.int) + + val query = input.groupBy('a, 'a)(sum('b)) + val optimized = Optimize.execute(query) + + val correctAnswer = input.groupBy('a)(sum('b)) + + comparePlans(optimized, correctAnswer) + } } From f72fd67d297afb9e5d45c9c48884989f3eaa0e52 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 28 Apr 2016 13:13:30 -0700 Subject: [PATCH 2/3] update testcases. --- .../sql/catalyst/optimizer/AggregateOptimizeSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index 20d0d2d554fc8..80587f234a4dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -45,12 +45,12 @@ class AggregateOptimizeSuite extends PlanTest { } test("remove repetition in grouping expression") { - val input = LocalRelation('a.int, 'b.int) + val input = LocalRelation('a.int, 'b.int, 'c.int) - val query = input.groupBy('a, 'a)(sum('b)) + val query = input.groupBy('a, 'b, 'b, 'a)(sum('c)) val optimized = Optimize.execute(query) - val correctAnswer = input.groupBy('a)(sum('b)) + val correctAnswer = input.groupBy('a, 'b)(sum('c)) comparePlans(optimized, correctAnswer) } From 2198f0f73a863a801b0e6da9fdcc15908550ba4d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 2 May 2016 11:04:37 -0700 Subject: [PATCH 3/3] Use ExpressionSet. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 2 +- .../catalyst/optimizer/AggregateOptimizeSuite.scala | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4e568668d538c..a147fff274139 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1447,7 +1447,7 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case a @ Aggregate(grouping, _, _) => - val newGrouping = grouping.distinct + val newGrouping = ExpressionSet(grouping).toSeq a.copy(groupingExpressions = newGrouping) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index 80587f234a4dc..c94dcb33546f8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal @@ -25,6 +28,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor class AggregateOptimizeSuite extends PlanTest { + val conf = new SimpleCatalystConf(caseSensitiveAnalysis = false) + val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) + val analyzer = new Analyzer(catalog, conf) object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), @@ -47,10 +53,10 @@ class AggregateOptimizeSuite extends PlanTest { test("remove repetition in grouping expression") { val input = LocalRelation('a.int, 'b.int, 'c.int) - val query = input.groupBy('a, 'b, 'b, 'a)(sum('c)) - val optimized = Optimize.execute(query) + val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) + val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = input.groupBy('a, 'b)(sum('c)) + val correctAnswer = analyzer.execute(input.groupBy('a + 1, 'b + 2)(sum('c))) comparePlans(optimized, correctAnswer) }