diff --git a/pkg/expression/aggregation/BUILD.bazel b/pkg/expression/aggregation/BUILD.bazel index 502d43d527cbc..d4d35efd5700a 100644 --- a/pkg/expression/aggregation/BUILD.bazel +++ b/pkg/expression/aggregation/BUILD.bazel @@ -50,6 +50,7 @@ go_library( "//pkg/util/mvmap", "//pkg/util/size", "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_tipb//go-tipb", ], ) diff --git a/pkg/expression/aggregation/aggregation.go b/pkg/expression/aggregation/aggregation.go index 71567ed560534..4229d40dc9e76 100644 --- a/pkg/expression/aggregation/aggregation.go +++ b/pkg/expression/aggregation/aggregation.go @@ -134,6 +134,23 @@ const ( DedupMode ) +// ToString show the agg mode. +func (a AggFunctionMode) ToString() string { + switch a { + case CompleteMode: + return "complete" + case FinalMode: + return "final" + case Partial1Mode: + return "partial1" + case Partial2Mode: + return "partial2" + case DedupMode: + return "deduplicate" + } + return "" +} + type aggFunction struct { *AggFuncDesc } diff --git a/pkg/expression/aggregation/explain.go b/pkg/expression/aggregation/explain.go index 29f88499e1bd1..d89fc08d88dc6 100644 --- a/pkg/expression/aggregation/explain.go +++ b/pkg/expression/aggregation/explain.go @@ -18,6 +18,7 @@ import ( "bytes" "fmt" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/parser/ast" ) @@ -25,7 +26,18 @@ import ( // ExplainAggFunc generates explain information for a aggregation function. func ExplainAggFunc(ctx expression.EvalContext, agg *AggFuncDesc, normalized bool) string { var buffer bytes.Buffer - fmt.Fprintf(&buffer, "%s(", agg.Name) + showMode := false + failpoint.Inject("show-agg-mode", func(v failpoint.Value) { + if v.(bool) { + showMode = true + } + }) + if showMode { + fmt.Fprintf(&buffer, "%s(%s,", agg.Name, agg.Mode.ToString()) + } else { + fmt.Fprintf(&buffer, "%s(", agg.Name) + } + if agg.HasDistinct { buffer.WriteString("distinct ") } diff --git a/pkg/planner/core/enforce_mpp_test.go b/pkg/planner/core/enforce_mpp_test.go index f161f7b7bd7b0..6e8deb41b8197 100644 --- a/pkg/planner/core/enforce_mpp_test.go +++ b/pkg/planner/core/enforce_mpp_test.go @@ -19,12 +19,62 @@ import ( "strconv" "testing" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/domain" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/testkit" "github.com/stretchr/testify/require" ) +func TestMppAggShouldAlignFinalMode(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t (" + + " d date," + + " v int," + + " primary key(d, v)" + + ") partition by range columns (d) (" + + " partition p1 values less than ('2023-07-02')," + + " partition p2 values less than ('2023-07-03')" + + ");") + // Create virtual tiflash replica info. + dom := domain.GetDomain(tk.Session()) + is := dom.InfoSchema() + db, exists := is.SchemaByName(model.NewCIStr("test")) + require.True(t, exists) + for _, tblInfo := range db.Tables { + if tblInfo.Name.L == "t" { + tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: 1, + Available: true, + } + } + } + tk.MustExec(`set tidb_partition_prune_mode='static';`) + err := failpoint.Enable("github.com/pingcap/tidb/pkg/expression/aggregation/show-agg-mode", "return(true)") + require.Nil(t, err) + + tk.MustQuery("explain format='brief' select 1 from (" + + " select /*+ read_from_storage(tiflash[t]) */ /*+ set_var(mpp_version=\"0\") */ sum(1)" + + " from t where d BETWEEN '2023-07-01' and '2023-07-03' group by d" + + ") total;").Check(testkit.Rows("Projection 400.00 root 1->Column#4", + "└─HashAgg 400.00 root group by:test.t.d, funcs:count(complete,1)->Column#8", + " └─PartitionUnion 400.00 root ", + " ├─Projection 200.00 root test.t.d", + " │ └─HashAgg 200.00 root group by:test.t.d, funcs:firstrow(partial2,test.t.d)->test.t.d, funcs:count(final,Column#12)->Column#9", + " │ └─TableReader 200.00 root data:HashAgg", + " │ └─HashAgg 200.00 cop[tikv] group by:test.t.d, funcs:count(partial1,1)->Column#12", + " │ └─TableRangeScan 250.00 cop[tikv] table:t, partition:p1 range:[2023-07-01,2023-07-03], keep order:false, stats:pseudo", + " └─Projection 200.00 root test.t.d", + " └─HashAgg 200.00 root group by:test.t.d, funcs:firstrow(partial2,test.t.d)->test.t.d, funcs:count(final,Column#16)->Column#10", + " └─TableReader 200.00 root data:HashAgg", + " └─HashAgg 200.00 cop[tikv] group by:test.t.d, funcs:count(partial1,1)->Column#16", + " └─TableRangeScan 250.00 cop[tikv] table:t, partition:p2 range:[2023-07-01,2023-07-03], keep order:false, stats:pseudo")) + + err = failpoint.Disable("github.com/pingcap/tidb/pkg/expression/aggregation/show-agg-mode") + require.Nil(t, err) +} func TestRowSizeInMPP(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git a/pkg/planner/core/exhaust_physical_plans.go b/pkg/planner/core/exhaust_physical_plans.go index db85041869b00..43e4dea250222 100644 --- a/pkg/planner/core/exhaust_physical_plans.go +++ b/pkg/planner/core/exhaust_physical_plans.go @@ -3382,6 +3382,22 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert } } } + // ref: https://github.com/pingcap/tiflash/blob/3ebb102fba17dce3d990d824a9df93d93f1ab + // 766/dbms/src/Flash/Coprocessor/AggregationInterpreterHelper.cpp#L26 + validMppAgg := func(mppAgg *PhysicalHashAgg) bool { + isFinalAgg := true + if mppAgg.AggFuncs[0].Mode != aggregation.FinalMode && mppAgg.AggFuncs[0].Mode != aggregation.CompleteMode { + isFinalAgg = false + } + for _, one := range mppAgg.AggFuncs[1:] { + otherIsFinalAgg := one.Mode == aggregation.FinalMode || one.Mode == aggregation.CompleteMode + if isFinalAgg != otherIsFinalAgg { + // different agg mode detected in mpp side. + return false + } + } + return true + } if len(la.GroupByItems) > 0 { partitionCols := la.GetPotentialPartitionKeys() @@ -3415,7 +3431,9 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert agg.SetSchema(la.schema.Clone()) agg.MppRunMode = Mpp1Phase finalAggAdjust(agg.AggFuncs) - hashAggs = append(hashAggs, agg) + if validMppAgg(agg) { + hashAggs = append(hashAggs, agg) + } } // Final agg can't be split into multi-stage aggregate, so exit early @@ -3430,7 +3448,9 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert agg.SetSchema(la.schema.Clone()) agg.MppRunMode = Mpp2Phase agg.MppPartitionCols = partitionCols - hashAggs = append(hashAggs, agg) + if validMppAgg(agg) { + hashAggs = append(hashAggs, agg) + } // agg runs on TiDB with a partial agg on TiFlash if possible if prop.TaskTp == property.RootTaskType {