Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: fix the problem that PointGet returns wrong results in the case of overflow #14776

Merged
merged 13 commits into from
Apr 23, 2020
51 changes: 33 additions & 18 deletions executor/point_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/store/mockstore"
"github.com/pingcap/tidb/store/tikv"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/testkit"
)

Expand Down Expand Up @@ -137,10 +138,10 @@ func (s *testPointGetSuite) TestPointGetCharPK(c *C) {
tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows(`a b`))
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "";`).Check(testkit.Rows(` `))
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = " ";`).Check(testkit.Rows())

}

Expand All @@ -153,15 +154,15 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) {

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select * from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = "aab";`).Check(testkit.Rows())

tk.MustExec(`truncate table t;`)
tk.MustExec(`insert into t values("a ", "b ");`)

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`))
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = "a ";`).Check(testkit.Rows())

// Test CHAR BINARY.
tk.MustExec(`drop table if exists t;`)
Expand All @@ -172,10 +173,10 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) {
tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`))
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = "";`).Check(testkit.Rows(` `))
tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = " ";`).Check(testkit.Rows())

// Test both wildcard and column name exist in select field list
tk.MustExec(`set @@sql_mode="";`)
Expand All @@ -188,9 +189,9 @@ func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) {
tk.MustPointGet(`select tmp.* from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select tmp.a, tmp.b from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb aa bb`))
tk.MustPointGet(`select tmp.* from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustPointGet(`select tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustPointGet(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select tmp.* from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows())

// Test using table alias in where clause
tk.MustPointGet(`select * from t tmp where tmp.a = "aa";`).Check(testkit.Rows(`aa bb`))
Expand Down Expand Up @@ -221,19 +222,19 @@ func (s *testPointGetSuite) TestIndexLookupChar(c *C) {

tk.MustExec(`set @@sql_mode="";`)
tk.MustIndexLookup(`select * from t where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustIndexLookup(`select * from t where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "aab";`).Check(testkit.Rows())

// Test query with table alias
tk.MustIndexLookup(`select * from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustIndexLookup(`select * from t tmp where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t tmp where a = "aab";`).Check(testkit.Rows())

tk.MustExec(`truncate table t;`)
tk.MustExec(`insert into t values("a ", "b ");`)

tk.MustExec(`set @@sql_mode="";`)
tk.MustIndexLookup(`select * from t where a = "a";`).Check(testkit.Rows(`a b`))
tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows())

// Test CHAR BINARY.
tk.MustExec(`drop table if exists t;`)
Expand All @@ -244,11 +245,11 @@ func (s *testPointGetSuite) TestIndexLookupChar(c *C) {
tk.MustExec(`set @@sql_mode="";`)
tk.MustIndexLookup(`select * from t where a = "a";`).Check(testkit.Rows(`a b`))
tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustIndexLookup(`select * from t where a = "";`).Check(testkit.Rows(` `))
tk.MustIndexLookup(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustIndexLookup(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustIndexLookup(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = " ";`).Check(testkit.Rows())

}

Expand All @@ -261,15 +262,15 @@ func (s *testPointGetSuite) TestPointGetVarcharPK(c *C) {

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "aa";`).Check(testkit.Rows(`aa bb`))
tk.MustPointGet(`select * from t where a = "aab";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "aab";`).Check(testkit.Rows())

tk.MustExec(`truncate table t;`)
tk.MustExec(`insert into t values("a ", "b ");`)

tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows(`a b `))
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows())

// // Test VARCHAR BINARY.
tk.MustExec(`drop table if exists t;`)
Expand All @@ -280,10 +281,10 @@ func (s *testPointGetSuite) TestPointGetVarcharPK(c *C) {
tk.MustExec(`set @@sql_mode="";`)
tk.MustPointGet(`select * from t where a = "a";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows(`a b `))
tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows(` `))
tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows())
tk.MustTableDual(`select * from t where a = " ";`).Check(testkit.Rows())

}

Expand Down Expand Up @@ -364,6 +365,20 @@ func (s *testPointGetSuite) TestIndexLookupBinary(c *C) {

}

func (s *testPointGetSuite) TestOverflowOrTruncated(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk.MustExec("create table t6 (id bigint, a bigint, primary key(id), unique key(a));")
tk.MustExec("insert into t6 values(9223372036854775807, 9223372036854775807);")
tk.MustExec("insert into t6 values(1, 1);")
var nilVal []string
// for unique key
tk.MustQuery("select * from t6 where a = 9223372036854775808").Check(testkit.Rows(nilVal...))
tk.MustQuery("select * from t6 where a = '1.123'").Check(testkit.Rows(nilVal...))
// for primary key
tk.MustQuery("select * from t6 where id = 9223372036854775808").Check(testkit.Rows(nilVal...))
tk.MustQuery("select * from t6 where id = '1.123'").Check(testkit.Rows(nilVal...))
}

func (s *testPointGetSuite) TestIssue10448(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
113 changes: 60 additions & 53 deletions planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/charset"
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/opcode"
Expand Down Expand Up @@ -556,11 +557,15 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
if tbl == nil {
return nil
}

// Do not handle partitioned table.
// Table partition implementation translates LogicalPlan from `DataSource` to
// `Union -> DataSource` in the logical plan optimization pass, since PointGetPlan
// bypass the logical plan optimization, it can't support partitioned table.
pi := tbl.GetPartitionInfo()
if pi != nil && pi.Type != model.PartitionTypeHash {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a comment why partitions rather than HashPartition are not supported.

PS: It is actually not related to this PR, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's the original code, I only move it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now, we only support for HashPartition.

return nil
}
for _, col := range tbl.Columns {
// Do not handle generated columns.
if col.IsGenerated() {
Expand All @@ -571,52 +576,38 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
return nil
}
}
schema, names := buildSchemaFromFields(tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields)
if schema == nil {
return nil
}
dbName := tblName.Schema.L
if dbName == "" {
dbName = ctx.GetSessionVars().CurrentDB
}

pairs := make([]nameValuePair, 0, 4)
pairs = getNameValuePairs(pairs, tblAlias, selStmt.Where)
pairs, isTableDual := getNameValuePairs(ctx.GetSessionVars().StmtCtx, tbl, tblAlias, pairs, selStmt.Where)
if isTableDual {
p := newPointGetPlan(ctx, tblName.Schema.O, schema, tbl, names)
p.IsTableDual = true
return p
}
if pairs == nil {
return nil
}

var partitionInfo *model.PartitionDefinition
if pi != nil {
if pi.Type != model.PartitionTypeHash {
return nil
}
partitionInfo = getPartitionInfo(ctx, tbl, pairs)
if partitionInfo == nil {
return nil
}
}

handlePair, fieldType := findPKHandle(tbl, pairs)
if handlePair.value.Kind() != types.KindNull && len(pairs) == 1 {
schema, names := buildSchemaFromFields(tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields)
if schema == nil {
return nil
}
dbName := tblName.Schema.L
if dbName == "" {
dbName = ctx.GetSessionVars().CurrentDB
}
p := newPointGetPlan(ctx, dbName, schema, tbl, names)
intDatum, err := handlePair.value.ConvertTo(ctx.GetSessionVars().StmtCtx, fieldType)
if err != nil {
if terror.ErrorEqual(types.ErrOverflow, err) {
p.IsTableDual = true
return p
}
// some scenarios cast to int with error, but we may use this value in point get
if !terror.ErrorEqual(types.ErrTruncatedWrongVal, err) {
return nil
}
}
cmp, err := intDatum.CompareDatum(ctx.GetSessionVars().StmtCtx, &handlePair.value)
if err != nil {
return nil
} else if cmp != 0 {
p.IsTableDual = true
return p
}
p.Handle = intDatum.GetInt64()
p.Handle = handlePair.value.GetInt64()
p.UnsignedHandle = mysql.HasUnsignedFlag(fieldType.Flag)
p.HandleParam = handlePair.param
p.PartitionInfo = partitionInfo
Expand All @@ -634,14 +625,6 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP
if idxValues == nil {
continue
}
schema, names := buildSchemaFromFields(tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields)
if schema == nil {
return nil
}
dbName := tblName.Schema.L
if dbName == "" {
dbName = ctx.GetSessionVars().CurrentDB
}
p := newPointGetPlan(ctx, dbName, schema, tbl, names)
p.IndexInfo = idxInfo
p.IndexValues = idxValues
Expand Down Expand Up @@ -769,21 +752,22 @@ func getSingleTableNameAndAlias(tableRefs *ast.TableRefsClause) (tblName *ast.Ta
}

// getNameValuePairs extracts `column = constant/paramMarker` conditions from expr as name value pairs.
func getNameValuePairs(nvPairs []nameValuePair, tblName model.CIStr, expr ast.ExprNode) []nameValuePair {
func getNameValuePairs(stmtCtx *stmtctx.StatementContext, tbl *model.TableInfo, tblName model.CIStr, nvPairs []nameValuePair, expr ast.ExprNode) (
pairs []nameValuePair, isTableDual bool) {
binOp, ok := expr.(*ast.BinaryOperationExpr)
if !ok {
return nil
return nil, false
}
if binOp.Op == opcode.LogicAnd {
nvPairs = getNameValuePairs(nvPairs, tblName, binOp.L)
if nvPairs == nil {
return nil
nvPairs, isTableDual = getNameValuePairs(stmtCtx, tbl, tblName, nvPairs, binOp.L)
if nvPairs == nil || isTableDual {
return nil, isTableDual
}
nvPairs = getNameValuePairs(nvPairs, tblName, binOp.R)
if nvPairs == nil {
return nil
nvPairs, isTableDual = getNameValuePairs(stmtCtx, tbl, tblName, nvPairs, binOp.R)
if nvPairs == nil || isTableDual {
return nil, isTableDual
}
return nvPairs
return nvPairs, isTableDual
} else if binOp.Op == opcode.EQ {
var d types.Datum
var colName *ast.ColumnNameExpr
Expand All @@ -806,17 +790,40 @@ func getNameValuePairs(nvPairs []nameValuePair, tblName model.CIStr, expr ast.Ex
param = x
}
} else {
return nil
return nil, false
}
if d.IsNull() {
return nil
return nil, false
}
if colName.Name.Table.L != "" && colName.Name.Table.L != tblName.L {
return nil
return nil, false
}
return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param})
col := model.FindColumnInfo(tbl.Cols(), colName.Name.Name.L)
if col == nil || // Handling the case when the column is _tidb_rowid.
(col.Tp == mysql.TypeString && col.Collate == charset.CollationBin) { // This type we needn't to pad `\0` in here.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any test case for that? I don't know this is used for what.

Copy link
Contributor Author

@zimulala zimulala Apr 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. These tests already exist. The test case for the table of create table t(a binary(2) primary key, b binary(2)); table, like select * from t where a = "a ";
@crazycs520

return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}), false
}
dVal, err := d.ConvertTo(stmtCtx, &col.FieldType)
if err != nil {
if terror.ErrorEqual(types.ErrOverflow, err) {
return nil, true
}
// Some scenarios cast to int with error, but we may use this value in point get.
if !terror.ErrorEqual(types.ErrTruncatedWrongVal, err) {
return nil, false
}
}
// The converted result must be same as original datum.
cmp, err := d.CompareDatum(stmtCtx, &dVal)
if err != nil {
return nil, false
} else if cmp != 0 {
return nil, true
}

return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: dVal, param: param}), false
}
return nil
return nil, false
}

func findPKHandle(tblInfo *model.TableInfo, pairs []nameValuePair) (handlePair nameValuePair, fieldType *types.FieldType) {
Expand Down
2 changes: 1 addition & 1 deletion util/testkit/testkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func (tk *TestKit) MustTableDual(sql string, args ...interface{}) *Result {
func (tk *TestKit) MustPointGet(sql string, args ...interface{}) *Result {
rs := tk.MustQuery("explain "+sql, args...)
tk.c.Assert(len(rs.rows), check.Equals, 1)
tk.c.Assert(strings.Contains(rs.rows[0][0], "Point_Get"), check.IsTrue)
tk.c.Assert(strings.Contains(rs.rows[0][0], "Point_Get"), check.IsTrue, check.Commentf("plan %v", rs.rows[0][0]))
return tk.MustQuery(sql, args...)
}

Expand Down