From 384be78d73aaeb9b2f4dbfba46d24d99c2910cc1 Mon Sep 17 00:00:00 2001 From: amyangfei Date: Mon, 29 Jul 2019 10:14:52 +0800 Subject: [PATCH] types: fix string to integer cast (#11295) (#11469) --- executor/executor.go | 2 + executor/executor_test.go | 1 + executor/point_get_test.go | 2 + expression/builtin_cast_test.go | 4 ++ planner/core/point_get_plan.go | 5 +- sessionctx/stmtctx/stmtctx.go | 5 ++ types/convert.go | 39 +++++++++-- types/convert_test.go | 110 ++++++++++++++++++++++++++++++++ types/datum.go | 6 +- 9 files changed, 165 insertions(+), 9 deletions(-) diff --git a/executor/executor.go b/executor/executor.go index 02abdf1079843..bbe9e99d8d0fa 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1395,8 +1395,10 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.NotFillCache = !opts.SQLCache } sc.PadCharToFullLength = ctx.GetSessionVars().SQLMode.HasPadCharToFullLengthMode() + sc.CastStrToIntStrict = true case *ast.ExplainStmt: sc.InExplainStmt = true + sc.CastStrToIntStrict = true case *ast.ShowStmt: sc.IgnoreTruncate = true sc.IgnoreZeroInDate = true diff --git a/executor/executor_test.go b/executor/executor_test.go index b08f4cd093581..36636c431832b 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -363,6 +363,7 @@ func checkCases(tests []testCase, ld *executor.LoadDataInfo, ctx.GetSessionVars().StmtCtx.DupKeyAsWarning = true ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true ctx.GetSessionVars().StmtCtx.InLoadDataStmt = true + ctx.GetSessionVars().StmtCtx.InDeleteStmt = false data, reachLimit, err1 := ld.InsertData(context.Background(), tt.data1, tt.data2) c.Assert(err1, IsNil) c.Assert(reachLimit, IsFalse) diff --git a/executor/point_get_test.go b/executor/point_get_test.go index b14a7de0cadaf..e4df3c7a2c1bc 100644 --- a/executor/point_get_test.go +++ b/executor/point_get_test.go @@ -399,4 +399,6 @@ func (s *testPointGetSuite) TestIssue10677(c *C) { tk.MustQuery("select * from t where pk = 1").Check(testkit.Rows("1")) tk.MustQuery("desc select * from t where pk = '1'").Check(testkit.Rows("Point_Get_1 1.00 root table:t, handle:1")) tk.MustQuery("select * from t where pk = '1'").Check(testkit.Rows("1")) + tk.MustQuery("desc select * from t where pk = '1.0'").Check(testkit.Rows("Point_Get_1 1.00 root table:t, handle:1")) + tk.MustQuery("select * from t where pk = '1.0'").Check(testkit.Rows("1")) } diff --git a/expression/builtin_cast_test.go b/expression/builtin_cast_test.go index 05207acfde41a..c0636e4b34a04 100644 --- a/expression/builtin_cast_test.go +++ b/expression/builtin_cast_test.go @@ -88,7 +88,11 @@ func (s *testEvaluatorSuite) TestCast(c *C) { c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err)) origSc := sc + oldInSelectStmt := sc.InSelectStmt sc.InSelectStmt = true + defer func() { + sc.InSelectStmt = oldInSelectStmt + }() sc.OverflowAsWarning = true // cast('18446744073709551616' as unsigned); diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 7be3ba1e49b07..ae2cfbb733be2 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -219,7 +219,10 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP p.IsTableDual = true return p } - return nil + // some scenarios cast to int with error, but we may use this value in point get + if !terror.ErrorEqual(types.ErrTruncatedWrongVal, err) { + return nil + } } cmp, err := intDatum.CompareDatum(ctx.GetSessionVars().StmtCtx, &handlePair.value) if err != nil { diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 9be31482c77c6..8e08cb30a0a43 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -69,6 +69,11 @@ type StatementContext struct { BatchCheck bool InNullRejectCheck bool AllowInvalidDate bool + // CastStrToIntStrict is used to control the way we cast float format string to int. + // If ConvertStrToIntStrict is false, we convert it to a valid float string first, + // then cast the float string to int string. Otherwise, we cast string to integer + // prefix in a strict way, only extract 0-9 and (+ or - in first bit). + CastStrToIntStrict bool // mu struct holds variables that change during execution. mu struct { diff --git a/types/convert.go b/types/convert.go index c86b342108704..c42ebad12229d 100644 --- a/types/convert.go +++ b/types/convert.go @@ -362,11 +362,37 @@ func NumberToDuration(number int64, fsp int) (Duration, error) { // getValidIntPrefix gets prefix of the string which can be successfully parsed as int. func getValidIntPrefix(sc *stmtctx.StatementContext, str string) (string, error) { - floatPrefix, err := getValidFloatPrefix(sc, str) - if err != nil { - return floatPrefix, errors.Trace(err) + if !sc.CastStrToIntStrict { + floatPrefix, err := getValidFloatPrefix(sc, str) + if err != nil { + return floatPrefix, errors.Trace(err) + } + return floatStrToIntStr(sc, floatPrefix, str) + } + + validLen := 0 + + for i := 0; i < len(str); i++ { + c := str[i] + if (c == '+' || c == '-') && i == 0 { + continue + } + + if c >= '0' && c <= '9' { + validLen = i + 1 + continue + } + + break } - return floatStrToIntStr(sc, floatPrefix, str) + valid := str[:validLen] + if valid == "" { + valid = "0" + } + if validLen == 0 || validLen != len(str) { + return valid, errors.Trace(handleTruncateError(sc, ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", str))) + } + return valid, nil } // roundIntStr is to round a **valid int string** base on the number following dot. @@ -587,6 +613,9 @@ func ConvertJSONToDecimal(sc *stmtctx.StatementContext, j json.BinaryJSON) (*MyD // getValidFloatPrefix gets prefix of string which can be successfully parsed as float. func getValidFloatPrefix(sc *stmtctx.StatementContext, s string) (valid string, err error) { + if (sc.InDeleteStmt || sc.InSelectStmt || sc.InUpdateStmt) && s == "" { + return "0", nil + } var ( sawDot bool sawDigit bool @@ -627,7 +656,7 @@ func getValidFloatPrefix(sc *stmtctx.StatementContext, s string) (valid string, valid = "0" } if validLen == 0 || validLen != len(s) { - err = errors.Trace(handleTruncateError(sc)) + err = errors.Trace(handleTruncateError(sc, ErrTruncated)) } return valid, err } diff --git a/types/convert_test.go b/types/convert_test.go index a14371b88e7a5..9d292632e518e 100644 --- a/types/convert_test.go +++ b/types/convert_test.go @@ -461,6 +461,42 @@ func (s *testTypeConvertSuite) TestStrToNum(c *C) { testStrToFloat(c, "1e649", math.MaxFloat64, false, nil) testStrToFloat(c, "-1e649", -math.MaxFloat64, true, ErrTruncatedWrongVal) testStrToFloat(c, "-1e649", -math.MaxFloat64, false, nil) + + // for issue #10806, #11179 + testSelectUpdateDeleteEmptyStringError(c) +} + +func testSelectUpdateDeleteEmptyStringError(c *C) { + testCases := []struct { + inSelect bool + inUpdate bool + inDelete bool + }{ + {true, false, false}, + {false, true, false}, + {false, false, true}, + } + sc := new(stmtctx.StatementContext) + for _, tc := range testCases { + sc.InSelectStmt = tc.inSelect + sc.InUpdateStmt = tc.inUpdate + sc.InDeleteStmt = tc.inDelete + + str := "" + expect := 0 + + val, err := StrToInt(sc, str) + c.Assert(err, IsNil) + c.Assert(val, Equals, int64(expect)) + + val1, err := StrToUint(sc, str) + c.Assert(err, IsNil) + c.Assert(val1, Equals, uint64(expect)) + + val2, err := StrToFloat(sc, str) + c.Assert(err, IsNil) + c.Assert(val2, Equals, float64(expect)) + } } func (s *testTypeConvertSuite) TestFieldTypeToStr(c *C) { @@ -666,6 +702,80 @@ func (s *testTypeConvertSuite) TestConvert(c *C) { signedAccept(c, mysql.TypeNewDecimal, dec, "-0.00123") } +func (s *testTypeConvertSuite) TestGetValidInt(c *C) { + tests := []struct { + origin string + valid string + warning bool + }{ + {"100", "100", false}, + {"-100", "-100", false}, + {"1abc", "1", true}, + {"-1-1", "-1", true}, + {"+1+1", "+1", true}, + {"123..34", "123", true}, + {"123.23E-10", "123", true}, + {"1.1e1.3", "1", true}, + {"11e1.3", "11", true}, + {"1.", "1", true}, + {".1", "0", true}, + {"", "0", true}, + {"123e+", "123", true}, + {"123de", "123", true}, + } + sc := new(stmtctx.StatementContext) + sc.TruncateAsWarning = true + sc.CastStrToIntStrict = true + warningCount := 0 + for _, tt := range tests { + prefix, err := getValidIntPrefix(sc, tt.origin) + c.Assert(err, IsNil) + c.Assert(prefix, Equals, tt.valid) + _, err = strconv.ParseInt(prefix, 10, 64) + c.Assert(err, IsNil) + warnings := sc.GetWarnings() + if tt.warning { + c.Assert(warnings, HasLen, warningCount+1) + c.Assert(terror.ErrorEqual(warnings[len(warnings)-1].Err, ErrTruncatedWrongVal), IsTrue) + warningCount += 1 + } else { + c.Assert(warnings, HasLen, warningCount) + } + } + + tests2 := []struct { + origin string + valid string + warning bool + }{ + {"100", "100", false}, + {"-100", "-100", false}, + {"1abc", "1", true}, + {"-1-1", "-1", true}, + {"+1+1", "+1", true}, + {"123..34", "123.", true}, + {"123.23E-10", "0", false}, + {"1.1e1.3", "1.1e1", true}, + {"11e1.3", "11e1", true}, + {"1.", "1", false}, + {".1", "0", false}, + {"", "0", true}, + {"123e+", "123", true}, + {"123de", "123", true}, + } + sc.TruncateAsWarning = false + sc.CastStrToIntStrict = false + for _, tt := range tests2 { + prefix, err := getValidIntPrefix(sc, tt.origin) + if tt.warning { + c.Assert(terror.ErrorEqual(err, ErrTruncated), IsTrue) + } else { + c.Assert(err, IsNil) + } + c.Assert(prefix, Equals, tt.valid) + } +} + func (s *testTypeConvertSuite) TestRoundIntStr(c *C) { cases := []struct { a string diff --git a/types/datum.go b/types/datum.go index 9819a59a1f442..0d2053f06e954 100644 --- a/types/datum.go +++ b/types/datum.go @@ -1781,14 +1781,14 @@ func (ds *datumsSorter) Swap(i, j int) { ds.datums[i], ds.datums[j] = ds.datums[j], ds.datums[i] } -func handleTruncateError(sc *stmtctx.StatementContext) error { +func handleTruncateError(sc *stmtctx.StatementContext, err error) error { if sc.IgnoreTruncate { return nil } if !sc.TruncateAsWarning { - return ErrTruncated + return err } - sc.AppendWarning(ErrTruncated) + sc.AppendWarning(err) return nil }