Skip to content

Commit

Permalink
*: replace compareDatum by compare and fix compare (#30090)
Browse files Browse the repository at this point in the history
  • Loading branch information
wjhuang2016 authored Nov 24, 2021
1 parent 3d26751 commit b5ded48
Show file tree
Hide file tree
Showing 11 changed files with 39 additions and 35 deletions.
9 changes: 4 additions & 5 deletions executor/aggfuncs/window_func_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,16 @@ import (
"testing"
"time"

"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"

"github.com/pingcap/tidb/executor/aggfuncs"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/mock"

"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -78,7 +77,7 @@ func testWindowFunc(t *testing.T, p windowTest) {
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[i])
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[i], collate.GetCollator(desc.RetTp.Collate))
require.NoError(t, err)
require.Equal(t, 0, result)
resultChk.Reset()
Expand Down
7 changes: 2 additions & 5 deletions executor/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/meta/autoid"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/charset"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/tablecodec"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/memory"
)

Expand Down Expand Up @@ -94,11 +94,8 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h kv.Handle, old

// 3. Compare datum, then handle some flags.
for i, col := range t.Cols() {
collation := newData[i].Collation()
// We should use binary collation to compare datum, otherwise the result will be incorrect.
newData[i].SetCollation(charset.CollationBin)
cmp, err := newData[i].CompareDatum(sc, &oldData[i])
newData[i].SetCollation(collation)
cmp, err := newData[i].Compare(sc, &oldData[i], collate.GetBinaryCollator())
if err != nil {
return false, err
}
Expand Down
3 changes: 2 additions & 1 deletion expression/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/collate"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -574,7 +575,7 @@ func TestUnaryOp(t *testing.T) {
require.NoError(t, err)

expect := types.NewDatum(tt.result)
ret, err := result.CompareDatum(ctx.GetSessionVars().StmtCtx, &expect)
ret, err := result.Compare(ctx.GetSessionVars().StmtCtx, &expect, collate.GetBinaryCollator())
require.NoError(t, err)
require.Equal(t, 0, ret, Commentf("%v %s", tt.arg, tt.op))
}
Expand Down
5 changes: 3 additions & 2 deletions statistics/statistics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/mock"
"github.com/pingcap/tidb/util/ranger"
"github.com/pingcap/tidb/util/sqlexec"
Expand Down Expand Up @@ -178,11 +179,11 @@ func TestMergeHistogram(t *testing.T) {
require.Equal(t, tt.bucketNum, h.Len())
require.Equal(t, tt.leftNum+tt.rightNum, int64(h.TotalRowCount()))
expectLower := types.NewIntDatum(tt.leftLower)
cmp, err := h.GetLower(0).CompareDatum(sc, &expectLower)
cmp, err := h.GetLower(0).Compare(sc, &expectLower, collate.GetBinaryCollator())
require.NoError(t, err)
require.Equal(t, 0, cmp)
expectUpper := types.NewIntDatum(tt.rightLower + tt.rightNum - 1)
cmp, err = h.GetUpper(h.Len()-1).CompareDatum(sc, &expectUpper)
cmp, err = h.GetUpper(h.Len()-1).Compare(sc, &expectUpper, collate.GetBinaryCollator())
require.NoError(t, err)
require.Equal(t, 0, cmp)
}
Expand Down
21 changes: 11 additions & 10 deletions tablecodec/tablecodec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/rowcodec"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -107,7 +108,7 @@ func TestRowCodec(t *testing.T) {
for i, col := range cols {
v, ok := r[col.id]
require.True(t, ok)
equal, err1 := v.CompareDatum(sc, &row[i])
equal, err1 := v.Compare(sc, &row[i], collate.GetBinaryCollator())
require.NoError(t, err1)
require.Equalf(t, 0, equal, "expect: %v, got %v", row[i], v)
}
Expand All @@ -121,7 +122,7 @@ func TestRowCodec(t *testing.T) {
for i, col := range cols {
v, ok := r[col.id]
require.True(t, ok)
equal, err1 := v.CompareDatum(sc, &row[i])
equal, err1 := v.Compare(sc, &row[i], collate.GetBinaryCollator())
require.NoError(t, err1)
require.Equal(t, 0, equal)
}
Expand All @@ -139,7 +140,7 @@ func TestRowCodec(t *testing.T) {
}
v, ok := r[col.id]
require.True(t, ok)
equal, err1 := v.CompareDatum(sc, &row[i])
equal, err1 := v.Compare(sc, &row[i], collate.GetBinaryCollator())
require.NoError(t, err1)
require.Equal(t, 0, equal)
}
Expand Down Expand Up @@ -168,7 +169,7 @@ func TestDecodeColumnValue(t *testing.T) {
tp := types.NewFieldType(mysql.TypeTimestamp)
d1, err := DecodeColumnValue(bs, tp, sc.TimeZone)
require.NoError(t, err)
cmp, err := d1.CompareDatum(sc, &d)
cmp, err := d1.Compare(sc, &d, collate.GetBinaryCollator())
require.NoError(t, err)
require.Equal(t, 0, cmp)

Expand All @@ -185,7 +186,7 @@ func TestDecodeColumnValue(t *testing.T) {
tp.Elems = elems
d1, err = DecodeColumnValue(bs, tp, sc.TimeZone)
require.NoError(t, err)
cmp, err = d1.CompareDatum(sc, &d)
cmp, err = d1.Compare(sc, &d, collate.GetCollator(tp.Collate))
require.NoError(t, err)
require.Equal(t, 0, cmp)

Expand All @@ -200,7 +201,7 @@ func TestDecodeColumnValue(t *testing.T) {
tp.Flen = 24
d1, err = DecodeColumnValue(bs, tp, sc.TimeZone)
require.NoError(t, err)
cmp, err = d1.CompareDatum(sc, &d)
cmp, err = d1.Compare(sc, &d, collate.GetBinaryCollator())
require.NoError(t, err)
require.Equal(t, 0, cmp)

Expand All @@ -214,7 +215,7 @@ func TestDecodeColumnValue(t *testing.T) {
tp = types.NewFieldType(mysql.TypeEnum)
d1, err = DecodeColumnValue(bs, tp, sc.TimeZone)
require.NoError(t, err)
cmp, err = d1.CompareDatum(sc, &d)
cmp, err = d1.Compare(sc, &d, collate.GetCollator(tp.Collate))
require.NoError(t, err)
require.Equal(t, 0, cmp)
}
Expand All @@ -226,7 +227,7 @@ func TestUnflattenDatums(t *testing.T) {
tps := []*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}
output, err := UnflattenDatums(input, tps, sc.TimeZone)
require.NoError(t, err)
cmp, err := input[0].CompareDatum(sc, &output[0])
cmp, err := input[0].Compare(sc, &output[0], collate.GetBinaryCollator())
require.NoError(t, err)
require.Equal(t, 0, cmp)

Expand All @@ -235,7 +236,7 @@ func TestUnflattenDatums(t *testing.T) {
tps[0].Collate = "utf8mb4_unicode_ci"
output, err = UnflattenDatums(input, tps, sc.TimeZone)
require.NoError(t, err)
cmp, err = input[0].CompareDatum(sc, &output[0])
cmp, err = input[0].Compare(sc, &output[0], collate.GetBinaryCollator())
require.NoError(t, err)
require.Equal(t, 0, cmp)
require.Equal(t, "utf8mb4_unicode_ci", output[0].Collation())
Expand Down Expand Up @@ -285,7 +286,7 @@ func TestTimeCodec(t *testing.T) {
for i, col := range cols {
v, ok := r[col.id]
require.True(t, ok)
equal, err1 := v.CompareDatum(sc, &row[i])
equal, err1 := v.Compare(sc, &row[i], collate.GetBinaryCollator())
require.Nil(t, err1)
require.Equal(t, 0, equal)
}
Expand Down
7 changes: 4 additions & 3 deletions types/compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/util/collate"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -151,7 +152,7 @@ func compareForTest(a, b interface{}) (int, error) {
sc.IgnoreTruncate = true
aDatum := NewDatum(a)
bDatum := NewDatum(b)
return aDatum.CompareDatum(sc, &bDatum)
return aDatum.Compare(sc, &bDatum, collate.GetBinaryCollator())
}

func TestCompareDatum(t *testing.T) {
Expand All @@ -174,11 +175,11 @@ func TestCompareDatum(t *testing.T) {
sc := new(stmtctx.StatementContext)
sc.IgnoreTruncate = true
for i, tt := range cmpTbl {
ret, err := tt.lhs.CompareDatum(sc, &tt.rhs)
ret, err := tt.lhs.Compare(sc, &tt.rhs, collate.GetBinaryCollator())
require.NoError(t, err)
require.Equal(t, tt.ret, ret, "%d %v %v", i, tt.lhs, tt.rhs)

ret, err = tt.rhs.CompareDatum(sc, &tt.lhs)
ret, err = tt.rhs.Compare(sc, &tt.lhs, collate.GetBinaryCollator())
require.NoError(t, err)
require.Equal(t, -tt.ret, ret, "%d %v %v", i, tt.lhs, tt.rhs)
}
Expand Down
4 changes: 2 additions & 2 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ func (d *Datum) Compare(sc *stmtctx.StatementContext, ad *Datum, comparer collat
case KindString:
return d.compareStringNew(sc, ad.GetString(), comparer)
case KindBytes:
return comparer.Compare(d.GetString(), ad.GetString()), nil
return d.compareStringNew(sc, ad.GetString(), comparer)
case KindMysqlDecimal:
return d.compareMysqlDecimal(sc, ad.GetMysqlDecimal())
case KindMysqlDuration:
Expand Down Expand Up @@ -754,7 +754,7 @@ func (d *Datum) compareStringNew(sc *stmtctx.StatementContext, s string, compare
case KindMysqlEnum:
return comparer.Compare(d.GetMysqlEnum().String(), s), nil
case KindBinaryLiteral, KindMysqlBit:
return comparer.Compare(d.GetBinaryLiteral4Cmp().String(), s), nil
return comparer.Compare(d.GetBinaryLiteral4Cmp().ToString(), s), nil
default:
fVal, err := StrToFloat(sc, s, false)
if err != nil {
Expand Down
9 changes: 5 additions & 4 deletions types/datum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/hack"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -248,7 +249,7 @@ func TestToJSON(t *testing.T) {
require.NoError(t, err)

var cmp int
cmp, err = obtain.CompareDatum(sc, &expected)
cmp, err = obtain.Compare(sc, &expected, collate.GetBinaryCollator())
require.NoError(t, err)
require.Equal(t, 0, cmp)
} else {
Expand Down Expand Up @@ -324,7 +325,7 @@ func TestComputePlusAndMinus(t *testing.T) {
for ith, tt := range tests {
got, err := ComputePlus(tt.a, tt.b)
require.Equal(t, tt.hasErr, err != nil)
v, err := got.CompareDatum(sc, &tt.plus)
v, err := got.Compare(sc, &tt.plus, collate.GetBinaryCollator())
require.NoError(t, err)
require.Equalf(t, 0, v, "%dth got:%#v, %#v, expect:%#v, %#v", ith, got, got.x, tt.plus, tt.plus.x)
}
Expand All @@ -347,7 +348,7 @@ func TestCloneDatum(t *testing.T) {
sc.IgnoreTruncate = true
for _, tt := range tests {
tt1 := *tt.Clone()
res, err := tt.CompareDatum(sc, &tt1)
res, err := tt.Compare(sc, &tt1, collate.GetBinaryCollator())
require.NoError(t, err)
require.Equal(t, 0, res)
if tt.b != nil {
Expand Down Expand Up @@ -542,7 +543,7 @@ func BenchmarkCompareDatum(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j, v := range vals {
_, err := v.CompareDatum(sc, &vals1[j])
_, err := v.Compare(sc, &vals1[j], collate.GetBinaryCollator())
if err != nil {
b.Fatal(err)
}
Expand Down
3 changes: 2 additions & 1 deletion util/chunk/mutrow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/collate"
"github.com/stretchr/testify/require"
)

Expand All @@ -35,7 +36,7 @@ func TestMutRow(t *testing.T) {
val := zeroValForType(allTypes[i])
d := row.GetDatum(i, allTypes[i])
d2 := types.NewDatum(val)
cmp, err := d.CompareDatum(sc, &d2)
cmp, err := d.Compare(sc, &d2, collate.GetCollator(allTypes[i].Collate))
require.Nil(t, err)
require.Equal(t, 0, cmp)
}
Expand Down
3 changes: 2 additions & 1 deletion util/rowDecoder/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/pingcap/tidb/tablecodec"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/mock"
decoder "github.com/pingcap/tidb/util/rowDecoder"
"github.com/pingcap/tidb/util/rowcodec"
Expand Down Expand Up @@ -204,7 +205,7 @@ func TestClusterIndexRowDecoder(t *testing.T) {
for i, col := range cols {
v, ok := r[col.ID]
require.True(t, ok)
equal, err1 := v.CompareDatum(sc, &row.output[i])
equal, err1 := v.Compare(sc, &row.output[i], collate.GetBinaryCollator())
require.Nil(t, err1)
require.Equal(t, 0, equal)
}
Expand Down
3 changes: 2 additions & 1 deletion util/testutil/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/pingcap/tidb/testkit/testdata"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/logutil"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -111,7 +112,7 @@ func (checker *datumEqualsChecker) Check(params []interface{}, names []string) (
panic("the second param should be datum")
}
sc := new(stmtctx.StatementContext)
res, err := paramFirst.CompareDatum(sc, &paramSecond)
res, err := paramFirst.Compare(sc, &paramSecond, collate.GetBinaryCollator())
if err != nil {
panic(err)
}
Expand Down

0 comments on commit b5ded48

Please sign in to comment.