Skip to content

Commit

Permalink
types: fix string to integer cast (#11295) (#11469)
Browse files Browse the repository at this point in the history
  • Loading branch information
amyangfei authored and zz-jason committed Jul 29, 2019
1 parent 6766735 commit 384be78
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 9 deletions.
2 changes: 2 additions & 0 deletions executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1395,8 +1395,10 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.NotFillCache = !opts.SQLCache
}
sc.PadCharToFullLength = ctx.GetSessionVars().SQLMode.HasPadCharToFullLengthMode()
sc.CastStrToIntStrict = true
case *ast.ExplainStmt:
sc.InExplainStmt = true
sc.CastStrToIntStrict = true
case *ast.ShowStmt:
sc.IgnoreTruncate = true
sc.IgnoreZeroInDate = true
Expand Down
1 change: 1 addition & 0 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ func checkCases(tests []testCase, ld *executor.LoadDataInfo,
ctx.GetSessionVars().StmtCtx.DupKeyAsWarning = true
ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true
ctx.GetSessionVars().StmtCtx.InLoadDataStmt = true
ctx.GetSessionVars().StmtCtx.InDeleteStmt = false
data, reachLimit, err1 := ld.InsertData(context.Background(), tt.data1, tt.data2)
c.Assert(err1, IsNil)
c.Assert(reachLimit, IsFalse)
Expand Down
2 changes: 2 additions & 0 deletions executor/point_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,4 +399,6 @@ func (s *testPointGetSuite) TestIssue10677(c *C) {
tk.MustQuery("select * from t where pk = 1").Check(testkit.Rows("1"))
tk.MustQuery("desc select * from t where pk = '1'").Check(testkit.Rows("Point_Get_1 1.00 root table:t, handle:1"))
tk.MustQuery("select * from t where pk = '1'").Check(testkit.Rows("1"))
tk.MustQuery("desc select * from t where pk = '1.0'").Check(testkit.Rows("Point_Get_1 1.00 root table:t, handle:1"))
tk.MustQuery("select * from t where pk = '1.0'").Check(testkit.Rows("1"))
}
4 changes: 4 additions & 0 deletions expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@ func (s *testEvaluatorSuite) TestCast(c *C) {
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err))

origSc := sc
oldInSelectStmt := sc.InSelectStmt
sc.InSelectStmt = true
defer func() {
sc.InSelectStmt = oldInSelectStmt
}()
sc.OverflowAsWarning = true

