From 51e527074e987d5f5913c6dfc4f7fb8d62b40688 Mon Sep 17 00:00:00 2001 From: Zhichao Zhang Date: Wed, 24 Jan 2024 14:45:32 +0800 Subject: [PATCH] [GLUTEN-4668][CH] Merge two phase hash-based aggregate into one aggregate in the spark plan when there is no shuffle Examples: HashAggregate(t1.i, SUM, final) | => HashAggregate(t1.i, SUM, complete) HashAggregate(t1.i, SUM, partial) Currently this feature is only supported for the CH backend. Close #4668. Co-authored-by: lgbo --- .../backendsapi/clickhouse/CHBackend.scala | 2 + .../CHHashAggregateExecTransformer.scala | 8 +- ...tenClickHouseColumnarShuffleAQESuite.scala | 8 +- ...enClickHouseDSV2ColumnarShuffleSuite.scala | 72 +++++- ...lutenClickHouseNativeWriteTableSuite.scala | 1 - .../GlutenClickHouseTPCDSParquetSuite.scala | 2 - .../GlutenClickHouseTPCHBucketSuite.scala | 228 ++++++++++++++++++ ...utenClickHouseTPCHParquetBucketSuite.scala | 144 +++++++++++ .../GlutenClickhouseFunctionSuite.scala | 1 - .../Operator/GraceMergingAggregatedStep.cpp | 148 ++++++++---- .../Operator/GraceMergingAggregatedStep.h | 23 +- .../Parser/AggregateFunctionParser.cpp | 17 +- .../Parser/AggregateFunctionParser.h | 2 +- .../Parser/AggregateRelParser.cpp | 130 +++++++++- .../local-engine/Parser/AggregateRelParser.h | 3 + cpp-ch/local-engine/Parser/TypeParser.cpp | 6 + cpp-ch/local-engine/Parser/TypeParser.h | 1 + .../local-engine/Parser/WindowRelParser.cpp | 2 +- .../BloomFilterAggParser.cpp | 2 +- .../CollectListParser.h | 2 +- .../expression/AggregateFunctionNode.java | 3 + .../backendsapi/BackendSettingsApi.scala | 3 + .../HashAggregateExecBaseTransformer.scala | 9 +- .../extension/ColumnarOverrides.scala | 1 + .../MergeTwoPhasesHashAggregate.scala | 158 ++++++++++++ .../GlutenReplaceHashWithSortAggSuite.scala | 7 +- .../clickhouse/ClickHouseTestSettings.scala | 1 + 27 files changed, 897 insertions(+), 87 deletions(-) create mode 100644 gluten-core/src/main/scala/io/glutenproject/extension/MergeTwoPhasesHashAggregate.scala diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHBackend.scala index 7ba369b37a7b6..fbcb804a3e59d 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHBackend.scala @@ -282,4 +282,6 @@ object CHBackendSettings extends BackendSettingsApi with Logging { override def enableNativeWriteFiles(): Boolean = { GlutenConfig.getConf.enableNativeWriter.getOrElse(false) } + + override def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = true } diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala index 3a28615ab71f6..1131a349bdf1b 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala @@ -109,8 +109,8 @@ case class CHHashAggregateExecTransformer( val typeList = new util.ArrayList[TypeNode]() val nameList = new util.ArrayList[String]() val (inputAttrs, outputAttrs) = { - if (modes.isEmpty) { - // When there is no aggregate function, it does not need + if (modes.isEmpty || modes.forall(_ == Complete)) { + // When there is no aggregate function or all functions are in Complete mode, it does not need
// to handle outputs according to the AggregateMode for (attr <- child.output) { typeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) @@ -212,7 +212,7 @@ case class CHHashAggregateExecTransformer( val aggregateFunc = aggExpr.aggregateFunction val childrenNodeList = new util.ArrayList[ExpressionNode]() val childrenNodes = aggExpr.mode match { - case Partial => + case Partial | Complete => aggregateFunc.children.toList.map( expr => { ExpressionConverter @@ -446,7 +446,7 @@ case class CHHashAggregateExecPullOutHelper( } resIndex += aggBufferAttr.size resIndex - case Final => + case Final | Complete => aggregateAttr += aggregateAttributeList(resIndex) resIndex += 1 resIndex diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseColumnarShuffleAQESuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseColumnarShuffleAQESuite.scala index 0a858a35d9fa4..be2a29775117d 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseColumnarShuffleAQESuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseColumnarShuffleAQESuite.scala @@ -139,7 +139,13 @@ class GlutenClickHouseColumnarShuffleAQESuite } test("TPCH Q18") { - runTPCHQuery(18) { df => } + runTPCHQuery(18) { + df => + val hashAggregates = collect(df.queryExecution.executedPlan) { + case hash: HashAggregateExecBaseTransformer => hash + } + assert(hashAggregates.size == 3) + } } test("TPCH Q19") { diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseDSV2ColumnarShuffleSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseDSV2ColumnarShuffleSuite.scala index 21fc37660e174..5b03c742310af 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseDSV2ColumnarShuffleSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseDSV2ColumnarShuffleSuite.scala @@ -42,7 +42,13 @@ class GlutenClickHouseDSV2ColumnarShuffleSuite extends GlutenClickHouseTPCHAbstr } test("TPCH Q3") { - runTPCHQuery(3) { df => } + runTPCHQuery(3) { + df => + val aggs = df.queryExecution.executedPlan.collectWithSubqueries { + case agg: HashAggregateExecBaseTransformer => agg + } + assert(aggs.size == 1) + } } test("TPCH Q4") { @@ -74,7 +80,13 @@ class GlutenClickHouseDSV2ColumnarShuffleSuite extends GlutenClickHouseTPCHAbstr } test("TPCH Q11") { - runTPCHQuery(11) { df => } + runTPCHQuery(11) { + df => + val aggs = df.queryExecution.executedPlan.collectWithSubqueries { + case agg: HashAggregateExecBaseTransformer => agg + } + assert(aggs.size == 3) + } } test("TPCH Q12") { @@ -82,15 +94,33 @@ class GlutenClickHouseDSV2ColumnarShuffleSuite extends GlutenClickHouseTPCHAbstr } test("TPCH Q13") { - runTPCHQuery(13) { df => } + runTPCHQuery(13) { + df => + val aggs = df.queryExecution.executedPlan.collectWithSubqueries { + case agg: HashAggregateExecBaseTransformer => agg + } + assert(aggs.size == 3) + } } test("TPCH Q14") { - runTPCHQuery(14) { df => } + runTPCHQuery(14) { + df => + val aggs = df.queryExecution.executedPlan.collectWithSubqueries { + case agg: HashAggregateExecBaseTransformer => agg + } + assert(aggs.size == 1) + } } test("TPCH Q15") { - runTPCHQuery(15) { df => } + runTPCHQuery(15) { + df => + val aggs = df.queryExecution.executedPlan.collectWithSubqueries { + case agg: HashAggregateExecBaseTransformer => agg + } + assert(aggs.size == 4) + } } test("TPCH Q16") { @@ 
-98,19 +128,43 @@ class GlutenClickHouseDSV2ColumnarShuffleSuite extends GlutenClickHouseTPCHAbstr } test("TPCH Q17") { - runTPCHQuery(17) { df => } + runTPCHQuery(17) { + df => + val aggs = df.queryExecution.executedPlan.collectWithSubqueries { + case agg: HashAggregateExecBaseTransformer => agg + } + assert(aggs.size == 3) + } } test("TPCH Q18") { - runTPCHQuery(18) { df => } + runTPCHQuery(18) { + df => + val aggs = df.queryExecution.executedPlan.collectWithSubqueries { + case agg: HashAggregateExecBaseTransformer => agg + } + assert(aggs.size == 4) + } } test("TPCH Q19") { - runTPCHQuery(19) { df => } + runTPCHQuery(19) { + df => + val aggs = df.queryExecution.executedPlan.collectWithSubqueries { + case agg: HashAggregateExecBaseTransformer => agg + } + assert(aggs.size == 1) + } } test("TPCH Q20") { - runTPCHQuery(20) { df => } + runTPCHQuery(20) { + df => + val aggs = df.queryExecution.executedPlan.collectWithSubqueries { + case agg: HashAggregateExecBaseTransformer => agg + } + assert(aggs.size == 1) + } } test("TPCH Q21") { diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala index 9f86bbbc915c1..d4ad99c78485c 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseNativeWriteTableSuite.scala @@ -74,7 +74,6 @@ class GlutenClickHouseNativeWriteTableSuite .set("spark.gluten.sql.enable.native.validation", "false") // TODO: support default ANSI policy .set("spark.sql.storeAssignmentPolicy", "legacy") -// .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "debug") .set("spark.sql.warehouse.dir", getWarehouseDir) .setMaster("local[1]") } diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetSuite.scala index c3537e62e8cda..ea557655ba942 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetSuite.scala @@ -46,8 +46,6 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSui .set("spark.memory.offHeap.size", "4g") .set("spark.gluten.sql.validation.logLevel", "ERROR") .set("spark.gluten.sql.validation.printStackOnFailure", "true") -// .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "debug") -// .setMaster("local[1]") } executeTPCDSTest(false) diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHBucketSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHBucketSuite.scala index 3b37def852c7e..70fb16f2f9d6e 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHBucketSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHBucketSuite.scala @@ -17,13 +17,17 @@ package io.glutenproject.execution import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.execution.InputIteratorTransformer import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import 
org.apache.spark.sql.execution.aggregate.SortAggregateExec import org.apache.commons.io.FileUtils import java.io.File +import scala.collection.mutable + // Some sqls' line length exceeds 100 // scalastyle:off line.size.limit @@ -491,5 +495,229 @@ class GlutenClickHouseTPCHBucketSuite .isInstanceOf[ProjectExecTransformer]) }) } + + test("GLUTEN-4668: Merge two phase hash-based aggregate into one aggregate") { + def checkHashAggregateCount(df: DataFrame, expectedCount: Int): Unit = { + val plans = collect(df.queryExecution.executedPlan) { + case agg: HashAggregateExecBaseTransformer => agg + } + assert(plans.size == expectedCount) + } + + def checkResult(df: DataFrame, exceptedResult: Array[Row]): Unit = { + // check the result + val result = df.collect() + assert(result.size == exceptedResult.size) + result.equals(exceptedResult) + } + + val SQL = + """ + |select l_orderkey, l_returnflag, collect_list(l_linenumber) as t + |from lineitem group by l_orderkey, l_returnflag + |order by l_orderkey, l_returnflag, t limit 5 + |""".stripMargin + runSql(SQL)( + df => { + checkResult( + df, + Array( + Row(1, "N", mutable.WrappedArray.make(Array(3, 6, 1, 5, 2, 4))), + Row(2, "N", mutable.WrappedArray.make(Array(1))), + Row(3, "A", mutable.WrappedArray.make(Array(6, 4, 3))), + Row(3, "R", mutable.WrappedArray.make(Array(2, 5, 1))), + Row(4, "N", mutable.WrappedArray.make(Array(1))) + ) + ) + checkHashAggregateCount(df, 1) + }) + + val SQL1 = + """ + |select l_orderkey, l_returnflag, + |sum(l_linenumber) as t, + |count(l_linenumber) as t1, + |min(l_linenumber) as t2 + |from lineitem group by l_orderkey, l_returnflag + |order by l_orderkey, l_returnflag, t, t1, t2 limit 5 + |""".stripMargin + runSql(SQL1)( + df => { + checkResult( + df, + Array( + Row(1, "N", 21, 6, 1), + Row(2, "N", 1, 1, 1), + Row(3, "A", 13, 3, 3), + Row(3, "R", 8, 3, 1), + Row(4, "N", 1, 1, 1) + ) + ) + checkHashAggregateCount(df, 1) + }) + + val SQL2 = + """ + |select l_returnflag, l_orderkey, collect_list(l_linenumber) as t + |from lineitem group by l_orderkey, l_returnflag + |order by l_returnflag, l_orderkey, t limit 5 + |""".stripMargin + runSql(SQL2)( + df => { + checkResult( + df, + Array( + Row("A", 3, mutable.WrappedArray.make(Array(6, 4, 3))), + Row("A", 5, mutable.WrappedArray.make(Array(3))), + Row("A", 6, mutable.WrappedArray.make(Array(1))), + Row("A", 33, mutable.WrappedArray.make(Array(1, 2, 3))), + Row("A", 37, mutable.WrappedArray.make(Array(2, 3, 1))) + ) + ) + checkHashAggregateCount(df, 1) + }) + + // will merge four aggregates into two one. 
+ val SQL3 = + """ + |select l_returnflag, l_orderkey, + |count(distinct l_linenumber) as t + |from lineitem group by l_orderkey, l_returnflag + |order by l_returnflag, l_orderkey, t limit 5 + |""".stripMargin + runSql(SQL3)( + df => { + checkResult( + df, + Array( + Row("A", 3, 3), + Row("A", 5, 1), + Row("A", 6, 1), + Row("A", 33, 3), + Row("A", 37, 3) + ) + ) + checkHashAggregateCount(df, 2) + }) + + // not support when there are more than one count distinct + val SQL4 = + """ + |select l_returnflag, l_orderkey, + |count(distinct l_linenumber) as t, + |count(distinct l_discount) as t1 + |from lineitem group by l_orderkey, l_returnflag + |order by l_returnflag, l_orderkey, t, t1 limit 5 + |""".stripMargin + runSql(SQL4)( + df => { + checkResult( + df, + Array( + Row("A", 3, 3, 3), + Row("A", 5, 1, 1), + Row("A", 6, 1, 1), + Row("A", 33, 3, 3), + Row("A", 37, 3, 2) + ) + ) + checkHashAggregateCount(df, 4) + }) + + val SQL5 = + """ + |select l_returnflag, l_orderkey, + |count(distinct l_linenumber) as t, + |sum(l_linenumber) as t1 + |from lineitem group by l_orderkey, l_returnflag + |order by l_returnflag, l_orderkey, t, t1 limit 5 + |""".stripMargin + runSql(SQL5)( + df => { + checkResult( + df, + Array( + Row("A", 3, 3, 13), + Row("A", 5, 1, 3), + Row("A", 6, 1, 1), + Row("A", 33, 3, 6), + Row("A", 37, 3, 6) + ) + ) + checkHashAggregateCount(df, 4) + }) + + val SQL6 = + """ + |select count(1) from lineitem + |""".stripMargin + runSql(SQL6)( + df => { + checkResult(df, Array(Row(600572))) + // there is a shuffle between two phase hash aggregates. + checkHashAggregateCount(df, 2) + }) + + // test sort aggregates + val SQL7 = + """ + |select l_orderkey, l_returnflag, max(l_shipinstruct) as t + |from lineitem + |group by l_orderkey, l_returnflag + |order by l_orderkey, l_returnflag, t + |limit 10 + |""".stripMargin + runSql(SQL7)( + df => { + checkResult( + df, + Array( + Row(1, "N", "TAKE BACK RETURN"), + Row(2, "N", "TAKE BACK RETURN"), + Row(3, "A", "TAKE BACK RETURN"), + Row(3, "R", "TAKE BACK RETURN"), + Row(4, "N", "DELIVER IN PERSON"), + Row(5, "A", "DELIVER IN PERSON"), + Row(5, "R", "NONE"), + Row(6, "A", "TAKE BACK RETURN"), + Row(7, "N", "TAKE BACK RETURN"), + Row(32, "N", "TAKE BACK RETURN") + ) + ) + checkHashAggregateCount(df, 1) + }) + + withSQLConf(("spark.gluten.sql.columnar.force.hashagg", "false")) { + val SQL = + """ + |select l_orderkey, l_returnflag, max(l_shipinstruct) as t + |from lineitem + |group by l_orderkey, l_returnflag + |order by l_orderkey, l_returnflag, t + |limit 10 + |""".stripMargin + runSql(SQL7, false)( + df => { + checkResult( + df, + Array( + Row(1, "N", "TAKE BACK RETURN"), + Row(2, "N", "TAKE BACK RETURN"), + Row(3, "A", "TAKE BACK RETURN"), + Row(3, "R", "TAKE BACK RETURN"), + Row(4, "N", "DELIVER IN PERSON"), + Row(5, "A", "DELIVER IN PERSON"), + Row(5, "R", "NONE"), + Row(6, "A", "TAKE BACK RETURN"), + Row(7, "N", "TAKE BACK RETURN"), + Row(32, "N", "TAKE BACK RETURN") + ) + ) + checkHashAggregateCount(df, 0) + val plans = collect(df.queryExecution.executedPlan) { case agg: SortAggregateExec => agg } + assert(plans.size == 2) + }) + } + } } // scalastyle:off line.size.limit diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetBucketSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetBucketSuite.scala index e840cde6e99b9..8480f3b6a1bd0 100644 --- 
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetBucketSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetBucketSuite.scala @@ -17,8 +17,10 @@ package io.glutenproject.execution import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf} +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.execution.InputIteratorTransformer import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.aggregate.SortAggregateExec import org.apache.commons.io.FileUtils @@ -566,5 +568,147 @@ class GlutenClickHouseTPCHParquetBucketSuite df => {} ) } + + test("GLUTEN-4668: Merge two phase hash-based aggregate into one aggregate") { + def checkHashAggregateCount(df: DataFrame, expectedCount: Int): Unit = { + val plans = collect(df.queryExecution.executedPlan) { + case agg: HashAggregateExecBaseTransformer => agg + } + assert(plans.size == expectedCount) + } + + val SQL = + """ + |select l_orderkey, l_returnflag, collect_list(l_linenumber) as t + |from lineitem group by l_orderkey, l_returnflag + |order by l_orderkey, l_returnflag, t limit 100 + |""".stripMargin + compareResultsAgainstVanillaSpark( + SQL, + true, + df => { checkHashAggregateCount(df, 1) } + ) + + val SQL1 = + """ + |select l_orderkey, l_returnflag, + |sum(l_linenumber) as t, + |count(l_linenumber) as t1, + |min(l_linenumber) as t2 + |from lineitem group by l_orderkey, l_returnflag + |order by l_orderkey, l_returnflag, t, t1, t2 limit 100 + |""".stripMargin + compareResultsAgainstVanillaSpark( + SQL1, + true, + df => { checkHashAggregateCount(df, 1) } + ) + + val SQL2 = + """ + |select l_returnflag, l_orderkey, collect_list(l_linenumber) as t + |from lineitem group by l_orderkey, l_returnflag + |order by l_returnflag, l_orderkey, t limit 100 + |""".stripMargin + compareResultsAgainstVanillaSpark( + SQL2, + true, + df => { checkHashAggregateCount(df, 1) } + ) + + // will merge four aggregates into two one. + val SQL3 = + """ + |select l_returnflag, l_orderkey, + |count(distinct l_linenumber) as t + |from lineitem group by l_orderkey, l_returnflag + |order by l_returnflag, l_orderkey, t limit 100 + |""".stripMargin + compareResultsAgainstVanillaSpark( + SQL3, + true, + df => { checkHashAggregateCount(df, 2) } + ) + + // not support when there are more than one count distinct + val SQL4 = + """ + |select l_returnflag, l_orderkey, + |count(distinct l_linenumber) as t, + |count(distinct l_discount) as t1 + |from lineitem group by l_orderkey, l_returnflag + |order by l_returnflag, l_orderkey, t, t1 limit 100 + |""".stripMargin + compareResultsAgainstVanillaSpark( + SQL4, + true, + df => { checkHashAggregateCount(df, 4) } + ) + + val SQL5 = + """ + |select l_returnflag, l_orderkey, + |count(distinct l_linenumber) as t, + |sum(l_linenumber) as t1 + |from lineitem group by l_orderkey, l_returnflag + |order by l_returnflag, l_orderkey, t, t1 limit 100 + |""".stripMargin + compareResultsAgainstVanillaSpark( + SQL5, + true, + df => { checkHashAggregateCount(df, 4) } + ) + + val SQL6 = + """ + |select count(1) from lineitem + |""".stripMargin + compareResultsAgainstVanillaSpark( + SQL6, + true, + df => { + // there is a shuffle between two phase hash aggregate. 
+ checkHashAggregateCount(df, 2) + } + ) + + // test sort aggregates + val SQL7 = + """ + |select l_orderkey, l_returnflag, max(l_shipinstruct) as t + |from lineitem + |group by l_orderkey, l_returnflag + |order by l_orderkey, l_returnflag, t + |limit 100 + |""".stripMargin + compareResultsAgainstVanillaSpark( + SQL7, + true, + df => { + checkHashAggregateCount(df, 1) + } + ) + + withSQLConf(("spark.gluten.sql.columnar.force.hashagg", "false")) { + val SQL = + """ + |select l_orderkey, l_returnflag, max(l_shipinstruct) as t + |from lineitem + |group by l_orderkey, l_returnflag + |order by l_orderkey, l_returnflag, t + |limit 100 + |""".stripMargin + compareResultsAgainstVanillaSpark( + SQL, + true, + df => { + checkHashAggregateCount(df, 0) + val plans = collect(df.queryExecution.executedPlan) { case agg: SortAggregateExec => agg } + assert(plans.size == 2) + }, + noFallBack = false + ) + } + } } // scalastyle:on line.size.limit diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseFunctionSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseFunctionSuite.scala index 30184389706f1..a2388c4a3d5c1 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseFunctionSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickhouseFunctionSuite.scala @@ -67,7 +67,6 @@ class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite { .set("spark.gluten.sql.enable.native.validation", "false") // TODO: support default ANSI policy .set("spark.sql.storeAssignmentPolicy", "legacy") - // .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "debug") .set("spark.sql.warehouse.dir", warehouse) .setMaster("local[1]") } diff --git a/cpp-ch/local-engine/Operator/GraceMergingAggregatedStep.cpp b/cpp-ch/local-engine/Operator/GraceMergingAggregatedStep.cpp index 2bb163561a720..9294d9719b189 100644 --- a/cpp-ch/local-engine/Operator/GraceMergingAggregatedStep.cpp +++ b/cpp-ch/local-engine/Operator/GraceMergingAggregatedStep.cpp @@ -47,19 +47,21 @@ static DB::ITransformingStep::Traits getTraits() }; } -static DB::Block buildOutputHeader(const DB::Block & input_header_, const DB::Aggregator::Params params_) +static DB::Block buildOutputHeader(const DB::Block & input_header_, const DB::Aggregator::Params params_, bool final) { - return params_.getHeader(input_header_, true); + return params_.getHeader(input_header_, final); } GraceMergingAggregatedStep::GraceMergingAggregatedStep( DB::ContextPtr context_, const DB::DataStream & input_stream_, - DB::Aggregator::Params params_) + DB::Aggregator::Params params_, + bool no_pre_aggregated_) : DB::ITransformingStep( - input_stream_, buildOutputHeader(input_stream_.header, params_), getTraits()) + input_stream_, buildOutputHeader(input_stream_.header, params_, true), getTraits()) , context(context_) , params(std::move(params_)) + , no_pre_aggregated(no_pre_aggregated_) { } @@ -73,7 +75,7 @@ void GraceMergingAggregatedStep::transformPipeline(DB::QueryPipelineBuilder & pi DB::Processors new_processors; for (auto & output : outputs) { - auto op = std::make_shared(pipeline.getHeader(), transform_params, context); + auto op = std::make_shared(pipeline.getHeader(), transform_params, context, no_pre_aggregated); new_processors.push_back(op); DB::connect(*output, op->getInputs().front()); } @@ -95,14 +97,17 @@ void GraceMergingAggregatedStep::describeActions(DB::JSONBuilder::JSONMap & map) void 
GraceMergingAggregatedStep::updateOutputStream() { - output_stream = createOutputStream(input_streams.front(), buildOutputHeader(input_streams.front().header, params), getDataStreamTraits()); + output_stream = createOutputStream(input_streams.front(), buildOutputHeader(input_streams.front().header, params, true), getDataStreamTraits()); } -GraceMergingAggregatedTransform::GraceMergingAggregatedTransform(const DB::Block &header_, DB::AggregatingTransformParamsPtr params_, DB::ContextPtr context_) +GraceMergingAggregatedTransform::GraceMergingAggregatedTransform(const DB::Block &header_, DB::AggregatingTransformParamsPtr params_, DB::ContextPtr context_, bool no_pre_aggregated_) : IProcessor({header_}, {params_->getHeader()}) , header(header_) , params(params_) , context(context_) + , key_columns(params_->params.keys_size) + , aggregate_columns(params_->params.aggregates_size) + , no_pre_aggregated(no_pre_aggregated_) , tmp_data_disk(std::make_unique(context_->getTempDataOnDisk())) { max_buckets = context->getConfigRef().getUInt64("max_grace_aggregate_merging_buckets", 32); @@ -201,7 +206,7 @@ void GraceMergingAggregatedTransform::work() { assert(!input_finished); auto block = header.cloneWithColumns(input_chunk.detachColumns()); - mergeOneBlock(block); + mergeOneBlock(block, true); has_input = false; } else @@ -294,9 +299,10 @@ void GraceMergingAggregatedTransform::rehashDataVariants() } for (size_t i = current_bucket_index + 1; i < getBucketsNum(); ++i) { - addBlockIntoFileBucket(i, scattered_blocks[i]); + addBlockIntoFileBucket(i, scattered_blocks[i], false); scattered_blocks[i] = {}; } + params->aggregator.mergeOnBlock(scattered_blocks[current_bucket_index], *current_data_variants, no_more_keys); } if (block_rows) @@ -326,7 +332,7 @@ DB::Blocks GraceMergingAggregatedTransform::scatterBlock(const DB::Block & block return blocks; } -void GraceMergingAggregatedTransform::addBlockIntoFileBucket(size_t bucket_index, const DB::Block & block) +void GraceMergingAggregatedTransform::addBlockIntoFileBucket(size_t bucket_index, const DB::Block & block, bool is_original_block) { if (!block.rows()) return; @@ -336,7 +342,10 @@ void GraceMergingAggregatedTransform::addBlockIntoFileBucket(size_t bucket_index } auto & file_stream = buckets[bucket_index]; file_stream.pending_bytes += block.allocatedBytes(); - file_stream.blocks.push_back(block); + if (is_original_block && no_pre_aggregated) + file_stream.original_blocks.push_back(block); + else + file_stream.intermediate_blocks.push_back(block); if (file_stream.pending_bytes > max_pending_flush_blocks_per_bucket) { flushBucket(bucket_index); @@ -350,37 +359,54 @@ void GraceMergingAggregatedTransform::flushBuckets() flushBucket(i); } -size_t GraceMergingAggregatedTransform::flushBucket(size_t bucket_index) +static size_t flushBlocksInfoDisk(DB::TemporaryFileStream * file_stream, std::list & blocks) { - Stopwatch watch; - auto & file_stream = buckets[bucket_index]; - if (file_stream.blocks.empty()) - return 0; - if (!file_stream.file_stream) - file_stream.file_stream = &tmp_data_disk->createStream(header); - DB::Blocks blocks; size_t flush_bytes = 0; - while (!file_stream.blocks.empty()) + DB::Blocks tmp_blocks; + while (!blocks.empty()) { - while (!file_stream.blocks.empty()) + while (!blocks.empty()) { - if (!blocks.empty() && blocks.back().info.bucket_num != file_stream.blocks.front().info.bucket_num) + if (!tmp_blocks.empty() && tmp_blocks.back().info.bucket_num != blocks.front().info.bucket_num) break; - blocks.push_back(file_stream.blocks.front()); - 
file_stream.blocks.pop_front(); + tmp_blocks.push_back(blocks.front()); + blocks.pop_front(); } - auto bucket = blocks.front().info.bucket_num; - auto merged_block = BlockUtil::concatenateBlocksMemoryEfficiently(std::move(blocks)); + auto bucket = tmp_blocks.front().info.bucket_num; + auto merged_block = BlockUtil::concatenateBlocksMemoryEfficiently(std::move(tmp_blocks)); merged_block.info.bucket_num = bucket; - blocks.clear(); + tmp_blocks.clear(); flush_bytes += merged_block.bytes(); if (merged_block.rows()) { - file_stream.file_stream->write(merged_block); + file_stream->write(merged_block); } } if (flush_bytes) - file_stream.file_stream->flush(); + file_stream->flush(); + return flush_bytes; +} + +size_t GraceMergingAggregatedTransform::flushBucket(size_t bucket_index) +{ + Stopwatch watch; + auto & file_stream = buckets[bucket_index]; + size_t flush_bytes = 0; + if (!file_stream.original_blocks.empty()) + { + if (!file_stream.original_file_stream) + file_stream.original_file_stream = &tmp_data_disk->createStream(header); + flush_bytes += flushBlocksInfoDisk(file_stream.original_file_stream, file_stream.original_blocks); + } + if (!file_stream.intermediate_blocks.empty()) + { + if (!file_stream.intermediate_file_stream) + { + auto intermediate_header = params->aggregator.getHeader(false); + file_stream.intermediate_file_stream = &tmp_data_disk->createStream(intermediate_header); + } + flush_bytes += flushBlocksInfoDisk(file_stream.intermediate_file_stream, file_stream.intermediate_blocks); + } total_spill_disk_bytes += flush_bytes; total_spill_disk_time += watch.elapsedMilliseconds(); return flush_bytes; @@ -389,7 +415,8 @@ size_t GraceMergingAggregatedTransform::flushBucket(size_t bucket_index) std::unique_ptr GraceMergingAggregatedTransform::prepareBucketOutputBlocks(size_t bucket_index) { auto & buffer_file_stream = buckets[bucket_index]; - if (!current_data_variants && !buffer_file_stream.file_stream && buffer_file_stream.blocks.empty()) + if (!current_data_variants && !buffer_file_stream.intermediate_file_stream && buffer_file_stream.intermediate_blocks.empty() + && !buffer_file_stream.original_file_stream && buffer_file_stream.original_blocks.empty()) { return nullptr; } @@ -400,30 +427,56 @@ std::unique_ptr GraceMergingAggregatedTransform::pr checkAndSetupCurrentDataVariants(); - if (buffer_file_stream.file_stream) + if (buffer_file_stream.intermediate_file_stream) + { + buffer_file_stream.intermediate_file_stream->finishWriting(); + while (true) + { + auto block = buffer_file_stream.intermediate_file_stream->read(); + if (!block.rows()) + break; + read_bytes += block.bytes(); + read_rows += block.rows(); + mergeOneBlock(block, false); + block = {}; + } + buffer_file_stream.intermediate_file_stream = nullptr; + total_read_disk_time += watch.elapsedMilliseconds(); + } + if (!buffer_file_stream.intermediate_blocks.empty()) + { + for (auto & block : buffer_file_stream.intermediate_blocks) + { + mergeOneBlock(block, false); + block = {}; + } + } + + if (buffer_file_stream.original_file_stream) { - buffer_file_stream.file_stream->finishWriting(); + buffer_file_stream.original_file_stream->finishWriting(); while (true) { - auto block = buffer_file_stream.file_stream->read(); + auto block = buffer_file_stream.original_file_stream->read(); if (!block.rows()) break; read_bytes += block.bytes(); read_rows += block.rows(); - mergeOneBlock(block); + mergeOneBlock(block, true); block = {}; } - buffer_file_stream.file_stream = nullptr; + buffer_file_stream.original_file_stream = nullptr; 
total_read_disk_time += watch.elapsedMilliseconds(); } - if (!buffer_file_stream.blocks.empty()) + if (!buffer_file_stream.original_blocks.empty()) { - for (auto & block : buffer_file_stream.blocks) + for (auto & block : buffer_file_stream.original_blocks) { - mergeOneBlock(block); + mergeOneBlock(block, true); block = {}; } } + auto last_data_variants_size = current_data_variants->size(); auto converter = currentDataVariantToBlockConverter(true); LOG_INFO( @@ -458,7 +511,7 @@ void GraceMergingAggregatedTransform::checkAndSetupCurrentDataVariants() } } -void GraceMergingAggregatedTransform::mergeOneBlock(const DB::Block &block) +void GraceMergingAggregatedTransform::mergeOneBlock(const DB::Block &block, bool is_original_block) { if (!block.rows()) return; @@ -488,7 +541,10 @@ void GraceMergingAggregatedTransform::mergeOneBlock(const DB::Block &block) /// so if the buckets number is not changed since it was scattered, we don't need to scatter it again. if (block.info.bucket_num == static_cast(getBucketsNum()) || getBucketsNum() == 1) { - params->aggregator.mergeOnBlock(block, *current_data_variants, no_more_keys); + if (is_original_block && no_pre_aggregated) + params->aggregator.executeOnBlock(block, *current_data_variants, key_columns, aggregate_columns, no_more_keys); + else + params->aggregator.mergeOnBlock(block, *current_data_variants, no_more_keys); } else { @@ -511,9 +567,17 @@ void GraceMergingAggregatedTransform::mergeOneBlock(const DB::Block &block) } for (size_t i = current_bucket_index + 1; i < getBucketsNum(); ++i) { - addBlockIntoFileBucket(i, scattered_blocks[i]); + addBlockIntoFileBucket(i, scattered_blocks[i], is_original_block); + } + + if (is_original_block && no_pre_aggregated) + { + params->aggregator.executeOnBlock(scattered_blocks[current_bucket_index], *current_data_variants, key_columns, aggregate_columns, no_more_keys); + } + else + { + params->aggregator.mergeOnBlock(scattered_blocks[current_bucket_index], *current_data_variants, no_more_keys); } - params->aggregator.mergeOnBlock(scattered_blocks[current_bucket_index], *current_data_variants, no_more_keys); } } diff --git a/cpp-ch/local-engine/Operator/GraceMergingAggregatedStep.h b/cpp-ch/local-engine/Operator/GraceMergingAggregatedStep.h index 761fb7bfec222..788ca47dc0a9e 100644 --- a/cpp-ch/local-engine/Operator/GraceMergingAggregatedStep.h +++ b/cpp-ch/local-engine/Operator/GraceMergingAggregatedStep.h @@ -40,7 +40,8 @@ class GraceMergingAggregatedStep : public DB::ITransformingStep explicit GraceMergingAggregatedStep( DB::ContextPtr context_, const DB::DataStream & input_stream_, - DB::Aggregator::Params params_); + DB::Aggregator::Params params_, + bool no_pre_aggregated_); ~GraceMergingAggregatedStep() override = default; String getName() const override { return "GraceMergingAggregatedStep"; } @@ -52,6 +53,7 @@ class GraceMergingAggregatedStep : public DB::ITransformingStep private: DB::ContextPtr context; DB::Aggregator::Params params; + bool no_pre_aggregated; void updateOutputStream() override; }; @@ -59,14 +61,17 @@ class GraceMergingAggregatedTransform : public DB::IProcessor { public: using Status = DB::IProcessor::Status; - explicit GraceMergingAggregatedTransform(const DB::Block &header_, DB::AggregatingTransformParamsPtr params_, DB::ContextPtr context_); + explicit GraceMergingAggregatedTransform(const DB::Block &header_, DB::AggregatingTransformParamsPtr params_, DB::ContextPtr context_, bool no_pre_aggregated_); ~GraceMergingAggregatedTransform() override; Status prepare() override; void work() 
override; String getName() const override { return "GraceMergingAggregatedTransform"; } private: + bool no_pre_aggregated; DB::Block header; + DB::ColumnRawPtrs key_columns; + DB::Aggregator::AggregateColumns aggregate_columns; DB::AggregatingTransformParamsPtr params; DB::ContextPtr context; DB::TemporaryDataOnDiskPtr tmp_data_disk; @@ -88,8 +93,14 @@ class GraceMergingAggregatedTransform : public DB::IProcessor struct BufferFileStream { - std::list blocks; - DB::TemporaryFileStream * file_stream = nullptr; + /// store the intermediate result blocks. + std::list intermediate_blocks; + /// Only be used when there is no pre-aggregated step, store the original input blocks. + std::list original_blocks; + /// store the intermediate result blocks. + DB::TemporaryFileStream * intermediate_file_stream = nullptr; + /// Only be used when there is no pre-aggregated step + DB::TemporaryFileStream * original_file_stream = nullptr; size_t pending_bytes = 0; }; std::unordered_map buckets; @@ -99,7 +110,7 @@ class GraceMergingAggregatedTransform : public DB::IProcessor void rehashDataVariants(); DB::Blocks scatterBlock(const DB::Block & block); /// Add a block into a bucket, if the pending bytes reaches limit, flush it into disk. - void addBlockIntoFileBucket(size_t bucket_index, const DB::Block & block); + void addBlockIntoFileBucket(size_t bucket_index, const DB::Block & block, bool is_original_block); void flushBuckets(); size_t flushBucket(size_t bucket_index); /// Load blocks from disk and merge them into a new hash table, make a new AggregateDataBlockConverter @@ -109,7 +120,7 @@ class GraceMergingAggregatedTransform : public DB::IProcessor std::unique_ptr currentDataVariantToBlockConverter(bool final); void checkAndSetupCurrentDataVariants(); /// Merge one block into current_data_variants. 
- void mergeOneBlock(const DB::Block &block); + void mergeOneBlock(const DB::Block &block, bool is_original_block); bool isMemoryOverflow(); bool input_finished = false; diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp index 125147ac1b06f..3d2318cc7fab8 100644 --- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp @@ -88,7 +88,8 @@ std::pair AggregateFunctionParser::tryApplyCHCombinator( }; String combinator_function_name = ch_func_name; DB::DataTypes combinator_arg_column_types = arg_column_types; - if (func_info.phase != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE) + if (func_info.phase != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE && + func_info.phase != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT) { if (arg_column_types.size() != 1) { @@ -146,10 +147,20 @@ std::pair AggregateFunctionParser::tryApplyCHCombinator( const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded( const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, - DB::ActionsDAGPtr & actions_dag) const + DB::ActionsDAGPtr & actions_dag, + bool withNullability) const { const auto & output_type = func_info.output_type; - if (!TypeParser::isTypeMatched(output_type, func_node->result_type)) + bool needToConvertNodeType = false; + if (withNullability) + { + needToConvertNodeType = !TypeParser::isTypeMatchedWithNullability(output_type, func_node->result_type); + } + else + { + needToConvertNodeType = !TypeParser::isTypeMatched(output_type, func_node->result_type); + } + if (needToConvertNodeType) { func_node = ActionsDAGUtil::convertNodeType( actions_dag, func_node, TypeParser::parseType(output_type)->getName(), func_node->result_name); diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h index bfa932b819f8a..a9840eeef8253 100644 --- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h +++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h @@ -106,7 +106,7 @@ class AggregateFunctionParser /// Make a postprojection for the function result. virtual const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( - const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag) const; + const CommonFunctionInfo & func_info, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag, bool withNullability) const; /// Parameters are only used in aggregate functions at present. e.g. percentiles(0.5)(x). /// 0.5 is the parameter of percentiles function. 
diff --git a/cpp-ch/local-engine/Parser/AggregateRelParser.cpp b/cpp-ch/local-engine/Parser/AggregateRelParser.cpp index 72eaf34bbfe35..e71cf9f8d2496 100644 --- a/cpp-ch/local-engine/Parser/AggregateRelParser.cpp +++ b/cpp-ch/local-engine/Parser/AggregateRelParser.cpp @@ -63,6 +63,13 @@ DB::QueryPlanPtr AggregateRelParser::parse(DB::QueryPlanPtr query_plan, const su addPostProjection(); LOG_TRACE(logger, "header after post-projection is: {}", plan->getCurrentDataStream().header.dumpStructure()); } + else if (has_complete_stage) + { + addCompleteModeAggregatedStep(); + LOG_TRACE(logger, "header after complete aggregate is: {}", plan->getCurrentDataStream().header.dumpStructure()); + addPostProjection(); + LOG_TRACE(logger, "header after post-projection is: {}", plan->getCurrentDataStream().header.dumpStructure()); + } else { addAggregatingStep(); @@ -71,7 +78,7 @@ DB::QueryPlanPtr AggregateRelParser::parse(DB::QueryPlanPtr query_plan, const su /// If the groupings is empty, we still need to return one row with default values even if the input is empty. if ((rel.aggregate().groupings().empty() || rel.aggregate().groupings()[0].grouping_expressions().empty()) - && (has_final_stage || rel.aggregate().measures().empty())) + && (has_final_stage || has_complete_stage || rel.aggregate().measures().empty())) { LOG_TRACE(&Poco::Logger::get("AggregateRelParser"), "default aggregate result step"); auto default_agg_result = std::make_unique(plan->getCurrentDataStream()); @@ -96,6 +103,7 @@ void AggregateRelParser::setup(DB::QueryPlanPtr query_plan, const substrait::Rel has_first_stage = phase_set.contains(substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE); has_inter_stage = phase_set.contains(substrait::AggregationPhase::AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE); has_final_stage = phase_set.contains(substrait::AggregationPhase::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT); + has_complete_stage = phase_set.contains(substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT); if (aggregate_rel->measures().empty()) { /// According to planAggregateWithoutDistinct in AggUtils.scala, an aggregate without aggregate @@ -110,6 +118,11 @@ void AggregateRelParser::setup(DB::QueryPlanPtr query_plan, const substrait::Rel throw DB::Exception( DB::ErrorCodes::LOGICAL_ERROR, "AggregateRelParser: multiple aggregation phases with final stage are not supported"); } + if (phase_set.size() > 1 && has_complete_stage) + { + throw DB::Exception( + DB::ErrorCodes::LOGICAL_ERROR, "AggregateRelParser: multiple aggregation phases with complete mode are not supported"); + } auto input_header = plan->getCurrentDataStream().header; for (const auto & measure : aggregate_rel->measures()) @@ -182,7 +195,7 @@ void AggregateRelParser::buildAggregateDescriptions(AggregateDescriptions & desc { auto build_result_column_name = [](const String & function_name, const Strings & arg_column_names, substrait::AggregationPhase phase) { - if (phase != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE) + if (phase == substrait::AggregationPhase::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT) { assert(arg_column_names.size() == 1); return arg_column_names[0]; @@ -195,6 +208,7 @@ void AggregateRelParser::buildAggregateDescriptions(AggregateDescriptions & desc AggregateDescription description; const auto & measure = agg_info.measure->measure(); description.column_name = build_result_column_name(agg_info.function_name, agg_info.arg_column_names, measure.phase()); + agg_info.measure_column_name = 
description.column_name; description.argument_names = agg_info.arg_column_names; DB::AggregateFunctionProperties properties; @@ -257,7 +271,7 @@ void AggregateRelParser::addMergingAggregatedStep() if (enable_streaming_aggregating) { params.group_by_two_level_threshold = settings.group_by_two_level_threshold; - auto merging_step = std::make_unique(getContext(), plan->getCurrentDataStream(), params); + auto merging_step = std::make_unique(getContext(), plan->getCurrentDataStream(), params, false); steps.emplace_back(merging_step.get()); plan->addStep(std::move(merging_step)); } @@ -280,6 +294,84 @@ void AggregateRelParser::addMergingAggregatedStep() } } +void AggregateRelParser::addCompleteModeAggregatedStep() +{ + AggregateDescriptions aggregate_descriptions; + buildAggregateDescriptions(aggregate_descriptions); + auto settings = getContext()->getSettingsRef(); + bool enable_streaming_aggregating = getContext()->getConfigRef().getBool("enable_streaming_aggregating", true); + if (enable_streaming_aggregating) + { + Aggregator::Params params( + grouping_keys, + aggregate_descriptions, + false, + settings.max_rows_to_group_by, + settings.group_by_overflow_mode, + settings.group_by_two_level_threshold, + settings.group_by_two_level_threshold_bytes, + settings.max_bytes_before_external_group_by, + settings.empty_result_for_aggregation_by_empty_set, + getContext()->getTempDataOnDisk(), + settings.max_threads, + settings.min_free_disk_space_for_temporary_data, + true, + 3, + PODArrayUtil::adjustMemoryEfficientSize(settings.max_block_size), + /*enable_prefetch*/ true, + /*only_merge*/ false, + settings.optimize_group_by_constant_keys, + settings.min_hit_rate_to_use_consecutive_keys_optimization, + /*StatsCollectingParams*/{}); + auto merging_step = std::make_unique(getContext(), plan->getCurrentDataStream(), params, true); + steps.emplace_back(merging_step.get()); + plan->addStep(std::move(merging_step)); + } + else + { + Aggregator::Params params( + grouping_keys, + aggregate_descriptions, + false, + settings.max_rows_to_group_by, + settings.group_by_overflow_mode, + settings.group_by_two_level_threshold, + settings.group_by_two_level_threshold_bytes, + settings.max_bytes_before_external_group_by, + settings.empty_result_for_aggregation_by_empty_set, + getContext()->getTempDataOnDisk(), + settings.max_threads, + settings.min_free_disk_space_for_temporary_data, + true, + 3, + PODArrayUtil::adjustMemoryEfficientSize(settings.max_block_size), + /*enable_prefetch*/ true, + /*only_merge*/ false, + settings.optimize_group_by_constant_keys, + settings.min_hit_rate_to_use_consecutive_keys_optimization, + /*StatsCollectingParams*/{}); + + auto aggregating_step = std::make_unique( + plan->getCurrentDataStream(), + params, + GroupingSetsParamsList(), + true, + settings.max_block_size, + settings.aggregation_in_order_max_block_bytes, + 1, + 1, + false, + false, + SortDescription(), + SortDescription(), + false, + false, + false); + steps.emplace_back(aggregating_step.get()); + plan->addStep(std::move(aggregating_step)); + } +} + void AggregateRelParser::addAggregatingStep() { AggregateDescriptions aggregate_descriptions; @@ -371,13 +463,33 @@ void AggregateRelParser::addPostProjection() auto input_header = plan->getCurrentDataStream().header; ActionsDAGPtr project_actions_dag = std::make_shared(input_header.getColumnsWithTypeAndName()); auto dag_footprint = project_actions_dag->dumpDAG(); - for (const auto & agg_info : aggregates) + + if (has_final_stage) { - /// For final stage, the aggregate function's 
input is only one intermediate result columns. - /// The final result columm's position is the same as the intermediate result column's position. - auto pos = agg_info.measure->measure().arguments(0).value().selection().direct_reference().struct_field().field(); - const auto * agg_result_node = project_actions_dag->getInputs()[pos]; - agg_info.function_parser->convertNodeTypeIfNeeded(agg_info.parser_func_info, agg_result_node, project_actions_dag); + for (const auto & agg_info : aggregates) + { + for (const auto * input_node : project_actions_dag->getInputs()) + { + if (input_node->result_name == agg_info.measure_column_name) + { + agg_info.function_parser->convertNodeTypeIfNeeded(agg_info.parser_func_info, input_node, project_actions_dag, false); + } + } + } + } + else if (has_complete_stage) + { + // on the complete mode, it must consider the nullability when converting node type + for (const auto & agg_info : aggregates) + { + for (const auto * output_node : project_actions_dag->getOutputs()) + { + if (output_node->result_name == agg_info.measure_column_name) + { + agg_info.function_parser->convertNodeTypeIfNeeded(agg_info.parser_func_info, output_node, project_actions_dag, true); + } + } + } } if (project_actions_dag->dumpDAG() != dag_footprint) { diff --git a/cpp-ch/local-engine/Parser/AggregateRelParser.h b/cpp-ch/local-engine/Parser/AggregateRelParser.h index 53b5f7fb70480..8f68f858fc511 100644 --- a/cpp-ch/local-engine/Parser/AggregateRelParser.h +++ b/cpp-ch/local-engine/Parser/AggregateRelParser.h @@ -36,6 +36,7 @@ class AggregateRelParser : public RelParser struct AggregateInfo { const substrait::AggregateRel::Measure * measure = nullptr; + String measure_column_name; Strings arg_column_names; DB::DataTypes arg_column_types; Array params; @@ -53,6 +54,7 @@ class AggregateRelParser : public RelParser bool has_first_stage = false; bool has_inter_stage = false; bool has_final_stage = false; + bool has_complete_stage = false; DB::QueryPlanPtr plan = nullptr; const substrait::AggregateRel * aggregate_rel = nullptr; @@ -62,6 +64,7 @@ class AggregateRelParser : public RelParser void setup(DB::QueryPlanPtr query_plan, const substrait::Rel & rel); void addPreProjection(); void addMergingAggregatedStep(); + void addCompleteModeAggregatedStep(); void addAggregatingStep(); void addPostProjection(); diff --git a/cpp-ch/local-engine/Parser/TypeParser.cpp b/cpp-ch/local-engine/Parser/TypeParser.cpp index 958a5fb4518f1..2edd8c1c83ec7 100644 --- a/cpp-ch/local-engine/Parser/TypeParser.cpp +++ b/cpp-ch/local-engine/Parser/TypeParser.cpp @@ -315,6 +315,12 @@ bool TypeParser::isTypeMatched(const substrait::Type & substrait_type, const Dat return a->equals(*b); } +bool TypeParser::isTypeMatchedWithNullability(const substrait::Type & substrait_type, const DataTypePtr & ch_type) +{ + const auto parsed_ch_type = TypeParser::parseType(substrait_type); + return parsed_ch_type->equals(*ch_type); +} + DB::DataTypePtr TypeParser::tryWrapNullable(substrait::Type_Nullability nullable, DB::DataTypePtr nested_type) { if (nullable == substrait::Type_Nullability::Type_Nullability_NULLABILITY_NULLABLE && !nested_type->isNullable()) diff --git a/cpp-ch/local-engine/Parser/TypeParser.h b/cpp-ch/local-engine/Parser/TypeParser.h index 7793ae198b860..a25b2f50afe83 100644 --- a/cpp-ch/local-engine/Parser/TypeParser.h +++ b/cpp-ch/local-engine/Parser/TypeParser.h @@ -48,6 +48,7 @@ class TypeParser static DB::Block buildBlockFromNamedStructWithoutDFS(const substrait::NamedStruct & struct_); static bool isTypeMatched(const 
substrait::Type & substrait_type, const DB::DataTypePtr & ch_type); + static bool isTypeMatchedWithNullability(const substrait::Type & substrait_type, const DB::DataTypePtr & ch_type); private: /// Mapping spark type names to CH type names. static std::unordered_map type_names_mapping; diff --git a/cpp-ch/local-engine/Parser/WindowRelParser.cpp b/cpp-ch/local-engine/Parser/WindowRelParser.cpp index 969959c3aec7a..a1787a2c93c5c 100644 --- a/cpp-ch/local-engine/Parser/WindowRelParser.cpp +++ b/cpp-ch/local-engine/Parser/WindowRelParser.cpp @@ -357,7 +357,7 @@ void WindowRelParser::tryAddProjectionAfterWindow() { auto & win_info = win_infos[i]; const auto * win_result_node = &actions_dag->findInOutputs(win_info.result_column_name); - win_info.function_parser->convertNodeTypeIfNeeded(win_info.parser_func_info, win_result_node, actions_dag); + win_info.function_parser->convertNodeTypeIfNeeded(win_info.parser_func_info, win_result_node, actions_dag, false); } if (actions_dag->dumpDAG() != dag_footprint) diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp index 5f508b1333566..8788abb6dcf79 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp @@ -55,7 +55,7 @@ DB::Array get_parameters(Int64 insert_num, Int64 bits_num) DB::Array AggregateFunctionParserBloomFilterAgg::parseFunctionParameters( const CommonFunctionInfo & func_info, DB::ActionsDAG::NodeRawConstPtrs & arg_nodes) const { - if (func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE) + if (func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE || func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT) { auto get_parameter_field = [](const DB::ActionsDAG::Node * node, size_t /*paramter_index*/) -> DB::Field { diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/CollectListParser.h b/cpp-ch/local-engine/Parser/aggregate_function_parser/CollectListParser.h index a75e9ee2ad3c4..d7a9c1a5c1889 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/CollectListParser.h +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/CollectListParser.h @@ -52,7 +52,7 @@ class CollectFunctionParser : public AggregateFunctionParser throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Not implement"); } const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( - const CommonFunctionInfo &, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag) const override + const CommonFunctionInfo &, const DB::ActionsDAG::Node * func_node, DB::ActionsDAGPtr & actions_dag, bool /* withNullability */) const override { const DB::ActionsDAG::Node * ret_node = func_node; if (func_node->result_type->isNullable()) diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/expression/AggregateFunctionNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/expression/AggregateFunctionNode.java index eb952b8427f98..ac1500fbb90d6 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/expression/AggregateFunctionNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/expression/AggregateFunctionNode.java @@ -57,6 +57,9 @@ public AggregateFunction toProtobuf() { case "PARTIAL_MERGE": aggBuilder.setPhase(AggregationPhase.AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE); break; + case "COMPLETE": + 
aggBuilder.setPhase(AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT); + break; case "FINAL": aggBuilder.setPhase(AggregationPhase.AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT); break; diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala index a2c07f7b05e8c..db586239e01e9 100644 --- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala +++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/BackendSettingsApi.scala @@ -125,4 +125,7 @@ trait BackendSettingsApi { def shouldRewriteCount(): Boolean = false def supportCartesianProductExec(): Boolean = false + + /** Merge two phases hash based aggregate if need */ + def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = false } diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala index 29febc6a64763..eff758b479841 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala @@ -179,17 +179,17 @@ abstract class HashAggregateExecBaseTransformer( case s: Sum if s.prettyName.equals("try_sum") => false case _: CollectList | _: CollectSet => mode match { - case Partial | Final => true + case Partial | Final | Complete => true case _ => false } case bloom if bloom.getClass.getSimpleName.equals("BloomFilterAggregate") => mode match { - case Partial | Final => true + case Partial | Final | Complete => true case _ => false } case _ => mode match { - case Partial | PartialMerge | Final => true + case Partial | PartialMerge | Final | Complete => true case _ => false } } @@ -199,6 +199,7 @@ abstract class HashAggregateExecBaseTransformer( aggregateMode match { case Partial => "PARTIAL" case PartialMerge => "PARTIAL_MERGE" + case Complete => "COMPLETE" case Final => "FINAL" case other => throw new UnsupportedOperationException(s"not currently supported: $other.") @@ -237,7 +238,7 @@ abstract class HashAggregateExecBaseTransformer( } val aggregateFunc = aggExpr.aggregateFunction val childrenNodes = aggExpr.mode match { - case Partial => + case Partial | Complete => aggregateFunc.children.toList.map( expr => { ExpressionConverter diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala index 197ac75368fd7..5ed11d0664d07 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala @@ -638,6 +638,7 @@ case class ColumnarOverrideRules(session: SparkSession) ) ::: BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules() ::: List( + (spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark), (_: SparkSession) => rewriteSparkPlanRule(), (_: SparkSession) => AddTransformHintRule(), (_: SparkSession) => FallbackBloomFilterAggIfNeeded(), diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/MergeTwoPhasesHashAggregate.scala b/gluten-core/src/main/scala/io/glutenproject/extension/MergeTwoPhasesHashAggregate.scala new file mode 100644 index 0000000000000..91c4225215099 --- /dev/null +++ 
b/gluten-core/src/main/scala/io/glutenproject/extension/MergeTwoPhasesHashAggregate.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.glutenproject.extension + +import io.glutenproject.GlutenConfig +import io.glutenproject.backendsapi.BackendsApiManager +import io.glutenproject.extension.columnar.TransformHints +import io.glutenproject.utils.PhysicalPlanSelector + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Final, Partial} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} + +/** + * Merge two phase hash-based aggregate into one aggregate in the spark plan if there is no shuffle: + * + * Merge HashAggregate(t1.i, SUM, final) + HashAggregate(t1.i, SUM, partial) into + * HashAggregate(t1.i, SUM, complete) + * + * Note: this rule must be applied before the `PullOutPreProject` rule, because the + * `PullOutPreProject` rule will modify the attributes in some cases. + */ +case class MergeTwoPhasesHashBaseAggregate(session: SparkSession) extends Rule[SparkPlan] { + + val columnarConf: GlutenConfig = GlutenConfig.getConf + val scanOnly: Boolean = columnarConf.enableScanOnly + val enableColumnarHashAgg: Boolean = !scanOnly && columnarConf.enableColumnarHashAgg + val replaceSortAggWithHashAgg = BackendsApiManager.getSettings.replaceSortAggWithHashAgg + + private def isTransformable(agg: BaseAggregateExec): Boolean = { + if (!TransformHints.isAlreadyTagged(agg)) { + // Check whether the aggregate is transformable. + // Note: do not use the AddTransformHintRule to apply all plan again + val transformer = BackendsApiManager.getSparkPlanExecApiInstance + .genHashAggregateExecTransformer( + agg.requiredChildDistributionExpressions, + agg.groupingExpressions, + agg.aggregateExpressions, + agg.aggregateAttributes, + 0, // there is no `initialInputBufferOffset` field in Spark 3.2 + agg.resultExpressions, + agg.child + ) + transformer.doValidate().isValid + } else { + // The transformable tag is already set before + TransformHints.isTransformable(agg) + } + } + + private def isPartialAgg(partialAgg: BaseAggregateExec, finalAgg: BaseAggregateExec): Boolean = { + // TODO: now it can not support to merge agg which there are the filters in the aggregate exprs. 
+    if (
+      partialAgg.aggregateExpressions.forall(x => x.mode == Partial && x.filter.isEmpty) &&
+      finalAgg.aggregateExpressions.forall(x => x.mode == Final && x.filter.isEmpty)
+    ) {
+      (finalAgg.logicalLink, partialAgg.logicalLink) match {
+        case (Some(agg1), Some(agg2)) => agg1.sameResult(agg2)
+        case _ => false
+      }
+    } else {
+      false
+    }
+  }
+
+  override def apply(plan: SparkPlan): SparkPlan = PhysicalPlanSelector.maybe(session, plan) {
+    if (
+      !enableColumnarHashAgg || !BackendsApiManager.getSettings
+        .mergeTwoPhasesHashBaseAggregateIfNeed()
+    ) {
+      plan
+    } else {
+      plan.transformDown {
+        case hashAgg @ HashAggregateExec(
+              _,
+              isStreaming,
+              _,
+              _,
+              aggregateExpressions,
+              aggregateAttributes,
+              _,
+              resultExpressions,
+              child: HashAggregateExec)
+            if !isStreaming && isTransformable(hashAgg) && isTransformable(child) && isPartialAgg(
+              child,
+              hashAgg) =>
+          // convert to complete mode aggregate expressions
+          val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
+          hashAgg.copy(
+            requiredChildDistributionExpressions = None,
+            groupingExpressions = child.groupingExpressions,
+            aggregateExpressions = completeAggregateExpressions,
+            initialInputBufferOffset = 0,
+            child = child.child
+          )
+        case objectHashAgg @ ObjectHashAggregateExec(
+              _,
+              isStreaming,
+              _,
+              _,
+              aggregateExpressions,
+              aggregateAttributes,
+              _,
+              resultExpressions,
+              child: ObjectHashAggregateExec)
+            if !isStreaming && isTransformable(objectHashAgg) && isTransformable(
+              child) && isPartialAgg(child, objectHashAgg) =>
+          // convert to complete mode aggregate expressions
+          val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
+          objectHashAgg.copy(
+            requiredChildDistributionExpressions = None,
+            groupingExpressions = child.groupingExpressions,
+            aggregateExpressions = completeAggregateExpressions,
+            initialInputBufferOffset = 0,
+            child = child.child
+          )
+        case sortAgg @ SortAggregateExec(
+              _,
+              isStreaming,
+              _,
+              _,
+              aggregateExpressions,
+              aggregateAttributes,
+              _,
+              resultExpressions,
+              child: SortAggregateExec)
+            if replaceSortAggWithHashAgg && !isStreaming && isTransformable(
+              sortAgg) && isTransformable(child) && isPartialAgg(child, sortAgg) =>
+          // convert to complete mode aggregate expressions
+          val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
+          sortAgg.copy(
+            requiredChildDistributionExpressions = None,
+            groupingExpressions = child.groupingExpressions,
+            aggregateExpressions = completeAggregateExpressions,
+            initialInputBufferOffset = 0,
+            child = child.child
+          )
+        case plan: SparkPlan => plan
+      }
+    }
+  }
+}
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala
index c60a1cc4686e1..3a16215c7d13f 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.execution
 
+import io.glutenproject.backendsapi.BackendsApiManager
 import io.glutenproject.execution.HashAggregateExecBaseTransformer
 
 import org.apache.spark.sql.{DataFrame, GlutenSQLTestsBaseTrait}
@@ -100,7 +101,11 @@ class GlutenReplaceHashWithSortAggSuite
             |)
             |GROUP BY key
           """.stripMargin
-      checkAggs(query, 2, 0, 2, 0)
+      if (BackendsApiManager.getSettings.mergeTwoPhasesHashBaseAggregateIfNeed()) {
+        checkAggs(query, 1, 0, 1, 0)
+      } else {
+        checkAggs(query, 2, 0, 2, 0)
+      }
     }
   }
 }
diff --git a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala
index c8b61ba5695db..f0f57aa87dae2 100644
--- a/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark34/src/test/scala/io/glutenproject/utils/clickhouse/ClickHouseTestSettings.scala
@@ -831,6 +831,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("do not replace hash aggregate if child does not have sort order")
     .exclude("do not replace hash aggregate if there is no group-by column")
     .exclude("Gluten - replace partial hash aggregate with sort aggregate")
+    .exclude("Gluten - replace partial and final hash aggregate together with sort aggregate")
   enableSuite[GlutenReuseExchangeAndSubquerySuite]
   enableSuite[GlutenSQLWindowFunctionSuite]
     .exclude("window function: partition and order expressions")
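
For reviewers, a minimal spark-shell sketch of how the merged aggregate can be observed with the
CH backend enabled; the bucketed table name `tpch_bucket.lineitem` and the session setup are
assumptions for illustration only and are not part of this patch:

    import io.glutenproject.execution.HashAggregateExecBaseTransformer

    // Assumes a Gluten-enabled SparkSession and a bucketed table, so that the
    // group-by below needs no shuffle between the two aggregation phases.
    val df = spark.sql(
      "SELECT l_orderkey, sum(l_quantity) FROM tpch_bucket.lineitem GROUP BY l_orderkey")
    val hashAggs = df.queryExecution.executedPlan.collect {
      case agg: HashAggregateExecBaseTransformer => agg
    }
    // Without this rule the plan would contain a Partial + Final pair; with it,
    // a single Complete-mode hash aggregate is expected.
    assert(hashAggs.size == 1)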