Skip to content

Commit

Permalink
planner: Fix the issue where any(nil) is not considered equal to nil (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
King-Dylan authored Nov 20, 2024
1 parent 14e99ea commit c5e9cc7
Show file tree
Hide file tree
Showing 12 changed files with 106 additions and 91 deletions.
9 changes: 6 additions & 3 deletions pkg/expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,16 @@ func (a *baseFuncDesc) Hash64(h base.Hasher) {

// Equals implements the base.Equals interface.
func (a *baseFuncDesc) Equals(other any) bool {
if other == nil {
return false
}
a2, ok := other.(*baseFuncDesc)
if !ok {
return false
}
if a == nil {
return a2 == nil
}
if a2 == nil {
return false
}
ok = a.Name == a2.Name && len(a.Args) == len(a2.Args) && ((a.RetTp == nil && a2.RetTp == nil) || (a.RetTp != nil && a2.RetTp != nil && a.RetTp.Equals(a2.RetTp)))
if !ok {
return false
Expand Down
9 changes: 6 additions & 3 deletions pkg/expression/aggregation/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,16 @@ func (a *AggFuncDesc) Hash64(h base.Hasher) {

// Equals checks whether two aggregation function signatures are equal.
func (a *AggFuncDesc) Equals(other any) bool {
if other == nil {
return false
}
otherAgg, ok := other.(*AggFuncDesc)
if !ok {
return false
}
if a == nil {
return otherAgg == nil
}
if otherAgg == nil {
return false
}
if a.Mode != otherAgg.Mode || a.HasDistinct != otherAgg.HasDistinct || len(a.OrderByItems) != len(otherAgg.OrderByItems) {
return false
}
Expand Down
30 changes: 13 additions & 17 deletions pkg/expression/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,16 +263,14 @@ func (col *CorrelatedColumn) Hash64(h base.Hasher) {

// Equals implements HashEquals.<1st> interface.
func (col *CorrelatedColumn) Equals(other any) bool {
if other == nil {
col2, ok := other.(*CorrelatedColumn)
if !ok {
return false
}
var col2 *CorrelatedColumn
switch x := other.(type) {
case CorrelatedColumn:
col2 = &x
case *CorrelatedColumn:
col2 = x
default:
if col == nil {
return col2 == nil
}
if col2 == nil {
return false
}
return col.Column.Equals(&col2.Column)
Expand Down Expand Up @@ -511,21 +509,19 @@ func (col *Column) Hash64(h base.Hasher) {

// Equals implements HashEquals.<1st> interface.
func (col *Column) Equals(other any) bool {
if other == nil {
col2, ok := other.(*Column)
if !ok {
return false
}
var col2 *Column
switch x := other.(type) {
case Column:
col2 = &x
case *Column:
col2 = x
default:
if col == nil {
return col2 == nil
}
if col2 == nil {
return false
}
// when step into here, we could ensure that col1.RetType and col2.RetType are same type.
// and we should ensure col1.RetType and col2.RetType is not nil ourselves.
ok := col.RetType == nil && col2.RetType == nil || col.RetType != nil && col2.RetType != nil && col.RetType.Equal(col2.RetType)
ok = col.RetType == nil && col2.RetType == nil || col.RetType != nil && col2.RetType != nil && col.RetType.Equal(col2.RetType)
ok = ok && (col.VirtualExpr == nil && col2.VirtualExpr == nil || col.VirtualExpr != nil && col2.VirtualExpr != nil && col.VirtualExpr.Equals(col2.VirtualExpr))
return ok &&
col.ID == col2.ID &&
Expand Down
19 changes: 19 additions & 0 deletions pkg/expression/column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,22 @@ func TestColumnHashEuqals4VirtualExpr(t *testing.T) {
require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
require.True(t, col1.Equals(col2))
}

func TestColumnEqualsWithNilWrappedInAny(t *testing.T) {
col1 := &Column{UniqueID: 1}
// Test: other is nil
require.False(t, col1.Equals(nil))
// Test: other is *Column(nil) wrapped in any
var col2 *Column = nil
var col2AsAny any = col2
require.False(t, col1.Equals(col2AsAny))
// Test: both Columns are nil
var col3 *Column = nil
require.True(t, col3.Equals(col2AsAny))
// Test: two Columns with the same values
col4 := &Column{UniqueID: 1}
require.True(t, col1.Equals(col4))
// Test: two Columns with different values
col5 := &Column{UniqueID: 2}
require.False(t, col1.Equals(col5))
}
16 changes: 7 additions & 9 deletions pkg/expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,19 +560,17 @@ func (c *Constant) Hash64(h base.Hasher) {

// Equals implements HashEquals.<1st> interface.
func (c *Constant) Equals(other any) bool {
if other == nil {
c2, ok := other.(*Constant)
if !ok {
return false
}
var c2 *Constant
switch x := other.(type) {
case *Constant:
c2 = x
case Constant:
c2 = &x
default:
if c == nil {
return c2 == nil
}
if c2 == nil {
return false
}
ok := c.RetType == nil && c2.RetType == nil || c.RetType != nil && c2.RetType != nil && c.RetType.Equals(c2.RetType)
ok = c.RetType == nil && c2.RetType == nil || c.RetType != nil && c2.RetType != nil && c.RetType.Equals(c2.RetType)
ok = ok && c.collationInfo.Equals(c2.collationInfo)
ok = ok && (c.DeferredExpr == nil && c2.DeferredExpr == nil || c.DeferredExpr != nil && c2.DeferredExpr != nil && c.DeferredExpr.Equals(c2.DeferredExpr))
ok = ok && (c.ParamMarker == nil && c2.ParamMarker == nil || c.ParamMarker != nil && c2.ParamMarker != nil && c.ParamMarker.order == c2.ParamMarker.order)
Expand Down
16 changes: 7 additions & 9 deletions pkg/expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -700,19 +700,17 @@ func (sf *ScalarFunction) Hash64(h base.Hasher) {

// Equals implements HashEquals.<1th> interface.
func (sf *ScalarFunction) Equals(other any) bool {
if other == nil {
sf2, ok := other.(*ScalarFunction)
if !ok {
return false
}
var sf2 *ScalarFunction
switch x := other.(type) {
case *ScalarFunction:
sf2 = x
case ScalarFunction:
sf2 = &x
default:
if sf == nil {
return sf2 == nil
}
if sf2 == nil {
return false
}
ok := sf.FuncName.L == sf2.FuncName.L
ok = sf.FuncName.L == sf2.FuncName.L
ok = ok && (sf.RetType == nil && sf2.RetType == nil || sf.RetType != nil && sf2.RetType != nil && sf.RetType.Equals(sf2.RetType))
if len(sf.GetArgs()) != len(sf2.GetArgs()) {
return false
Expand Down
14 changes: 7 additions & 7 deletions pkg/planner/cascades/memo/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,17 @@ func (g *Group) Hash64(h base.Hasher) {

// Equals implements the HashEquals.<1st> interface.
func (g *Group) Equals(other any) bool {
if other == nil {
g2, ok := other.(*Group)
if !ok {
return false
}
switch x := other.(type) {
case *Group:
return g.groupID == x.groupID
case Group:
return g.groupID == x.groupID
default:
if g == nil {
return g2 == nil
}
if g2 == nil {
return false
}
return g.groupID == g2.groupID
}

// ******************************************* end of HashEqual methods *******************************************
Expand Down
9 changes: 5 additions & 4 deletions pkg/planner/cascades/memo/group_and_expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,17 @@ func TestGroupHashEquals(t *testing.T) {
a.Hash64(hasher1)
b.Hash64(hasher2)
require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
require.True(t, a.Equals(b))
require.True(t, a.Equals(&b))

require.True(t, (&a).Equals(&b))
require.False(t, a.Equals(b))
require.False(t, (&a).Equals(b))
// change the id.
b.groupID = 2
hasher2.Reset()
b.Hash64(hasher2)
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
require.False(t, a.Equals(b))
require.False(t, a.Equals(&b))
require.False(t, (&a).Equals(&b))
}

func TestGroupExpressionHashEquals(t *testing.T) {
Expand All @@ -62,7 +63,7 @@ func TestGroupExpressionHashEquals(t *testing.T) {
a.Hash64(hasher1)
b.Hash64(hasher2)
require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
require.True(t, a.Equals(b))
require.False(t, a.Equals(b))
require.True(t, a.Equals(&b))

// change the children order, like join commutative.
Expand Down
14 changes: 6 additions & 8 deletions pkg/planner/cascades/memo/group_expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,14 @@ func (e *GroupExpression) Hash64(h base2.Hasher) {

// Equals implements the Equals interface.
func (e *GroupExpression) Equals(other any) bool {
if other == nil {
e2, ok := other.(*GroupExpression)
if !ok {
return false
}
var e2 *GroupExpression
switch x := other.(type) {
case *GroupExpression:
e2 = x
case GroupExpression:
e2 = &x
default:
if e == nil {
return e2 == nil
}
if e2 == nil {
return false
}
if len(e.Inputs) != len(e2.Inputs) {
Expand Down
14 changes: 6 additions & 8 deletions pkg/planner/core/scalar_subq_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,14 @@ func (s *ScalarSubQueryExpr) Hash64(h base2.Hasher) {

// Equals implements the HashEquals.<1st> interface.
func (s *ScalarSubQueryExpr) Equals(other any) bool {
if other == nil {
s2, ok := other.(*ScalarSubQueryExpr)
if !ok {
return false
}
var s2 *ScalarSubQueryExpr
switch x := other.(type) {
case *ScalarSubQueryExpr:
s2 = x
case ScalarSubQueryExpr:
s2 = &x
default:
if s == nil {
return s2 == nil
}
if s2 == nil {
return false
}
return s.scalarSubqueryColID == s2.scalarSubqueryColID
Expand Down
14 changes: 6 additions & 8 deletions pkg/planner/property/physical_property.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,14 @@ func (s *SortItem) Hash64(h base.Hasher) {

// Equals implements the HashEquals interface.
func (s *SortItem) Equals(other any) bool {
if other == nil {
s2, ok := other.(*SortItem)
if !ok {
return false
}
var s2 *SortItem
switch x := other.(type) {
case *SortItem:
s2 = x
case SortItem:
s2 = &x
default:
if s == nil {
return s2 == nil
}
if s2 == nil {
return false
}
return s.Col.Equals(s2.Col) && s.Desc == s2.Desc
Expand Down
33 changes: 18 additions & 15 deletions pkg/util/codec/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1297,23 +1297,26 @@ func TestHashChunkColumns(t *testing.T) {

func TestDatumHashEquals(t *testing.T) {
now := time.Now()
newDatumPtr := func(d types.Datum) *types.Datum {
return &d
}
tests := []struct {
d1 types.Datum
d2 types.Datum
d1 *types.Datum
d2 *types.Datum
}{
{types.NewIntDatum(1), types.NewIntDatum(1)},
{types.NewUintDatum(1), types.NewUintDatum(1)},
{types.NewFloat64Datum(1.1), types.NewFloat64Datum(1.1)},
{types.NewStringDatum("abc"), types.NewStringDatum("abc")},
{types.NewBytesDatum([]byte("abc")), types.NewBytesDatum([]byte("abc"))},
{types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 1}), types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 1})},
{types.NewMysqlSetDatum(types.Set{Name: "a", Value: 1}, "a"), types.NewMysqlSetDatum(types.Set{Name: "a", Value: 1}, "a")},
{types.NewBinaryLiteralDatum([]byte{0x01}), types.NewBinaryLiteralDatum([]byte{0x01})},
{types.NewMysqlBitDatum(types.NewBinaryLiteralFromUint(1, -1)), types.NewMysqlBitDatum(types.NewBinaryLiteralFromUint(1, -1))},
{types.NewTimeDatum(types.NewTime(types.FromGoTime(now), mysql.TypeDatetime, 6)), types.NewTimeDatum(types.NewTime(types.FromGoTime(now), mysql.TypeDatetime, 6))},
{types.NewDurationDatum(types.Duration{Duration: time.Second}), types.NewDurationDatum(types.Duration{Duration: time.Second})},
{types.NewJSONDatum(types.CreateBinaryJSON("a")), types.NewJSONDatum(types.CreateBinaryJSON("a"))},
{types.NewTimeDatum(types.NewTime(types.FromGoTime(now), mysql.TypeDatetime, 6)), types.NewTimeDatum(types.NewTime(types.FromGoTime(time.Now()), mysql.TypeDatetime, 6))},
{newDatumPtr(types.NewIntDatum(1)), newDatumPtr(types.NewIntDatum(1))},
{newDatumPtr(types.NewUintDatum(1)), newDatumPtr(types.NewUintDatum(1))},
{newDatumPtr(types.NewFloat64Datum(1.1)), newDatumPtr(types.NewFloat64Datum(1.1))},
{newDatumPtr(types.NewStringDatum("abc")), newDatumPtr(types.NewStringDatum("abc"))},
{newDatumPtr(types.NewBytesDatum([]byte("abc"))), newDatumPtr(types.NewBytesDatum([]byte("abc")))},
{newDatumPtr(types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 1})), newDatumPtr(types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 1}))},
{newDatumPtr(types.NewMysqlSetDatum(types.Set{Name: "a", Value: 1}, "a")), newDatumPtr(types.NewMysqlSetDatum(types.Set{Name: "a", Value: 1}, "a"))},
{newDatumPtr(types.NewBinaryLiteralDatum([]byte{0x01})), newDatumPtr(types.NewBinaryLiteralDatum([]byte{0x01}))},
{newDatumPtr(types.NewMysqlBitDatum(types.NewBinaryLiteralFromUint(1, -1))), newDatumPtr(types.NewMysqlBitDatum(types.NewBinaryLiteralFromUint(1, -1)))},
{newDatumPtr(types.NewTimeDatum(types.NewTime(types.FromGoTime(now), mysql.TypeDatetime, 6))), newDatumPtr(types.NewTimeDatum(types.NewTime(types.FromGoTime(now), mysql.TypeDatetime, 6)))},
{newDatumPtr(types.NewDurationDatum(types.Duration{Duration: time.Second})), newDatumPtr(types.NewDurationDatum(types.Duration{Duration: time.Second}))},
{newDatumPtr(types.NewJSONDatum(types.CreateBinaryJSON("a"))), newDatumPtr(types.NewJSONDatum(types.CreateBinaryJSON("a")))},
{newDatumPtr(types.NewTimeDatum(types.NewTime(types.FromGoTime(now), mysql.TypeDatetime, 6))), newDatumPtr(types.NewTimeDatum(types.NewTime(types.FromGoTime(time.Now()), mysql.TypeDatetime, 6)))},
}
hasher1 := base.NewHashEqualer()
hasher2 := base.NewHashEqualer()
Expand Down

0 comments on commit c5e9cc7

Please sign in to comment.