Skip to content

Commit

Permalink
*: fix the problem that PointGet returns wrong results in the case of…
Browse files Browse the repository at this point in the history
… overflow (pingcap#14776)
  • Loading branch information
zimulala authored and sre-bot committed Apr 23, 2020
1 parent 586915e commit 49a31f1
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 66 deletions.
40 changes: 27 additions & 13 deletions executor/point_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ func (s *testPointGetSuite) TestPointGetCharPK(c *C) {
tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows(`a b`))
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "";`).Check(testkit.Rows(` `))
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = " ";`).Check(testkit.Rows())

}

Expand All @@ -157,15 +157,15 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) {

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select * from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = "aab";`).Check(testkit.Rows())

tk.MustExec(`truncate table t;`)
tk.MustExec(`insert into t values("a ", "b ");`)

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`))
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = "a ";`).Check(testkit.Rows())

// Test CHAR BINARY.
tk.MustExec(`drop table if exists t;`)
Expand All @@ -176,10 +176,10 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) {
tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`))
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = "";`).Check(testkit.Rows(` `))
tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = " ";`).Check(testkit.Rows())

// Test both wildcard and column name exist in select field list
tk.MustExec(`set @@sql_mode="";`)
Expand All @@ -192,9 +192,9 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) {
tk.MustPointGet(`select tmp.* from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select tmp.a, tmp.b from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb aa bb`))
tk.MustPointGet(`select tmp.* from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustPointGet(`select tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustPointGet(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select tmp.* from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())

// Test using table alias in where clause
tk.MustPointGet(`select * from t tmp where tmp.a = "aa";`).Check(testkit.Rows(`aa bb`))
Expand Down Expand Up @@ -265,15 +265,15 @@ func (s *testPointGetSuite) TestPointGetVarcharPK(c *C) {

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select * from t where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "aab";`).Check(testkit.Rows())

tk.MustExec(`truncate table t;`)
tk.MustExec(`insert into t values("a ", "b ");`)

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows(`a b `))
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows())

// // Test VARCHAR BINARY.
tk.MustExec(`drop table if exists t;`)
Expand All @@ -284,10 +284,10 @@ func (s *testPointGetSuite) TestPointGetVarcharPK(c *C) {
tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows(`a b `))
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows(` `))
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = " ";`).Check(testkit.Rows())

}

Expand Down Expand Up @@ -368,6 +368,20 @@ func (s *testPointGetSuite) TestIndexLookupBinary(c *C) {

}

func (s *testPointGetSuite) TestOverflowOrTruncated(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk.MustExec("create table t6 (id bigint, a bigint, primary key(id), unique key(a));")
tk.MustExec("insert into t6 values(9223372036854775807, 9223372036854775807);")
tk.MustExec("insert into t6 values(1, 1);")
var nilVal []string
// for unique key
tk.MustQuery("select * from t6 where a = 9223372036854775808").Check(testkit.Rows(nilVal...))
tk.MustQuery("select * from t6 where a = '1.123'").Check(testkit.Rows(nilVal...))
// for primary key
tk.MustQuery("select * from t6 where id = 9223372036854775808").Check(testkit.Rows(nilVal...))
tk.MustQuery("select * from t6 where id = '1.123'").Check(testkit.Rows(nilVal...))
}

func (s *testPointGetSuite) TestIssue10448(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
121 changes: 69 additions & 52 deletions planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/charset"
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/opcode"
Expand Down Expand Up @@ -652,6 +653,9 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
return nil
}
pi := tbl.GetPartitionInfo()
if pi != nil && pi.Type != model.PartitionTypeHash {
return nil
}
for _, col := range tbl.Columns {
// Do not handle generated columns.
if col.IsGenerated() {
Expand All @@ -662,53 +666,40 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
return nil
}
}
schema, names := buildSchemaFromFields(tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields)
if schema == nil {
return nil
}
dbName := tblName.Schema.L
if dbName == "" {
dbName = ctx.GetSessionVars().CurrentDB
}

