From e85f74c6245bbf36ff705e41d1414dd5847bc5c3 Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Mon, 1 Apr 2019 11:13:01 +0800 Subject: [PATCH] expression: fix issue that `date_add` and `date_sub` is incompatible with MySQL (#9966) --- expression/builtin_time.go | 14 ++++++++ expression/errors.go | 3 +- expression/integration_test.go | 34 +++++++++++++++++++ types/time.go | 61 ++++++++++++++++++++++++++++++---- 4 files changed, 104 insertions(+), 8 deletions(-) diff --git a/expression/builtin_time.go b/expression/builtin_time.go index 46acaba245637..859c81066af97 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -2651,6 +2651,13 @@ func (du *baseDateArithmitical) add(ctx sessionctx.Context, date types.Time, int } date.Time = types.FromGoTime(goTime) + overflow, err := types.DateTimeIsOverflow(ctx.GetSessionVars().StmtCtx, date) + if err != nil { + return types.Time{}, true, err + } + if overflow { + return types.Time{}, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")) + } return date, false, nil } @@ -2677,6 +2684,13 @@ func (du *baseDateArithmitical) sub(ctx sessionctx.Context, date types.Time, int } date.Time = types.FromGoTime(goTime) + overflow, err := types.DateTimeIsOverflow(ctx.GetSessionVars().StmtCtx, date) + if err != nil { + return types.Time{}, true, err + } + if overflow { + return types.Time{}, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")) + } return date, false, nil } diff --git a/expression/errors.go b/expression/errors.go index 4b60261337cd8..d56f45ab209e9 100644 --- a/expression/errors.go +++ b/expression/errors.go @@ -69,7 +69,8 @@ func init() { // handleInvalidTimeError reports error or warning depend on the context. func handleInvalidTimeError(ctx sessionctx.Context, err error) error { if err == nil || !(terror.ErrorEqual(err, types.ErrInvalidTimeFormat) || types.ErrIncorrectDatetimeValue.Equal(err) || - types.ErrTruncatedWrongValue.Equal(err) || types.ErrInvalidWeekModeFormat.Equal(err)) { + types.ErrTruncatedWrongValue.Equal(err) || types.ErrInvalidWeekModeFormat.Equal(err) || + types.ErrDatetimeFunctionOverflow.Equal(err)) { return err } sc := ctx.GetSessionVars().StmtCtx diff --git a/expression/integration_test.go b/expression/integration_test.go index 5b0162ced2be5..4d30385860b6d 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -1881,6 +1881,40 @@ func (s *testIntegrationSuite) TestOpBuiltin(c *C) { result.Check(testkit.Rows("1 0 -9 -0.001 0.999 aaa")) } +func (s *testIntegrationSuite) TestDatetimeOverflow(c *C) { + defer s.cleanEnv(c) + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + tk.MustExec("create table t1 (d date)") + tk.MustExec("set sql_mode='traditional'") + overflowSQLs := []string{ + "insert into t1 (d) select date_add('2000-01-01',interval 8000 year)", + "insert into t1 (d) select date_sub('2000-01-01', INTERVAL 2001 YEAR)", + "insert into t1 (d) select date_add('9999-12-31',interval 1 year)", + "insert into t1 (d) select date_sub('1000-01-01', INTERVAL 1 YEAR)", + "insert into t1 (d) select date_add('9999-12-31',interval 1 day)", + "insert into t1 (d) select date_sub('1000-01-01', INTERVAL 1 day)", + "insert into t1 (d) select date_sub('1000-01-01', INTERVAL 1 second)", + } + + for _, sql := range overflowSQLs { + _, err := tk.Exec(sql) + c.Assert(err.Error(), Equals, "[types:1441]Datetime function: datetime field overflow") + } + + tk.MustExec("set sql_mode=''") + for _, sql := range overflowSQLs { + tk.MustExec(sql) + } + + rows := make([]string, 0, len(overflowSQLs)) + for range overflowSQLs { + rows = append(rows, "") + } + tk.MustQuery("select * from t1").Check(testkit.Rows(rows...)) +} + func (s *testIntegrationSuite) TestBuiltin(c *C) { defer s.cleanEnv(c) tk := testkit.NewTestKit(c, s.store) diff --git a/types/time.go b/types/time.go index f9e49a238d929..d626c08018fc3 100644 --- a/types/time.go +++ b/types/time.go @@ -32,13 +32,14 @@ import ( // Portable analogs of some common call errors. var ( - ErrInvalidTimeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid time format: '%v'") - ErrInvalidWeekModeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid week mode format: '%v'") - ErrInvalidYearFormat = errors.New("invalid year format") - ErrInvalidYear = errors.New("invalid year") - ErrZeroDate = errors.New("datetime zero in date") - ErrIncorrectDatetimeValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "Incorrect datetime value: '%s'") - ErrTruncatedWrongValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue]) + ErrInvalidTimeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid time format: '%v'") + ErrInvalidWeekModeFormat = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "invalid week mode format: '%v'") + ErrInvalidYearFormat = errors.New("invalid year format") + ErrInvalidYear = errors.New("invalid year") + ErrZeroDate = errors.New("datetime zero in date") + ErrIncorrectDatetimeValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, "Incorrect datetime value: '%s'") + ErrDatetimeFunctionOverflow = terror.ClassTypes.New(mysql.ErrDatetimeFunctionOverflow, mysql.MySQLErrName[mysql.ErrDatetimeFunctionOverflow]) + ErrTruncatedWrongValue = terror.ClassTypes.New(mysql.ErrTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue]) ) // Time format without fractional seconds precision. @@ -2450,3 +2451,49 @@ func DateFSP(date string) (fsp int) { } return } + +// DateTimeIsOverflow return if this date is overflow. +// See: https://dev.mysql.com/doc/refman/8.0/en/datetime.html +func DateTimeIsOverflow(sc *stmtctx.StatementContext, date Time) (bool, error) { + tz := sc.TimeZone + if tz == nil { + tz = gotime.Local + } + + var err error + var b, e, t gotime.Time + switch date.Type { + case mysql.TypeDate, mysql.TypeDatetime: + if b, err = MinDatetime.GoTime(tz); err != nil { + return false, err + } + if e, err = MaxDatetime.GoTime(tz); err != nil { + return false, err + } + case mysql.TypeTimestamp: + minTS, maxTS := MinTimestamp, MaxTimestamp + if tz != gotime.UTC { + if err = minTS.ConvertTimeZone(gotime.UTC, tz); err != nil { + return false, err + } + if err = maxTS.ConvertTimeZone(gotime.UTC, tz); err != nil { + return false, err + } + } + if b, err = minTS.Time.GoTime(tz); err != nil { + return false, err + } + if e, err = maxTS.Time.GoTime(tz); err != nil { + return false, err + } + default: + return false, nil + } + + if t, err = date.Time.GoTime(tz); err != nil { + return false, err + } + + inRange := (t.After(b) || t.Equal(b)) && (t.Before(e) || t.Equal(e)) + return !inRange, nil +}