// cast('18446744073709551616' as unsigned);
Expand Down
5 changes: 4 additions & 1 deletion planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,10 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
p.IsTableDual = true
return p
}
return nil
// some scenarios cast to int with error, but we may use this value in point get
if !terror.ErrorEqual(types.ErrTruncatedWrongVal, err) {
return nil
}
}
cmp, err := intDatum.CompareDatum(ctx.GetSessionVars().StmtCtx, &handlePair.value)
if err != nil {
Expand Down
5 changes: 5 additions & 0 deletions sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ type StatementContext struct {
BatchCheck bool
InNullRejectCheck bool
AllowInvalidDate bool
// CastStrToIntStrict is used to control the way we cast float format string to int.
// If ConvertStrToIntStrict is false, we convert it to a valid float string first,
// then cast the float string to int string. Otherwise, we cast string to integer
// prefix in a strict way, only extract 0-9 and (+ or - in first bit).
CastStrToIntStrict bool

// mu struct holds variables that change during execution.
mu struct {
Expand Down
39 changes: 34 additions & 5 deletions types/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,11 +362,37 @@ func NumberToDuration(number int64, fsp int) (Duration, error) {

// getValidIntPrefix gets prefix of the string which can be successfully parsed as int.
func getValidIntPrefix(sc *stmtctx.StatementContext, str string) (string, error) {
floatPrefix, err := getValidFloatPrefix(sc, str)
if err != nil {
return floatPrefix, errors.Trace(err)
if !sc.CastStrToIntStrict {
floatPrefix, err := getValidFloatPrefix(sc, str)
if err != nil {
return floatPrefix, errors.Trace(err)
}
return floatStrToIntStr(sc, floatPrefix, str)
}

validLen := 0

for i := 0; i < len(str); i++ {
c := str[i]
if (c == '+' || c == '-') && i == 0 {
continue
}

if c >= '0' && c <= '9' {
validLen = i + 1
continue
}

break
}
return floatStrToIntStr(sc, floatPrefix, str)
valid := str[:validLen]
if valid == "" {
valid = "0"
}
if validLen == 0 || validLen != len(str) {
return valid, errors.Trace(handleTruncateError(sc, ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", str)))
}
return valid, nil
}

// roundIntStr is to round a **valid int string** base on the number following dot.
Expand Down Expand Up @@ -587,6 +613,9 @@ func ConvertJSONToDecimal(sc *stmtctx.StatementContext, j json.BinaryJSON) (*MyD

// getValidFloatPrefix gets prefix of string which can be successfully parsed as float.
func getValidFloatPrefix(sc *stmtctx.StatementContext, s string) (valid string, err error) {
if (sc.InDeleteStmt || sc.InSelectStmt || sc.InUpdateStmt) && s == "" {
return "0", nil
}
var (
sawDot bool
sawDigit bool
Expand Down Expand Up @@ -627,7 +656,7 @@ func getValidFloatPrefix(sc *stmtctx.StatementContext, s string) (valid string,
valid = "0"
}
if validLen == 0 || validLen != len(s) {
err = errors.Trace(handleTruncateError(sc))
err = errors.Trace(handleTruncateError(sc, ErrTruncated))
}
return valid, err
}
Expand Down
110 changes: 110 additions & 0 deletions types/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,42 @@ func (s *testTypeConvertSuite) TestStrToNum(c *C) {
testStrToFloat(c, "1e649", math.MaxFloat64, false, nil)
testStrToFloat(c, "-1e649", -math.MaxFloat64, true, ErrTruncatedWrongVal)
testStrToFloat(c, "-1e649", -math.MaxFloat64, false, nil)

// for issue #10806, #11179
testSelectUpdateDeleteEmptyStringError(c)
}

func testSelectUpdateDeleteEmptyStringError(c *C) {
testCases := []struct {
inSelect bool
inUpdate bool
inDelete bool
}{
{true, false, false},
{false, true, false},
{false, false, true},
}
sc := new(stmtctx.StatementContext)
for _, tc := range testCases {
sc.InSelectStmt = tc.inSelect
sc.InUpdateStmt = tc.inUpdate
sc.InDeleteStmt = tc.inDelete

str := ""
expect := 0

val, err := StrToInt(sc, str)
c.Assert(err, IsNil)
c.Assert(val, Equals, int64(expect))

val1, err := StrToUint(sc, str)
c.Assert(err, IsNil)
c.Assert(val1, Equals, uint64(expect))

val2, err := StrToFloat(sc, str)
c.Assert(err, IsNil)
c.Assert(val2, Equals, float64(expect))
}
}

func (s *testTypeConvertSuite) TestFieldTypeToStr(c *C) {
Expand Down Expand Up @@ -666,6 +702,80 @@ func (s *testTypeConvertSuite) TestConvert(c *C) {
signedAccept(c, mysql.TypeNewDecimal, dec, "-0.00123")
}

func (s *testTypeConvertSuite) TestGetValidInt(c *C) {
tests := []struct {
origin string
valid string
warning bool
}{
{"100", "100", false},
{"-100", "-100", false},
{"1abc", "1", true},
{"-1-1", "-1", true},
{"+1+1", "+1", true},
{"123..34", "123", true},
{"123.23E-10", "123", true},
{"1.1e1.3", "1", true},
{"11e1.3", "11", true},
{"1.", "1", true},
{".1", "0", true},
{"", "0", true},
{"123e+", "123", true},
{"123de", "123", true},
}
sc := new(stmtctx.StatementContext)
sc.TruncateAsWarning = true
sc.CastStrToIntStrict = true
warningCount := 0
for _, tt := range tests {
prefix, err := getValidIntPrefix(sc, tt.origin)
c.Assert(err, IsNil)
c.Assert(prefix, Equals, tt.valid)
_, err = strconv.ParseInt(prefix, 10, 64)
c.Assert(err, IsNil)
warnings := sc.GetWarnings()
if tt.warning {
c.Assert(warnings, HasLen, warningCount+1)
c.Assert(terror.ErrorEqual(warnings[len(warnings)-1].Err, ErrTruncatedWrongVal), IsTrue)
warningCount += 1
} else {
c.Assert(warnings, HasLen, warningCount)
}
}

tests2 := []struct {
origin string
valid string
warning bool
}{
{"100", "100", false},
{"-100", "-100", false},
{"1abc", "1", true},
{"-1-1", "-1", true},
{"+1+1", "+1", true},
{"123..34", "123.", true},
{"123.23E-10", "0", false},
{"1.1e1.3", "1.1e1", true},
{"11e1.3", "11e1", true},
{"1.", "1", false},
{".1", "0", false},
{"", "0", true},
{"123e+", "123", true},
{"123de", "123", true},
}
sc.TruncateAsWarning = false
sc.CastStrToIntStrict = false
for _, tt := range tests2 {
prefix, err := getValidIntPrefix(sc, tt.origin)
if tt.warning {
c.Assert(terror.ErrorEqual(err, ErrTruncated), IsTrue)
} else {
c.Assert(err, IsNil)
}
c.Assert(prefix, Equals, tt.valid)
}
}

func (s *testTypeConvertSuite) TestRoundIntStr(c *C) {
cases := []struct {
a string
Expand Down
6 changes: 3 additions & 3 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -1781,14 +1781,14 @@ func (ds *datumsSorter) Swap(i, j int) {
ds.datums[i], ds.datums[j] = ds.datums[j], ds.datums[i]
}

func handleTruncateError(sc *stmtctx.StatementContext) error {
func handleTruncateError(sc *stmtctx.StatementContext, err error) error {
if sc.IgnoreTruncate {
return nil
}
if !sc.TruncateAsWarning {
return ErrTruncated
return err
}
sc.AppendWarning(ErrTruncated)
sc.AppendWarning(err)
return nil
}

Expand Down

0 comments on commit 384be78

Please sign in to comment.