pairs := make([]nameValuePair, 0, 4)
pairs = getNameValuePairs(pairs, tblAlias, selStmt.Where)
if pairs == nil {
pairs, isTableDual := getNameValuePairs(ctx.GetSessionVars().StmtCtx, tbl, tblAlias, pairs, selStmt.Where)
if pairs == nil && !isTableDual {
return nil
}

var partitionInfo *model.PartitionDefinition
var pos int
if pi != nil {
if pi.Type != model.PartitionTypeHash {
return nil
}
partitionInfo, pos = getPartitionInfo(ctx, tbl, pairs)
if partitionInfo == nil {
return nil
}
}

handlePair, fieldType := findPKHandle(tbl, pairs)
if handlePair.value.Kind() != types.KindNull && len(pairs) == 1 {
schema, names := buildSchemaFromFields(tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields)
if schema == nil {
return nil
}
dbName := tblName.Schema.L
if dbName == "" {
dbName = ctx.GetSessionVars().CurrentDB
}
p := newPointGetPlan(ctx, dbName, schema, tbl, names)
intDatum, err := handlePair.value.ConvertTo(ctx.GetSessionVars().StmtCtx, fieldType)
if err != nil {
if terror.ErrorEqual(types.ErrOverflow, err) {
p.IsTableDual = true
return p
}
// some scenarios cast to int with error, but we may use this value in point get
if !terror.ErrorEqual(types.ErrTruncatedWrongVal, err) {
return nil
}
}
cmp, err := intDatum.CompareDatum(ctx.GetSessionVars().StmtCtx, &handlePair.value)
if err != nil {
return nil
} else if cmp != 0 {
if isTableDual {
p := newPointGetPlan(ctx, tblName.Schema.O, schema, tbl, names)
p.IsTableDual = true
return p
}
p.Handle = intDatum.GetInt64()

p := newPointGetPlan(ctx, dbName, schema, tbl, names)
p.Handle = handlePair.value.GetInt64()
p.UnsignedHandle = mysql.HasUnsignedFlag(fieldType.Flag)
p.HandleParam = handlePair.param
p.PartitionInfo = partitionInfo
Expand All @@ -722,18 +713,16 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
if idxInfo.State != model.StatePublic {
continue
}
if isTableDual {
p := newPointGetPlan(ctx, tblName.Schema.O, schema, tbl, names)
p.IsTableDual = true
return p
}

idxValues, idxValueParams := getIndexValues(idxInfo, pairs)
if idxValues == nil {
continue
}
schema, names := buildSchemaFromFields(tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields)
if schema == nil {
return nil
}
dbName := tblName.Schema.L
if dbName == "" {
dbName = ctx.GetSessionVars().CurrentDB
}
p := newPointGetPlan(ctx, dbName, schema, tbl, names)
p.IndexInfo = idxInfo
p.IndexValues = idxValues
Expand Down Expand Up @@ -864,21 +853,22 @@ func getSingleTableNameAndAlias(tableRefs *ast.TableRefsClause) (tblName *ast.Ta
}

// getNameValuePairs extracts `column = constant/paramMarker` conditions from expr as name value pairs.
func getNameValuePairs(nvPairs []nameValuePair, tblName model.CIStr, expr ast.ExprNode) []nameValuePair {
func getNameValuePairs(stmtCtx *stmtctx.StatementContext, tbl *model.TableInfo, tblName model.CIStr, nvPairs []nameValuePair, expr ast.ExprNode) (
pairs []nameValuePair, isTableDual bool) {
binOp, ok := expr.(*ast.BinaryOperationExpr)
if !ok {
return nil
return nil, false
}
if binOp.Op == opcode.LogicAnd {
nvPairs = getNameValuePairs(nvPairs, tblName, binOp.L)
if nvPairs == nil {
return nil
nvPairs, isTableDual = getNameValuePairs(stmtCtx, tbl, tblName, nvPairs, binOp.L)
if nvPairs == nil || isTableDual {
return nil, isTableDual
}
nvPairs = getNameValuePairs(nvPairs, tblName, binOp.R)
if nvPairs == nil {
return nil
nvPairs, isTableDual = getNameValuePairs(stmtCtx, tbl, tblName, nvPairs, binOp.R)
if nvPairs == nil || isTableDual {
return nil, isTableDual
}
return nvPairs
return nvPairs, isTableDual
} else if binOp.Op == opcode.EQ {
var d types.Datum
var colName *ast.ColumnNameExpr
Expand All @@ -901,17 +891,44 @@ func getNameValuePairs(nvPairs []nameValuePair, tblName model.CIStr, expr ast.Ex
param = x
}
} else {
return nil
return nil, false
}
if d.IsNull() {
return nil
return nil, false
}
// Views' columns have no FieldType.
if tbl.IsView() {
return nil, false
}
if colName.Name.Table.L != "" && colName.Name.Table.L != tblName.L {
return nil
return nil, false
}
col := model.FindColumnInfo(tbl.Cols(), colName.Name.Name.L)
if col == nil || // Handling the case when the column is _tidb_rowid.
(col.Tp == mysql.TypeString && col.Collate == charset.CollationBin) { // This type we needn't to pad `\0` in here.
return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}), false
}
return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param})
dVal, err := d.ConvertTo(stmtCtx, &col.FieldType)
if err != nil {
if terror.ErrorEqual(types.ErrOverflow, err) {
return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}), true
}
// Some scenarios cast to int with error, but we may use this value in point get.
if !terror.ErrorEqual(types.ErrTruncatedWrongVal, err) {
return nil, false
}
}
// The converted result must be same as original datum.
cmp, err := d.CompareDatum(stmtCtx, &dVal)
if err != nil {
return nil, false
} else if cmp != 0 {
return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: dVal, param: param}), true
}

return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: dVal, param: param}), false
}
return nil
return nil, false
}

func findPKHandle(tblInfo *model.TableInfo, pairs []nameValuePair) (handlePair nameValuePair, fieldType *types.FieldType) {
Expand Down
2 changes: 1 addition & 1 deletion util/testkit/testkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func (tk *TestKit) MustTableDual(sql string, args ...interface{}) *Result {
func (tk *TestKit) MustPointGet(sql string, args ...interface{}) *Result {
rs := tk.MustQuery("explain "+sql, args...)
tk.c.Assert(len(rs.rows), check.Equals, 1)
tk.c.Assert(strings.Contains(rs.rows[0][0], "Point_Get"), check.IsTrue)
tk.c.Assert(strings.Contains(rs.rows[0][0], "Point_Get"), check.IsTrue, check.Commentf("plan %v", rs.rows[0][0]))
return tk.MustQuery(sql, args...)
}

Expand Down

0 comments on commit 49a31f1

Please sign in to comment.