diff --git a/expression/grouping_sets.go b/expression/grouping_sets.go index b252cacdf059a..14f57b0567364 100644 --- a/expression/grouping_sets.go +++ b/expression/grouping_sets.go @@ -18,7 +18,9 @@ import ( "strings" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/mysql" fd "github.com/pingcap/tidb/planner/funcdep" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/util/size" "github.com/pingcap/tipb/go-tipb" @@ -174,7 +176,7 @@ func (gss GroupingSets) TargetOne(normalAggArgs []Expression) int { allGroupingColIDs := gss.AllSetsColIDs() for idx, groupingSet := range gss { // diffCols are those columns being filled with null in the group row data of current grouping set. - diffCols := allGroupingColIDs.Difference(*groupingSet.allSetColIDs()) + diffCols := allGroupingColIDs.Difference(*groupingSet.AllColIDs()) if diffCols.Intersects(normalAggArgsIDSet) { // normal agg's target arg columns are being filled with null value in this grouping set, continue next grouping set check. continue @@ -191,7 +193,7 @@ func (gss GroupingSets) NeedCloneColumn() bool { // the column c should be copied one more time here, otherwise it will be filled with null values and not visible for the other grouping set again. setIDs := make([]*fd.FastIntSet, 0, len(gss)) for _, groupingSet := range gss { - setIDs = append(setIDs, groupingSet.allSetColIDs()) + setIDs = append(setIDs, groupingSet.AllColIDs()) } for idx, oneSetIDs := range setIDs { for j := idx + 1; j < len(setIDs); j++ { @@ -217,7 +219,8 @@ func (gs GroupingSet) IsEmpty() bool { return true } -func (gs GroupingSet) allSetColIDs() *fd.FastIntSet { +// AllColIDs collect all the grouping col's uniqueID. 
(here assuming that all the grouping expressions are single col) +func (gs GroupingSet) AllColIDs() *fd.FastIntSet { res := fd.NewFastIntSet() for _, groupingExprs := range gs { for _, one := range groupingExprs { @@ -300,7 +303,7 @@ func (gss GroupingSets) IsEmpty() bool { func (gss GroupingSets) AllSetsColIDs() *fd.FastIntSet { res := fd.NewFastIntSet() for _, groupingSet := range gss { - res.UnionWith(*groupingSet.allSetColIDs()) + res.UnionWith(*groupingSet.AllColIDs()) } return &res } @@ -396,3 +399,213 @@ func (g GroupingExprs) MemoryUsage() int64 { } return sum } + +/********************************* ROLLUP SUPPORT UTILITY *********************************/ +// +// case: select count(1) from t group by a,a,a,a,a with rollup. +// +// Eg: in the group by a,a,a,a,a with rollup cases. We don't really need rollup +// the grouping sets as: {a,a,a,a,a},{a,a,a,a},{a,a,a},{a,a},{a},{}, although it +// can work with copying the column 'a' multiple times and distinguishing them as a1, a2... +// +// since group by {a,a} is the same as group by {a,a,a} and {a}...; or since group by +// {b,a,a} is the same as group by {b,a} and {b,a,a,a}. In order to avoid unnecessary +// redundant column copy, we introduce a new gen_col named as "grouping_pos" to +// distinguish several grouping sets with the same basic distinct group items. +// +// the above case can be simplified as: {a, gid, gpos}, and gid is used to distinguish +// deduplicated grouping sets, and gpos is introduced to distinguish duplicated grouping +// sets. the grouping sets are: +// {a,0,0}, {a,0,1}, {a,0,2}, {a,0,3}, {a,0,4},{null,1,5} +// +// when a exists, the gid is always the same 0, that is to say, the first 5 grouping +// sets are all the same one grouping set logically, but why do we still need to store them +// 5 times? Because we want to output needed aggregation rows as required by rollup syntax; +// so here we use gpos to distinguish them. 
In the sixth grouping set, the 'a' is null, the +// gid is set to a new value 1, it means a new grouping set different from those previous ones, for +// computation convenience, we increment its gpos number as well. +// +// so should we care about the gpos when doing the data aggregation? Yes we do, in the above +// case, shuffle keys should also include the grouping_pos, while the grouping function +// still only cares about the gid as before. +// +// mysql> select count(1), grouping(a+1) from t group by a+1, 1+a with rollup; +// original grouping sets are: +// 1: {a+1, 1+a} ---> duplicate! proj a+1 let's say as column#1 +// 2: {a+1} +// 3: {} +// distinguish them in projection item: +// Expand schema:[a, column#1, gid, gpos] +// | L1 proj: [a, column#1, 0, 0]; L2 proj: [a, column#1, 0, 1]; L3 proj: [a, null, 1, 2] +// | +// +--- Projection a+1 as column#1 +// +// +----------+---------------+ +// | count(1) | grouping(a+1) | +// +----------+---------------+ a+1, gid, gpos +// | 1 | 0 | ---+-----> grouping set{column#1, 0, 0} +// | 1 | 0 | ---+ +// | 1 | 0 | ---+-----> grouping set{column#1, 0, 1} +// | 1 | 0 | ---+ +// | 2 | 1 | ---------> grouping set{column#1, 1, 2} +// +----------+---------------+ +// 5 rows in set (0.01 sec) +// +// Yes, as you see it, tidb should be able to do +// 1: grouping item semantic equivalence check logic +// 2: matching select list items to group by items (this seems hard) + +// RollupGroupingSets rolls up the given expressions, and iterates out a slice of grouping set expressions. +func RollupGroupingSets(rollupExprs []Expression) GroupingSets { + // rollupExprs is a recursively decreasing pattern. + // eg: [a,b,c] => {a,b,c}, {a,b},{a},{} + res := GroupingSets{} + // loop from -1, leading the inner loop to start from {}, {a}, ... + for i := -1; i < len(rollupExprs); i++ { + groupingExprs := GroupingExprs{} + for j := 0; j <= i; j++ { + // append all the items from 0 up to i. 
+ groupingExprs = append(groupingExprs, rollupExprs[j]) + } + res = append(res, newGroupingSet(groupingExprs)) + } + return res +} + +// AdjustNullabilityFromGroupingSets adjusts the nullability of the Expand schema out. +func AdjustNullabilityFromGroupingSets(gss GroupingSets, schema *Schema) { + // If any one (grouping set) of the grouping sets doesn't include one grouping-set col, it means that + // this col must be filled with the null value in one level's projection. + // Exception: let's say a non-rollup case: grouping sets (a,b,c),(a,b),(a), the `a` is in every grouping + // set, so it won't be filled with a null value at any time, so the nullable change is unnecessary. + groupingIDs := gss.AllSetsColIDs() + // cache the grouping ids set to avoid fetching them multiple times below. + groupingIDsSlice := make([]*fd.FastIntSet, 0, len(gss)) + for _, oneGroupingSet := range gss { + groupingIDsSlice = append(groupingIDsSlice, oneGroupingSet.AllColIDs()) + } + for idx, col := range schema.Columns { + if !groupingIDs.Has(int(col.UniqueID)) { + // not a grouping-set col, maybe un-related column, keep it real. + continue + } + for _, oneSetIDs := range groupingIDsSlice { + if !oneSetIDs.Has(int(col.UniqueID)) { + // this col may be projected as null value, change it to nullable. + schema.Columns[idx].RetType.SetFlag(col.RetType.GetFlag() &^ mysql.NotNullFlag) + } + } + } +} + +// DeduplicateGbyExpression tries to detect semantically equivalent expressions, and marks them into the same position. +// eg: group by a+b, b+a, b with rollup. +// the 1st and 2nd expressions are semantically equivalent, so we only need to keep the distinct expressions: [a+b, b] +// down, and output another position slice out, the [0, 0, 1] for the case above. 
+func DeduplicateGbyExpression(ctx sessionctx.Context, exprs []Expression) ([]Expression, []int) { + distinctExprs := make([]Expression, 0, len(exprs)) + sc := ctx.GetSessionVars().StmtCtx + sc.CanonicalHashCode = true + defer func() { + sc.CanonicalHashCode = false + }() + distinctMap := make(map[string]int, len(exprs)) + for _, expr := range exprs { + // -1 means pos is not assigned yet. + distinctMap[string(expr.HashCode(sc))] = -1 + } + // pos is from 0 to len(distinctMap)-1 + pos := 0 + posSlice := make([]int, 0, len(exprs)) + for _, one := range exprs { + key := string(one.HashCode(sc)) + if val, ok := distinctMap[key]; ok { + if val == -1 { + // means a new distinct expr. + distinctExprs = append(distinctExprs, one) + posSlice = append(posSlice, pos) + // assign this pos to val. + distinctMap[key] = pos + pos++ + } else { + // means this expr is SemanticEqual with a previous one, ref the pos + posSlice = append(posSlice, val) + } + } + } + return distinctExprs, posSlice +} + +// RestoreGbyExpression restores the new gby expression according to recorded idxes reference. +func RestoreGbyExpression(exprs []*Column, idxes []int) []Expression { + res := make([]Expression, 0, len(idxes)) + for _, pos := range idxes { + res = append(res, exprs[pos].Clone()) + } + return res +} + +// DistinctSize judges whether there exist duplicate grouping sets and outputs the distinct grouping set size. +// for example: grouping set{a, a, b}, and grouping set{a, b}, they actually group on the same thing --- (a,b) +// we can tell that from their unique id set equality. Note: every group item here must be a column at this point. +func (gss GroupingSets) DistinctSize() (int, []uint64, map[int]map[uint64]struct{}) { + // default 64 + return gss.DistinctSizeWithThreshold(64) +} + +// DistinctSizeWithThreshold is exported for tests with a smaller injected N rather than constructing large grouping sets. 
+func (gss GroupingSets) DistinctSizeWithThreshold(N int) (int, []uint64, map[int]map[uint64]struct{}) { + // all the group by items are cols, deduplicate from id-set. + distinctGroupingIDsPos := make([]int, 0, len(gss)) + originGroupingIDsSlice := make([]*fd.FastIntSet, 0, len(gss)) + + for _, oneGroupingSet := range gss { + curIDs := oneGroupingSet.AllColIDs() + duplicate := false + for _, offset := range distinctGroupingIDsPos { + if originGroupingIDsSlice[offset].Equals(*curIDs) { + duplicate = true + break + } + } + if !duplicate { + // if distinct, record its offset in originGroupingIDsSlice. + distinctGroupingIDsPos = append(distinctGroupingIDsPos, len(originGroupingIDsSlice)) + } + // record current grouping set ids. + originGroupingIDsSlice = append(originGroupingIDsSlice, curIDs) + } + // once the distinct size is greater than N (64 by default), just allocate a gid for every grouping set here (utilizing the cached slice here). + // otherwise, the later gid allocation logic will compute the logic above again. + if len(distinctGroupingIDsPos) > N { + gids := make([]uint64, len(originGroupingIDsSlice)) + // for every distinct grouping set, the gid allocation is from 0 ~ (distinctSize-1). + for gid, offset := range distinctGroupingIDsPos { + oneDistinctSetIDs := originGroupingIDsSlice[offset] + for idx, oneOriginSetIDs := range originGroupingIDsSlice { + // for every equivalent grouping set, store the same gid. + if oneDistinctSetIDs.Equals(*oneOriginSetIDs) { + gids[idx] = uint64(gid) + } + } + } + + allIDs := gss.AllSetsColIDs() + id2GIDs := make(map[int]map[uint64]struct{}, allIDs.Len()) + allIDs.ForEach(func(i int) { + collectionMap := make(map[uint64]struct{}) + // for every original unique column id, traverse all the grouping sets. + for idx, oneOriginSetIDs := range originGroupingIDsSlice { + if oneOriginSetIDs.Has(i) { + // this column is needed in this grouping set. 
+ continue + } + // this column is not needed in this grouping set.(this column is grouped) + collectionMap[gids[idx]] = struct{}{} + } + id2GIDs[i] = collectionMap + }) + return len(distinctGroupingIDsPos), gids, id2GIDs + } + return len(distinctGroupingIDsPos), nil, nil +} diff --git a/expression/grouping_sets_test.go b/expression/grouping_sets_test.go index e4f636bb9845d..3081ea426f1c2 100644 --- a/expression/grouping_sets_test.go +++ b/expression/grouping_sets_test.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/mock" "github.com/stretchr/testify/require" "go.opencensus.io/stats/view" ) @@ -246,3 +247,131 @@ func TestGroupingSetsMergeUnitTest(t *testing.T) { // [d] require.Equal(t, len(newGroupingSets[1][0]), 1) } + +func TestDistinctGroupingSets(t *testing.T) { + defer view.Stop() + // premise: every grouping item in grouping sets should be a col. + a := &Column{ + UniqueID: 1, + } + b := &Column{ + UniqueID: 2, + } + c := &Column{ + UniqueID: 3, + } + d := &Column{ + UniqueID: 4, + } + // case1: non-duplicated case. + // raw rollup expressions: [a,b,c,d] + rawRollupExprs := []Expression{a, b, c, d} + mockCtx := mock.NewContext() + deduplicateExprs, pos := DeduplicateGbyExpression(mockCtx, rawRollupExprs) + + // nothing to deduplicate. + require.Equal(t, len(rawRollupExprs), len(deduplicateExprs)) + require.Equal(t, len(pos), 4) + require.Equal(t, pos[0], 0) + require.Equal(t, pos[1], 1) + require.Equal(t, pos[2], 2) + require.Equal(t, pos[3], 3) + + // rollup grouping sets. 
+ gss := RollupGroupingSets(deduplicateExprs) + require.Equal(t, len(gss), 5) // len = N + 1 (N = len(raw rollup expression)) + require.Equal(t, gss[0].AllColIDs().String(), "()") + require.Equal(t, gss[1].AllColIDs().String(), "(1)") + require.Equal(t, gss[2].AllColIDs().String(), "(1,2)") + require.Equal(t, gss[3].AllColIDs().String(), "(1-3)") + require.Equal(t, gss[4].AllColIDs().String(), "(1-4)") + + size, _, _ := gss.DistinctSize() + require.Equal(t, size, 5) + + // case2: duplicated case. + rawRollupExprs = []Expression{a, b, b, c} + deduplicateExprs, pos = DeduplicateGbyExpression(mockCtx, rawRollupExprs) + require.Equal(t, len(deduplicateExprs), 3) + require.Equal(t, deduplicateExprs[0].String(), "Column#1") + require.Equal(t, deduplicateExprs[1].String(), "Column#2") + require.Equal(t, deduplicateExprs[2].String(), "Column#3") + deduplicateColumns := make([]*Column, 0, len(deduplicateExprs)) + for _, one := range deduplicateExprs { + deduplicateColumns = append(deduplicateColumns, one.(*Column)) + } + require.Equal(t, len(pos), 4) + require.Equal(t, pos[0], 0) + require.Equal(t, pos[1], 1) // ref to column#2 + require.Equal(t, pos[2], 1) // ref to column#2 + require.Equal(t, pos[3], 2) + + // restoreGbyExpression (some complicated expression will be projected as column before enter Expand) + // so cases like below is possible: + // GBY: a+b, b+a, c ==> Expand: rollup with (column#1, column#1, c) + // +----- proj: (a+b) as column#1 + // so that why restore gby expression according to their pos is necessary. 
+ restoreGbyExpressions := RestoreGbyExpression(deduplicateColumns, pos) + require.Equal(t, len(restoreGbyExpressions), 4) + require.Equal(t, restoreGbyExpressions[0].String(), "Column#1") + require.Equal(t, restoreGbyExpressions[1].String(), "Column#2") + require.Equal(t, restoreGbyExpressions[2].String(), "Column#2") + require.Equal(t, restoreGbyExpressions[3].String(), "Column#3") + + // rollup grouping sets (build grouping sets on the restored gby expression, because all the + // complicated expressions have been projected as simple columns at this time). + gss = RollupGroupingSets(restoreGbyExpressions) + require.Equal(t, len(gss), 5) // len = N + 1 (N = len(raw rollup expression)) + require.Equal(t, gss[0].AllColIDs().String(), "()") + require.Equal(t, gss[1].AllColIDs().String(), "(1)") + require.Equal(t, gss[2].AllColIDs().String(), "(1,2)") + require.Equal(t, gss[3].AllColIDs().String(), "(1,2)") + require.Equal(t, gss[4].AllColIDs().String(), "(1-3)") + + // after gss is built on all the simple columns, get the distinct size if possible. + size, _, _ = gss.DistinctSize() + require.Equal(t, size, 4) + + // case3: when the grouping sets number is bigger than 64, we will allocate gids incrementally rather than as a bit map. + // here we inject the N as 3, so we can test the computation logic easily. + size, gids, id2Gids := gss.DistinctSizeWithThreshold(3) + require.Equal(t, size, 4) + require.Equal(t, len(gids), len(gss)) + // the number of assigned gids should be equal to the origin gss size. + // for this case: + // {} () --> 0 + // {a} (1) --> 1 + // {a,b} (1,2) --> 2 --+---> has the same gid + // {a,b,b} (1,2) --> 2 __/ + // {a,b,b,c} (1,2,3) --> 3 + require.Equal(t, gids[0], uint64(0)) + require.Equal(t, gids[1], uint64(1)) + require.Equal(t, gids[2], uint64(2)) + require.Equal(t, gids[3], uint64(2)) + require.Equal(t, gids[4], uint64(3)) + + // for every col id, map it to a set of gids. 
+ require.Equal(t, len(id2Gids), 3) + // 0 --> when grouping(column#1), the corresponding affected gids should be {0} + // +--- explanation: when grouping(a), a is only grouped when grouping id = 0. + require.Equal(t, len(id2Gids[1]), 1) + _, ok := id2Gids[1][0] + require.Equal(t, ok, true) + // 1 --> when grouping(column#2), the corresponding affected gids should be {0,1} + // +--- explanation: when grouping(b), b is only grouped when grouping id = 0 or 1. + require.Equal(t, len(id2Gids[2]), 2) + _, ok = id2Gids[2][0] + require.Equal(t, ok, true) + _, ok = id2Gids[2][1] + require.Equal(t, ok, true) + // 2 --> when grouping(column#3), the corresponding affected gids should be {0,1,2} + // +--- explanation: when grouping(c), c is only grouped when grouping id = 0 or 1 or 2. + require.Equal(t, len(id2Gids[3]), 3) + _, ok = id2Gids[3][0] + require.Equal(t, ok, true) + _, ok = id2Gids[3][1] + require.Equal(t, ok, true) + _, ok = id2Gids[3][2] + require.Equal(t, ok, true) + // column d is not in the grouping set columns, so it won't be here. +}