[SPARK-12978][SQL] Skip unnecessary final group-by when input data already clustered with group-by keys

This ticket adds an optimization that skips the unnecessary final group-by operation shown below when the input data is already clustered by the group-by keys:

Without the optimization:
```
== Physical Plan ==
TungstenAggregate(key=[col0#159], functions=[(sum(col1#160),mode=Final,isDistinct=false),(avg(col2#161),mode=Final,isDistinct=false)], output=[col0#159,sum(col1)#177,avg(col2)#178])
+- TungstenAggregate(key=[col0#159], functions=[(sum(col1#160),mode=Partial,isDistinct=false),(avg(col2#161),mode=Partial,isDistinct=false)], output=[col0#159,sum#200,sum#201,count#202L])
   +- TungstenExchange hashpartitioning(col0#159,200), None
      +- InMemoryColumnarTableScan [col0#159,col1#160,col2#161], InMemoryRelation [col0#159,col1#160,col2#161], true, 10000, StorageLevel(true, true, false, true, 1), ConvertToUnsafe, None
```

With the optimization:
```
== Physical Plan ==
TungstenAggregate(key=[col0#159], functions=[(sum(col1#160),mode=Complete,isDistinct=false),(avg(col2#161),mode=Final,isDistinct=false)], output=[col0#159,sum(col1)#177,avg(col2)#178])
+- TungstenExchange hashpartitioning(col0#159,200), None
  +- InMemoryColumnarTableScan [col0#159,col1#160,col2#161], InMemoryRelation [col0#159,col1#160,col2#161], true, 10000, StorageLevel(true, true, false, true, 1), ConvertToUnsafe, None
```
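For reference, here is a minimal sketch (not part of the patch; the session setup, object name, and column names are assumed) of the kind of query this optimization targets: grouping by the same key the data is repartitioned on, so the planner can emit a single Complete aggregate instead of a Partial/Final pair. It mirrors the repartition-then-groupBy cases exercised in the DataFrameSuite changes below.

```
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, sum}

object SkipFinalGroupByExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("skip-final-groupby-sketch")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical three-column input matching the shape of the plans above.
    val df = Seq((1, 10L, 1.0), (1, 20L, 2.0), (2, 30L, 3.0))
      .toDF("col0", "col1", "col2")

    // Repartitioning by the grouping key clusters the input on col0, so the
    // aggregation below should no longer need a separate partial/final pair.
    df.repartition($"col0")
      .groupBy($"col0")
      .agg(sum($"col1"), avg($"col2"))
      .explain()

    spark.stop()
  }
}
```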

Author: Takeshi YAMAMURO <[email protected]>

Closes #10896 from maropu/SkipGroupbySpike.
maropu authored and hvanhovell committed Aug 25, 2016
1 parent 6b8cb1f commit 2b0cc4e
Showing 8 changed files with 257 additions and 224 deletions.
@@ -259,24 +259,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}

val aggregateOperator =
if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
if (functionsWithDistinct.nonEmpty) {
sys.error("Distinct columns cannot exist in Aggregate operator containing " +
"aggregate functions which don't support partial aggregation.")
} else {
aggregate.AggUtils.planAggregateWithoutPartial(
groupingExpressions,
aggregateExpressions,
resultExpressions,
planLater(child))
}
} else if (functionsWithDistinct.isEmpty) {
if (functionsWithDistinct.isEmpty) {
aggregate.AggUtils.planAggregateWithoutDistinct(
groupingExpressions,
aggregateExpressions,
resultExpressions,
planLater(child))
} else {
if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
sys.error("Distinct columns cannot exist in Aggregate operator containing " +
"aggregate functions which don't support partial aggregation.")
}
aggregate.AggUtils.planAggregateWithOneDistinct(
groupingExpressions,
functionsWithDistinct,

Large diffs are not rendered by default.

@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.aggregate

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.UnaryExecNode

/**
* A base class for aggregate implementation.
*/
abstract class AggregateExec extends UnaryExecNode {

def requiredChildDistributionExpressions: Option[Seq[Expression]]
def groupingExpressions: Seq[NamedExpression]
def aggregateExpressions: Seq[AggregateExpression]
def aggregateAttributes: Seq[Attribute]
def initialInputBufferOffset: Int
def resultExpressions: Seq[NamedExpression]

protected[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
}

override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
AttributeSet(aggregateBufferAttributes)

override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
case None => UnspecifiedDistribution :: Nil
}
}
}
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
@@ -42,11 +41,7 @@ case class HashAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryExecNode with CodegenSupport {

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
}
extends AggregateExec with CodegenSupport {

require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))

@@ -60,21 +55,6 @@
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
"aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))

override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
AttributeSet(aggregateBufferAttributes)

override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
case None => UnspecifiedDistribution :: Nil
}
}

// This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
// map and/or the sort-based aggregation once it has processed a given number of input rows.
private val testFallbackStartsAt: Option[(Int, Int)] = {
@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.Utils

@@ -38,30 +37,11 @@ case class SortAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryExecNode {

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
}

override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
AttributeSet(aggregateBufferAttributes)
extends AggregateExec {

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
case None => UnspecifiedDistribution :: Nil
}
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
}
@@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.AggUtils
import org.apache.spark.sql.execution.aggregate.PartialAggregate
import org.apache.spark.sql.internal.SQLConf

/**
@@ -151,18 +153,30 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
var children: Seq[SparkPlan] = operator.children
assert(requiredChildDistributions.length == children.length)
assert(requiredChildOrderings.length == children.length)
assert(requiredChildDistributions.length == operator.children.length)
assert(requiredChildOrderings.length == operator.children.length)

// Ensure that the operator's children satisfy their output distribution requirements:
children = children.zip(requiredChildDistributions).map {
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
child
case (child, BroadcastDistribution(mode)) =>
BroadcastExchangeExec(mode, child)
case (child, distribution) =>
ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
def createShuffleExchange(dist: Distribution, child: SparkPlan) =
ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child)

var (parent, children) = operator match {
case PartialAggregate(childDist) if !operator.outputPartitioning.satisfies(childDist) =>
// If an aggregation needs a shuffle and supports partial aggregation, a map-side partial
// aggregation and a shuffle are added as children.
val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
(mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil)
case _ =>
// Ensure that the operator's children satisfy their output distribution requirements:
val childrenWithDist = operator.children.zip(requiredChildDistributions)
val newChildren = childrenWithDist.map {
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
child
case (child, BroadcastDistribution(mode)) =>
BroadcastExchangeExec(mode, child)
case (child, distribution) =>
createShuffleExchange(distribution, child)
}
(operator, newChildren)
}

// If the operator has multiple children and specifies child output distributions (e.g. join),
@@ -246,7 +260,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
}
}

operator.withNewChildren(children)
parent.withNewChildren(children)
}

def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
@@ -1248,17 +1248,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}

/**
* Verifies that there is no Exchange between the Aggregations for `df`
* Verifies that there is a single Aggregation for `df`
*/
private def verifyNonExchangingAgg(df: DataFrame) = {
private def verifyNonExchangingSingleAgg(df: DataFrame) = {
var atFirstAgg: Boolean = false
df.queryExecution.executedPlan.foreach {
case agg: HashAggregateExec =>
atFirstAgg = !atFirstAgg
case _ =>
if (atFirstAgg) {
fail("Should not have operators between the two aggregations")
fail("Should not have back to back Aggregates")
}
atFirstAgg = true
case _ =>
}
}

@@ -1292,9 +1292,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
// Group by the column we are distributed by. This should generate a plan with no exchange
// between the aggregates
val df3 = testData.repartition($"key").groupBy("key").count()
verifyNonExchangingAgg(df3)
verifyNonExchangingAgg(testData.repartition($"key", $"value")
verifyNonExchangingSingleAgg(df3)
verifyNonExchangingSingleAgg(testData.repartition($"key", $"value")
.groupBy("key", "value").count())
verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key", "value").count())

// Grouping by just the first distributeBy expr, need to exchange.
verifyExchangingAgg(testData.repartition($"key", $"value")
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, Row}
import org.apache.spark.sql.{execution, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.Inner
@@ -37,36 +37,65 @@ class PlannerSuite extends SharedSQLContext {

setupTestData()

private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = {
val planner = spark.sessionState.planner
import planner._
val plannedOption = Aggregation(query).headOption
val planned =
plannedOption.getOrElse(
fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }

// For the new aggregation code path, there will be four aggregate operator for
// distinct aggregations.
assert(
aggregations.size == 2 || aggregations.size == 4,
s"The plan of query $query does not have partial aggregations.")
val ensureRequirements = EnsureRequirements(spark.sessionState.conf)
val planned = Aggregation(query).headOption.map(ensureRequirements(_))
.getOrElse(fail(s"Could not plan aggregation query $query. Is it an aggregation query?"))
planned.collect { case n if n.nodeName contains "Aggregate" => n }
}

test("count is partially aggregated") {
val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
testPartialAggregationPlan(query)
assert(testPartialAggregationPlan(query).size == 2,
s"The plan of query $query does not have partial aggregations.")
}

test("count distinct is partially aggregated") {
val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
testPartialAggregationPlan(query)
// For the new aggregation code path, there will be four aggregate operators for distinct
// aggregations.
assert(testPartialAggregationPlan(query).size == 4,
s"The plan of query $query does not have partial aggregations.")
}

test("mixed aggregates are partially aggregated") {
val query =
testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
testPartialAggregationPlan(query)
// For the new aggregation code path, there will be four aggregate operators for distinct
// aggregations.
assert(testPartialAggregationPlan(query).size == 4,
s"The plan of query $query does not have partial aggregations.")
}

test("non-partial aggregation for aggregates") {
withTempView("testNonPartialAggregation") {
val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)
val row = Row.fromSeq(Seq.fill(1)(null))
val rowRDD = sparkContext.parallelize(row :: Nil)
spark.createDataFrame(rowRDD, schema).repartition($"value")
.createOrReplaceTempView("testNonPartialAggregation")

val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value")
.queryExecution.executedPlan

// If the input data is already partitioned and the same columns are used as grouping keys and
// aggregation values, no partial aggregation exists in the query plan.
val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.")

val planned2 = sql(
"""
|SELECT t.value, SUM(DISTINCT t.value)
|FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t
|GROUP BY t.value
""".stripMargin).queryExecution.executedPlan

val aggOps2 = planned2.collect { case n if n.nodeName contains "Aggregate" => n }
assert(aggOps2.size == 1, s"The plan $planned2 has partial aggregations.")
}
}

test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
Expand Down

0 comments on commit 2b0cc4e

Please sign in to comment.