Skip to content

Commit

Permalink
expression: fix wrong result of greatest/least(mixed unsigned/signed …
Browse files Browse the repository at this point in the history
…int) (#30121) (#30791)

close #30101
  • Loading branch information
ti-srebot authored Dec 20, 2021
1 parent b97b00a commit ff8d96a
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 0 deletions.
28 changes: 28 additions & 0 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,14 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
}
switch tp {
case types.ETInt:
// adjust unsigned flag
greastInitUnsignedFlag := false
if isEqualsInitUnsignedFlag(greastInitUnsignedFlag, args) {
bf.tp.Flag &= ^mysql.UnsignedFlag
} else {
bf.tp.Flag |= mysql.UnsignedFlag
}

sig = &builtinGreatestIntSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_GreatestInt)
case types.ETReal:
Expand Down Expand Up @@ -736,6 +744,14 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
}
switch tp {
case types.ETInt:
// adjust unsigned flag
leastInitUnsignedFlag := true
if isEqualsInitUnsignedFlag(leastInitUnsignedFlag, args) {
bf.tp.Flag |= mysql.UnsignedFlag
} else {
bf.tp.Flag &= ^mysql.UnsignedFlag
}

sig = &builtinLeastIntSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LeastInt)
case types.ETReal:
Expand Down Expand Up @@ -2846,3 +2862,15 @@ func CompareJSON(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhs
}
return int64(json.CompareBinary(arg0, arg1)), false, nil
}

// isEqualsInitUnsignedFlag can adjust unsigned flag for greatest/least function.
// For greatest, returns unsigned result if there is at least one argument is unsigned.
// For least, returns signed result if there is at least one argument is signed.
func isEqualsInitUnsignedFlag(initUnsigned bool, args []Expression) bool {
for _, arg := range args {
if initUnsigned != mysql.HasUnsignedFlag(arg.GetType().Flag) {
return false
}
}
return true
}
10 changes: 10 additions & 0 deletions expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) {
sc := s.ctx.GetSessionVars().StmtCtx
originIgnoreTruncate := sc.IgnoreTruncate
sc.IgnoreTruncate = true
decG := &types.MyDecimal{}
decL := &types.MyDecimal{}
defer func() {
sc.IgnoreTruncate = originIgnoreTruncate
}()
Expand All @@ -274,6 +276,14 @@ func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) {
isNil bool
getErr bool
}{
{
[]interface{}{int64(-9223372036854775808), uint64(9223372036854775809)},
decG.FromUint(9223372036854775809), decL.FromInt(-9223372036854775808), false, false,
},
{
[]interface{}{uint64(9223372036854775808), uint64(9223372036854775809)},
uint64(9223372036854775809), uint64(9223372036854775808), false, false,
},
{
[]interface{}{1, 2, 3, 4},
int64(4), int64(1), false, false,
Expand Down
9 changes: 9 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9367,6 +9367,15 @@ func (s *testIntegrationSuite) TestConstPropNullFunctions(c *C) {
tk.MustQuery("select * from t2 where t2.i2=((select count(1) from t1 where t1.i1=t2.i2))").Check(testkit.Rows("1 <nil> 0.1"))
}

func (s *testIntegrationSuite) TestIssue30101(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t1;")
tk.MustExec("create table t1(c1 bigint unsigned, c2 bigint unsigned);")
tk.MustExec("insert into t1 values(9223372036854775808, 9223372036854775809);")
tk.MustQuery("select greatest(c1, c2) from t1;").Sort().Check(testkit.Rows("9223372036854775809"))
}

func (s *testIntegrationSuite) TestIssue28643(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
7 changes: 7 additions & 0 deletions expression/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,13 @@ func (s *testInferTypeSuite) createTestCase4CompareFuncs() []typeInferTestCase {

{"interval(c_int_d, c_int_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0},
{"interval(c_int_d, c_float_d, c_double_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0},

{"greatest(c_bigint_d, c_ubigint_d, c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0},
{"greatest(c_ubigint_d, c_ubigint_d, c_uint_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0},
{"greatest(c_uint_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 11, 0},
{"least(c_bigint_d, c_ubigint_d, c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0},
{"least(c_ubigint_d, c_ubigint_d, c_uint_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0},
{"least(c_uint_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
}
}

Expand Down

0 comments on commit ff8d96a

Please sign in to comment.