Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: check if period is valid in period_add #10430

Merged
merged 6 commits into from
May 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions expression/builtin_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -4887,6 +4887,11 @@ func (c *periodAddFunctionClass) getFunction(ctx sessionctx.Context, args []Expr
return sig, nil
}

// validPeriod checks if this period is valid, it comes from MySQL 8.0+.
func validPeriod(p int64) bool {
return !(p < 0 || p%100 == 0 || p%100 > 12)
}

// period2Month converts a period to months, in which period is represented in the format of YYMM or YYYYMM.
// Note that the period argument is not a date value.
func period2Month(period uint64) uint64 {
Expand Down Expand Up @@ -4938,15 +4943,16 @@ func (b *builtinPeriodAddSig) evalInt(row chunk.Row) (int64, bool, error) {
return 0, true, errors.Trace(err)
}

if p == 0 {
return 0, false, nil
}

n, isNull, err := b.args[1].EvalInt(b.ctx, row)
if isNull || err != nil {
return 0, true, errors.Trace(err)
}

// in MySQL, if p is invalid but n is NULL, the result is NULL, so we have to check if n is NULL first.
if !validPeriod(p) {
return 0, false, errIncorrectArgs.GenWithStackByArgs("period_add")
}

sumMonth := int64(period2Month(uint64(p))) + n
return int64(month2Period(uint64(sumMonth))), false, nil
}
Expand Down
4 changes: 2 additions & 2 deletions expression/builtin_time_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2146,8 +2146,8 @@ func (s *testEvaluatorSuite) TestPeriodAdd(c *C) {
{201611, -13, true, 201510},
{1611, 3, true, 201702},
{7011, 3, true, 197102},
{12323, 10, true, 12509},
{0, 3, true, 0},
{12323, 10, false, 0},
{0, 3, false, 0},
}

fc := funcs[ast.PeriodAdd]
Expand Down
14 changes: 10 additions & 4 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1450,10 +1450,16 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) {
result.Check(testkit.Rows("123456 10 <nil> <nil>"))

// for period_add
result = tk.MustQuery(`SELECT period_add(191, 2), period_add(191, -2), period_add(0, 20), period_add(0, 0);`)
result.Check(testkit.Rows("200809 200805 0 0"))
result = tk.MustQuery(`SELECT period_add(NULL, 2), period_add(-191, NULL), period_add(NULL, NULL), period_add(12.09, -2), period_add("21aa", "11aa"), period_add("", "");`)
result.Check(testkit.Rows("<nil> <nil> <nil> 200010 200208 0"))
result = tk.MustQuery(`SELECT period_add(200807, 2), period_add(200807, -2);`)
result.Check(testkit.Rows("200809 200805"))
result = tk.MustQuery(`SELECT period_add(NULL, 2), period_add(-191, NULL), period_add(NULL, NULL), period_add(12.09, -2), period_add("200207aa", "1aa");`)
result.Check(testkit.Rows("<nil> <nil> <nil> 200010 200208"))
for _, errPeriod := range []string{
"period_add(0, 20)", "period_add(0, 0)", "period_add(-1, 1)", "period_add(200013, 1)", "period_add(-200012, 1)", "period_add('', '')",
} {
err := tk.QueryToErr(fmt.Sprintf("SELECT %v;", errPeriod))
c.Assert(err.Error(), Equals, "[expression:1210]Incorrect arguments to period_add")
}

// for period_diff
result = tk.MustQuery(`SELECT period_diff(191, 2), period_diff(191, -2), period_diff(0, 0), period_diff(191, 191);`)
Expand Down
11 changes: 11 additions & 0 deletions util/testkit/testkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,17 @@ func (tk *TestKit) MustQuery(sql string, args ...interface{}) *Result {
return tk.ResultSetToResult(rs, comment)
}

// QueryToErr executes a sql statement and discard results.
func (tk *TestKit) QueryToErr(sql string, args ...interface{}) error {
comment := check.Commentf("sql:%s, args:%v", sql, args)
res, err := tk.Exec(sql, args...)
tk.c.Assert(errors.ErrorStack(err), check.Equals, "", comment)
tk.c.Assert(res, check.NotNil, comment)
_, resErr := session.GetRows4Test(context.Background(), tk.Se, res)
tk.c.Assert(res.Close(), check.IsNil)
return resErr
}

// ResultSetToResult converts sqlexec.RecordSet to testkit.Result.
// It is used to check results of execute statement in binary mode.
func (tk *TestKit) ResultSetToResult(rs sqlexec.RecordSet, comment check.CommentInterface) *Result {
Expand Down