Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: fix the problem that PointGet returns wrong results in the case of overflow (#14776) #16753

Merged
merged 3 commits into from
Apr 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions executor/point_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ func (s *testPointGetSuite) TestPointGetCharPK(c *C) {
tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows(`a b`))
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "";`).Check(testkit.Rows(` `))
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = " ";`).Check(testkit.Rows())

}

Expand All @@ -153,15 +153,15 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) {

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select * from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = "aab";`).Check(testkit.Rows())

tk.MustExec(`truncate table t;`)
tk.MustExec(`insert into t values("a ", "b ");`)

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`))
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = "a ";`).Check(testkit.Rows())

// Test CHAR BINARY.
tk.MustExec(`drop table if exists t;`)
Expand All @@ -172,10 +172,10 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) {
tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`))
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = "";`).Check(testkit.Rows(` `))
tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = " ";`).Check(testkit.Rows())

// Test both wildcard and column name exist in select field list
tk.MustExec(`set @@sql_mode="";`)
Expand All @@ -188,9 +188,9 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) {
tk.MustPointGet(`select tmp.* from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select tmp.a, tmp.b from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb aa bb`))
tk.MustPointGet(`select tmp.* from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustPointGet(`select tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustPointGet(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select tmp.* from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())

// Test using table alias in where clause
tk.MustPointGet(`select * from t tmp where tmp.a = "aa";`).Check(testkit.Rows(`aa bb`))
Expand Down Expand Up @@ -261,15 +261,15 @@ func (s *testPointGetSuite) TestPointGetVarcharPK(c *C) {

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select * from t where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "aab";`).Check(testkit.Rows())

tk.MustExec(`truncate table t;`)
tk.MustExec(`insert into t values("a ", "b ");`)

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows(`a b `))
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows())

// // Test VARCHAR BINARY.
tk.MustExec(`drop table if exists t;`)
Expand All @@ -280,10 +280,10 @@ func (s *testPointGetSuite) TestPointGetVarcharPK(c *C) {
tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows(`a b `))
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows(` `))
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = " ";`).Check(testkit.Rows())

}

Expand Down Expand Up @@ -364,6 +364,20 @@ func (s *testPointGetSuite) TestIndexLookupBinary(c *C) {

}

func (s *testPointGetSuite) TestOverflowOrTruncated(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk.MustExec("create table t6 (id bigint, a bigint, primary key(id), unique key(a));")
tk.MustExec("insert into t6 values(9223372036854775807, 9223372036854775807);")
tk.MustExec("insert into t6 values(1, 1);")
var nilVal []string
// for unique key
tk.MustQuery("select * from t6 where a = 9223372036854775808").Check(testkit.Rows(nilVal...))
tk.MustQuery("select * from t6 where a = '1.123'").Check(testkit.Rows(nilVal...))
// for primary key
tk.MustQuery("select * from t6 where id = 9223372036854775808").Check(testkit.Rows(nilVal...))
tk.MustQuery("select * from t6 where id = '1.123'").Check(testkit.Rows(nilVal...))
}

func (s *testPointGetSuite) TestIssue10448(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
115 changes: 66 additions & 49 deletions planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/charset"
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/opcode"
Expand Down Expand Up @@ -232,41 +233,31 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
return nil
}
}
schema := buildSchemaFromFields(ctx, tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields)
if schema == nil {
return nil
}
dbName := tblName.Schema.L
if dbName == "" {
dbName = ctx.GetSessionVars().CurrentDB
}

