[GLUTEN-4668][CH] Merge two phase hash-based aggregate into one aggregate in the spark plan when there is no shuffle

Examples:

 HashAggregate(t1.i, SUM, final)
                |                  =>    HashAggregate(t1.i, SUM, complete)
 HashAggregate(t1.i, SUM, partial)

Currently, this feature is only supported for the CH backend.
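For illustration only, a minimal sketch of the idea in terms of vanilla Spark operators (not the actual Gluten rule; the object name below is hypothetical, and attribute-rebinding details are glossed over): when a Final-mode HashAggregateExec sits directly on top of its Partial-mode counterpart, i.e. with no Exchange between them, the pair can be rewritten into a single Complete-mode aggregate.

import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Final, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.HashAggregateExec

// Hypothetical sketch: collapse adjacent partial/final hash aggregates into one
// Complete-mode aggregate when no shuffle separates them.
object MergeTwoPhaseAggregateSketch extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
    case finalAgg: HashAggregateExec
        if finalAgg.aggregateExpressions.forall(_.mode == Final) =>
      finalAgg.child match {
        // The partial aggregate is the direct child, so there is no Exchange
        // (shuffle) between the two phases and both see exactly the same rows.
        case partialAgg: HashAggregateExec
            if partialAgg.aggregateExpressions.forall(_.mode == Partial) =>
          finalAgg.copy(
            groupingExpressions = partialAgg.groupingExpressions,
            aggregateExpressions =
              finalAgg.aggregateExpressions.map(_.copy(mode = Complete)),
            initialInputBufferOffset = 0,
            child = partialAgg.child)
        case _ => finalAgg
      }
  }
}

In this change the rewrite is additionally gated by the backend: it only applies when mergeTwoPhasesHashBaseAggregateIfNeed() returns true, which the CH backend enables below.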

Closes apache#4668.

Co-authored-by: lgbo <[email protected]>
zzcclp and lgbo-ustc committed Feb 8, 2024
1 parent 97bd0f6 commit 35c2555
Showing 21 changed files with 579 additions and 89 deletions.
@@ -282,4 +282,6 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
override def enableNativeWriteFiles(): Boolean = {
GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
}

override def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = true
}
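A hedged sketch of how such a backend capability flag is typically consumed (the trait and rule names here are hypothetical, not the actual Gluten call sites): the merge rewrite is a no-op for backends that keep the default of false, while the CH backend opts in by returning true as above.

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical sketch of the settings-gated pattern.
trait BackendSettingsSketch {
  // Default off; a backend such as CH overrides this to true.
  def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = false
}

case class GatedMergeRule(settings: BackendSettingsSketch, merge: Rule[SparkPlan])
  extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan =
    if (settings.mergeTwoPhasesHashBaseAggregateIfNeed()) merge(plan) else plan
}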
@@ -109,16 +109,25 @@ case class CHHashAggregateExecTransformer(
val typeList = new util.ArrayList[TypeNode]()
val nameList = new util.ArrayList[String]()
val (inputAttrs, outputAttrs) = {
if (modes.isEmpty) {
// When there is no aggregate function, it does not need
if (modes.isEmpty || modes.forall(_ == Complete)) {
// When there is no aggregate function, or all functions are in Complete mode, it does not need
// to handle outputs according to the AggregateMode
for (attr <- child.output) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}
(child.output, output)
} else if (!modes.contains(Partial)) {
} else if (modes.forall(_ == Partial)) {
// partial mode
for (attr <- child.output) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}

(child.output, aggregateResultAttributes)
} else {
// other modes (e.g. Final / PartialMerge)
var resultAttrIndex = 0
for (attr <- aggregateResultAttributes) {
@@ -135,15 +144,6 @@
resultAttrIndex += 1
}
(aggregateResultAttributes, output)
} else {
// partial mode
for (attr <- child.output) {
typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
nameList.add(ConverterUtils.genColumnNameWithExprId(attr))
nameList.addAll(ConverterUtils.collectStructFieldNames(attr.dataType))
}

(child.output, aggregateResultAttributes)
}
}
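The effect of the branch above, summarized as a self-contained sketch (the helper name is hypothetical; the real transformer also fills the type and name lists shown in the diff):

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Complete, Partial}

// Hypothetical summary of how the substrait input/output attributes are chosen per phase.
def pickInputOutputAttrs(
    modes: Seq[AggregateMode],
    childOutput: Seq[Attribute],
    aggregateResultAttributes: Seq[Attribute],
    output: Seq[Attribute]): (Seq[Attribute], Seq[Attribute]) = {
  if (modes.isEmpty || modes.forall(_ == Complete)) {
    // No aggregate functions, or a merged one-phase (Complete) aggregate:
    // read the raw child output and emit the operator's output directly.
    (childOutput, output)
  } else if (modes.forall(_ == Partial)) {
    // First phase: read the raw child output, emit the aggregation buffers.
    (childOutput, aggregateResultAttributes)
  } else {
    // Final / PartialMerge: read the aggregation buffers, emit the operator's output.
    (aggregateResultAttributes, output)
  }
}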

@@ -212,7 +212,7 @@ case class CHHashAggregateExecTransformer(
val aggregateFunc = aggExpr.aggregateFunction
val childrenNodeList = new util.ArrayList[ExpressionNode]()
val childrenNodes = aggExpr.mode match {
case Partial =>
case Partial | Complete =>
aggregateFunc.children.toList.map(
expr => {
ExpressionConverter
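In practical terms, matching Complete alongside Partial here means the substrait aggregate function node is built from the function's raw input expressions rather than from intermediate buffers. A hedged, standalone sketch (hypothetical helper, not the transformer itself; the merge-side behavior reflects general Spark aggregation semantics):

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final, Partial, PartialMerge}

// Hypothetical helper: which expressions feed the aggregate function, per mode.
def aggregateFunctionInputs(aggExpr: AggregateExpression): Seq[Expression] =
  aggExpr.mode match {
    case Partial | Complete =>
      // Raw inputs, e.g. sum(l_linenumber) consumes l_linenumber.
      aggExpr.aggregateFunction.children
    case Final | PartialMerge =>
      // Merge side: consume the buffers produced by the partial phase.
      aggExpr.aggregateFunction.inputAggBufferAttributes
  }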
@@ -446,7 +446,7 @@ case class CHHashAggregateExecPullOutHelper(
}
resIndex += aggBufferAttr.size
resIndex
case Final =>
case Final | Complete =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
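A hedged sketch of why Final and Complete share this branch (the helper is hypothetical; the surrounding pull-out code is what actually advances resIndex): a merge-side or one-phase aggregate contributes exactly one fully evaluated result attribute per aggregate expression, whereas a Partial aggregate contributes one attribute per buffer slot, which is why the earlier branch advances resIndex by aggBufferAttr.size.

import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final, Partial, PartialMerge}

// Hypothetical helper: number of attributes one aggregate expression contributes
// to the aggregate result, per mode.
def resultAttributeCount(aggExpr: AggregateExpression): Int = aggExpr.mode match {
  case Partial | PartialMerge =>
    // One attribute per aggregation buffer slot (e.g. avg keeps sum and count).
    aggExpr.aggregateFunction.aggBufferAttributes.size
  case Final | Complete =>
    // A single, fully evaluated result value.
    1
}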
@@ -139,7 +139,13 @@ class GlutenClickHouseColumnarShuffleAQESuite
}

test("TPCH Q18") {
runTPCHQuery(18) { df => }
runTPCHQuery(18) {
df =>
val hashAggregates = collect(df.queryExecution.executedPlan) {
case hash: HashAggregateExecBaseTransformer => hash
}
assert(hashAggregates.size == 3)
}
}

test("TPCH Q19") {
@@ -42,7 +42,13 @@ class GlutenClickHouseDSV2ColumnarShuffleSuite extends GlutenClickHouseTPCHAbstr
}

test("TPCH Q3") {
runTPCHQuery(3) { df => }
runTPCHQuery(3) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 1)
}
}

test("TPCH Q4") {
@@ -74,43 +80,91 @@
}

test("TPCH Q11") {
runTPCHQuery(11) { df => }
runTPCHQuery(11) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 3)
}
}

test("TPCH Q12") {
runTPCHQuery(12) { df => }
}

test("TPCH Q13") {
runTPCHQuery(13) { df => }
runTPCHQuery(13) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 3)
}
}

test("TPCH Q14") {
runTPCHQuery(14) { df => }
runTPCHQuery(14) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 1)
}
}

test("TPCH Q15") {
runTPCHQuery(15) { df => }
runTPCHQuery(15) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 4)
}
}

test("TPCH Q16") {
runTPCHQuery(16, noFallBack = false) { df => }
}

test("TPCH Q17") {
runTPCHQuery(17) { df => }
runTPCHQuery(17) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 3)
}
}

test("TPCH Q18") {
runTPCHQuery(18) { df => }
runTPCHQuery(18) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 4)
}
}

test("TPCH Q19") {
runTPCHQuery(19) { df => }
runTPCHQuery(19) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 1)
}
}

test("TPCH Q20") {
runTPCHQuery(20) { df => }
runTPCHQuery(20) {
df =>
val aggs = df.queryExecution.executedPlan.collectWithSubqueries {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(aggs.size == 1)
}
}

test("TPCH Q21") {
@@ -74,7 +74,6 @@ class GlutenClickHouseNativeWriteTableSuite
.set("spark.gluten.sql.enable.native.validation", "false")
// TODO: support default ANSI policy
.set("spark.sql.storeAssignmentPolicy", "legacy")
// .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "debug")
.set("spark.sql.warehouse.dir", getWarehouseDir)
.setMaster("local[1]")
}
@@ -46,8 +46,6 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSui
.set("spark.memory.offHeap.size", "4g")
.set("spark.gluten.sql.validation.logLevel", "ERROR")
.set("spark.gluten.sql.validation.printStackOnFailure", "true")
// .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "debug")
// .setMaster("local[1]")
}

executeTPCDSTest(false)
@@ -17,6 +17,7 @@
package io.glutenproject.execution

import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.InputIteratorTransformer
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper

@@ -566,5 +567,108 @@ class GlutenClickHouseTPCHParquetBucketSuite
df => {}
)
}

test("GLUTEN-4668: Merge two phase hash-based aggregate into one aggregate") {
def checkHashAggregateCount(df: DataFrame, expectedCount: Int): Unit = {
val plans = collect(df.queryExecution.executedPlan) {
case agg: HashAggregateExecBaseTransformer => agg
}
assert(plans.size == expectedCount)
}

val SQL =
"""
|select l_orderkey, l_returnflag, collect_list(l_linenumber) as t
|from lineitem group by l_orderkey, l_returnflag
|order by l_orderkey, l_returnflag, t limit 100
|""".stripMargin
compareResultsAgainstVanillaSpark(
SQL,
true,
df => { checkHashAggregateCount(df, 1) }
)

val SQL1 =
"""
|select l_orderkey, l_returnflag,
|sum(l_linenumber) as t,
|count(l_linenumber) as t1,
|min(l_linenumber) as t2
|from lineitem group by l_orderkey, l_returnflag
|order by l_orderkey, l_returnflag, t, t1, t2 limit 100
|""".stripMargin
compareResultsAgainstVanillaSpark(
SQL1,
true,
df => { checkHashAggregateCount(df, 1) }
)

val SQL2 =
"""
|select l_returnflag, l_orderkey, collect_list(l_linenumber) as t
|from lineitem group by l_orderkey, l_returnflag
|order by l_returnflag, l_orderkey, t limit 100
|""".stripMargin
compareResultsAgainstVanillaSpark(
SQL2,
true,
df => { checkHashAggregateCount(df, 1) }
)

// will merge four aggregates into two.
val SQL3 =
"""
|select l_returnflag, l_orderkey,
|count(distinct l_linenumber) as t
|from lineitem group by l_orderkey, l_returnflag
|order by l_returnflag, l_orderkey, t limit 100
|""".stripMargin
compareResultsAgainstVanillaSpark(
SQL3,
true,
df => { checkHashAggregateCount(df, 2) }
)

// not supported when there is more than one count distinct
val SQL4 =
"""
|select l_returnflag, l_orderkey,
|count(distinct l_linenumber) as t,
|count(distinct l_discount) as t1
|from lineitem group by l_orderkey, l_returnflag
|order by l_returnflag, l_orderkey, t, t1 limit 100
|""".stripMargin
compareResultsAgainstVanillaSpark(
SQL4,
true,
df => { checkHashAggregateCount(df, 4) }
)

val SQL5 =
"""
|select l_returnflag, l_orderkey,
|count(distinct l_linenumber) as t,
|sum(l_linenumber) as t1
|from lineitem group by l_orderkey, l_returnflag
|order by l_returnflag, l_orderkey, t, t1 limit 100
|""".stripMargin
compareResultsAgainstVanillaSpark(
SQL5,
true,
df => { checkHashAggregateCount(df, 4) }
)

val SQL6 =
"""
|select count(1) from lineitem
|""".stripMargin
compareResultsAgainstVanillaSpark(
SQL6,
true,
df => {
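// count(1) over the whole table has no grouping keys, so a shuffle to a single
// partition is still required between the two phases; consistent with the
// "no shuffle" condition of this change, the two aggregates are not merged here.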
checkHashAggregateCount(df, 2)
}
)
}
}
// scalastyle:on line.size.limit
@@ -67,7 +67,6 @@ class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite {
.set("spark.gluten.sql.enable.native.validation", "false")
// TODO: support default ANSI policy
.set("spark.sql.storeAssignmentPolicy", "legacy")
// .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "debug")
.set("spark.sql.warehouse.dir", warehouse)
.setMaster("local[1]")
}