Skip to content

Commit

Permalink
expression: add rollup grouping sets util (#44105)
Browse files Browse the repository at this point in the history
close #44112
  • Loading branch information
AilinKid authored May 24, 2023
1 parent aa7a386 commit 1b0e019
Show file tree
Hide file tree
Showing 2 changed files with 346 additions and 4 deletions.
221 changes: 217 additions & 4 deletions expression/grouping_sets.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ import (
"strings"

"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/mysql"
fd "github.com/pingcap/tidb/planner/funcdep"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/util/size"
"github.com/pingcap/tipb/go-tipb"
Expand Down Expand Up @@ -174,7 +176,7 @@ func (gss GroupingSets) TargetOne(normalAggArgs []Expression) int {
allGroupingColIDs := gss.AllSetsColIDs()
for idx, groupingSet := range gss {
// diffCols are those columns being filled with null in the group row data of current grouping set.
diffCols := allGroupingColIDs.Difference(*groupingSet.allSetColIDs())
diffCols := allGroupingColIDs.Difference(*groupingSet.AllColIDs())
if diffCols.Intersects(normalAggArgsIDSet) {
// normal agg's target arg columns are being filled with null value in this grouping set, continue next grouping set check.
continue
Expand All @@ -191,7 +193,7 @@ func (gss GroupingSets) NeedCloneColumn() bool {
// the column c should be copied one more time here, otherwise it will be filled with null values and not visible for the other grouping set again.
setIDs := make([]*fd.FastIntSet, 0, len(gss))
for _, groupingSet := range gss {
setIDs = append(setIDs, groupingSet.allSetColIDs())
setIDs = append(setIDs, groupingSet.AllColIDs())
}
for idx, oneSetIDs := range setIDs {
for j := idx + 1; j < len(setIDs); j++ {
Expand All @@ -217,7 +219,8 @@ func (gs GroupingSet) IsEmpty() bool {
return true
}

func (gs GroupingSet) allSetColIDs() *fd.FastIntSet {
// AllColIDs collect all the grouping col's uniqueID. (here assuming that all the grouping expressions are single col)
func (gs GroupingSet) AllColIDs() *fd.FastIntSet {
res := fd.NewFastIntSet()
for _, groupingExprs := range gs {
for _, one := range groupingExprs {
Expand Down Expand Up @@ -300,7 +303,7 @@ func (gss GroupingSets) IsEmpty() bool {
// AllSetsColIDs unions the column-ID set of every grouping set in gss (as
// returned by GroupingSet.AllColIDs) into a single set and returns a pointer
// to it. An empty gss yields an empty set.
func (gss GroupingSets) AllSetsColIDs() *fd.FastIntSet {
	res := fd.NewFastIntSet()
	for _, groupingSet := range gss {
		// Only the exported AllColIDs accessor is used; the old unexported
		// allSetColIDs name was renamed and must not be called anymore.
		res.UnionWith(*groupingSet.AllColIDs())
	}
	return &res
}
Expand Down Expand Up @@ -396,3 +399,213 @@ func (g GroupingExprs) MemoryUsage() int64 {
}
return sum
}

/********************************* ROLLUP SUPPORT UTILITY *********************************/
//
// case: select count(1) from t group by a,a,a,a,a with rollup.
//
// Eg: in the group by a,a,a,a,a with rollup cases. We don't really need rollup
// the grouping sets as: {a,a,a,a,a},{a,a,a,a},{a,a,a},{a,a},{a},{}, although it
// can work with copying the column 'a' multi times and distinguish them as a1, a2...
//
// since group by {a,a} is same as group by {a,a,a} and {a}...; or since group by
// {b,a,a} is same as group by {b,a} and {b,a,a,a}. In order to avoid unnecessary
// redundant column copy, we introduce a new gen_col named as "grouping_pos" to
// distinguish several grouping set with same basic distinct group items.
//
// the above case can be simplified as: {a, gid, gpos}, and gid is used to distinguish
// deduplicated grouping sets, and gpos is indicated to distinguish duplicated grouping
// sets. the grouping sets are:
// {a,0,0}, {a,0,1}, {a,0,2}, {a,0,3}, {a,0,4},{null,1,5}
//
// when a is existed, the gid is all the same as 0, that's to say, the first 5 grouping
// set are all the same one grouping set logically, but why we still need to store them
// 5 times? because we want to output needed aggregation row as required from rollup syntax;
// so here we use gpos to distinguish them. In sixth grouping set, the 'a' is null, the
// gid is newed as 1, it means a new grouping set different with those previous ones, for
// computation convenience, we increment its gpos number as well.
//
// So should we care about the gpos when doing the data aggregation? Yes, we should: in the
// above case, the shuffle keys should also include the grouping_pos, while the grouping
// function still only cares about the gid as before.
//
// mysql> select count(1), grouping(a+1) from t group by a+1, 1+a with rollup;
// original grouping sets are:
// 1: {a+1, 1+a} ---> duplicate! proj a+1 let's say as column#1
// 2: {a+1}
// 3: {}
// distinguish them in projection item:
// Expand schema:[a, column#1, gid, gpos]
// | L1 proj: [a, column#1, 0, 0]; L2 proj: [a, column#1, 0, 1]; L3 proj: [a, null, 1, 2]
// |
// +--- Projection a+1 as column#1
//
// +----------+---------------+
// | count(1) | grouping(a+1) |
// +----------+---------------+ a+1, gid, gpos
// | 1 | 0 | ---+-----> grouping set{column#1, 0, 0}
// | 1 | 0 | ---+
// | 1 | 0 | ---+-----> grouping set{column#1, 0, 1}
// | 1 | 0 | ---+
// | 2 | 1 | ---------> grouping set{column#1, 1, 2}
// +----------+---------------+
// 5 rows in set (0.01 sec)
//
// As you can see, to support this tidb needs to be able to do:
// 1: grouping item semantic equivalence check logic
// 2: matching select list item to group by item (this seems hard)

// RollupGroupingSets expands the rollup expressions into one grouping set per
// prefix of rollupExprs. The sets are produced from the shortest prefix to the
// longest, e.g. [a,b,c] => {}, {a}, {a,b}, {a,b,c}, so the result always holds
// len(rollupExprs)+1 grouping sets, the first of which is empty.
func RollupGroupingSets(rollupExprs []Expression) GroupingSets {
	total := len(rollupExprs)
	sets := make(GroupingSets, 0, total+1)
	// prefixLen walks 0..total inclusive; 0 yields the empty grouping set.
	for prefixLen := 0; prefixLen <= total; prefixLen++ {
		// Each grouping set gets its own backing slice rather than sharing
		// rollupExprs' storage.
		prefix := make(GroupingExprs, 0, prefixLen)
		for _, expr := range rollupExprs[:prefixLen] {
			prefix = append(prefix, expr)
		}
		sets = append(sets, newGroupingSet(prefix))
	}
	return sets
}

// AdjustNullabilityFromGroupingSets adjusts the nullability of the Expand
// output schema in place.
//
// A schema column that is a grouping column (its UniqueID appears in at least
// one grouping set) but is missing from some grouping set may be projected as
// null for that set, so its NotNullFlag is cleared. Columns that are not
// grouping columns, and grouping columns present in every grouping set (e.g.
// `a` in grouping sets (a,b,c),(a,b),(a)), are left untouched.
func AdjustNullabilityFromGroupingSets(gss GroupingSets, schema *Schema) {
	groupingIDs := gss.AllSetsColIDs()
	// Cache each grouping set's column IDs so they are not recomputed per column.
	groupingIDsSlice := make([]*fd.FastIntSet, 0, len(gss))
	for _, oneGroupingSet := range gss {
		groupingIDsSlice = append(groupingIDsSlice, oneGroupingSet.AllColIDs())
	}
	for idx, col := range schema.Columns {
		colID := int(col.UniqueID)
		if !groupingIDs.Has(colID) {
			// Not a grouping-set column; keep its nullability as is.
			continue
		}
		for _, oneSetIDs := range groupingIDsSlice {
			if !oneSetIDs.Has(colID) {
				// This column may be projected as null for this grouping set.
				schema.Columns[idx].RetType.SetFlag(col.RetType.GetFlag() &^ mysql.NotNullFlag)
				// One missing set is enough; clearing the flag again for the
				// remaining sets would be redundant.
				break
			}
		}
	}
}

// DeduplicateGbyExpression detects group-by expressions that share the same
// hash code (computed with StmtCtx.CanonicalHashCode enabled) and maps them to
// the same position.
//
// It returns the distinct expressions in order of first appearance, and a
// position slice with one entry per input expression giving the index of that
// expression's representative in the distinct slice.
// eg: group by a+b, b+a, b with rollup.
// the 1st and 2nd expressions are semantically equivalent, so only [a+b, b] is
// kept and the position slice is [0, 0, 1].
func DeduplicateGbyExpression(ctx sessionctx.Context, exprs []Expression) ([]Expression, []int) {
	sc := ctx.GetSessionVars().StmtCtx
	// Enable canonical hash codes only for the duration of this call, then put
	// back whatever the statement context had before instead of forcing false.
	prevCanonical := sc.CanonicalHashCode
	sc.CanonicalHashCode = true
	defer func() {
		sc.CanonicalHashCode = prevCanonical
	}()

	distinctExprs := make([]Expression, 0, len(exprs))
	posSlice := make([]int, 0, len(exprs))
	// seen maps an expression's hash code to its index in distinctExprs.
	seen := make(map[string]int, len(exprs))
	for _, expr := range exprs {
		key := string(expr.HashCode(sc))
		pos, ok := seen[key]
		if !ok {
			// First occurrence: it becomes the representative at the next position.
			pos = len(distinctExprs)
			seen[key] = pos
			distinctExprs = append(distinctExprs, expr)
		}
		// Every input expression gets exactly one entry in posSlice.
		posSlice = append(posSlice, pos)
	}
	return distinctExprs, posSlice
}

// RestoreGbyExpression rebuilds the full group-by expression list from the
// recorded position references: the i-th result is a clone of exprs[idxes[i]].
// The returned slice has the same length as idxes.
func RestoreGbyExpression(exprs []*Column, idxes []int) []Expression {
	restored := make([]Expression, len(idxes))
	for i, ref := range idxes {
		// Clone so that repeated references do not share one Column instance.
		restored[i] = exprs[ref].Clone()
	}
	return restored
}

// DistinctSize reports how many distinct grouping sets exist, comparing them by
// their column unique-ID sets. For example grouping set{a, a, b} and grouping
// set{a, b} group on the same thing --- (a,b) --- and count once.
// Note: group items here are assumed to be plain columns.
// It delegates to DistinctSizeWithThreshold with the default threshold of 64;
// see that method for the meaning of the second and third return values.
func (gss GroupingSets) DistinctSize() (int, []uint64, map[int]map[uint64]struct{}) {
	const defaultThreshold = 64
	return gss.DistinctSizeWithThreshold(defaultThreshold)
}

// DistinctSizeWithThreshold is exported so tests can inject a small N instead
// of constructing a very large grouping sets.
//
// Grouping sets are compared by their column unique-ID sets (group items are
// assumed to be plain columns). The first return value is the number of
// distinct ID sets. When that number is at most N, the other two results are
// nil. When it exceeds N, the method also returns:
//   - a gid per original grouping set (same order as gss), where equivalent
//     grouping sets share a gid and gids range over 0 ~ (distinctSize-1);
//   - a map from each grouping column's unique ID to the set of gids of the
//     grouping sets that do NOT contain that column.
func (gss GroupingSets) DistinctSizeWithThreshold(N int) (int, []uint64, map[int]map[uint64]struct{}) {
	// idSets[i] is the column-ID set of gss[i]; firstSeen lists, for every
	// distinct ID set, the index in idSets of its first occurrence.
	idSets := make([]*fd.FastIntSet, 0, len(gss))
	firstSeen := make([]int, 0, len(gss))
	for _, gs := range gss {
		ids := gs.AllColIDs()
		known := false
		for _, at := range firstSeen {
			if idSets[at].Equals(*ids) {
				known = true
				break
			}
		}
		if !known {
			// ids is about to be stored at index len(idSets).
			firstSeen = append(firstSeen, len(idSets))
		}
		idSets = append(idSets, ids)
	}

	distinct := len(firstSeen)
	if distinct <= N {
		// Few enough distinct sets: no gid allocation is done here.
		return distinct, nil, nil
	}

	// More than N distinct sets: allocate a gid for every grouping set now,
	// reusing the ID sets cached above.
	gids := make([]uint64, len(idSets))
	for gid, at := range firstSeen {
		representative := idSets[at]
		for i, ids := range idSets {
			// Every grouping set equivalent to the representative shares its gid.
			if representative.Equals(*ids) {
				gids[i] = uint64(gid)
			}
		}
	}

	allIDs := gss.AllSetsColIDs()
	id2GIDs := make(map[int]map[uint64]struct{}, allIDs.Len())
	allIDs.ForEach(func(colID int) {
		// Collect the gids of all grouping sets in which colID is absent.
		absentGIDs := make(map[uint64]struct{})
		for i, ids := range idSets {
			if !ids.Has(colID) {
				absentGIDs[gids[i]] = struct{}{}
			}
		}
		id2GIDs[colID] = absentGIDs
	})
	return distinct, gids, id2GIDs
}
Loading

0 comments on commit 1b0e019

Please sign in to comment.