pairs := make([]nameValuePair, 0, 4)
pairs = getNameValuePairs(pairs, tblAlias, selStmt.Where)
if pairs == nil {
pairs, isTableDual := getNameValuePairs(ctx.GetSessionVars().StmtCtx, tbl, tblAlias, pairs, selStmt.Where)
if pairs == nil && !isTableDual {
return nil
}

handlePair, fieldType := findPKHandle(tbl, pairs)
if handlePair.value.Kind() != types.KindNull && len(pairs) == 1 {
schema := buildSchemaFromFields(ctx, tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields)
if schema == nil {
return nil
}
dbName := tblName.Schema.L
if dbName == "" {
dbName = ctx.GetSessionVars().CurrentDB
}
p := newPointGetPlan(ctx, dbName, schema, tbl)
intDatum, err := handlePair.value.ConvertTo(ctx.GetSessionVars().StmtCtx, fieldType)
if err != nil {
if terror.ErrorEqual(types.ErrOverflow, err) {
p.IsTableDual = true
return p
}
// some scenarios cast to int with error, but we may use this value in point get
if !terror.ErrorEqual(types.ErrTruncatedWrongVal, err) {
return nil
}
}
cmp, err := intDatum.CompareDatum(ctx.GetSessionVars().StmtCtx, &handlePair.value)
if err != nil {
return nil
} else if cmp != 0 {
if isTableDual {
p := newPointGetPlan(ctx, tblName.Schema.O, schema, tbl)
p.IsTableDual = true
return p
}
p.Handle = intDatum.GetInt64()

p := newPointGetPlan(ctx, dbName, schema, tbl)
p.Handle = handlePair.value.GetInt64()
p.UnsignedHandle = mysql.HasUnsignedFlag(fieldType.Flag)
p.HandleParam = handlePair.param
return p
Expand All @@ -279,18 +270,16 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
if idxInfo.State != model.StatePublic {
continue
}
if isTableDual {
p := newPointGetPlan(ctx, tblName.Schema.O, schema, tbl)
p.IsTableDual = true
return p
}

idxValues, idxValueParams := getIndexValues(idxInfo, pairs)
if idxValues == nil {
continue
}
schema := buildSchemaFromFields(ctx, tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields)
if schema == nil {
return nil
}
dbName := tblName.Schema.L
if dbName == "" {
dbName = ctx.GetSessionVars().CurrentDB
}
p := newPointGetPlan(ctx, dbName, schema, tbl)
p.IndexInfo = idxInfo
p.IndexValues = idxValues
Expand Down Expand Up @@ -402,21 +391,22 @@ func getSingleTableNameAndAlias(tableRefs *ast.TableRefsClause) (tblName *ast.Ta
}

// getNameValuePairs extracts `column = constant/paramMarker` conditions from expr as name value pairs.
func getNameValuePairs(nvPairs []nameValuePair, tblName model.CIStr, expr ast.ExprNode) []nameValuePair {
func getNameValuePairs(stmtCtx *stmtctx.StatementContext, tbl *model.TableInfo, tblName model.CIStr, nvPairs []nameValuePair, expr ast.ExprNode) (
pairs []nameValuePair, isTableDual bool) {
binOp, ok := expr.(*ast.BinaryOperationExpr)
if !ok {
return nil
return nil, false
}
if binOp.Op == opcode.LogicAnd {
nvPairs = getNameValuePairs(nvPairs, tblName, binOp.L)
if nvPairs == nil {
return nil
nvPairs, isTableDual = getNameValuePairs(stmtCtx, tbl, tblName, nvPairs, binOp.L)
if nvPairs == nil || isTableDual {
return nil, isTableDual
}
nvPairs = getNameValuePairs(nvPairs, tblName, binOp.R)
if nvPairs == nil {
return nil
nvPairs, isTableDual = getNameValuePairs(stmtCtx, tbl, tblName, nvPairs, binOp.R)
if nvPairs == nil || isTableDual {
return nil, isTableDual
}
return nvPairs
return nvPairs, isTableDual
} else if binOp.Op == opcode.EQ {
var d types.Datum
var colName *ast.ColumnNameExpr
Expand All @@ -439,17 +429,44 @@ func getNameValuePairs(nvPairs []nameValuePair, tblName model.CIStr, expr ast.Ex
param = x
}
} else {
return nil
return nil, false
}
if d.IsNull() {
return nil
return nil, false
}
// Views' columns have no FieldType.
if tbl.IsView() {
return nil, false
}
if colName.Name.Table.L != "" && colName.Name.Table.L != tblName.L {
return nil
return nil, false
}
col := model.FindColumnInfo(tbl.Cols(), colName.Name.Name.L)
if col == nil || // Handling the case when the column is _tidb_rowid.
(col.Tp == mysql.TypeString && col.Collate == charset.CollationBin) { // This type we needn't to pad `\0` in here.
return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}), false
}
return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param})
dVal, err := d.ConvertTo(stmtCtx, &col.FieldType)
if err != nil {
if terror.ErrorEqual(types.ErrOverflow, err) {
return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}), true
}
// Some scenarios cast to int with error, but we may use this value in point get.
if !terror.ErrorEqual(types.ErrTruncatedWrongVal, err) {
return nil, false
}
}
// The converted result must be same as original datum.
cmp, err := d.CompareDatum(stmtCtx, &dVal)
if err != nil {
return nil, false
} else if cmp != 0 {
return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: dVal, param: param}), true
}

return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: dVal, param: param}), false
}
return nil
return nil, false
}

func findPKHandle(tblInfo *model.TableInfo, pairs []nameValuePair) (handlePair nameValuePair, fieldType *types.FieldType) {
Expand Down
2 changes: 1 addition & 1 deletion util/testkit/testkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func (tk *TestKit) MustTableDual(sql string, args ...interface{}) *Result {
func (tk *TestKit) MustPointGet(sql string, args ...interface{}) *Result {
rs := tk.MustQuery("explain "+sql, args...)
tk.c.Assert(len(rs.rows), check.Equals, 1)
tk.c.Assert(strings.Contains(rs.rows[0][0], "Point_Get"), check.IsTrue)
tk.c.Assert(strings.Contains(rs.rows[0][0], "Point_Get"), check.IsTrue, check.Commentf("plan %v", rs.rows[0][0]))
return tk.MustQuery(sql, args...)
}

Expand Down