Skip to content

Commit

Permalink
cherry pick pingcap#14776 to release-3.0
Browse files Browse the repository at this point in the history
Signed-off-by: sre-bot <[email protected]>
  • Loading branch information
zimulala authored and sre-bot committed Apr 23, 2020
1 parent 445e69b commit 6a1bd9f
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 31 deletions.
40 changes: 27 additions & 13 deletions executor/point_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ func (s *testPointGetSuite) TestPointGetCharPK(c *C) {
tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows(`a b`))
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "";`).Check(testkit.Rows(` `))
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = " ";`).Check(testkit.Rows())

}

Expand All @@ -153,15 +153,15 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) {

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select * from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = "aab";`).Check(testkit.Rows())

tk.MustExec(`truncate table t;`)
tk.MustExec(`insert into t values("a ", "b ");`)

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`))
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = "a ";`).Check(testkit.Rows())

// Test CHAR BINARY.
tk.MustExec(`drop table if exists t;`)
Expand All @@ -172,10 +172,10 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) {
tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`))
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = "";`).Check(testkit.Rows(` `))
tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = " ";`).Check(testkit.Rows())

// Test both wildcard and column name exist in select field list
tk.MustExec(`set @@sql_mode="";`)
Expand All @@ -188,9 +188,9 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) {
tk.MustPointGet(`select tmp.* from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select tmp.a, tmp.b from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb aa bb`))
tk.MustPointGet(`select tmp.* from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustPointGet(`select tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustPointGet(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select tmp.* from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())

// Test using table alias in where clause
tk.MustPointGet(`select * from t tmp where tmp.a = "aa";`).Check(testkit.Rows(`aa bb`))
Expand Down Expand Up @@ -261,15 +261,15 @@ func (s *testPointGetSuite) TestPointGetVarcharPK(c *C) {

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select * from t where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "aab";`).Check(testkit.Rows())

tk.MustExec(`truncate table t;`)
tk.MustExec(`insert into t values("a ", "b ");`)

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows(`a b `))
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows())

// // Test VARCHAR BINARY.
tk.MustExec(`drop table if exists t;`)
Expand All @@ -280,10 +280,10 @@ func (s *testPointGetSuite) TestPointGetVarcharPK(c *C) {
tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows(`a b `))
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows(` `))
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = " ";`).Check(testkit.Rows())

}

Expand Down Expand Up @@ -364,6 +364,20 @@ func (s *testPointGetSuite) TestIndexLookupBinary(c *C) {

}

func (s *testPointGetSuite) TestOverflowOrTruncated(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk.MustExec("create table t6 (id bigint, a bigint, primary key(id), unique key(a));")
tk.MustExec("insert into t6 values(9223372036854775807, 9223372036854775807);")
tk.MustExec("insert into t6 values(1, 1);")
var nilVal []string
// for unique key
tk.MustQuery("select * from t6 where a = 9223372036854775808").Check(testkit.Rows(nilVal...))
tk.MustQuery("select * from t6 where a = '1.123'").Check(testkit.Rows(nilVal...))
// for primary key
tk.MustQuery("select * from t6 where id = 9223372036854775808").Check(testkit.Rows(nilVal...))
tk.MustQuery("select * from t6 where id = '1.123'").Check(testkit.Rows(nilVal...))
}

