diff --git a/executor/aggfuncs/window_func_test.go b/executor/aggfuncs/window_func_test.go index 5468ed2c779b7..a063951dd44e6 100644 --- a/executor/aggfuncs/window_func_test.go +++ b/executor/aggfuncs/window_func_test.go @@ -18,17 +18,16 @@ import ( "testing" "time" - "github.com/pingcap/tidb/parser/ast" - "github.com/pingcap/tidb/parser/mysql" - "github.com/pingcap/tidb/executor/aggfuncs" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/mock" - "github.com/stretchr/testify/require" ) @@ -78,7 +77,7 @@ func testWindowFunc(t *testing.T, p windowTest) { err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk) require.NoError(t, err) dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) - result, err := dt.CompareDatum(ctx.GetSessionVars().StmtCtx, &p.results[i]) + result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[i], collate.GetCollator(desc.RetTp.Collate)) require.NoError(t, err) require.Equal(t, 0, result) resultChk.Reset() diff --git a/executor/write.go b/executor/write.go index f8e0c2c19c1b3..9b690d1141382 100644 --- a/executor/write.go +++ b/executor/write.go @@ -25,13 +25,13 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/parser/ast" - "github.com/pingcap/tidb/parser/charset" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/memory" ) @@ -94,11 +94,8 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h kv.Handle, old // 3. Compare datum, then handle some flags. for i, col := range t.Cols() { - collation := newData[i].Collation() // We should use binary collation to compare datum, otherwise the result will be incorrect. - newData[i].SetCollation(charset.CollationBin) - cmp, err := newData[i].CompareDatum(sc, &oldData[i]) - newData[i].SetCollation(collation) + cmp, err := newData[i].Compare(sc, &oldData[i], collate.GetBinaryCollator()) if err != nil { return false, err } diff --git a/expression/evaluator_test.go b/expression/evaluator_test.go index 26deecca77c0e..c808d604e4724 100644 --- a/expression/evaluator_test.go +++ b/expression/evaluator_test.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/collate" "github.com/stretchr/testify/require" ) @@ -574,7 +575,7 @@ func TestUnaryOp(t *testing.T) { require.NoError(t, err) expect := types.NewDatum(tt.result) - ret, err := result.CompareDatum(ctx.GetSessionVars().StmtCtx, &expect) + ret, err := result.Compare(ctx.GetSessionVars().StmtCtx, &expect, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, ret, Commentf("%v %s", tt.arg, tt.op)) } diff --git a/statistics/statistics_test.go b/statistics/statistics_test.go index 581705108b9f0..53df0e04bd7be 100644 --- a/statistics/statistics_test.go +++ b/statistics/statistics_test.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/ranger" "github.com/pingcap/tidb/util/sqlexec" @@ -178,11 +179,11 @@ func TestMergeHistogram(t *testing.T) { require.Equal(t, tt.bucketNum, h.Len()) require.Equal(t, tt.leftNum+tt.rightNum, int64(h.TotalRowCount())) expectLower := types.NewIntDatum(tt.leftLower) - cmp, err := h.GetLower(0).CompareDatum(sc, &expectLower) + cmp, err := h.GetLower(0).Compare(sc, &expectLower, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) expectUpper := types.NewIntDatum(tt.rightLower + tt.rightNum - 1) - cmp, err = h.GetUpper(h.Len()-1).CompareDatum(sc, &expectUpper) + cmp, err = h.GetUpper(h.Len()-1).Compare(sc, &expectUpper, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) } diff --git a/tablecodec/tablecodec_test.go b/tablecodec/tablecodec_test.go index e373eadba33b9..83839d4463a89 100644 --- a/tablecodec/tablecodec_test.go +++ b/tablecodec/tablecodec_test.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/rowcodec" "github.com/stretchr/testify/require" ) @@ -107,7 +108,7 @@ func TestRowCodec(t *testing.T) { for i, col := range cols { v, ok := r[col.id] require.True(t, ok) - equal, err1 := v.CompareDatum(sc, &row[i]) + equal, err1 := v.Compare(sc, &row[i], collate.GetBinaryCollator()) require.NoError(t, err1) require.Equalf(t, 0, equal, "expect: %v, got %v", row[i], v) } @@ -121,7 +122,7 @@ func TestRowCodec(t *testing.T) { for i, col := range cols { v, ok := r[col.id] require.True(t, ok) - equal, err1 := v.CompareDatum(sc, &row[i]) + equal, err1 := v.Compare(sc, &row[i], collate.GetBinaryCollator()) require.NoError(t, err1) require.Equal(t, 0, equal) } @@ -139,7 +140,7 @@ func TestRowCodec(t *testing.T) { } v, ok := r[col.id] require.True(t, ok) - equal, err1 := v.CompareDatum(sc, &row[i]) + equal, err1 := v.Compare(sc, &row[i], collate.GetBinaryCollator()) require.NoError(t, err1) require.Equal(t, 0, equal) } @@ -168,7 +169,7 @@ func TestDecodeColumnValue(t *testing.T) { tp := types.NewFieldType(mysql.TypeTimestamp) d1, err := DecodeColumnValue(bs, tp, sc.TimeZone) require.NoError(t, err) - cmp, err := d1.CompareDatum(sc, &d) + cmp, err := d1.Compare(sc, &d, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) @@ -185,7 +186,7 @@ func TestDecodeColumnValue(t *testing.T) { tp.Elems = elems d1, err = DecodeColumnValue(bs, tp, sc.TimeZone) require.NoError(t, err) - cmp, err = d1.CompareDatum(sc, &d) + cmp, err = d1.Compare(sc, &d, collate.GetCollator(tp.Collate)) require.NoError(t, err) require.Equal(t, 0, cmp) @@ -200,7 +201,7 @@ func TestDecodeColumnValue(t *testing.T) { tp.Flen = 24 d1, err = DecodeColumnValue(bs, tp, sc.TimeZone) require.NoError(t, err) - cmp, err = d1.CompareDatum(sc, &d) + cmp, err = d1.Compare(sc, &d, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) @@ -214,7 +215,7 @@ func TestDecodeColumnValue(t *testing.T) { tp = types.NewFieldType(mysql.TypeEnum) d1, err = DecodeColumnValue(bs, tp, sc.TimeZone) require.NoError(t, err) - cmp, err = d1.CompareDatum(sc, &d) + cmp, err = d1.Compare(sc, &d, collate.GetCollator(tp.Collate)) require.NoError(t, err) require.Equal(t, 0, cmp) } @@ -226,7 +227,7 @@ func TestUnflattenDatums(t *testing.T) { tps := []*types.FieldType{types.NewFieldType(mysql.TypeLonglong)} output, err := UnflattenDatums(input, tps, sc.TimeZone) require.NoError(t, err) - cmp, err := input[0].CompareDatum(sc, &output[0]) + cmp, err := input[0].Compare(sc, &output[0], collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) @@ -235,7 +236,7 @@ func TestUnflattenDatums(t *testing.T) { tps[0].Collate = "utf8mb4_unicode_ci" output, err = UnflattenDatums(input, tps, sc.TimeZone) require.NoError(t, err) - cmp, err = input[0].CompareDatum(sc, &output[0]) + cmp, err = input[0].Compare(sc, &output[0], collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) require.Equal(t, "utf8mb4_unicode_ci", output[0].Collation()) @@ -285,7 +286,7 @@ func TestTimeCodec(t *testing.T) { for i, col := range cols { v, ok := r[col.id] require.True(t, ok) - equal, err1 := v.CompareDatum(sc, &row[i]) + equal, err1 := v.Compare(sc, &row[i], collate.GetBinaryCollator()) require.Nil(t, err1) require.Equal(t, 0, equal) } diff --git a/types/compare_test.go b/types/compare_test.go index ffd9a8895efb8..ae701553a0e5a 100644 --- a/types/compare_test.go +++ b/types/compare_test.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/util/collate" "github.com/stretchr/testify/require" ) @@ -151,7 +152,7 @@ func compareForTest(a, b interface{}) (int, error) { sc.IgnoreTruncate = true aDatum := NewDatum(a) bDatum := NewDatum(b) - return aDatum.CompareDatum(sc, &bDatum) + return aDatum.Compare(sc, &bDatum, collate.GetBinaryCollator()) } func TestCompareDatum(t *testing.T) { @@ -174,11 +175,11 @@ func TestCompareDatum(t *testing.T) { sc := new(stmtctx.StatementContext) sc.IgnoreTruncate = true for i, tt := range cmpTbl { - ret, err := tt.lhs.CompareDatum(sc, &tt.rhs) + ret, err := tt.lhs.Compare(sc, &tt.rhs, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, tt.ret, ret, "%d %v %v", i, tt.lhs, tt.rhs) - ret, err = tt.rhs.CompareDatum(sc, &tt.lhs) + ret, err = tt.rhs.Compare(sc, &tt.lhs, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, -tt.ret, ret, "%d %v %v", i, tt.lhs, tt.rhs) } diff --git a/types/datum.go b/types/datum.go index 03874e1f02e35..cd8db788d2d3c 100644 --- a/types/datum.go +++ b/types/datum.go @@ -585,7 +585,7 @@ func (d *Datum) Compare(sc *stmtctx.StatementContext, ad *Datum, comparer collat case KindString: return d.compareStringNew(sc, ad.GetString(), comparer) case KindBytes: - return comparer.Compare(d.GetString(), ad.GetString()), nil + return d.compareStringNew(sc, ad.GetString(), comparer) case KindMysqlDecimal: return d.compareMysqlDecimal(sc, ad.GetMysqlDecimal()) case KindMysqlDuration: @@ -754,7 +754,7 @@ func (d *Datum) compareStringNew(sc *stmtctx.StatementContext, s string, compare case KindMysqlEnum: return comparer.Compare(d.GetMysqlEnum().String(), s), nil case KindBinaryLiteral, KindMysqlBit: - return comparer.Compare(d.GetBinaryLiteral4Cmp().String(), s), nil + return comparer.Compare(d.GetBinaryLiteral4Cmp().ToString(), s), nil default: fVal, err := StrToFloat(sc, s, false) if err != nil { diff --git a/types/datum_test.go b/types/datum_test.go index 091b92c752b23..36f6776282e72 100644 --- a/types/datum_test.go +++ b/types/datum_test.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types/json" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/hack" "github.com/stretchr/testify/require" ) @@ -248,7 +249,7 @@ func TestToJSON(t *testing.T) { require.NoError(t, err) var cmp int - cmp, err = obtain.CompareDatum(sc, &expected) + cmp, err = obtain.Compare(sc, &expected, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, cmp) } else { @@ -324,7 +325,7 @@ func TestComputePlusAndMinus(t *testing.T) { for ith, tt := range tests { got, err := ComputePlus(tt.a, tt.b) require.Equal(t, tt.hasErr, err != nil) - v, err := got.CompareDatum(sc, &tt.plus) + v, err := got.Compare(sc, &tt.plus, collate.GetBinaryCollator()) require.NoError(t, err) require.Equalf(t, 0, v, "%dth got:%#v, %#v, expect:%#v, %#v", ith, got, got.x, tt.plus, tt.plus.x) } @@ -347,7 +348,7 @@ func TestCloneDatum(t *testing.T) { sc.IgnoreTruncate = true for _, tt := range tests { tt1 := *tt.Clone() - res, err := tt.CompareDatum(sc, &tt1) + res, err := tt.Compare(sc, &tt1, collate.GetBinaryCollator()) require.NoError(t, err) require.Equal(t, 0, res) if tt.b != nil { @@ -542,7 +543,7 @@ func BenchmarkCompareDatum(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { for j, v := range vals { - _, err := v.CompareDatum(sc, &vals1[j]) + _, err := v.Compare(sc, &vals1[j], collate.GetBinaryCollator()) if err != nil { b.Fatal(err) } diff --git a/util/chunk/mutrow_test.go b/util/chunk/mutrow_test.go index 4abd9ad217357..a2a650d9bdc95 100644 --- a/util/chunk/mutrow_test.go +++ b/util/chunk/mutrow_test.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" + "github.com/pingcap/tidb/util/collate" "github.com/stretchr/testify/require" ) @@ -35,7 +36,7 @@ func TestMutRow(t *testing.T) { val := zeroValForType(allTypes[i]) d := row.GetDatum(i, allTypes[i]) d2 := types.NewDatum(val) - cmp, err := d.CompareDatum(sc, &d2) + cmp, err := d.Compare(sc, &d2, collate.GetCollator(allTypes[i].Collate)) require.Nil(t, err) require.Equal(t, 0, cmp) } diff --git a/util/rowDecoder/decoder_test.go b/util/rowDecoder/decoder_test.go index 4aab176723536..cb743418d1692 100644 --- a/util/rowDecoder/decoder_test.go +++ b/util/rowDecoder/decoder_test.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/mock" decoder "github.com/pingcap/tidb/util/rowDecoder" "github.com/pingcap/tidb/util/rowcodec" @@ -204,7 +205,7 @@ func TestClusterIndexRowDecoder(t *testing.T) { for i, col := range cols { v, ok := r[col.ID] require.True(t, ok) - equal, err1 := v.CompareDatum(sc, &row.output[i]) + equal, err1 := v.Compare(sc, &row.output[i], collate.GetBinaryCollator()) require.Nil(t, err1) require.Equal(t, 0, equal) } diff --git a/util/testutil/testutil.go b/util/testutil/testutil.go index 270e54241b51b..4a49636743e11 100644 --- a/util/testutil/testutil.go +++ b/util/testutil/testutil.go @@ -38,6 +38,7 @@ import ( "github.com/pingcap/tidb/testkit/testdata" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" ) @@ -111,7 +112,7 @@ func (checker *datumEqualsChecker) Check(params []interface{}, names []string) ( panic("the second param should be datum") } sc := new(stmtctx.StatementContext) - res, err := paramFirst.CompareDatum(sc, ¶mSecond) + res, err := paramFirst.Compare(sc, ¶mSecond, collate.GetBinaryCollator()) if err != nil { panic(err) }