Skip to content

Commit

Permalink
types/util: use statement context to handle truncate error. (#2147)
Browse files Browse the repository at this point in the history
use statement context to handle truncate error in ConvertTo MyDecimal.
  • Loading branch information
coocood authored Dec 1, 2016
1 parent 3a721da commit e288b35
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 3 deletions.
34 changes: 34 additions & 0 deletions executor/statement_context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2016 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package executor_test

import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/types"
)

func (s *testSuite) TestStatementContext(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("create table sc (a int)")
tk.MustExec("insert sc values (1), (2)")
tk.MustQuery("select * from sc where a > cast(1.1 as decimal)").Check(testkit.Rows("2"))
_, err := tk.Exec("update sc set a = 4 where a > cast(1.1 as decimal)")
c.Check(terror.ErrorEqual(err, types.ErrTruncated), IsTrue)
tk.MustExec("set sql_mode = 0")
tk.MustExec("update sc set a = 3 where a > cast(1.1 as decimal)")
tk.MustQuery("select * from sc").Check(testkit.Rows("1", "3"))
}
1 change: 1 addition & 0 deletions util/types/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type invalidMockType struct {
func Convert(val interface{}, target *FieldType) (v interface{}, err error) {
d := NewDatum(val)
sc := new(variable.StatementContext)
sc.TruncateAsError = true
ret, err := d.ConvertTo(sc, target)
if err != nil {
return ret.GetValue(), errors.Trace(err)
Expand Down
8 changes: 5 additions & 3 deletions util/types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ func (d *Datum) ConvertTo(sc *variable.StatementContext, target *FieldType) (Dat
case mysql.TypeBit:
return d.convertToMysqlBit(target)
case mysql.TypeDecimal, mysql.TypeNewDecimal:
return d.convertToMysqlDecimal(target)
return d.convertToMysqlDecimal(sc, target)
case mysql.TypeYear:
return d.convertToMysqlYear(target)
case mysql.TypeEnum:
Expand Down Expand Up @@ -943,7 +943,7 @@ func (d *Datum) convertToMysqlDuration(target *FieldType) (Datum, error) {
return ret, nil
}

func (d *Datum) convertToMysqlDecimal(target *FieldType) (Datum, error) {
func (d *Datum) convertToMysqlDecimal(sc *variable.StatementContext, target *FieldType) (Datum, error) {
var ret Datum
ret.SetLength(target.Flen)
ret.SetFrac(target.Decimal)
Expand Down Expand Up @@ -983,7 +983,9 @@ func (d *Datum) convertToMysqlDecimal(target *FieldType) (Datum, error) {
} else if frac != target.Decimal {
dec.Round(dec, target.Decimal)
if frac > target.Decimal {
err = errors.Trace(ErrTruncated)
if sc.TruncateAsError {
err = errors.Trace(ErrTruncated)
}
}
}
}
Expand Down

0 comments on commit e288b35

Please sign in to comment.