diff --git a/expression/expressions/cast.go b/expression/expressions/cast.go index f2ecd1e8d0b54..8c29b6df39ed1 100644 --- a/expression/expressions/cast.go +++ b/expression/expressions/cast.go @@ -19,7 +19,6 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" mysql "github.com/pingcap/tidb/mysqldef" - "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/types" ) @@ -76,30 +75,9 @@ func (f *FunctionCast) Eval(ctx context.Context, args map[interface{}]interface{ if value == nil { return nil, nil } - - // TODO: we need a better function convert between any two types according to FieldType. - // Not only check Type, but also consider Flen/Decimal/Charset and so on. nv, err := types.Convert(value, f.Tp) if err != nil { return nil, err } - if f.Tp.Tp == mysql.TypeString && f.Tp.Charset == charset.CharsetBin { - nv = []byte(nv.(string)) - } - if f.Tp.Flen != types.UnspecifiedLength { - switch f.Tp.Tp { - case mysql.TypeString: - v := nv.(string) - if len(v) > int(f.Tp.Flen) { - v = v[:f.Tp.Flen] - } - return v, nil - } - } - if f.Tp.Tp == mysql.TypeLonglong { - if mysql.HasUnsignedFlag(f.Tp.Flag) { - return uint64(nv.(int64)), nil - } - } return nv, nil } diff --git a/util/types/convert.go b/util/types/convert.go index d725015e1774b..0cf14d05c63d9 100644 --- a/util/types/convert.go +++ b/util/types/convert.go @@ -19,6 +19,7 @@ import ( "github.com/juju/errors" mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/util/charset" ) // InvConv returns a failed convertion error. @@ -71,6 +72,10 @@ func Convert(val interface{}, target *FieldType) (v interface{}, err error) { // } // TODO: consider target.Charset/Collate x = truncateStr(x, target.Flen) + if target.Charset == charset.CharsetBin { + bx := []byte(x) + return bx, nil + } return x, nil case mysql.TypeBlob: x, err := ToString(val) @@ -78,6 +83,10 @@ func Convert(val interface{}, target *FieldType) (v interface{}, err error) { // return InvConv(val, tp) } x = truncateStr(x, target.Flen) + if target.Charset == charset.CharsetBin { + bx := []byte(x) + return bx, nil + } return x, nil case mysql.TypeDuration: fsp := mysql.DefaultFsp @@ -127,6 +136,10 @@ func Convert(val interface{}, target *FieldType) (v interface{}, err error) { // if err != nil { return InvConv(val, tp) } + // TODO: We should first convert to uint64 then check unsigned flag. + if mysql.HasUnsignedFlag(target.Flag) { + return uint64(x), nil + } return x, nil case mysql.TypeNewDecimal: x, err := ToDecimal(val) diff --git a/util/types/convert_test.go b/util/types/convert_test.go index d5c69830b28e9..2c7aed98a92d4 100644 --- a/util/types/convert_test.go +++ b/util/types/convert_test.go @@ -36,6 +36,12 @@ func (s *testTypeConvertSuite) TestConvertType(c *C) { v, err := Convert("123456", ft) c.Assert(err, IsNil) c.Assert(v, Equals, "1234") + ft = NewFieldType(mysql.TypeString) + ft.Flen = 4 + ft.Charset = charset.CharsetBin + v, err = Convert("12345", ft) + c.Assert(err, IsNil) + c.Assert(v, DeepEquals, []byte("1234")) ft = NewFieldType(mysql.TypeFloat) ft.Flen = 5 @@ -111,6 +117,12 @@ func (s *testTypeConvertSuite) TestConvertType(c *C) { v, err = Convert("12345", ft) c.Assert(err, IsNil) c.Assert(v, Equals, "123") + ft = NewFieldType(mysql.TypeString) + ft.Flen = 3 + ft.Charset = charset.CharsetBin + v, err = Convert("12345", ft) + c.Assert(err, IsNil) + c.Assert(v, DeepEquals, []byte("123")) // For TypeDuration ft = NewFieldType(mysql.TypeDuration) @@ -146,6 +158,11 @@ func (s *testTypeConvertSuite) TestConvertType(c *C) { v, err = Convert("100", ft) c.Assert(err, IsNil) c.Assert(v, Equals, int64(100)) + ft = NewFieldType(mysql.TypeLonglong) + ft.Flag |= mysql.UnsignedFlag + v, err = Convert("100", ft) + c.Assert(err, IsNil) + c.Assert(v, Equals, uint64(100)) // For TypeNewDecimal ft = NewFieldType(mysql.TypeNewDecimal)