diff --git a/pkg/planner/core/generator/hash64_equals/hash64_equals_generator.go b/pkg/planner/core/generator/hash64_equals/hash64_equals_generator.go index 29843f055f262..714491d91f3cd 100644 --- a/pkg/planner/core/generator/hash64_equals/hash64_equals_generator.go +++ b/pkg/planner/core/generator/hash64_equals/hash64_equals_generator.go @@ -21,6 +21,7 @@ import ( "log" "os" "reflect" + "strings" "github.com/pingcap/tidb/pkg/planner/cascades/base" "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" @@ -34,7 +35,7 @@ import ( // If a field is tagged with `hash64-equals`, then it will be computed in hash64 and equals func. // If a field is not tagged, then it will be skipped. func GenHash64Equals4LogicalOps() ([]byte, error) { - var structures = []any{logicalop.LogicalJoin{}, logicalop.LogicalAggregation{}, logicalop.LogicalApply{}} + var structures = []any{logicalop.LogicalJoin{}, logicalop.LogicalAggregation{}, logicalop.LogicalApply{}, logicalop.LogicalExpand{}} c := new(cc) c.write(codeGenHash64EqualsPrefix) for _, s := range structures { @@ -101,6 +102,8 @@ func logicalOpName2PlanCodecString(name string) string { return "plancodec.TypeAgg" case "LogicalApply": return "plancodec.TypeApply" + case "LogicalExpand": + return "plancodec.TypeExpand" default: return "" } @@ -115,7 +118,11 @@ func (c *cc) EqualsElement(fType reflect.Type, lhs, rhs string, i string) { switch fType.Kind() { case reflect.Slice: c.write("if len(%v) != len(%v) { return false }", lhs, rhs) - c.write("for %v, one := range %v {", i, lhs) + itemName := "one" + if strings.HasPrefix(lhs, "one") { + itemName = lhs + "e" + } + c.write("for %v, %v := range %v {", i, itemName, lhs) // one more round rhs = rhs + "[" + i + "]" // for ?, one := range [][][][]... @@ -126,7 +133,7 @@ func (c *cc) EqualsElement(fType reflect.Type, lhs, rhs string, i string) { // for iii, one := range [] // and so on... newi := i + "i" - c.EqualsElement(fType.Elem(), "one", rhs, newi) + c.EqualsElement(fType.Elem(), itemName, rhs, newi) c.write("}") case reflect.String, reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: @@ -147,9 +154,13 @@ func (c *cc) Hash64Element(fType reflect.Type, callName string) { switch fType.Kind() { case reflect.Slice: c.write("h.HashInt(len(%v))", callName) - c.write("for _, one := range %v {", callName) + itemName := "one" + if strings.HasPrefix(callName, "one") { + itemName = callName + "e" + } + c.write("for _, %v := range %v {", itemName, callName) // one more round - c.Hash64Element(fType.Elem(), "one") + c.Hash64Element(fType.Elem(), itemName) c.write("}") case reflect.String: c.write("h.HashString(%v)", callName) diff --git a/pkg/planner/core/operator/logicalop/hash64_equals_generated.go b/pkg/planner/core/operator/logicalop/hash64_equals_generated.go index e756c932ccce7..65a551921e762 100644 --- a/pkg/planner/core/operator/logicalop/hash64_equals_generated.go +++ b/pkg/planner/core/operator/logicalop/hash64_equals_generated.go @@ -155,8 +155,8 @@ func (op *LogicalAggregation) Hash64(h base.Hasher) { h.HashInt(len(op.PossibleProperties)) for _, one := range op.PossibleProperties { h.HashInt(len(one)) - for _, one := range one { - one.Hash64(h) + for _, onee := range one { + onee.Hash64(h) } } } @@ -194,8 +194,8 @@ func (op *LogicalAggregation) Equals(other any) bool { if len(one) != len(op2.PossibleProperties[i]) { return false } - for ii, one := range one { - if !one.Equals(op2.PossibleProperties[i][ii]) { + for ii, onee := range one { + if !onee.Equals(op2.PossibleProperties[i][ii]) { return false } } @@ -244,3 +244,134 @@ func (op *LogicalApply) Equals(other any) bool { } return true } + +// Hash64 implements the Hash64Equals interface. +func (op *LogicalExpand) Hash64(h base.Hasher) { + h.HashString(plancodec.TypeExpand) + if op.DistinctGroupByCol == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + h.HashInt(len(op.DistinctGroupByCol)) + for _, one := range op.DistinctGroupByCol { + one.Hash64(h) + } + } + if op.DistinctGbyExprs == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + h.HashInt(len(op.DistinctGbyExprs)) + for _, one := range op.DistinctGbyExprs { + one.Hash64(h) + } + } + h.HashInt64(int64(op.DistinctSize)) + if op.RollupGroupingSets == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + h.HashInt(len(op.RollupGroupingSets)) + for _, one := range op.RollupGroupingSets { + h.HashInt(len(one)) + for _, onee := range one { + h.HashInt(len(onee)) + for _, oneee := range onee { + oneee.Hash64(h) + } + } + } + } + if op.LevelExprs == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + h.HashInt(len(op.LevelExprs)) + for _, one := range op.LevelExprs { + h.HashInt(len(one)) + for _, onee := range one { + onee.Hash64(h) + } + } + } + if op.GID == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + op.GID.Hash64(h) + } + if op.GPos == nil { + h.HashByte(base.NilFlag) + } else { + h.HashByte(base.NotNilFlag) + op.GPos.Hash64(h) + } +} + +// Equals implements the Hash64Equals interface, only receive *LogicalExpand pointer. +func (op *LogicalExpand) Equals(other any) bool { + if other == nil { + return false + } + op2, ok := other.(*LogicalExpand) + if !ok { + return false + } + if len(op.DistinctGroupByCol) != len(op2.DistinctGroupByCol) { + return false + } + for i, one := range op.DistinctGroupByCol { + if !one.Equals(op2.DistinctGroupByCol[i]) { + return false + } + } + if len(op.DistinctGbyExprs) != len(op2.DistinctGbyExprs) { + return false + } + for i, one := range op.DistinctGbyExprs { + if !one.Equals(op2.DistinctGbyExprs[i]) { + return false + } + } + if op.DistinctSize != op2.DistinctSize { + return false + } + if len(op.RollupGroupingSets) != len(op2.RollupGroupingSets) { + return false + } + for i, one := range op.RollupGroupingSets { + if len(one) != len(op2.RollupGroupingSets[i]) { + return false + } + for ii, onee := range one { + if len(onee) != len(op2.RollupGroupingSets[i][ii]) { + return false + } + for iii, oneee := range onee { + if !oneee.Equals(op2.RollupGroupingSets[i][ii][iii]) { + return false + } + } + } + } + if len(op.LevelExprs) != len(op2.LevelExprs) { + return false + } + for i, one := range op.LevelExprs { + if len(one) != len(op2.LevelExprs[i]) { + return false + } + for ii, onee := range one { + if !onee.Equals(op2.LevelExprs[i][ii]) { + return false + } + } + } + if !op.GID.Equals(op2.GID) { + return false + } + if !op.GPos.Equals(op2.GPos) { + return false + } + return true +} diff --git a/pkg/planner/core/operator/logicalop/logical_expand.go b/pkg/planner/core/operator/logicalop/logical_expand.go index bd5cc3ba2a054..5760637b8a269 100644 --- a/pkg/planner/core/operator/logicalop/logical_expand.go +++ b/pkg/planner/core/operator/logicalop/logical_expand.go @@ -36,20 +36,20 @@ type LogicalExpand struct { LogicalSchemaProducer // distinct group by columns. (maybe projected below if it's a non-col) - DistinctGroupByCol []*expression.Column + DistinctGroupByCol []*expression.Column `hash64-equals:"true"` DistinctGbyColNames []*types.FieldName // keep the old gbyExprs for resolve cases like grouping(a+b), the args: // a+b should be resolved to new projected gby col according to ref pos. - DistinctGbyExprs []expression.Expression + DistinctGbyExprs []expression.Expression `hash64-equals:"true"` // rollup grouping sets. - DistinctSize int - RollupGroupingSets expression.GroupingSets + DistinctSize int `hash64-equals:"true"` + RollupGroupingSets expression.GroupingSets `hash64-equals:"true"` RollupID2GIDS map[int]map[uint64]struct{} RollupGroupingIDs []uint64 // The level projections is generated from grouping sets,make execution more clearly. - LevelExprs [][]expression.Expression + LevelExprs [][]expression.Expression `hash64-equals:"true"` // The generated column names. Eg: "grouping_id" and so on. ExtraGroupingColNames []string @@ -58,9 +58,9 @@ type LogicalExpand struct { GroupingMode tipb.GroupingMode // The GID and GPos column generated by logical expand if any. - GID *expression.Column + GID *expression.Column `hash64-equals:"true"` GIDName *types.FieldName - GPos *expression.Column + GPos *expression.Column `hash64-equals:"true"` GPosName *types.FieldName } diff --git a/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel b/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel index 7695efef4251c..1a823ad3b63ba 100644 --- a/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel +++ b/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel @@ -8,7 +8,7 @@ go_test( "logical_mem_table_predicate_extractor_test.go", ], flaky = True, - shard_count = 16, + shard_count = 17, deps = [ "//pkg/domain", "//pkg/expression", diff --git a/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go b/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go index 87cb7b9cdf5a3..ed0e5b882bd5a 100644 --- a/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go +++ b/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go @@ -28,6 +28,88 @@ import ( "github.com/stretchr/testify/require" ) +func TestLogicalExpandHash64Equals(t *testing.T) { + col1 := &expression.Column{ + ID: 1, + Index: 0, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + col2 := &expression.Column{ + ID: 2, + Index: 0, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + expand1 := &logicalop.LogicalExpand{ + DistinctGroupByCol: []*expression.Column{col1}, + DistinctGbyExprs: []expression.Expression{col1}, + DistinctSize: 1, + RollupGroupingSets: nil, + LevelExprs: nil, + GID: col1, + GPos: col1, + } + expand2 := &logicalop.LogicalExpand{ + DistinctGroupByCol: []*expression.Column{col1}, + DistinctGbyExprs: []expression.Expression{col1}, + DistinctSize: 1, + RollupGroupingSets: nil, + LevelExprs: nil, + GID: col1, + GPos: col1, + } + hasher1 := base.NewHashEqualer() + hasher2 := base.NewHashEqualer() + expand1.Hash64(hasher1) + expand2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + + expand2.DistinctGroupByCol = []*expression.Column{col2} + hasher2.Reset() + expand2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + + expand2.DistinctGroupByCol = []*expression.Column{col1} + expand2.DistinctGbyExprs = []expression.Expression{col2} + hasher2.Reset() + expand2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + + expand2.DistinctGbyExprs = []expression.Expression{col1} + expand2.DistinctSize = 2 + hasher2.Reset() + expand2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + + expand2.DistinctSize = 1 + expand2.RollupGroupingSets = expression.GroupingSets{expression.GroupingSet{expression.GroupingExprs{col1}}} + hasher2.Reset() + expand2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + + expand2.RollupGroupingSets = nil + expand2.LevelExprs = [][]expression.Expression{{col1}} + hasher2.Reset() + expand2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + + expand2.LevelExprs = nil + expand2.GID = col2 + hasher2.Reset() + expand2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + + expand2.GID = col1 + expand2.GPos = col2 + hasher2.Reset() + expand2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + + expand2.GPos = col1 + hasher2.Reset() + expand2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) +} + func TestLogicalApplyHash64Equals(t *testing.T) { col1 := &expression.Column{ ID: 1,