func (s *testPointGetSuite) TestIssue10448(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
106 changes: 89 additions & 17 deletions planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/charset"
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/opcode"
Expand Down Expand Up @@ -215,11 +216,16 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
if tbl == nil {
return nil
}
<<<<<<< HEAD
// Do not handle partitioned table.
// Table partition implementation translates LogicalPlan from `DataSource` to
// `Union -> DataSource` in the logical plan optimization pass, since PointGetPlan
// bypass the logical plan optimization, it can't support partitioned table.
if tbl.GetPartitionInfo() != nil {
=======
pi := tbl.GetPartitionInfo()
if pi != nil && pi.Type != model.PartitionTypeHash {
>>>>>>> e521ea9... *: fix the problem that PointGet returns wrong results in the case of overflow (#14776)
return nil
}
for _, col := range tbl.Columns {
Expand All @@ -232,11 +238,21 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
return nil
}
}
schema, names := buildSchemaFromFields(tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields)
if schema == nil {
return nil
}
dbName := tblName.Schema.L
if dbName == "" {
dbName = ctx.GetSessionVars().CurrentDB
}

pairs := make([]nameValuePair, 0, 4)
pairs = getNameValuePairs(pairs, tblAlias, selStmt.Where)
if pairs == nil {
pairs, isTableDual := getNameValuePairs(ctx.GetSessionVars().StmtCtx, tbl, tblAlias, pairs, selStmt.Where)
if pairs == nil && !isTableDual {
return nil
}
<<<<<<< HEAD
handlePair, fieldType := findPKHandle(tbl, pairs)
if handlePair.value.Kind() != types.KindNull && len(pairs) == 1 {
schema := buildSchemaFromFields(ctx, tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields)
Expand All @@ -263,10 +279,28 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
if err != nil {
return nil
} else if cmp != 0 {
=======

var partitionInfo *model.PartitionDefinition
var pos int
if pi != nil {
partitionInfo, pos = getPartitionInfo(ctx, tbl, pairs)
if partitionInfo == nil {
return nil
}
}

handlePair, fieldType := findPKHandle(tbl, pairs)
if handlePair.value.Kind() != types.KindNull && len(pairs) == 1 {
if isTableDual {
p := newPointGetPlan(ctx, tblName.Schema.O, schema, tbl, names)
>>>>>>> e521ea9... *: fix the problem that PointGet returns wrong results in the case of overflow (#14776)
p.IsTableDual = true
return p
}
p.Handle = intDatum.GetInt64()

p := newPointGetPlan(ctx, dbName, schema, tbl, names)
p.Handle = handlePair.value.GetInt64()
p.UnsignedHandle = mysql.HasUnsignedFlag(fieldType.Flag)
p.HandleParam = handlePair.param
return p
Expand All @@ -279,10 +313,17 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
if idxInfo.State != model.StatePublic {
continue
}
if isTableDual {
p := newPointGetPlan(ctx, tblName.Schema.O, schema, tbl, names)
p.IsTableDual = true
return p
}

idxValues, idxValueParams := getIndexValues(idxInfo, pairs)
if idxValues == nil {
continue
}
<<<<<<< HEAD
schema := buildSchemaFromFields(ctx, tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields)
if schema == nil {
return nil
Expand All @@ -292,6 +333,9 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
dbName = ctx.GetSessionVars().CurrentDB
}
p := newPointGetPlan(ctx, dbName, schema, tbl)
=======
p := newPointGetPlan(ctx, dbName, schema, tbl, names)
>>>>>>> e521ea9... *: fix the problem that PointGet returns wrong results in the case of overflow (#14776)
p.IndexInfo = idxInfo
p.IndexValues = idxValues
p.IndexValueParams = idxValueParams
Expand Down Expand Up @@ -402,21 +446,22 @@ func getSingleTableNameAndAlias(tableRefs *ast.TableRefsClause) (tblName *ast.Ta
}

// getNameValuePairs extracts `column = constant/paramMarker` conditions from expr as name value pairs.
func getNameValuePairs(nvPairs []nameValuePair, tblName model.CIStr, expr ast.ExprNode) []nameValuePair {
func getNameValuePairs(stmtCtx *stmtctx.StatementContext, tbl *model.TableInfo, tblName model.CIStr, nvPairs []nameValuePair, expr ast.ExprNode) (
pairs []nameValuePair, isTableDual bool) {
binOp, ok := expr.(*ast.BinaryOperationExpr)
if !ok {
return nil
return nil, false
}
if binOp.Op == opcode.LogicAnd {
nvPairs = getNameValuePairs(nvPairs, tblName, binOp.L)
if nvPairs == nil {
return nil
nvPairs, isTableDual = getNameValuePairs(stmtCtx, tbl, tblName, nvPairs, binOp.L)
if nvPairs == nil || isTableDual {
return nil, isTableDual
}
nvPairs = getNameValuePairs(nvPairs, tblName, binOp.R)
if nvPairs == nil {
return nil
nvPairs, isTableDual = getNameValuePairs(stmtCtx, tbl, tblName, nvPairs, binOp.R)
if nvPairs == nil || isTableDual {
return nil, isTableDual
}
return nvPairs
return nvPairs, isTableDual
} else if binOp.Op == opcode.EQ {
var d types.Datum
var colName *ast.ColumnNameExpr
Expand All @@ -439,17 +484,44 @@ func getNameValuePairs(nvPairs []nameValuePair, tblName model.CIStr, expr ast.Ex
param = x
}
} else {
return nil
return nil, false
}
if d.IsNull() {
return nil
return nil, false
}
// Views' columns have no FieldType.
if tbl.IsView() {
return nil, false
}
if colName.Name.Table.L != "" && colName.Name.Table.L != tblName.L {
return nil
return nil, false
}
return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param})
col := model.FindColumnInfo(tbl.Cols(), colName.Name.Name.L)
if col == nil || // Handling the case when the column is _tidb_rowid.
(col.Tp == mysql.TypeString && col.Collate == charset.CollationBin) { // This type we needn't to pad `\0` in here.
return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}), false
}
dVal, err := d.ConvertTo(stmtCtx, &col.FieldType)
if err != nil {
if terror.ErrorEqual(types.ErrOverflow, err) {
return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}), true
}
// Some scenarios cast to int with error, but we may use this value in point get.
if !terror.ErrorEqual(types.ErrTruncatedWrongVal, err) {
return nil, false
}
}
// The converted result must be same as original datum.
cmp, err := d.CompareDatum(stmtCtx, &dVal)
if err != nil {
return nil, false
} else if cmp != 0 {
return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: dVal, param: param}), true
}

return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: dVal, param: param}), false
}
return nil
return nil, false
}

func findPKHandle(tblInfo *model.TableInfo, pairs []nameValuePair) (handlePair nameValuePair, fieldType *types.FieldType) {
Expand Down
2 changes: 1 addition & 1 deletion util/testkit/testkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func (tk *TestKit) MustTableDual(sql string, args ...interface{}) *Result {
func (tk *TestKit) MustPointGet(sql string, args ...interface{}) *Result {
rs := tk.MustQuery("explain "+sql, args...)
tk.c.Assert(len(rs.rows), check.Equals, 1)
tk.c.Assert(strings.Contains(rs.rows[0][0], "Point_Get"), check.IsTrue)
tk.c.Assert(strings.Contains(rs.rows[0][0], "Point_Get"), check.IsTrue, check.Commentf("plan %v", rs.rows[0][0]))
return tk.MustQuery(sql, args...)
}

Expand Down

0 comments on commit 6a1bd9f

Please sign in to comment.