diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index a142dcbed194f..b19a2a7ba5040 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -1615,27 +1615,37 @@ func (s *testEvaluatorSuite) TestDateArithFuncs(c *C) { fcAdd := funcs[ast.DateAdd] fcSub := funcs[ast.DateSub] - args := types.MakeDatums(date[0], 1, "DAY") - f, err := fcAdd.getFunction(s.ctx, s.datumsToConstants(args)) - c.Assert(err, IsNil) - c.Assert(f, NotNil) - v, err := evalBuiltinFunc(f, chunk.Row{}) - c.Assert(err, IsNil) - c.Assert(v.GetMysqlTime().String(), Equals, date[1]) + tests := []struct { + inputDate string + fc functionClass + inputDecimal float64 + expect string + }{ + {date[0], fcAdd, 1, date[1]}, + {date[1], fcAdd, -1, date[0]}, + {date[1], fcAdd, -0.5, date[0]}, + {date[1], fcAdd, -1.4, date[0]}, - args = types.MakeDatums(date[1], 1, "DAY") - f, err = fcSub.getFunction(s.ctx, s.datumsToConstants(args)) - c.Assert(err, IsNil) - c.Assert(f, NotNil) - v, err = evalBuiltinFunc(f, chunk.Row{}) - c.Assert(err, IsNil) - c.Assert(v.GetMysqlTime().String(), Equals, date[0]) + {date[1], fcSub, 1, date[0]}, + {date[0], fcSub, -1, date[1]}, + {date[0], fcSub, -0.5, date[1]}, + {date[0], fcSub, -1.4, date[1]}, + } + for _, test := range tests { + args := types.MakeDatums(test.inputDate, test.inputDecimal, "DAY") + f, err := test.fc.getFunction(s.ctx, s.datumsToConstants(args)) + c.Assert(err, IsNil) + c.Assert(f, NotNil) + v, err := evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v.GetMysqlTime().String(), Equals, test.expect) + } - args = types.MakeDatums(date[0], nil, "DAY") - f, err = fcAdd.getFunction(s.ctx, s.datumsToConstants(args)) + args := types.MakeDatums(date[0], nil, "DAY") + f, err := fcAdd.getFunction(s.ctx, s.datumsToConstants(args)) c.Assert(err, IsNil) c.Assert(f, NotNil) - v, err = evalBuiltinFunc(f, chunk.Row{}) + v, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) c.Assert(v.IsNull(), IsTrue) diff --git a/types/time.go b/types/time.go index 8a4ff10ebe0b4..73a20ea3040f0 100644 --- a/types/time.go +++ b/types/time.go @@ -1524,7 +1524,7 @@ func extractSingleTimeValue(unit string, format string) (int64, int64, int64, fl if err != nil { return 0, 0, 0, 0, ErrIncorrectDatetimeValue.GenWithStackByArgs(format) } - iv := int64(fv + 0.5) + iv := int64(math.Round(fv)) switch strings.ToUpper(unit) { case "MICROSECOND":