From 0be031a5005fd1ce95d1240b13ad121a77814c71 Mon Sep 17 00:00:00 2001 From: pingcap-github-bot Date: Thu, 30 Apr 2020 14:56:56 +0800 Subject: [PATCH] *: fix the problem that PointGet returns wrong results in the case of overflow (#14776) (#16753) --- executor/point_get_test.go | 40 ++++++++---- planner/core/point_get_plan.go | 115 +++++++++++++++++++-------------- util/testkit/testkit.go | 2 +- 3 files changed, 94 insertions(+), 63 deletions(-) diff --git a/executor/point_get_test.go b/executor/point_get_test.go index 8b5cf8bcdc5d8..f999981175e04 100644 --- a/executor/point_get_test.go +++ b/executor/point_get_test.go @@ -137,10 +137,10 @@ func (s *testPointGetSuite) TestPointGetCharPK(c *C) { tk.MustExec(`set @@sql_mode="";`) tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows(`a b`)) tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows()) - tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows()) tk.MustPointGet(`select * from t where a = "";`).Check(testkit.Rows(` `)) tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows()) - tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t where a = " ";`).Check(testkit.Rows()) } @@ -153,7 +153,7 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) { tk.MustExec(`set @@sql_mode="";`) tk.MustPointGet(`select * from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`)) - tk.MustPointGet(`select * from t tmp where a = "aab";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t tmp where a = "aab";`).Check(testkit.Rows()) tk.MustExec(`truncate table t;`) tk.MustExec(`insert into t values("a ", "b ");`) @@ -161,7 +161,7 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) { tk.MustExec(`set @@sql_mode="";`) tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`)) tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) - tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) // Test CHAR BINARY. tk.MustExec(`drop table if exists t;`) @@ -172,10 +172,10 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) { tk.MustExec(`set @@sql_mode="";`) tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`)) tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) - tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) tk.MustPointGet(`select * from t tmp where a = "";`).Check(testkit.Rows(` `)) tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows()) - tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t tmp where a = " ";`).Check(testkit.Rows()) // Test both wildcard and column name exist in select field list tk.MustExec(`set @@sql_mode="";`) @@ -188,9 +188,9 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) { tk.MustPointGet(`select tmp.* from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`)) tk.MustPointGet(`select tmp.a, tmp.b from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`)) tk.MustPointGet(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb aa bb`)) - tk.MustPointGet(`select tmp.* from t tmp where a = "aab";`).Check(testkit.Rows()) - tk.MustPointGet(`select tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows()) - tk.MustPointGet(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows()) + tk.MustTableDual(`select tmp.* from t tmp where a = "aab";`).Check(testkit.Rows()) + tk.MustTableDual(`select tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows()) + tk.MustTableDual(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows()) // Test using table alias in where clause tk.MustPointGet(`select * from t tmp where tmp.a = "aa";`).Check(testkit.Rows(`aa bb`)) @@ -261,7 +261,7 @@ func (s *testPointGetSuite) TestPointGetVarcharPK(c *C) { tk.MustExec(`set @@sql_mode="";`) tk.MustPointGet(`select * from t where a = "aa";`).Check(testkit.Rows(`aa bb`)) - tk.MustPointGet(`select * from t where a = "aab";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t where a = "aab";`).Check(testkit.Rows()) tk.MustExec(`truncate table t;`) tk.MustExec(`insert into t values("a ", "b ");`) @@ -269,7 +269,7 @@ func (s *testPointGetSuite) TestPointGetVarcharPK(c *C) { tk.MustExec(`set @@sql_mode="";`) tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows()) tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows(`a b `)) - tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows()) // // Test VARCHAR BINARY. tk.MustExec(`drop table if exists t;`) @@ -280,10 +280,10 @@ func (s *testPointGetSuite) TestPointGetVarcharPK(c *C) { tk.MustExec(`set @@sql_mode="";`) tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows()) tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows(`a b `)) - tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows()) tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows()) tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows(` `)) - tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t where a = " ";`).Check(testkit.Rows()) } @@ -364,6 +364,20 @@ func (s *testPointGetSuite) TestIndexLookupBinary(c *C) { } +func (s *testPointGetSuite) TestOverflowOrTruncated(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("create table t6 (id bigint, a bigint, primary key(id), unique key(a));") + tk.MustExec("insert into t6 values(9223372036854775807, 9223372036854775807);") + tk.MustExec("insert into t6 values(1, 1);") + var nilVal []string + // for unique key + tk.MustQuery("select * from t6 where a = 9223372036854775808").Check(testkit.Rows(nilVal...)) + tk.MustQuery("select * from t6 where a = '1.123'").Check(testkit.Rows(nilVal...)) + // for primary key + tk.MustQuery("select * from t6 where id = 9223372036854775808").Check(testkit.Rows(nilVal...)) + tk.MustQuery("select * from t6 where id = '1.123'").Check(testkit.Rows(nilVal...)) +} + func (s *testPointGetSuite) TestIssue10448(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 94d5fca0997ea..cd2228a962a98 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -19,6 +19,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/charset" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" @@ -232,41 +233,31 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP return nil } } + schema := buildSchemaFromFields(ctx, tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields) + if schema == nil { + return nil + } + dbName := tblName.Schema.L + if dbName == "" { + dbName = ctx.GetSessionVars().CurrentDB + } + pairs := make([]nameValuePair, 0, 4) - pairs = getNameValuePairs(pairs, tblAlias, selStmt.Where) - if pairs == nil { + pairs, isTableDual := getNameValuePairs(ctx.GetSessionVars().StmtCtx, tbl, tblAlias, pairs, selStmt.Where) + if pairs == nil && !isTableDual { return nil } + handlePair, fieldType := findPKHandle(tbl, pairs) if handlePair.value.Kind() != types.KindNull && len(pairs) == 1 { - schema := buildSchemaFromFields(ctx, tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields) - if schema == nil { - return nil - } - dbName := tblName.Schema.L - if dbName == "" { - dbName = ctx.GetSessionVars().CurrentDB - } - p := newPointGetPlan(ctx, dbName, schema, tbl) - intDatum, err := handlePair.value.ConvertTo(ctx.GetSessionVars().StmtCtx, fieldType) - if err != nil { - if terror.ErrorEqual(types.ErrOverflow, err) { - p.IsTableDual = true - return p - } - // some scenarios cast to int with error, but we may use this value in point get - if !terror.ErrorEqual(types.ErrTruncatedWrongVal, err) { - return nil - } - } - cmp, err := intDatum.CompareDatum(ctx.GetSessionVars().StmtCtx, &handlePair.value) - if err != nil { - return nil - } else if cmp != 0 { + if isTableDual { + p := newPointGetPlan(ctx, tblName.Schema.O, schema, tbl) p.IsTableDual = true return p } - p.Handle = intDatum.GetInt64() + + p := newPointGetPlan(ctx, dbName, schema, tbl) + p.Handle = handlePair.value.GetInt64() p.UnsignedHandle = mysql.HasUnsignedFlag(fieldType.Flag) p.HandleParam = handlePair.param return p @@ -279,18 +270,16 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP if idxInfo.State != model.StatePublic { continue } + if isTableDual { + p := newPointGetPlan(ctx, tblName.Schema.O, schema, tbl) + p.IsTableDual = true + return p + } + idxValues, idxValueParams := getIndexValues(idxInfo, pairs) if idxValues == nil { continue } - schema := buildSchemaFromFields(ctx, tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields) - if schema == nil { - return nil - } - dbName := tblName.Schema.L - if dbName == "" { - dbName = ctx.GetSessionVars().CurrentDB - } p := newPointGetPlan(ctx, dbName, schema, tbl) p.IndexInfo = idxInfo p.IndexValues = idxValues @@ -402,21 +391,22 @@ func getSingleTableNameAndAlias(tableRefs *ast.TableRefsClause) (tblName *ast.Ta } // getNameValuePairs extracts `column = constant/paramMarker` conditions from expr as name value pairs. -func getNameValuePairs(nvPairs []nameValuePair, tblName model.CIStr, expr ast.ExprNode) []nameValuePair { +func getNameValuePairs(stmtCtx *stmtctx.StatementContext, tbl *model.TableInfo, tblName model.CIStr, nvPairs []nameValuePair, expr ast.ExprNode) ( + pairs []nameValuePair, isTableDual bool) { binOp, ok := expr.(*ast.BinaryOperationExpr) if !ok { - return nil + return nil, false } if binOp.Op == opcode.LogicAnd { - nvPairs = getNameValuePairs(nvPairs, tblName, binOp.L) - if nvPairs == nil { - return nil + nvPairs, isTableDual = getNameValuePairs(stmtCtx, tbl, tblName, nvPairs, binOp.L) + if nvPairs == nil || isTableDual { + return nil, isTableDual } - nvPairs = getNameValuePairs(nvPairs, tblName, binOp.R) - if nvPairs == nil { - return nil + nvPairs, isTableDual = getNameValuePairs(stmtCtx, tbl, tblName, nvPairs, binOp.R) + if nvPairs == nil || isTableDual { + return nil, isTableDual } - return nvPairs + return nvPairs, isTableDual } else if binOp.Op == opcode.EQ { var d types.Datum var colName *ast.ColumnNameExpr @@ -439,17 +429,44 @@ func getNameValuePairs(nvPairs []nameValuePair, tblName model.CIStr, expr ast.Ex param = x } } else { - return nil + return nil, false } if d.IsNull() { - return nil + return nil, false + } + // Views' columns have no FieldType. + if tbl.IsView() { + return nil, false } if colName.Name.Table.L != "" && colName.Name.Table.L != tblName.L { - return nil + return nil, false + } + col := model.FindColumnInfo(tbl.Cols(), colName.Name.Name.L) + if col == nil || // Handling the case when the column is _tidb_rowid. + (col.Tp == mysql.TypeString && col.Collate == charset.CollationBin) { // This type we needn't to pad `\0` in here. + return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}), false } - return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}) + dVal, err := d.ConvertTo(stmtCtx, &col.FieldType) + if err != nil { + if terror.ErrorEqual(types.ErrOverflow, err) { + return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}), true + } + // Some scenarios cast to int with error, but we may use this value in point get. + if !terror.ErrorEqual(types.ErrTruncatedWrongVal, err) { + return nil, false + } + } + // The converted result must be same as original datum. + cmp, err := d.CompareDatum(stmtCtx, &dVal) + if err != nil { + return nil, false + } else if cmp != 0 { + return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: dVal, param: param}), true + } + + return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: dVal, param: param}), false } - return nil + return nil, false } func findPKHandle(tblInfo *model.TableInfo, pairs []nameValuePair) (handlePair nameValuePair, fieldType *types.FieldType) { diff --git a/util/testkit/testkit.go b/util/testkit/testkit.go index a02463450180b..69af93c50d7e4 100644 --- a/util/testkit/testkit.go +++ b/util/testkit/testkit.go @@ -220,7 +220,7 @@ func (tk *TestKit) MustTableDual(sql string, args ...interface{}) *Result { func (tk *TestKit) MustPointGet(sql string, args ...interface{}) *Result { rs := tk.MustQuery("explain "+sql, args...) tk.c.Assert(len(rs.rows), check.Equals, 1) - tk.c.Assert(strings.Contains(rs.rows[0][0], "Point_Get"), check.IsTrue) + tk.c.Assert(strings.Contains(rs.rows[0][0], "Point_Get"), check.IsTrue, check.Commentf("plan %v", rs.rows[0][0])) return tk.MustQuery(sql, args...) }