From c5e9cc77e646e31c01029623660485501db8b1da Mon Sep 17 00:00:00 2001 From: Jinlong Liu <50897894+King-Dylan@users.noreply.github.com> Date: Tue, 19 Nov 2024 22:17:38 -0800 Subject: [PATCH] planner: Fix the issue where any(nil) is not considered equal to nil (#57428) close pingcap/tidb#57326 --- pkg/expression/aggregation/base_func.go | 9 +++-- pkg/expression/aggregation/descriptor.go | 9 +++-- pkg/expression/column.go | 30 ++++++++--------- pkg/expression/column_test.go | 19 +++++++++++ pkg/expression/constant.go | 16 ++++----- pkg/expression/scalar_function.go | 16 ++++----- pkg/planner/cascades/memo/group.go | 14 ++++---- .../cascades/memo/group_and_expr_test.go | 9 ++--- pkg/planner/cascades/memo/group_expr.go | 14 ++++---- pkg/planner/core/scalar_subq_expression.go | 14 ++++---- pkg/planner/property/physical_property.go | 14 ++++---- pkg/util/codec/codec_test.go | 33 ++++++++++--------- 12 files changed, 106 insertions(+), 91 deletions(-) diff --git a/pkg/expression/aggregation/base_func.go b/pkg/expression/aggregation/base_func.go index 158251754ee54..3ae05057a6fe2 100644 --- a/pkg/expression/aggregation/base_func.go +++ b/pkg/expression/aggregation/base_func.go @@ -64,13 +64,16 @@ func (a *baseFuncDesc) Hash64(h base.Hasher) { // Equals implements the base.Equals interface. func (a *baseFuncDesc) Equals(other any) bool { - if other == nil { - return false - } a2, ok := other.(*baseFuncDesc) if !ok { return false } + if a == nil { + return a2 == nil + } + if a2 == nil { + return false + } ok = a.Name == a2.Name && len(a.Args) == len(a2.Args) && ((a.RetTp == nil && a2.RetTp == nil) || (a.RetTp != nil && a2.RetTp != nil && a.RetTp.Equals(a2.RetTp))) if !ok { return false diff --git a/pkg/expression/aggregation/descriptor.go b/pkg/expression/aggregation/descriptor.go index 1795d6224dacd..b762c5127be6e 100644 --- a/pkg/expression/aggregation/descriptor.go +++ b/pkg/expression/aggregation/descriptor.go @@ -74,13 +74,16 @@ func (a *AggFuncDesc) Hash64(h base.Hasher) { // Equals checks whether two aggregation function signatures are equal. func (a *AggFuncDesc) Equals(other any) bool { - if other == nil { - return false - } otherAgg, ok := other.(*AggFuncDesc) if !ok { return false } + if a == nil { + return otherAgg == nil + } + if otherAgg == nil { + return false + } if a.Mode != otherAgg.Mode || a.HasDistinct != otherAgg.HasDistinct || len(a.OrderByItems) != len(otherAgg.OrderByItems) { return false } diff --git a/pkg/expression/column.go b/pkg/expression/column.go index a1779f60cb98d..635f500717a55 100644 --- a/pkg/expression/column.go +++ b/pkg/expression/column.go @@ -263,16 +263,14 @@ func (col *CorrelatedColumn) Hash64(h base.Hasher) { // Equals implements HashEquals.<1st> interface. func (col *CorrelatedColumn) Equals(other any) bool { - if other == nil { + col2, ok := other.(*CorrelatedColumn) + if !ok { return false } - var col2 *CorrelatedColumn - switch x := other.(type) { - case CorrelatedColumn: - col2 = &x - case *CorrelatedColumn: - col2 = x - default: + if col == nil { + return col2 == nil + } + if col2 == nil { return false } return col.Column.Equals(&col2.Column) @@ -511,21 +509,19 @@ func (col *Column) Hash64(h base.Hasher) { // Equals implements HashEquals.<1st> interface. func (col *Column) Equals(other any) bool { - if other == nil { + col2, ok := other.(*Column) + if !ok { return false } - var col2 *Column - switch x := other.(type) { - case Column: - col2 = &x - case *Column: - col2 = x - default: + if col == nil { + return col2 == nil + } + if col2 == nil { return false } // when step into here, we could ensure that col1.RetType and col2.RetType are same type. // and we should ensure col1.RetType and col2.RetType is not nil ourselves. - ok := col.RetType == nil && col2.RetType == nil || col.RetType != nil && col2.RetType != nil && col.RetType.Equal(col2.RetType) + ok = col.RetType == nil && col2.RetType == nil || col.RetType != nil && col2.RetType != nil && col.RetType.Equal(col2.RetType) ok = ok && (col.VirtualExpr == nil && col2.VirtualExpr == nil || col.VirtualExpr != nil && col2.VirtualExpr != nil && col.VirtualExpr.Equals(col2.VirtualExpr)) return ok && col.ID == col2.ID && diff --git a/pkg/expression/column_test.go b/pkg/expression/column_test.go index 153cd6f929b43..a1babed68ab9a 100644 --- a/pkg/expression/column_test.go +++ b/pkg/expression/column_test.go @@ -493,3 +493,22 @@ func TestColumnHashEuqals4VirtualExpr(t *testing.T) { require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) require.True(t, col1.Equals(col2)) } + +func TestColumnEqualsWithNilWrappedInAny(t *testing.T) { + col1 := &Column{UniqueID: 1} + // Test: other is nil + require.False(t, col1.Equals(nil)) + // Test: other is *Column(nil) wrapped in any + var col2 *Column = nil + var col2AsAny any = col2 + require.False(t, col1.Equals(col2AsAny)) + // Test: both Columns are nil + var col3 *Column = nil + require.True(t, col3.Equals(col2AsAny)) + // Test: two Columns with the same values + col4 := &Column{UniqueID: 1} + require.True(t, col1.Equals(col4)) + // Test: two Columns with different values + col5 := &Column{UniqueID: 2} + require.False(t, col1.Equals(col5)) +} diff --git a/pkg/expression/constant.go b/pkg/expression/constant.go index 9739437ec4721..d2e24734ee28d 100644 --- a/pkg/expression/constant.go +++ b/pkg/expression/constant.go @@ -560,19 +560,17 @@ func (c *Constant) Hash64(h base.Hasher) { // Equals implements HashEquals.<1st> interface. func (c *Constant) Equals(other any) bool { - if other == nil { + c2, ok := other.(*Constant) + if !ok { return false } - var c2 *Constant - switch x := other.(type) { - case *Constant: - c2 = x - case Constant: - c2 = &x - default: + if c == nil { + return c2 == nil + } + if c2 == nil { return false } - ok := c.RetType == nil && c2.RetType == nil || c.RetType != nil && c2.RetType != nil && c.RetType.Equals(c2.RetType) + ok = c.RetType == nil && c2.RetType == nil || c.RetType != nil && c2.RetType != nil && c.RetType.Equals(c2.RetType) ok = ok && c.collationInfo.Equals(c2.collationInfo) ok = ok && (c.DeferredExpr == nil && c2.DeferredExpr == nil || c.DeferredExpr != nil && c2.DeferredExpr != nil && c.DeferredExpr.Equals(c2.DeferredExpr)) ok = ok && (c.ParamMarker == nil && c2.ParamMarker == nil || c.ParamMarker != nil && c2.ParamMarker != nil && c.ParamMarker.order == c2.ParamMarker.order) diff --git a/pkg/expression/scalar_function.go b/pkg/expression/scalar_function.go index 6aa7613accc18..6a6dadc4f1b5d 100644 --- a/pkg/expression/scalar_function.go +++ b/pkg/expression/scalar_function.go @@ -700,19 +700,17 @@ func (sf *ScalarFunction) Hash64(h base.Hasher) { // Equals implements HashEquals.<1th> interface. func (sf *ScalarFunction) Equals(other any) bool { - if other == nil { + sf2, ok := other.(*ScalarFunction) + if !ok { return false } - var sf2 *ScalarFunction - switch x := other.(type) { - case *ScalarFunction: - sf2 = x - case ScalarFunction: - sf2 = &x - default: + if sf == nil { + return sf2 == nil + } + if sf2 == nil { return false } - ok := sf.FuncName.L == sf2.FuncName.L + ok = sf.FuncName.L == sf2.FuncName.L ok = ok && (sf.RetType == nil && sf2.RetType == nil || sf.RetType != nil && sf2.RetType != nil && sf.RetType.Equals(sf2.RetType)) if len(sf.GetArgs()) != len(sf2.GetArgs()) { return false diff --git a/pkg/planner/cascades/memo/group.go b/pkg/planner/cascades/memo/group.go index 5da51ea3be8bb..1be6681f0f681 100644 --- a/pkg/planner/cascades/memo/group.go +++ b/pkg/planner/cascades/memo/group.go @@ -59,17 +59,17 @@ func (g *Group) Hash64(h base.Hasher) { // Equals implements the HashEquals.<1st> interface. func (g *Group) Equals(other any) bool { - if other == nil { + g2, ok := other.(*Group) + if !ok { return false } - switch x := other.(type) { - case *Group: - return g.groupID == x.groupID - case Group: - return g.groupID == x.groupID - default: + if g == nil { + return g2 == nil + } + if g2 == nil { return false } + return g.groupID == g2.groupID } // ******************************************* end of HashEqual methods ******************************************* diff --git a/pkg/planner/cascades/memo/group_and_expr_test.go b/pkg/planner/cascades/memo/group_and_expr_test.go index bc3f0b8d32608..5b6babb0544a0 100644 --- a/pkg/planner/cascades/memo/group_and_expr_test.go +++ b/pkg/planner/cascades/memo/group_and_expr_test.go @@ -31,16 +31,17 @@ func TestGroupHashEquals(t *testing.T) { a.Hash64(hasher1) b.Hash64(hasher2) require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) - require.True(t, a.Equals(b)) require.True(t, a.Equals(&b)) - + require.True(t, (&a).Equals(&b)) + require.False(t, a.Equals(b)) + require.False(t, (&a).Equals(b)) // change the id. b.groupID = 2 hasher2.Reset() b.Hash64(hasher2) require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) - require.False(t, a.Equals(b)) require.False(t, a.Equals(&b)) + require.False(t, (&a).Equals(&b)) } func TestGroupExpressionHashEquals(t *testing.T) { @@ -62,7 +63,7 @@ func TestGroupExpressionHashEquals(t *testing.T) { a.Hash64(hasher1) b.Hash64(hasher2) require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) - require.True(t, a.Equals(b)) + require.False(t, a.Equals(b)) require.True(t, a.Equals(&b)) // change the children order, like join commutative. diff --git a/pkg/planner/cascades/memo/group_expr.go b/pkg/planner/cascades/memo/group_expr.go index 306ed221e23db..373aa59531916 100644 --- a/pkg/planner/cascades/memo/group_expr.go +++ b/pkg/planner/cascades/memo/group_expr.go @@ -82,16 +82,14 @@ func (e *GroupExpression) Hash64(h base2.Hasher) { // Equals implements the Equals interface. func (e *GroupExpression) Equals(other any) bool { - if other == nil { + e2, ok := other.(*GroupExpression) + if !ok { return false } - var e2 *GroupExpression - switch x := other.(type) { - case *GroupExpression: - e2 = x - case GroupExpression: - e2 = &x - default: + if e == nil { + return e2 == nil + } + if e2 == nil { return false } if len(e.Inputs) != len(e2.Inputs) { diff --git a/pkg/planner/core/scalar_subq_expression.go b/pkg/planner/core/scalar_subq_expression.go index 2a7f2bf16ecee..50a07bc6de0a4 100644 --- a/pkg/planner/core/scalar_subq_expression.go +++ b/pkg/planner/core/scalar_subq_expression.go @@ -233,16 +233,14 @@ func (s *ScalarSubQueryExpr) Hash64(h base2.Hasher) { // Equals implements the HashEquals.<1st> interface. func (s *ScalarSubQueryExpr) Equals(other any) bool { - if other == nil { + s2, ok := other.(*ScalarSubQueryExpr) + if !ok { return false } - var s2 *ScalarSubQueryExpr - switch x := other.(type) { - case *ScalarSubQueryExpr: - s2 = x - case ScalarSubQueryExpr: - s2 = &x - default: + if s == nil { + return s2 == nil + } + if s2 == nil { return false } return s.scalarSubqueryColID == s2.scalarSubqueryColID diff --git a/pkg/planner/property/physical_property.go b/pkg/planner/property/physical_property.go index a147be0737fca..8eec99aef5c12 100644 --- a/pkg/planner/property/physical_property.go +++ b/pkg/planner/property/physical_property.go @@ -51,16 +51,14 @@ func (s *SortItem) Hash64(h base.Hasher) { // Equals implements the HashEquals interface. func (s *SortItem) Equals(other any) bool { - if other == nil { + s2, ok := other.(*SortItem) + if !ok { return false } - var s2 *SortItem - switch x := other.(type) { - case *SortItem: - s2 = x - case SortItem: - s2 = &x - default: + if s == nil { + return s2 == nil + } + if s2 == nil { return false } return s.Col.Equals(s2.Col) && s.Desc == s2.Desc diff --git a/pkg/util/codec/codec_test.go b/pkg/util/codec/codec_test.go index 6abc1c382c16f..e23ba29f151b6 100644 --- a/pkg/util/codec/codec_test.go +++ b/pkg/util/codec/codec_test.go @@ -1297,23 +1297,26 @@ func TestHashChunkColumns(t *testing.T) { func TestDatumHashEquals(t *testing.T) { now := time.Now() + newDatumPtr := func(d types.Datum) *types.Datum { + return &d + } tests := []struct { - d1 types.Datum - d2 types.Datum + d1 *types.Datum + d2 *types.Datum }{ - {types.NewIntDatum(1), types.NewIntDatum(1)}, - {types.NewUintDatum(1), types.NewUintDatum(1)}, - {types.NewFloat64Datum(1.1), types.NewFloat64Datum(1.1)}, - {types.NewStringDatum("abc"), types.NewStringDatum("abc")}, - {types.NewBytesDatum([]byte("abc")), types.NewBytesDatum([]byte("abc"))}, - {types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 1}), types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 1})}, - {types.NewMysqlSetDatum(types.Set{Name: "a", Value: 1}, "a"), types.NewMysqlSetDatum(types.Set{Name: "a", Value: 1}, "a")}, - {types.NewBinaryLiteralDatum([]byte{0x01}), types.NewBinaryLiteralDatum([]byte{0x01})}, - {types.NewMysqlBitDatum(types.NewBinaryLiteralFromUint(1, -1)), types.NewMysqlBitDatum(types.NewBinaryLiteralFromUint(1, -1))}, - {types.NewTimeDatum(types.NewTime(types.FromGoTime(now), mysql.TypeDatetime, 6)), types.NewTimeDatum(types.NewTime(types.FromGoTime(now), mysql.TypeDatetime, 6))}, - {types.NewDurationDatum(types.Duration{Duration: time.Second}), types.NewDurationDatum(types.Duration{Duration: time.Second})}, - {types.NewJSONDatum(types.CreateBinaryJSON("a")), types.NewJSONDatum(types.CreateBinaryJSON("a"))}, - {types.NewTimeDatum(types.NewTime(types.FromGoTime(now), mysql.TypeDatetime, 6)), types.NewTimeDatum(types.NewTime(types.FromGoTime(time.Now()), mysql.TypeDatetime, 6))}, + {newDatumPtr(types.NewIntDatum(1)), newDatumPtr(types.NewIntDatum(1))}, + {newDatumPtr(types.NewUintDatum(1)), newDatumPtr(types.NewUintDatum(1))}, + {newDatumPtr(types.NewFloat64Datum(1.1)), newDatumPtr(types.NewFloat64Datum(1.1))}, + {newDatumPtr(types.NewStringDatum("abc")), newDatumPtr(types.NewStringDatum("abc"))}, + {newDatumPtr(types.NewBytesDatum([]byte("abc"))), newDatumPtr(types.NewBytesDatum([]byte("abc")))}, + {newDatumPtr(types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 1})), newDatumPtr(types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 1}))}, + {newDatumPtr(types.NewMysqlSetDatum(types.Set{Name: "a", Value: 1}, "a")), newDatumPtr(types.NewMysqlSetDatum(types.Set{Name: "a", Value: 1}, "a"))}, + {newDatumPtr(types.NewBinaryLiteralDatum([]byte{0x01})), newDatumPtr(types.NewBinaryLiteralDatum([]byte{0x01}))}, + {newDatumPtr(types.NewMysqlBitDatum(types.NewBinaryLiteralFromUint(1, -1))), newDatumPtr(types.NewMysqlBitDatum(types.NewBinaryLiteralFromUint(1, -1)))}, + {newDatumPtr(types.NewTimeDatum(types.NewTime(types.FromGoTime(now), mysql.TypeDatetime, 6))), newDatumPtr(types.NewTimeDatum(types.NewTime(types.FromGoTime(now), mysql.TypeDatetime, 6)))}, + {newDatumPtr(types.NewDurationDatum(types.Duration{Duration: time.Second})), newDatumPtr(types.NewDurationDatum(types.Duration{Duration: time.Second}))}, + {newDatumPtr(types.NewJSONDatum(types.CreateBinaryJSON("a"))), newDatumPtr(types.NewJSONDatum(types.CreateBinaryJSON("a")))}, + {newDatumPtr(types.NewTimeDatum(types.NewTime(types.FromGoTime(now), mysql.TypeDatetime, 6))), newDatumPtr(types.NewTimeDatum(types.NewTime(types.FromGoTime(time.Now()), mysql.TypeDatetime, 6)))}, } hasher1 := base.NewHashEqualer() hasher2 := base.NewHashEqualer()