diff --git a/executor/statement_context_test.go b/executor/statement_context_test.go new file mode 100644 index 0000000000000..6f6f2f44068d1 --- /dev/null +++ b/executor/statement_context_test.go @@ -0,0 +1,34 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb/terror" + "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/types" +) + +func (s *testSuite) TestStatementContext(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table sc (a int)") + tk.MustExec("insert sc values (1), (2)") + tk.MustQuery("select * from sc where a > cast(1.1 as decimal)").Check(testkit.Rows("2")) + _, err := tk.Exec("update sc set a = 4 where a > cast(1.1 as decimal)") + c.Check(terror.ErrorEqual(err, types.ErrTruncated), IsTrue) + tk.MustExec("set sql_mode = 0") + tk.MustExec("update sc set a = 3 where a > cast(1.1 as decimal)") + tk.MustQuery("select * from sc").Check(testkit.Rows("1", "3")) +} diff --git a/util/types/convert_test.go b/util/types/convert_test.go index 97a2e8d77effa..ad1f370f76cec 100644 --- a/util/types/convert_test.go +++ b/util/types/convert_test.go @@ -39,6 +39,7 @@ type invalidMockType struct { func Convert(val interface{}, target *FieldType) (v interface{}, err error) { d := NewDatum(val) sc := new(variable.StatementContext) + sc.TruncateAsError = true ret, err := d.ConvertTo(sc, target) if err != nil { return ret.GetValue(), errors.Trace(err) diff --git a/util/types/datum.go b/util/types/datum.go index 1372b1c2425d0..7548b50715dcc 100644 --- a/util/types/datum.go +++ b/util/types/datum.go @@ -664,7 +664,7 @@ func (d *Datum) ConvertTo(sc *variable.StatementContext, target *FieldType) (Dat case mysql.TypeBit: return d.convertToMysqlBit(target) case mysql.TypeDecimal, mysql.TypeNewDecimal: - return d.convertToMysqlDecimal(target) + return d.convertToMysqlDecimal(sc, target) case mysql.TypeYear: return d.convertToMysqlYear(target) case mysql.TypeEnum: @@ -943,7 +943,7 @@ func (d *Datum) convertToMysqlDuration(target *FieldType) (Datum, error) { return ret, nil } -func (d *Datum) convertToMysqlDecimal(target *FieldType) (Datum, error) { +func (d *Datum) convertToMysqlDecimal(sc *variable.StatementContext, target *FieldType) (Datum, error) { var ret Datum ret.SetLength(target.Flen) ret.SetFrac(target.Decimal) @@ -983,7 +983,9 @@ func (d *Datum) convertToMysqlDecimal(target *FieldType) (Datum, error) { } else if frac != target.Decimal { dec.Round(dec, target.Decimal) if frac > target.Decimal { - err = errors.Trace(ErrTruncated) + if sc.TruncateAsError { + err = errors.Trace(ErrTruncated) + } } } }