diff --git a/README.md b/README.md index 1bb34299bc0ae..d4f38ad22cdc3 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ [![GitHub release date](https://img.shields.io/github/release-date/pingcap/tidb.svg)](https://github.com/pingcap/tidb/releases) [![CircleCI Status](https://circleci.com/gh/pingcap/tidb.svg?style=shield)](https://circleci.com/gh/pingcap/tidb) [![Coverage Status](https://codecov.io/gh/pingcap/tidb/branch/master/graph/badge.svg)](https://codecov.io/gh/pingcap/tidb) +[![GoDoc](https://img.shields.io/badge/Godoc-reference-blue.svg)](https://godoc.org/github.com/pingcap/tidb) - [**Stack Overflow**](https://stackoverflow.com/questions/tagged/tidb) - Community [**Slack Channel**](https://pingcap.com/tidbslack/) diff --git a/config/config.toml.example b/config/config.toml.example index 6ff6fd9b0741a..15f079e1de12d 100644 --- a/config/config.toml.example +++ b/config/config.toml.example @@ -57,7 +57,7 @@ treat-old-version-utf8-as-utf8mb4 = true # enable-table-lock is used to control table lock feature. Default is false, indicate the table lock feature is disabled. enable-table-lock = false -# delay-clean-table-lock is used to control whether delayed-release the table lock in the abnormal situation. (Milliseconds) +# delay-clean-table-lock is used to control the time (Milliseconds) of delay before unlock the table in the abnormal situation. delay-clean-table-lock = 0 [log] diff --git a/executor/adapter.go b/executor/adapter.go index ee7bc38441d1a..a08c26ab11abf 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -608,7 +608,8 @@ func (a *ExecStmt) logAudit() { audit := plugin.DeclareAuditManifest(p.Manifest) if audit.OnGeneralEvent != nil { cmd := mysql.Command2Str[byte(atomic.LoadUint32(&a.Ctx.GetSessionVars().CommandValue))] - audit.OnGeneralEvent(context.Background(), sessVars, plugin.Log, cmd) + ctx := context.WithValue(context.Background(), plugin.ExecStartTimeCtxKey, a.StartTime) + audit.OnGeneralEvent(ctx, sessVars, plugin.Log, cmd) } return nil }) diff --git a/executor/analyze.go b/executor/analyze.go index d37a27a32fc5d..c159b82313206 100644 --- a/executor/analyze.go +++ b/executor/analyze.go @@ -620,12 +620,9 @@ func (e *AnalyzeFastExec) getSampRegionsRowCount(bo *tikv.Backoffer, needRebuild if !ok { return } - req := &tikvrpc.Request{ - Type: tikvrpc.CmdDebugGetRegionProperties, - DebugGetRegionProperties: &debugpb.GetRegionPropertiesRequest{ - RegionId: loc.Region.GetID(), - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdDebugGetRegionProperties, &debugpb.GetRegionPropertiesRequest{ + RegionId: loc.Region.GetID(), + }) var resp *tikvrpc.Response var rpcCtx *tikv.RPCContext rpcCtx, *err = e.cache.GetRPCContext(bo, loc.Region) @@ -638,11 +635,11 @@ func (e *AnalyzeFastExec) getSampRegionsRowCount(bo *tikv.Backoffer, needRebuild if *err != nil { return } - if resp.DebugGetRegionProperties == nil || len(resp.DebugGetRegionProperties.Props) == 0 { + if resp.Resp == nil || len(resp.Resp.(*debugpb.GetRegionPropertiesResponse).Props) == 0 { *needRebuild = true return } - for _, prop := range resp.DebugGetRegionProperties.Props { + for _, prop := range resp.Resp.(*debugpb.GetRegionPropertiesResponse).Props { if prop.Name == "mvcc.num_rows" { var cnt uint64 cnt, *err = strconv.ParseUint(prop.Value, 10, 64) @@ -1002,7 +999,7 @@ func (e *AnalyzeFastExec) buildColumnStats(ID int64, collector *statistics.Sampl collector.NullCount++ continue } - bytes, err := tablecodec.EncodeValue(e.ctx.GetSessionVars().StmtCtx, sample.Value) + bytes, err := 
tablecodec.EncodeValue(e.ctx.GetSessionVars().StmtCtx, nil, sample.Value) if err != nil { return nil, nil, err } diff --git a/executor/analyze_test.go b/executor/analyze_test.go index ec3d98e188d42..50577ca7c50ba 100644 --- a/executor/analyze_test.go +++ b/executor/analyze_test.go @@ -400,7 +400,7 @@ func (c *regionProperityClient) SendRequest(ctx context.Context, addr string, re defer c.mu.Unlock() c.mu.count++ // Mock failure once. - if req.DebugGetRegionProperties.RegionId == c.mu.regionID { + if req.DebugGetRegionProperties().RegionId == c.mu.regionID { c.mu.regionID = 0 return &tikvrpc.Response{}, nil } diff --git a/executor/checksum.go b/executor/checksum.go index 9ca8692088825..c84579fe85ee8 100644 --- a/executor/checksum.go +++ b/executor/checksum.go @@ -187,27 +187,47 @@ func newChecksumContext(db *model.DBInfo, table *model.TableInfo, startTs uint64 } func (c *checksumContext) BuildRequests(ctx sessionctx.Context) ([]*kv.Request, error) { - reqs := make([]*kv.Request, 0, len(c.TableInfo.Indices)+1) - req, err := c.buildTableRequest(ctx) - if err != nil { + var partDefs []model.PartitionDefinition + if part := c.TableInfo.Partition; part != nil { + partDefs = part.Definitions + } + + reqs := make([]*kv.Request, 0, (len(c.TableInfo.Indices)+1)*(len(partDefs)+1)) + if err := c.appendRequest(ctx, c.TableInfo.ID, &reqs); err != nil { return nil, err } - reqs = append(reqs, req) + + for _, partDef := range partDefs { + if err := c.appendRequest(ctx, partDef.ID, &reqs); err != nil { + return nil, err + } + } + + return reqs, nil +} + +func (c *checksumContext) appendRequest(ctx sessionctx.Context, tableID int64, reqs *[]*kv.Request) error { + req, err := c.buildTableRequest(ctx, tableID) + if err != nil { + return err + } + + *reqs = append(*reqs, req) for _, indexInfo := range c.TableInfo.Indices { if indexInfo.State != model.StatePublic { continue } - req, err = c.buildIndexRequest(ctx, indexInfo) + req, err = c.buildIndexRequest(ctx, tableID, indexInfo) if err != nil { - return nil, err + return err } - reqs = append(reqs, req) + *reqs = append(*reqs, req) } - return reqs, nil + return nil } -func (c *checksumContext) buildTableRequest(ctx sessionctx.Context) (*kv.Request, error) { +func (c *checksumContext) buildTableRequest(ctx sessionctx.Context, tableID int64) (*kv.Request, error) { checksum := &tipb.ChecksumRequest{ StartTs: c.StartTs, ScanOn: tipb.ChecksumScanOn_Table, @@ -217,13 +237,13 @@ func (c *checksumContext) buildTableRequest(ctx sessionctx.Context) (*kv.Request ranges := ranger.FullIntRange(false) var builder distsql.RequestBuilder - return builder.SetTableRanges(c.TableInfo.ID, ranges, nil). + return builder.SetTableRanges(tableID, ranges, nil). SetChecksumRequest(checksum). SetConcurrency(ctx.GetSessionVars().DistSQLScanConcurrency). Build() } -func (c *checksumContext) buildIndexRequest(ctx sessionctx.Context, indexInfo *model.IndexInfo) (*kv.Request, error) { +func (c *checksumContext) buildIndexRequest(ctx sessionctx.Context, tableID int64, indexInfo *model.IndexInfo) (*kv.Request, error) { checksum := &tipb.ChecksumRequest{ StartTs: c.StartTs, ScanOn: tipb.ChecksumScanOn_Index, @@ -233,7 +253,7 @@ func (c *checksumContext) buildIndexRequest(ctx sessionctx.Context, indexInfo *m ranges := ranger.FullRange() var builder distsql.RequestBuilder - return builder.SetIndexRanges(ctx.GetSessionVars().StmtCtx, c.TableInfo.ID, indexInfo.ID, ranges). + return builder.SetIndexRanges(ctx.GetSessionVars().StmtCtx, tableID, indexInfo.ID, ranges). SetChecksumRequest(checksum). 
SetConcurrency(ctx.GetSessionVars().DistSQLScanConcurrency). Build() diff --git a/executor/executor_test.go b/executor/executor_test.go index e860a6e200d52..3efde355754f3 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -87,7 +87,8 @@ func TestT(t *testing.T) { testleak.AfterTestT(t)() } -var _ = Suite(&testSuite{}) +var _ = Suite(&testSuite{&baseTestSuite{}}) +var _ = Suite(&testSuiteP1{&baseTestSuite{}}) var _ = Suite(&testSuite1{}) var _ = Suite(&testSuite2{}) var _ = Suite(&testSuite3{}) @@ -98,9 +99,11 @@ var _ = Suite(&testOOMSuite{}) var _ = Suite(&testPointGetSuite{}) var _ = Suite(&testRecoverTable{}) var _ = Suite(&testFlushSuite{}) -var _ = Suite(&testShowStatsSuite{}) -type testSuite struct { +type testSuite struct{ *baseTestSuite } +type testSuiteP1 struct{ *baseTestSuite } + +type baseTestSuite struct { cluster *mocktikv.Cluster mvccStore mocktikv.MVCCStore store kv.Storage @@ -111,7 +114,7 @@ type testSuite struct { var mockTikv = flag.Bool("mockTikv", true, "use mock tikv store in executor test") -func (s *testSuite) SetUpSuite(c *C) { +func (s *baseTestSuite) SetUpSuite(c *C) { s.Parser = parser.New() flag.Lookup("mockTikv") useMockTikv := *mockTikv @@ -134,7 +137,7 @@ func (s *testSuite) SetUpSuite(c *C) { s.domain = d } -func (s *testSuite) TearDownSuite(c *C) { +func (s *baseTestSuite) TearDownSuite(c *C) { s.domain.Close() s.store.Close() } @@ -145,7 +148,7 @@ func enablePessimisticTxn(enable bool) { config.StoreGlobalConfig(newConf) } -func (s *testSuite) TestPessimisticSelectForUpdate(c *C) { +func (s *testSuiteP1) TestPessimisticSelectForUpdate(c *C) { defer func() { enablePessimisticTxn(false) }() enablePessimisticTxn(true) tk := testkit.NewTestKit(c, s.store) @@ -170,7 +173,7 @@ func (s *testSuite) TearDownTest(c *C) { } } -func (s *testSuite) TestBind(c *C) { +func (s *testSuiteP1) TestBind(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists testbind") @@ -185,7 +188,7 @@ func (s *testSuite) TestBind(c *C) { tk.MustExec("drop session binding for select * from testbind") } -func (s *testSuite) TestChange(c *C) { +func (s *testSuiteP1) TestChange(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -195,18 +198,20 @@ func (s *testSuite) TestChange(c *C) { c.Assert(tk.ExecToErr("alter table t change c d varchar(100)"), NotNil) } -func (s *testSuite) TestLoadStats(c *C) { +func (s *testSuiteP1) TestLoadStats(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") c.Assert(tk.ExecToErr("load stats"), NotNil) c.Assert(tk.ExecToErr("load stats ./xxx.json"), NotNil) } -func (s *testSuite) TestShow(c *C) { +func (s *testSuiteP1) TestShow(c *C) { tk := testkit.NewTestKit(c, s.store) - tk.MustExec("use test") + tk.MustExec("create database test_show;") + tk.MustExec("use test_show") tk.MustQuery("show engines") + tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int primary key)") c.Assert(len(tk.MustQuery("show index in t").Rows()), Equals, 1) c.Assert(len(tk.MustQuery("show index from t").Rows()), Equals, 1) @@ -218,7 +223,7 @@ func (s *testSuite) TestShow(c *C) { "latin1 Latin1 latin1_bin 1", "binary binary binary 1")) c.Assert(len(tk.MustQuery("show master status").Rows()), Equals, 1) - tk.MustQuery("show create database test").Check(testkit.Rows("test CREATE DATABASE `test` /*!40100 DEFAULT CHARACTER SET utf8mb4 */")) + tk.MustQuery("show create database test_show").Check(testkit.Rows("test_show CREATE 
DATABASE `test_show` /*!40100 DEFAULT CHARACTER SET utf8mb4 */")) tk.MustQuery("show privileges").Check(testkit.Rows("Alter Tables To alter the table", "Alter Tables To alter the table", "Alter routine Functions,Procedures To alter or drop stored functions/procedures", @@ -254,7 +259,7 @@ func (s *testSuite) TestShow(c *C) { c.Assert(len(tk.MustQuery("show table status").Rows()), Equals, 1) } -func (s *testSuite) TestAdmin(c *C) { +func (s *testSuiteP1) TestAdmin(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists admin_test") @@ -412,7 +417,18 @@ func (s *testSuite) TestAdmin(c *C) { c.Assert(historyJobs, DeepEquals, historyJobs2) } -func (s *testSuite) fillData(tk *testkit.TestKit, table string) { +func (s *testSuite) TestAdminChecksumOfPartitionedTable(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("USE test;") + tk.MustExec("DROP TABLE IF EXISTS admin_checksum_partition_test;") + tk.MustExec("CREATE TABLE admin_checksum_partition_test (a INT) PARTITION BY HASH(a) PARTITIONS 4;") + tk.MustExec("INSERT INTO admin_checksum_partition_test VALUES (1), (2);") + + r := tk.MustQuery("ADMIN CHECKSUM TABLE admin_checksum_partition_test;") + r.Check(testkit.Rows("test admin_checksum_partition_test 1 5 5")) +} + +func (s *baseTestSuite) fillData(tk *testkit.TestKit, table string) { tk.MustExec("use test") tk.MustExec(fmt.Sprintf("create table %s(id int not null default 1, name varchar(255), PRIMARY KEY(id));", table)) @@ -440,6 +456,7 @@ func checkCases(tests []testCase, ld *executor.LoadDataInfo, ctx.GetSessionVars().StmtCtx.DupKeyAsWarning = true ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true ctx.GetSessionVars().StmtCtx.InLoadDataStmt = true + ctx.GetSessionVars().StmtCtx.InDeleteStmt = false data, reachLimit, err1 := ld.InsertData(context.Background(), tt.data1, tt.data2) c.Assert(err1, IsNil) c.Assert(reachLimit, IsFalse) @@ -466,7 +483,7 @@ func checkCases(tests []testCase, ld *executor.LoadDataInfo, } } -func (s *testSuite) TestSelectWithoutFrom(c *C) { +func (s *testSuiteP1) TestSelectWithoutFrom(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -481,7 +498,7 @@ func (s *testSuite) TestSelectWithoutFrom(c *C) { } // TestSelectBackslashN Issue 3685. -func (s *testSuite) TestSelectBackslashN(c *C) { +func (s *testSuiteP1) TestSelectBackslashN(c *C) { tk := testkit.NewTestKit(c, s.store) sql := `select \N;` @@ -572,7 +589,7 @@ func (s *testSuite) TestSelectBackslashN(c *C) { } // TestSelectNull Issue #4053. -func (s *testSuite) TestSelectNull(c *C) { +func (s *testSuiteP1) TestSelectNull(c *C) { tk := testkit.NewTestKit(c, s.store) sql := `select nUll;` @@ -605,7 +622,7 @@ func (s *testSuite) TestSelectNull(c *C) { } // TestSelectStringLiteral Issue #3686. 
-func (s *testSuite) TestSelectStringLiteral(c *C) { +func (s *testSuiteP1) TestSelectStringLiteral(c *C) { tk := testkit.NewTestKit(c, s.store) sql := `select 'abc';` @@ -760,7 +777,7 @@ func (s *testSuite) TestSelectStringLiteral(c *C) { c.Check(fields[0].Column.Name.O, Equals, "ss") } -func (s *testSuite) TestSelectLimit(c *C) { +func (s *testSuiteP1) TestSelectLimit(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") s.fillData(tk, "select_limit") @@ -789,7 +806,7 @@ func (s *testSuite) TestSelectLimit(c *C) { c.Assert(err, NotNil) } -func (s *testSuite) TestSelectOrderBy(c *C) { +func (s *testSuiteP1) TestSelectOrderBy(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") s.fillData(tk, "select_order_test") @@ -886,7 +903,7 @@ func (s *testSuite) TestSelectOrderBy(c *C) { tk.MustQuery("select a from t use index(b) order by b").Check(testkit.Rows("9", "8", "7", "6", "5", "4", "3", "2", "1", "0")) } -func (s *testSuite) TestOrderBy(c *C) { +func (s *testSuiteP1) TestOrderBy(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk.MustExec("drop table if exists t") tk.MustExec("create table t (c1 int, c2 int, c3 varchar(20))") @@ -909,7 +926,7 @@ func (s *testSuite) TestOrderBy(c *C) { tk.MustQuery("select c1, c2 from t order by binary c3").Check(testkit.Rows("1 2", "2 1")) } -func (s *testSuite) TestSelectErrorRow(c *C) { +func (s *testSuiteP1) TestSelectErrorRow(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -939,7 +956,7 @@ func (s *testSuite) TestSelectErrorRow(c *C) { } // TestIssue2612 is related with https://github.com/pingcap/tidb/issues/2612 -func (s *testSuite) TestIssue2612(c *C) { +func (s *testSuiteP1) TestIssue2612(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec(`drop table if exists t`) @@ -957,7 +974,7 @@ func (s *testSuite) TestIssue2612(c *C) { } // TestIssue345 is related with https://github.com/pingcap/tidb/issues/345 -func (s *testSuite) TestIssue345(c *C) { +func (s *testSuiteP1) TestIssue345(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec(`drop table if exists t1, t2`) @@ -988,7 +1005,7 @@ func (s *testSuite) TestIssue345(c *C) { c.Assert(err, NotNil) } -func (s *testSuite) TestIssue5055(c *C) { +func (s *testSuiteP1) TestIssue5055(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec(`drop table if exists t1, t2`) @@ -1185,6 +1202,7 @@ func (s *testSuite) TestUnion(c *C) { tk.MustExec(`insert into t1 select * from t1;`) tk.MustExec(`insert into t2 values(1, 1);`) tk.MustExec(`set @@tidb_init_chunk_size=2;`) + tk.MustExec(`set @@sql_mode="";`) tk.MustQuery(`select count(*) from (select t1.a, t1.b from t1 left join t2 on t1.a=t2.a union all select t1.a, t1.a from t1 left join t2 on t1.a=t2.a) tmp;`).Check(testkit.Rows("128")) tk.MustQuery(`select tmp.a, count(*) from (select t1.a, t1.b from t1 left join t2 on t1.a=t2.a union all select t1.a, t1.a from t1 left join t2 on t1.a=t2.a) tmp;`).Check(testkit.Rows("1 128")) @@ -1239,7 +1257,7 @@ func (s *testSuite) TestUnion(c *C) { tk.MustQuery("select count(distinct a), sum(distinct a), avg(distinct a) from (select a from t union all select b from t) tmp;").Check(testkit.Rows("1 1.000 1.0000000")) } -func (s *testSuite) TestNeighbouringProj(c *C) { +func (s *testSuiteP1) TestNeighbouringProj(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -1257,7 +1275,7 @@ func (s *testSuite) TestNeighbouringProj(c *C) { rs.Check(testkit.Rows("1 1 1", "1 2 2", "1 3 
3")) } -func (s *testSuite) TestIn(c *C) { +func (s *testSuiteP1) TestIn(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec(`drop table if exists t`) @@ -1273,7 +1291,7 @@ func (s *testSuite) TestIn(c *C) { tk.MustQuery(queryStr).Check(testkit.Rows("7")) } -func (s *testSuite) TestTablePKisHandleScan(c *C) { +func (s *testSuiteP1) TestTablePKisHandleScan(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -1335,7 +1353,7 @@ func (s *testSuite) TestTablePKisHandleScan(c *C) { } } -func (s *testSuite) TestIndexScan(c *C) { +func (s *testSuiteP1) TestIndexScan(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -1418,7 +1436,7 @@ func (s *testSuite) TestIndexScan(c *C) { result.Check(testkit.Rows()) } -func (s *testSuite) TestIndexReverseOrder(c *C) { +func (s *testSuiteP1) TestIndexReverseOrder(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -1436,7 +1454,7 @@ func (s *testSuite) TestIndexReverseOrder(c *C) { result.Check(testkit.Rows("0 2", "0 1", "0 0", "1 2", "1 1", "1 0", "2 2", "2 1", "2 0")) } -func (s *testSuite) TestTableReverseOrder(c *C) { +func (s *testSuiteP1) TestTableReverseOrder(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -1448,7 +1466,7 @@ func (s *testSuite) TestTableReverseOrder(c *C) { result.Check(testkit.Rows("7", "6", "2", "1")) } -func (s *testSuite) TestDefaultNull(c *C) { +func (s *testSuiteP1) TestDefaultNull(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -1464,7 +1482,7 @@ func (s *testSuite) TestDefaultNull(c *C) { tk.MustQuery("select * from t").Check(testkit.Rows("1 1 ")) } -func (s *testSuite) TestUnsignedPKColumn(c *C) { +func (s *testSuiteP1) TestUnsignedPKColumn(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -1477,7 +1495,7 @@ func (s *testSuite) TestUnsignedPKColumn(c *C) { result.Check(testkit.Rows("1 1 2")) } -func (s *testSuite) TestJSON(c *C) { +func (s *testSuiteP1) TestJSON(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -1565,7 +1583,7 @@ func (s *testSuite) TestJSON(c *C) { "1234567890123456789012345678901234567890123456789012345.12")) } -func (s *testSuite) TestMultiUpdate(c *C) { +func (s *testSuiteP1) TestMultiUpdate(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec(`CREATE TABLE test_mu (a int primary key, b int, c int)`) @@ -1594,7 +1612,7 @@ func (s *testSuite) TestMultiUpdate(c *C) { result.Check(testkit.Rows(`1 7 2`, `4 8 8`, `7 8 8`)) } -func (s *testSuite) TestGeneratedColumnWrite(c *C) { +func (s *testSuiteP1) TestGeneratedColumnWrite(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") _, err := tk.Exec(`CREATE TABLE test_gc_write (a int primary key auto_increment, b int, c int as (a+8) virtual)`) @@ -1648,7 +1666,7 @@ func (s *testSuite) TestGeneratedColumnWrite(c *C) { // TestGeneratedColumnRead tests select generated columns from table. // They should be calculated from their generation expressions. 
-func (s *testSuite) TestGeneratedColumnRead(c *C) { +func (s *testSuiteP1) TestGeneratedColumnRead(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec(`CREATE TABLE test_gc_read(a int primary key, b int, c int as (a+b), d int as (a*b) stored, e int as (c*2))`) @@ -1810,7 +1828,7 @@ func (s *testSuite) TestGeneratedColumnRead(c *C) { } } -func (s *testSuite) TestToPBExpr(c *C) { +func (s *testSuiteP1) TestToPBExpr(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -1858,7 +1876,7 @@ func (s *testSuite) TestToPBExpr(c *C) { result.Check(testkit.Rows("1", "2")) } -func (s *testSuite) TestDatumXAPI(c *C) { +func (s *testSuiteP1) TestDatumXAPI(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -1883,7 +1901,7 @@ func (s *testSuite) TestDatumXAPI(c *C) { result.Check(testkit.Rows("11:11:12.000 11:11:12", "11:11:13.000 11:11:13")) } -func (s *testSuite) TestSQLMode(c *C) { +func (s *testSuiteP1) TestSQLMode(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -1921,6 +1939,7 @@ func (s *testSuite) TestSQLMode(c *C) { s.domain.GetGlobalVarsCache().Disable() tk2 := testkit.NewTestKit(c, s.store) tk2.MustExec("use test") + tk2.MustExec("drop table if exists t2") tk2.MustExec("create table t2 (a varchar(3))") tk2.MustExec("insert t2 values ('abcd')") tk2.MustQuery("select * from t2").Check(testkit.Rows("abc")) @@ -1932,7 +1951,7 @@ func (s *testSuite) TestSQLMode(c *C) { tk.MustExec("set @@global.sql_mode = 'STRICT_TRANS_TABLES'") } -func (s *testSuite) TestTableDual(c *C) { +func (s *testSuiteP1) TestTableDual(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") result := tk.MustQuery("Select 1") @@ -1944,11 +1963,12 @@ func (s *testSuite) TestTableDual(c *C) { result = tk.MustQuery("Select 1 from dual where 1") result.Check(testkit.Rows("1")) + tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int primary key)") tk.MustQuery("select t1.* from t t1, t t2 where t1.a=t2.a and 1=0").Check(testkit.Rows()) } -func (s *testSuite) TestTableScan(c *C) { +func (s *testSuiteP1) TestTableScan(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use information_schema") result := tk.MustQuery("select * from schemata") @@ -1967,7 +1987,7 @@ func (s *testSuite) TestTableScan(c *C) { result.Check(testkit.Rows("1")) } -func (s *testSuite) TestAdapterStatement(c *C) { +func (s *testSuiteP1) TestAdapterStatement(c *C) { se, err := session.CreateSession4Test(s.store) c.Check(err, IsNil) se.GetSessionVars().TxnCtx.InfoSchema = domain.GetDomain(se).InfoSchema() @@ -1985,7 +2005,7 @@ func (s *testSuite) TestAdapterStatement(c *C) { c.Check(stmt.OriginText(), Equals, "create table test.t (a int)") } -func (s *testSuite) TestIsPointGet(c *C) { +func (s *testSuiteP1) TestIsPointGet(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use mysql") ctx := tk.Se.(sessionctx.Context) @@ -2010,7 +2030,7 @@ func (s *testSuite) TestIsPointGet(c *C) { } } -func (s *testSuite) TestPointGetRepeatableRead(c *C) { +func (s *testSuiteP1) TestPointGetRepeatableRead(c *C) { tk1 := testkit.NewTestKit(c, s.store) tk1.MustExec("use test") tk1.MustExec(`create table point_get (a int, b int, c int, @@ -2046,7 +2066,7 @@ func (s *testSuite) TestPointGetRepeatableRead(c *C) { c.Assert(failpoint.Disable(step2), IsNil) } -func (s *testSuite) TestSplitRegionTimeout(c *C) { +func (s *testSuite4) 
TestSplitRegionTimeout(c *C) { c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/mockSplitRegionTimeout", `return(true)`), IsNil) tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -2066,7 +2086,7 @@ func (s *testSuite) TestSplitRegionTimeout(c *C) { c.Assert(failpoint.Disable("github.com/pingcap/tidb/executor/mockScatterRegionTimeout"), IsNil) } -func (s *testSuite) TestRow(c *C) { +func (s *testSuiteP1) TestRow(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -2115,7 +2135,7 @@ func (s *testSuite) TestRow(c *C) { result.Check(testkit.Rows("1")) } -func (s *testSuite) TestColumnName(c *C) { +func (s *testSuiteP1) TestColumnName(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -2196,7 +2216,7 @@ func (s *testSuite) TestColumnName(c *C) { rs.Close() } -func (s *testSuite) TestSelectVar(c *C) { +func (s *testSuiteP1) TestSelectVar(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -2211,7 +2231,7 @@ func (s *testSuite) TestSelectVar(c *C) { tk.MustExec("select SQL_BUFFER_RESULT d from t group by d") } -func (s *testSuite) TestHistoryRead(c *C) { +func (s *testSuiteP1) TestHistoryRead(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists history_read") @@ -2275,7 +2295,7 @@ func (s *testSuite) TestHistoryRead(c *C) { tk.MustQuery("select * from history_read order by a").Check(testkit.Rows("2 ", "4 ", "8 8", "9 9")) } -func (s *testSuite) TestLowResolutionTSORead(c *C) { +func (s *testSuiteP1) TestLowResolutionTSORead(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("set @@autocommit=1") tk.MustExec("use test") @@ -3882,7 +3902,7 @@ func (s *testSuite4) TearDownTest(c *C) { } } -func (s *testSuite) TestStrToDateBuiltin(c *C) { +func (s *testSuiteP1) TestStrToDateBuiltin(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustQuery(`select str_to_date('18/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("2018-10-22")) tk.MustQuery(`select str_to_date('a18/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("")) @@ -3922,7 +3942,7 @@ func (s *testSuite) TestStrToDateBuiltin(c *C) { tk.MustQuery(`select str_to_date('18_10_22','%y_%m_%d') from dual`).Check(testkit.Rows("2018-10-22")) } -func (s *testSuite) TestReadPartitionedTable(c *C) { +func (s *testSuiteP1) TestReadPartitionedTable(c *C) { // Test three reader on partitioned table. 
tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -3939,7 +3959,7 @@ func (s *testSuite) TestReadPartitionedTable(c *C) { tk.MustQuery("select a from pt where b = 3").Check(testkit.Rows("3")) } -func (s *testSuite) TestSplitRegion(c *C) { +func (s *testSuiteP1) TestSplitRegion(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -4139,9 +4159,10 @@ func testGetTableByName(c *C, ctx sessionctx.Context, db, table string) table.Ta return tbl } -func (s *testSuite) TestIssue10435(c *C) { +func (s *testSuiteP1) TestIssue10435(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") + tk.MustExec("drop table if exists t1") tk.MustExec("create table t1(i int, j int, k int)") tk.MustExec("insert into t1 VALUES (1,1,1),(2,2,2),(3,3,3),(4,4,4)") tk.MustExec("INSERT INTO t1 SELECT 10*i,j,5*j FROM t1 UNION SELECT 20*i,j,5*j FROM t1 UNION SELECT 30*i,j,5*j FROM t1") @@ -4152,7 +4173,7 @@ func (s *testSuite) TestIssue10435(c *C) { ) } -func (s *testSuite) TestUnsignedFeedback(c *C) { +func (s *testSuiteP1) TestUnsignedFeedback(c *C) { tk := testkit.NewTestKit(c, s.store) oriProbability := statistics.FeedbackProbability.Load() statistics.FeedbackProbability.Store(1.0) diff --git a/executor/set_test.go b/executor/set_test.go index 421b17071f293..fb370b1e855da 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -429,9 +429,25 @@ func (s *testSuite2) TestValidateSetVar(c *C) { c.Assert(terror.ErrorEqual(err, variable.ErrWrongValueForVar), IsTrue, Commentf("err %v", err)) tk.MustExec("set @@tidb_batch_delete='On';") + tk.MustQuery("select @@tidb_batch_delete;").Check(testkit.Rows("1")) tk.MustExec("set @@tidb_batch_delete='oFf';") + tk.MustQuery("select @@tidb_batch_delete;").Check(testkit.Rows("0")) tk.MustExec("set @@tidb_batch_delete=1;") + tk.MustQuery("select @@tidb_batch_delete;").Check(testkit.Rows("1")) tk.MustExec("set @@tidb_batch_delete=0;") + tk.MustQuery("select @@tidb_batch_delete;").Check(testkit.Rows("0")) + + tk.MustExec("set @@tidb_opt_agg_push_down=off;") + tk.MustQuery("select @@tidb_opt_agg_push_down;").Check(testkit.Rows("0")) + + tk.MustExec("set @@tidb_constraint_check_in_place=on;") + tk.MustQuery("select @@tidb_constraint_check_in_place;").Check(testkit.Rows("1")) + + tk.MustExec("set @@tidb_general_log=0;") + tk.MustQuery("select @@tidb_general_log;").Check(testkit.Rows("0")) + + tk.MustExec("set @@tidb_enable_streaming=1;") + tk.MustQuery("select @@tidb_enable_streaming;").Check(testkit.Rows("1")) _, err = tk.Exec("set @@tidb_batch_delete=3;") c.Assert(terror.ErrorEqual(err, variable.ErrWrongValueForVar), IsTrue, Commentf("err %v", err)) @@ -789,9 +805,9 @@ func (s *testSuite2) TestEnableNoopFunctionsVar(c *C) { _, err = tk.Exec(`set tidb_enable_noop_functions=11`) c.Assert(err, NotNil) tk.MustExec(`set tidb_enable_noop_functions="off";`) - tk.MustQuery(`select @@tidb_enable_noop_functions;`).Check(testkit.Rows("off")) + tk.MustQuery(`select @@tidb_enable_noop_functions;`).Check(testkit.Rows("0")) tk.MustExec(`set tidb_enable_noop_functions="on";`) - tk.MustQuery(`select @@tidb_enable_noop_functions;`).Check(testkit.Rows("on")) + tk.MustQuery(`select @@tidb_enable_noop_functions;`).Check(testkit.Rows("1")) tk.MustExec(`set tidb_enable_noop_functions=0;`) tk.MustQuery(`select @@tidb_enable_noop_functions;`).Check(testkit.Rows("0")) } diff --git a/executor/show_stats.go b/executor/show_stats.go index 9bbf358a26d39..690a95c1a73da 100644 --- a/executor/show_stats.go +++ 
b/executor/show_stats.go @@ -86,7 +86,7 @@ func (e *ShowExec) appendTableForStatsHistograms(dbName, tblName, partitionName if col.IsInvalid(nil, false) { continue } - e.histogramToRow(dbName, tblName, partitionName, col.Info.Name.O, 0, col.Histogram, col.AvgColSize(statsTbl.Count)) + e.histogramToRow(dbName, tblName, partitionName, col.Info.Name.O, 0, col.Histogram, col.AvgColSize(statsTbl.Count, false)) } for _, idx := range statsTbl.Indices { e.histogramToRow(dbName, tblName, partitionName, idx.Info.Name.O, 1, idx.Histogram, 0) diff --git a/expression/builtin_string.go b/expression/builtin_string.go index ed553476a74a1..38cd2c09faffc 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -272,17 +272,26 @@ func (c *concatFunctionClass) getFunction(ctx sessionctx.Context, args []Express if bf.tp.Flen >= mysql.MaxBlobWidth { bf.tp.Flen = mysql.MaxBlobWidth } - sig := &builtinConcatSig{bf} + + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, err + } + + sig := &builtinConcatSig{bf, maxAllowedPacket} return sig, nil } type builtinConcatSig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinConcatSig) Clone() builtinFunc { newSig := &builtinConcatSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -295,6 +304,10 @@ func (b *builtinConcatSig) evalString(row chunk.Row) (d string, isNull bool, err if isNull || err != nil { return d, isNull, err } + if uint64(len(s)+len(d)) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("concat", b.maxAllowedPacket)) + return "", true, nil + } s = append(s, []byte(d)...) } return string(s), false, nil @@ -337,17 +350,25 @@ func (c *concatWSFunctionClass) getFunction(ctx sessionctx.Context, args []Expre bf.tp.Flen = mysql.MaxBlobWidth } - sig := &builtinConcatWSSig{bf} + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, err + } + + sig := &builtinConcatWSSig{bf, maxAllowedPacket} return sig, nil } type builtinConcatWSSig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinConcatWSSig) Clone() builtinFunc { newSig := &builtinConcatWSSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -357,25 +378,35 @@ func (b *builtinConcatWSSig) evalString(row chunk.Row) (string, bool, error) { args := b.getArgs() strs := make([]string, 0, len(args)) var sep string - for i, arg := range args { - val, isNull, err := arg.EvalString(b.ctx, row) + var targetLength int + + N := len(args) + if N > 0 { + val, isNull, err := args[0].EvalString(b.ctx, row) + if err != nil || isNull { + // If the separator is NULL, the result is NULL. + return val, isNull, err + } + sep = val + } + for i := 1; i < N; i++ { + val, isNull, err := args[i].EvalString(b.ctx, row) if err != nil { return val, isNull, err } - if isNull { - // If the separator is NULL, the result is NULL. - if i == 0 { - return val, isNull, nil - } // CONCAT_WS() does not skip empty strings. However, // it does skip any NULL values after the separator argument. 
continue } - if i == 0 { - sep = val - continue + targetLength += len(val) + if i > 1 { + targetLength += len(sep) + } + if uint64(targetLength) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("concat_ws", b.maxAllowedPacket)) + return "", true, nil } strs = append(strs, val) } diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 5c9b6debb62e4..b07ebd7d11f13 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -171,6 +171,50 @@ func (s *testEvaluatorSuite) TestConcat(c *C) { } } +func (s *testEvaluatorSuite) TestConcatSig(c *C) { + colTypes := []*types.FieldType{ + {Tp: mysql.TypeVarchar}, + {Tp: mysql.TypeVarchar}, + } + resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000} + args := []Expression{ + &Column{Index: 0, RetType: colTypes[0]}, + &Column{Index: 1, RetType: colTypes[1]}, + } + base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} + concat := &builtinConcatSig{base, 5} + + cases := []struct { + args []interface{} + warnings int + res string + }{ + {[]interface{}{"a", "b"}, 0, "ab"}, + {[]interface{}{"aaa", "bbb"}, 1, ""}, + {[]interface{}{"中", "a"}, 0, "中a"}, + {[]interface{}{"中文", "a"}, 2, ""}, + } + + for _, t := range cases { + input := chunk.NewChunkWithCapacity(colTypes, 10) + input.AppendString(0, t.args[0].(string)) + input.AppendString(1, t.args[1].(string)) + + res, isNull, err := concat.evalString(input.GetRow(0)) + c.Assert(res, Equals, t.res) + c.Assert(err, IsNil) + if t.warnings == 0 { + c.Assert(isNull, IsFalse) + } else { + c.Assert(isNull, IsTrue) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(warnings, HasLen, t.warnings) + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) + } + } +} + func (s *testEvaluatorSuite) TestConcatWS(c *C) { defer testleak.AfterTest(c)() cases := []struct { @@ -246,6 +290,53 @@ func (s *testEvaluatorSuite) TestConcatWS(c *C) { c.Assert(err, IsNil) } +func (s *testEvaluatorSuite) TestConcatWSSig(c *C) { + colTypes := []*types.FieldType{ + {Tp: mysql.TypeVarchar}, + {Tp: mysql.TypeVarchar}, + {Tp: mysql.TypeVarchar}, + } + resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000} + args := []Expression{ + &Column{Index: 0, RetType: colTypes[0]}, + &Column{Index: 1, RetType: colTypes[1]}, + &Column{Index: 2, RetType: colTypes[2]}, + } + base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} + concat := &builtinConcatWSSig{base, 6} + + cases := []struct { + args []interface{} + warnings int + res string + }{ + {[]interface{}{",", "a", "b"}, 0, "a,b"}, + {[]interface{}{",", "aaa", "bbb"}, 1, ""}, + {[]interface{}{",", "中", "a"}, 0, "中,a"}, + {[]interface{}{",", "中文", "a"}, 2, ""}, + } + + for _, t := range cases { + input := chunk.NewChunkWithCapacity(colTypes, 10) + input.AppendString(0, t.args[0].(string)) + input.AppendString(1, t.args[1].(string)) + input.AppendString(2, t.args[2].(string)) + + res, isNull, err := concat.evalString(input.GetRow(0)) + c.Assert(res, Equals, t.res) + c.Assert(err, IsNil) + if t.warnings == 0 { + c.Assert(isNull, IsFalse) + } else { + c.Assert(isNull, IsTrue) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(warnings, HasLen, t.warnings) + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) + } + } +} + func (s *testEvaluatorSuite) TestLeft(c 
*C) { defer testleak.AfterTest(c)() stmtCtx := s.ctx.GetSessionVars().StmtCtx diff --git a/expression/builtin_time.go b/expression/builtin_time.go index 75cfa8acacad1..50a8c4ac75129 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -2662,26 +2662,38 @@ func (du *baseDateArithmitical) getIntervalFromDecimal(ctx sessionctx.Context, a } switch strings.ToUpper(unit) { - case "HOUR_MINUTE", "MINUTE_SECOND": - interval = strings.Replace(interval, ".", ":", -1) - case "YEAR_MONTH": - interval = strings.Replace(interval, ".", "-", -1) - case "DAY_HOUR": - interval = strings.Replace(interval, ".", " ", -1) - case "DAY_MINUTE": - interval = "0 " + strings.Replace(interval, ".", ":", -1) - case "DAY_SECOND": - interval = "0 00:" + strings.Replace(interval, ".", ":", -1) - case "DAY_MICROSECOND": - interval = "0 00:00:" + interval - case "HOUR_MICROSECOND": - interval = "00:00:" + interval - case "HOUR_SECOND": - interval = "00:" + strings.Replace(interval, ".", ":", -1) - case "MINUTE_MICROSECOND": - interval = "00:" + interval - case "SECOND_MICROSECOND": - /* keep interval as original decimal */ + case "HOUR_MINUTE", "MINUTE_SECOND", "YEAR_MONTH", "DAY_HOUR", "DAY_MINUTE", + "DAY_SECOND", "DAY_MICROSECOND", "HOUR_MICROSECOND", "HOUR_SECOND", "MINUTE_MICROSECOND", "SECOND_MICROSECOND": + neg := false + if interval != "" && interval[0] == '-' { + neg = true + interval = interval[1:] + } + switch strings.ToUpper(unit) { + case "HOUR_MINUTE", "MINUTE_SECOND": + interval = strings.Replace(interval, ".", ":", -1) + case "YEAR_MONTH": + interval = strings.Replace(interval, ".", "-", -1) + case "DAY_HOUR": + interval = strings.Replace(interval, ".", " ", -1) + case "DAY_MINUTE": + interval = "0 " + strings.Replace(interval, ".", ":", -1) + case "DAY_SECOND": + interval = "0 00:" + strings.Replace(interval, ".", ":", -1) + case "DAY_MICROSECOND": + interval = "0 00:00:" + interval + case "HOUR_MICROSECOND": + interval = "00:00:" + interval + case "HOUR_SECOND": + interval = "00:" + strings.Replace(interval, ".", ":", -1) + case "MINUTE_MICROSECOND": + interval = "00:" + interval + case "SECOND_MICROSECOND": + /* keep interval as original decimal */ + } + if neg { + interval = "-" + interval + } case "SECOND": // Decimal's EvalString is like %f format. 
interval, isNull, err = args[1].EvalString(ctx, row) diff --git a/expression/evaluator_test.go b/expression/evaluator_test.go index db1bbd54d02b5..ef86d53f0b242 100644 --- a/expression/evaluator_test.go +++ b/expression/evaluator_test.go @@ -32,7 +32,7 @@ import ( "github.com/pingcap/tidb/util/testutil" ) -var _ = Suite(&testEvaluatorSuite{}) +var _ = SerialSuites(&testEvaluatorSuite{}) func TestT(t *testing.T) { CustomVerboseFlag = true diff --git a/expression/integration_test.go b/expression/integration_test.go index febcd22bace4b..0af2ca210c41f 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2839,7 +2839,10 @@ func (s *testIntegrationSuite) TestArithmeticBuiltin(c *C) { c.Assert(err, NotNil) c.Assert(err.Error(), Equals, "[types:1690]BIGINT UNSIGNED value is out of range in '(18446744073709551615 - -1)'") c.Assert(rs.Close(), IsNil) + tk.MustQuery(`select cast(-3 as unsigned) - cast(-1 as signed);`).Check(testkit.Rows("18446744073709551614")) + tk.MustQuery("select 1.11 - 1.11;").Check(testkit.Rows("0.00")) + // for multiply tk.MustQuery("select 1234567890 * 1234567890").Check(testkit.Rows("1524157875019052100")) rs, err = tk.Exec("select 1234567890 * 12345671890") c.Assert(err, IsNil) @@ -2866,8 +2869,7 @@ func (s *testIntegrationSuite) TestArithmeticBuiltin(c *C) { _, err = session.GetRows4Test(ctx, tk.Se, rs) c.Assert(terror.ErrorEqual(err, types.ErrOverflow), IsTrue) c.Assert(rs.Close(), IsNil) - result = tk.MustQuery(`select cast(-3 as unsigned) - cast(-1 as signed);`) - result.Check(testkit.Rows("18446744073709551614")) + tk.MustQuery("select 0.0 * -1;").Check(testkit.Rows("0.0")) tk.MustExec("DROP TABLE IF EXISTS t;") tk.MustExec("CREATE TABLE t(a DECIMAL(4, 2), b DECIMAL(5, 3));") @@ -2937,6 +2939,7 @@ func (s *testIntegrationSuite) TestArithmeticBuiltin(c *C) { tk.MustExec("INSERT IGNORE INTO t VALUE(12 MOD 0);") tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1365 Division by 0")) tk.MustQuery("select v from t;").Check(testkit.Rows("")) + tk.MustQuery("select 0.000 % 0.11234500000000000000;").Check(testkit.Rows("0.00000000000000000000")) _, err = tk.Exec("INSERT INTO t VALUE(12 MOD 0);") c.Assert(terror.ErrorEqual(err, expression.ErrDivisionByZero), IsTrue) @@ -4130,6 +4133,10 @@ func (s *testIntegrationSuite) TestFuncNameConst(c *C) { r.Check(testkit.Rows("2")) r = tk.MustQuery("SELECT concat('hello', name_const('test_string', 'world')) FROM t;") r.Check(testkit.Rows("helloworld")) + r = tk.MustQuery("SELECT NAME_CONST('come', -1);") + r.Check(testkit.Rows("-1")) + r = tk.MustQuery("SELECT NAME_CONST('come', -1.0);") + r.Check(testkit.Rows("-1.0")) err := tk.ExecToErr(`select name_const(a,b) from t;`) c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST") err = tk.ExecToErr(`select name_const(a,"hello") from t;`) @@ -4461,6 +4468,149 @@ func (s *testIntegrationSuite) TestIssue10675(c *C) { tk.MustQuery(`select * from t where a > 184467440737095516167.1;`).Check(testkit.Rows()) } +func (s *testIntegrationSuite) TestDatetimeMicrosecond(c *C) { + tk := testkit.NewTestKit(c, s.store) + // For int + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2 SECOND_MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:27.800000")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2 MINUTE_MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:27.800000")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2 HOUR_MICROSECOND);`).Check( + testkit.Rows("2007-03-28 
22:08:27.800000")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2 DAY_MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:27.800000")) + + // For Decimal + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 HOUR_MINUTE);`).Check( + testkit.Rows("2007-03-29 00:10:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 MINUTE_SECOND);`).Check( + testkit.Rows("2007-03-28 22:10:30")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 YEAR_MONTH);`).Check( + testkit.Rows("2009-05-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 DAY_HOUR);`).Check( + testkit.Rows("2007-03-31 00:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 DAY_MINUTE);`).Check( + testkit.Rows("2007-03-29 00:10:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 DAY_SECOND);`).Check( + testkit.Rows("2007-03-28 22:10:30")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 HOUR_SECOND);`).Check( + testkit.Rows("2007-03-28 22:10:30")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 SECOND);`).Check( + testkit.Rows("2007-03-28 22:08:30.200000")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 YEAR);`).Check( + testkit.Rows("2009-03-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 QUARTER);`).Check( + testkit.Rows("2007-09-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 MONTH);`).Check( + testkit.Rows("2007-05-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 WEEK);`).Check( + testkit.Rows("2007-04-11 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 DAY);`).Check( + testkit.Rows("2007-03-30 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 HOUR);`).Check( + testkit.Rows("2007-03-29 00:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 MINUTE);`).Check( + testkit.Rows("2007-03-28 22:10:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:28.000002")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 HOUR_MINUTE);`).Check( + testkit.Rows("2007-03-28 20:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 MINUTE_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 YEAR_MONTH);`).Check( + testkit.Rows("2005-01-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 DAY_HOUR);`).Check( + testkit.Rows("2007-03-26 20:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 DAY_MINUTE);`).Check( + testkit.Rows("2007-03-28 20:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 DAY_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 HOUR_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + // tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 SECOND);`).Check( + // testkit.Rows("2007-03-28 22:08:25.800000")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 YEAR);`).Check( + testkit.Rows("2005-03-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 QUARTER);`).Check( + testkit.Rows("2006-09-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 
22:08:28',INTERVAL -2.2 MONTH);`).Check( + testkit.Rows("2007-01-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 WEEK);`).Check( + testkit.Rows("2007-03-14 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 DAY);`).Check( + testkit.Rows("2007-03-26 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 HOUR);`).Check( + testkit.Rows("2007-03-28 20:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 MINUTE);`).Check( + testkit.Rows("2007-03-28 22:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:27.999998")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" HOUR_MINUTE);`).Check( + testkit.Rows("2007-03-28 20:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" MINUTE_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" YEAR_MONTH);`).Check( + testkit.Rows("2005-01-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" DAY_HOUR);`).Check( + testkit.Rows("2007-03-26 20:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" DAY_MINUTE);`).Check( + testkit.Rows("2007-03-28 20:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" DAY_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" HOUR_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + // tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" SECOND);`).Check( + // testkit.Rows("2007-03-28 22:08:25.800000")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" YEAR);`).Check( + testkit.Rows("2005-03-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" QUARTER);`).Check( + testkit.Rows("2006-09-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" MONTH);`).Check( + testkit.Rows("2007-01-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" WEEK);`).Check( + testkit.Rows("2007-03-14 22:08:28")) + // tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" DAY);`).Check( + // testkit.Rows("2007-03-26 22:08:28")) + // tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" HOUR);`).Check( + // testkit.Rows("2007-03-28 20:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" MINUTE);`).Check( + testkit.Rows("2007-03-28 22:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:27.999998")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" HOUR_MINUTE);`).Check( + testkit.Rows("2007-03-28 20:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" MINUTE_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" YEAR_MONTH);`).Check( + testkit.Rows("2005-01-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" DAY_HOUR);`).Check( + testkit.Rows("2007-03-26 20:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" DAY_MINUTE);`).Check( + testkit.Rows("2007-03-28 20:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" 
DAY_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" HOUR_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + // tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" SECOND);`).Check( + // testkit.Rows("2007-03-28 22:08:26")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" YEAR);`).Check( + testkit.Rows("2005-03-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" QUARTER);`).Check( + testkit.Rows("2006-09-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" MONTH);`).Check( + testkit.Rows("2007-01-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" WEEK);`).Check( + testkit.Rows("2007-03-14 22:08:28")) + // tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" DAY);`).Check( + // testkit.Rows("2007-03-26 22:08:28")) + // tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" HOUR);`).Check( + // testkit.Rows("2007-03-28 20:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" MINUTE);`).Check( + testkit.Rows("2007-03-28 22:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:27.999998")) +} + func (s *testIntegrationSuite) TestFuncCaseWithLeftJoin(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) diff --git a/go.sum b/go.sum index 2bf2b6d4c859c..726d1d63d51be 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,7 @@ github.com/blacktear23/go-proxyprotocol v0.0.0-20180807104634-af7a81e8dd0d/go.mo github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20171208011716-f6d7a1f6fbf3/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd h1:qMd81Ts1T2OTKmB4acZcyKaMtRnY5Y44NuXGX2GFJ1w= github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= @@ -191,6 +192,7 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d h1:GoAlyOgbOEIFd github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/remyoudompheng/bigfft v0.0.0-20190512091148-babf20351dd7 h1:FUL3b97ZY2EPqg2NbXKuMHs5pXJB9hjj1fDHnF2vl28= github.com/remyoudompheng/bigfft v0.0.0-20190512091148-babf20351dd7/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44 h1:tB9NOR21++IjLyVx3/PCPhWMwqGNCMQEH96A6dMZ/gc= github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shirou/gopsutil v2.18.10+incompatible h1:cy84jW6EVRPa5g9HAHrlbxMSIjBhDSX0OFYyMYminYs= github.com/shirou/gopsutil v2.18.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= diff --git a/infoschema/tables_test.go b/infoschema/tables_test.go index 85033dd359d89..df0c95b2ab5a0 100644 --- a/infoschema/tables_test.go +++ b/infoschema/tables_test.go @@ -143,22 +143,22 @@ func (s *testTableSuite) TestDataForTableStatsField(c *C) { 
c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) c.Assert(h.Update(is), IsNil) tk.MustQuery("select table_rows, avg_row_length, data_length, index_length from information_schema.tables where table_name='t'").Check( - testkit.Rows("3 17 51 3")) + testkit.Rows("3 18 54 6")) tk.MustExec(`insert into t(c, d, e) values(4, 5, "f")`) c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) c.Assert(h.Update(is), IsNil) tk.MustQuery("select table_rows, avg_row_length, data_length, index_length from information_schema.tables where table_name='t'").Check( - testkit.Rows("4 17 68 4")) + testkit.Rows("4 18 72 8")) tk.MustExec("delete from t where c >= 3") c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) c.Assert(h.Update(is), IsNil) tk.MustQuery("select table_rows, avg_row_length, data_length, index_length from information_schema.tables where table_name='t'").Check( - testkit.Rows("2 17 34 2")) + testkit.Rows("2 18 36 4")) tk.MustExec("delete from t where c=3") c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) c.Assert(h.Update(is), IsNil) tk.MustQuery("select table_rows, avg_row_length, data_length, index_length from information_schema.tables where table_name='t'").Check( - testkit.Rows("2 17 34 2")) + testkit.Rows("2 18 36 4")) } func (s *testTableSuite) TestCharacterSetCollations(c *C) { diff --git a/metrics/metrics.go b/metrics/metrics.go index a21fd394c4d63..ba6a3ce73736d 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -115,6 +115,7 @@ func RegisterMetrics() { prometheus.MustRegister(TiKVSecondaryLockCleanupFailureCounter) prometheus.MustRegister(TiKVSendReqHistogram) prometheus.MustRegister(TiKVSnapshotCounter) + prometheus.MustRegister(TiKVTxnCmdCounter) prometheus.MustRegister(TiKVTxnCmdHistogram) prometheus.MustRegister(TiKVTxnCounter) prometheus.MustRegister(TiKVTxnRegionsNumHistogram) diff --git a/planner/core/cbo_test.go b/planner/core/cbo_test.go index 33b06eefa4881..cdce14d6e8f73 100644 --- a/planner/core/cbo_test.go +++ b/planner/core/cbo_test.go @@ -1090,3 +1090,18 @@ func (s *testAnalyzeSuite) TestLimitCrossEstimation(c *C) { " └─TableScan_19 6.00 cop table:t, keep order:false", )) } + +func (s *testAnalyzeSuite) TestUpdateProjEliminate(c *C) { + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("explain update t t1, (select distinct b from t) t2 set t1.b = t2.b") +} diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index fb19ca7f6169f..d053c88bff08e 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2576,7 +2576,9 @@ func (b *PlanBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { updt := Update{OrderedList: orderedList}.Init(b.ctx) updt.SetSchema(p.Schema()) - updt.SelectPlan, err = DoOptimize(b.optFlag, p) + // We cannot apply projection elimination when building the subplan, because + // columns in orderedList cannot be resolved. 
+ updt.SelectPlan, err = DoOptimize(b.optFlag&^flagEliminateProjection, p) if err != nil { return nil, err } diff --git a/planner/core/plan_to_pb.go b/planner/core/plan_to_pb.go index 4694f7bcc5c14..789d1f0a132b3 100644 --- a/planner/core/plan_to_pb.go +++ b/planner/core/plan_to_pb.go @@ -152,7 +152,7 @@ func SetPBColumnsDefaultValue(ctx sessionctx.Context, pbColumns []*tipb.ColumnIn return err } - pbColumns[i].DefaultVal, err = tablecodec.EncodeValue(ctx.GetSessionVars().StmtCtx, d) + pbColumns[i].DefaultVal, err = tablecodec.EncodeValue(sessVars.StmtCtx, nil, d) if err != nil { return err } diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 26d93eabaf9ab..770ba6f683e5a 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -178,7 +178,12 @@ func (p *preprocessor) Leave(in ast.Node) (out ast.Node, ok bool) { p.err = expression.ErrIncorrectParameterCount.GenWithStackByArgs(x.FnName.L) } else { _, isValueExpr1 := x.Args[0].(*driver.ValueExpr) - _, isValueExpr2 := x.Args[1].(*driver.ValueExpr) + isValueExpr2 := false + switch x.Args[1].(type) { + case *driver.ValueExpr, *ast.UnaryOperationExpr: + isValueExpr2 = true + } + if !isValueExpr1 || !isValueExpr2 { p.err = ErrWrongArguments.GenWithStackByArgs("NAME_CONST") } diff --git a/plugin/audit.go b/plugin/audit.go index 603b7e0f8982f..f1471562fc657 100644 --- a/plugin/audit.go +++ b/plugin/audit.go @@ -84,3 +84,8 @@ type AuditManifest struct { // OnParseEvent will be called around parse logic. OnParseEvent func(ctx context.Context, sctx *variable.SessionVars, event ParseEvent) error } + +const ( + // ExecStartTimeCtxKey indicates stmt start execution time. + ExecStartTimeCtxKey = "ExecStartTime" +) diff --git a/server/conn.go b/server/conn.go index 391cd779f4da7..11516087fc266 100644 --- a/server/conn.go +++ b/server/conn.go @@ -65,6 +65,7 @@ import ( "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/memory" + "github.com/pingcap/tidb/util/sqlexec" "go.uber.org/zap" ) @@ -945,6 +946,10 @@ func (cc *clientConn) flush() error { func (cc *clientConn) writeOK() error { msg := cc.ctx.LastMessage() + return cc.writeOkWith(msg, cc.ctx.AffectedRows(), cc.ctx.LastInsertID(), cc.ctx.Status(), cc.ctx.WarningCount()) +} + +func (cc *clientConn) writeOkWith(msg string, affectedRows, lastInsertID uint64, status, warnCnt uint16) error { enclen := 0 if len(msg) > 0 { enclen = lengthEncodedIntSize(uint64(len(msg))) + len(msg) @@ -952,11 +957,11 @@ func (cc *clientConn) writeOK() error { data := cc.alloc.AllocWithLen(4, 32+enclen) data = append(data, mysql.OKHeader) - data = dumpLengthEncodedInt(data, cc.ctx.AffectedRows()) - data = dumpLengthEncodedInt(data, cc.ctx.LastInsertID()) + data = dumpLengthEncodedInt(data, affectedRows) + data = dumpLengthEncodedInt(data, lastInsertID) if cc.capability&mysql.ClientProtocol41 > 0 { - data = dumpUint16(data, cc.ctx.Status()) - data = dumpUint16(data, cc.ctx.WarningCount()) + data = dumpUint16(data, status) + data = dumpUint16(data, warnCnt) } if enclen > 0 { // although MySQL manual says the info message is string(https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html), @@ -1403,12 +1408,27 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet } func (cc *clientConn) writeMultiResultset(ctx context.Context, rss []ResultSet, binary bool) error { - for _, rs := range rss { - if err := cc.writeResultset(ctx, rs, binary, mysql.ServerMoreResultsExists, 0); err != nil { + for i, rs := range 
rss { + lastRs := i == len(rss)-1 + if r, ok := rs.(*tidbResultSet).recordSet.(sqlexec.MultiQueryNoDelayResult); ok { + status := r.Status() + if !lastRs { + status |= mysql.ServerMoreResultsExists + } + if err := cc.writeOkWith(r.LastMessage(), r.AffectedRows(), r.LastInsertID(), status, r.WarnCount()); err != nil { + return err + } + continue + } + status := uint16(0) + if !lastRs { + status |= mysql.ServerMoreResultsExists + } + if err := cc.writeResultset(ctx, rs, binary, status, 0); err != nil { return err } } - return cc.writeOK() + return nil } func (cc *clientConn) setConn(conn net.Conn) { diff --git a/server/http_handler.go b/server/http_handler.go index 3e14f6e044685..c3507fed709e0 100644 --- a/server/http_handler.go +++ b/server/http_handler.go @@ -151,12 +151,9 @@ func (t *tikvHandlerTool) getMvccByStartTs(startTS uint64, startKey, endKey []by return nil, errors.Trace(err) } - tikvReq := &tikvrpc.Request{ - Type: tikvrpc.CmdMvccGetByStartTs, - MvccGetByStartTs: &kvrpcpb.MvccGetByStartTsRequest{ - StartTs: startTS, - }, - } + tikvReq := tikvrpc.NewRequest(tikvrpc.CmdMvccGetByStartTs, &kvrpcpb.MvccGetByStartTsRequest{ + StartTs: startTS, + }) tikvReq.Context.Priority = kvrpcpb.CommandPri_Low kvResp, err := t.Store.SendReq(bo, tikvReq, curRegion.Region, time.Hour) if err != nil { @@ -170,7 +167,7 @@ func (t *tikvHandlerTool) getMvccByStartTs(startTS uint64, startKey, endKey []by zap.Error(err)) return nil, errors.Trace(err) } - data := kvResp.MvccGetByStartTS + data := kvResp.Resp.(*kvrpcpb.MvccGetByStartTsResponse) if err := data.GetRegionError(); err != nil { logutil.BgLogger().Warn("get MVCC by startTS failed", zap.Uint64("txnStartTS", startTS), @@ -826,8 +823,8 @@ func (h tableHandler) addScatterSchedule(startKey, endKey []byte, name string) e } input := map[string]string{ "name": "scatter-range", - "start_key": string(startKey), - "end_key": string(endKey), + "start_key": url.QueryEscape(string(startKey)), + "end_key": url.QueryEscape(string(endKey)), "range_name": name, } v, err := json.Marshal(input) diff --git a/session/session.go b/session/session.go index 37b5fab958b94..07290a295b01f 100644 --- a/session/session.go +++ b/session/session.go @@ -991,7 +991,7 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecu s.processInfo.Store(&pi) } -func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode ast.StmtNode, stmt sqlexec.Statement, recordSets []sqlexec.RecordSet) ([]sqlexec.RecordSet, error) { +func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode ast.StmtNode, stmt sqlexec.Statement, recordSets []sqlexec.RecordSet, inMulitQuery bool) ([]sqlexec.RecordSet, error) { s.SetValue(sessionctx.QueryString, stmt.OriginText()) if _, ok := stmtNode.(ast.DDLNode); ok { s.SetValue(sessionctx.LastExecuteDDL, true) @@ -1016,6 +1016,16 @@ func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode sessionExecuteRunDurationGeneral.Observe(time.Since(startTime).Seconds()) } + if inMulitQuery && recordSet == nil { + recordSet = &multiQueryNoDelayRecordSet{ + affectedRows: s.AffectedRows(), + lastMessage: s.LastMessage(), + warnCount: s.sessionVars.StmtCtx.WarningCount(), + lastInsertID: s.sessionVars.StmtCtx.LastInsertID, + status: s.sessionVars.Status, + } + } + if recordSet != nil { recordSets = append(recordSets, recordSet) } @@ -1062,6 +1072,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec var tempStmtNodes []ast.StmtNode compiler := 
executor.Compiler{Ctx: s} + multiQuery := len(stmtNodes) > 1 for idx, stmtNode := range stmtNodes { s.PrepareTxnCtx(ctx) @@ -1098,7 +1109,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec s.currentPlan = stmt.Plan // Step3: Execute the physical plan. - if recordSets, err = s.executeStatement(ctx, connID, stmtNode, stmt, recordSets); err != nil { + if recordSets, err = s.executeStatement(ctx, connID, stmtNode, stmt, recordSets, multiQuery); err != nil { return nil, err } } @@ -1952,3 +1963,47 @@ func (s *session) recordTransactionCounter(err error) { } } } + +type multiQueryNoDelayRecordSet struct { + affectedRows uint64 + lastMessage string + status uint16 + warnCount uint16 + lastInsertID uint64 +} + +func (c *multiQueryNoDelayRecordSet) Fields() []*ast.ResultField { + panic("unsupported method") +} + +func (c *multiQueryNoDelayRecordSet) Next(ctx context.Context, chk *chunk.Chunk) error { + panic("unsupported method") +} + +func (c *multiQueryNoDelayRecordSet) NewChunk() *chunk.Chunk { + panic("unsupported method") +} + +func (c *multiQueryNoDelayRecordSet) Close() error { + return nil +} + +func (c *multiQueryNoDelayRecordSet) AffectedRows() uint64 { + return c.affectedRows +} + +func (c *multiQueryNoDelayRecordSet) LastMessage() string { + return c.lastMessage +} + +func (c *multiQueryNoDelayRecordSet) WarnCount() uint16 { + return c.warnCount +} + +func (c *multiQueryNoDelayRecordSet) Status() uint16 { + return c.status +} + +func (c *multiQueryNoDelayRecordSet) LastInsertID() uint64 { + return c.lastInsertID +} diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 65f69300ae61d..47a60f1a7b7ed 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -374,9 +374,16 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, return "1", nil } return value, ErrWrongValueForVar.GenWithStackByArgs(name, value) - case GeneralLog, TiDBGeneralLog, AvoidTemporalUpgrade, BigTables, CheckProxyUsers, LogBin, + case TiDBSkipUTF8Check, TiDBOptAggPushDown, + TiDBOptInSubqToJoinAndAgg, TiDBEnableFastAnalyze, + TiDBBatchInsert, TiDBDisableTxnAutoRetry, TiDBEnableStreaming, + TiDBBatchDelete, TiDBBatchCommit, TiDBEnableCascadesPlanner, TiDBEnableWindowFunction, + TiDBCheckMb4ValueInUTF8, TiDBLowResolutionTSO, TiDBEnableIndexMerge, TiDBEnableNoopFuncs, + TiDBScatterRegion, TiDBGeneralLog, TiDBConstraintCheckInPlace: + fallthrough + case GeneralLog, AvoidTemporalUpgrade, BigTables, CheckProxyUsers, LogBin, CoreFile, EndMakersInJSON, SQLLogBin, OfflineMode, PseudoSlaveMode, LowPriorityUpdates, - SkipNameResolve, SQLSafeUpdates, TiDBConstraintCheckInPlace, serverReadOnly, SlaveAllowBatching, + SkipNameResolve, SQLSafeUpdates, serverReadOnly, SlaveAllowBatching, Flush, PerformanceSchema, LocalInFile, ShowOldTemporals, KeepFilesOnCreate, AutoCommit, SQLWarnings, UniqueChecks, OldAlterTable, LogBinTrustFunctionCreators, SQLBigSelects, BinlogDirectNonTransactionalUpdates, SQLQuoteShowCreate, AutomaticSpPrivileges, @@ -415,16 +422,6 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, } } return value, ErrWrongValueForVar.GenWithStackByArgs(name, value) - case TiDBSkipUTF8Check, TiDBOptAggPushDown, - TiDBOptInSubqToJoinAndAgg, TiDBEnableFastAnalyze, - TiDBBatchInsert, TiDBDisableTxnAutoRetry, TiDBEnableStreaming, - TiDBBatchDelete, TiDBBatchCommit, TiDBEnableCascadesPlanner, TiDBEnableWindowFunction, - TiDBCheckMb4ValueInUTF8, TiDBLowResolutionTSO, 
TiDBEnableIndexMerge, TiDBEnableNoopFuncs, - TiDBScatterRegion: - if strings.EqualFold(value, "ON") || value == "1" || strings.EqualFold(value, "OFF") || value == "0" { - return value, nil - } - return value, ErrWrongValueForVar.GenWithStackByArgs(name, value) case MaxExecutionTime: return checkUInt64SystemVar(name, value, 0, math.MaxUint64, vars) case TiDBEnableTablePartition: diff --git a/statistics/handle/ddl_test.go b/statistics/handle/ddl_test.go index a85d40e154bef..4baab71c6c87c 100644 --- a/statistics/handle/ddl_test.go +++ b/statistics/handle/ddl_test.go @@ -160,7 +160,7 @@ func (s *testStatsSuite) TestDDLHistogram(c *C) { tableInfo = tbl.Meta() statsTbl = do.StatsHandle().GetTableStats(tableInfo) c.Assert(statsTbl.Pseudo, IsFalse) - c.Check(statsTbl.Columns[tableInfo.Columns[5].ID].AvgColSize(statsTbl.Count), Equals, 3.0) + c.Check(statsTbl.Columns[tableInfo.Columns[5].ID].AvgColSize(statsTbl.Count, false), Equals, 3.0) testKit.MustExec("create index i on t(c2, c1)") testKit.MustExec("analyze table t") @@ -212,6 +212,6 @@ PARTITION BY RANGE ( a ) ( for _, def := range pi.Definitions { statsTbl := h.GetPartitionStats(tableInfo, def.ID) c.Assert(statsTbl.Pseudo, IsFalse) - c.Check(statsTbl.Columns[tableInfo.Columns[2].ID].AvgColSize(statsTbl.Count), Equals, 3.0) + c.Check(statsTbl.Columns[tableInfo.Columns[2].ID].AvgColSize(statsTbl.Count, false), Equals, 3.0) } } diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go index 1d53e53ed2e16..7a23de43a73f1 100644 --- a/statistics/handle/handle_test.go +++ b/statistics/handle/handle_test.go @@ -208,19 +208,19 @@ func (s *testStatsSuite) TestAvgColLen(c *C) { c.Assert(err, IsNil) tableInfo := tbl.Meta() statsTbl := do.StatsHandle().GetTableStats(tableInfo) - c.Assert(statsTbl.Columns[tableInfo.Columns[0].ID].AvgColSize(statsTbl.Count), Equals, 8.0) + c.Assert(statsTbl.Columns[tableInfo.Columns[0].ID].AvgColSize(statsTbl.Count, false), Equals, 1.0) // The size of varchar type is LEN + BYTE, here is 1 + 7 = 8 - c.Assert(statsTbl.Columns[tableInfo.Columns[1].ID].AvgColSize(statsTbl.Count), Equals, 8.0) - c.Assert(statsTbl.Columns[tableInfo.Columns[2].ID].AvgColSize(statsTbl.Count), Equals, 4.0) - c.Assert(statsTbl.Columns[tableInfo.Columns[3].ID].AvgColSize(statsTbl.Count), Equals, 16.0) + c.Assert(statsTbl.Columns[tableInfo.Columns[1].ID].AvgColSize(statsTbl.Count, false), Equals, 8.0) + c.Assert(statsTbl.Columns[tableInfo.Columns[2].ID].AvgColSize(statsTbl.Count, false), Equals, 8.0) + c.Assert(statsTbl.Columns[tableInfo.Columns[3].ID].AvgColSize(statsTbl.Count, false), Equals, 8.0) testKit.MustExec("insert into t values(132, '123456789112', 1232.3, '2018-03-07 19:17:29')") testKit.MustExec("analyze table t") statsTbl = do.StatsHandle().GetTableStats(tableInfo) - c.Assert(statsTbl.Columns[tableInfo.Columns[0].ID].AvgColSize(statsTbl.Count), Equals, 8.0) - c.Assert(statsTbl.Columns[tableInfo.Columns[1].ID].AvgColSize(statsTbl.Count), Equals, 10.5) - c.Assert(statsTbl.Columns[tableInfo.Columns[2].ID].AvgColSize(statsTbl.Count), Equals, 4.0) - c.Assert(statsTbl.Columns[tableInfo.Columns[3].ID].AvgColSize(statsTbl.Count), Equals, 16.0) + c.Assert(statsTbl.Columns[tableInfo.Columns[0].ID].AvgColSize(statsTbl.Count, false), Equals, 1.5) + c.Assert(statsTbl.Columns[tableInfo.Columns[1].ID].AvgColSize(statsTbl.Count, false), Equals, 10.5) + c.Assert(statsTbl.Columns[tableInfo.Columns[2].ID].AvgColSize(statsTbl.Count, false), Equals, 8.0) + c.Assert(statsTbl.Columns[tableInfo.Columns[3].ID].AvgColSize(statsTbl.Count, 
false), Equals, 8.0) } func (s *testStatsSuite) TestDurationToTS(c *C) { diff --git a/statistics/handle/update_test.go b/statistics/handle/update_test.go index b15b8f47a5788..84e428ee22f5a 100644 --- a/statistics/handle/update_test.go +++ b/statistics/handle/update_test.go @@ -160,8 +160,8 @@ func (s *testStatsSuite) TestSingleSessionInsert(c *C) { rs := testKit.MustQuery("select modify_count from mysql.stats_meta") rs.Check(testkit.Rows("40", "70")) - rs = testKit.MustQuery("select tot_col_size from mysql.stats_histograms") - rs.Check(testkit.Rows("0", "0", "10", "10")) + rs = testKit.MustQuery("select tot_col_size from mysql.stats_histograms").Sort() + rs.Check(testkit.Rows("0", "0", "20", "20")) // test dump delta only when `modify count / count` is greater than the ratio. originValue := handle.DumpStatsDeltaRatio @@ -343,7 +343,7 @@ func (s *testStatsSuite) TestUpdatePartition(c *C) { statsTbl := h.GetPartitionStats(tableInfo, def.ID) c.Assert(statsTbl.ModifyCount, Equals, int64(1)) c.Assert(statsTbl.Count, Equals, int64(1)) - c.Assert(statsTbl.Columns[bColID].TotColSize, Equals, int64(1)) + c.Assert(statsTbl.Columns[bColID].TotColSize, Equals, int64(2)) } testKit.MustExec(`update t set a = a + 1, b = "aa"`) @@ -353,7 +353,7 @@ func (s *testStatsSuite) TestUpdatePartition(c *C) { statsTbl := h.GetPartitionStats(tableInfo, def.ID) c.Assert(statsTbl.ModifyCount, Equals, int64(2)) c.Assert(statsTbl.Count, Equals, int64(1)) - c.Assert(statsTbl.Columns[bColID].TotColSize, Equals, int64(2)) + c.Assert(statsTbl.Columns[bColID].TotColSize, Equals, int64(3)) } testKit.MustExec("delete from t") @@ -442,7 +442,7 @@ func (s *testStatsSuite) TestAutoUpdate(c *C) { c.Assert(stats.ModifyCount, Equals, int64(1)) for _, item := range stats.Columns { // TotColSize = 6, because the table has not been analyzed, and insert statement will add 3(length of 'eee') to TotColSize. 
- c.Assert(item.TotColSize, Equals, int64(14)) + c.Assert(item.TotColSize, Equals, int64(15)) break } @@ -1307,7 +1307,7 @@ func (s *testStatsSuite) TestIndexQueryFeedback(c *C) { }, { sql: "select * from t use index(idx_ac) where a = 1 and c < 21", - hist: "column:3 ndv:20 totColSize:20\n" + + hist: "column:3 ndv:20 totColSize:40\n" + "num: 13 lower_bound: -9223372036854775808 upper_bound: 6 repeats: 0\n" + "num: 13 lower_bound: 7 upper_bound: 13 repeats: 0\n" + "num: 12 lower_bound: 14 upper_bound: 21 repeats: 0", @@ -1318,7 +1318,7 @@ func (s *testStatsSuite) TestIndexQueryFeedback(c *C) { }, { sql: "select * from t use index(idx_ad) where a = 1 and d < 21", - hist: "column:4 ndv:20 totColSize:160\n" + + hist: "column:4 ndv:20 totColSize:320\n" + "num: 13 lower_bound: -10000000000000 upper_bound: 6 repeats: 0\n" + "num: 12 lower_bound: 7 upper_bound: 13 repeats: 0\n" + "num: 10 lower_bound: 14 upper_bound: 21 repeats: 0", @@ -1329,7 +1329,7 @@ func (s *testStatsSuite) TestIndexQueryFeedback(c *C) { }, { sql: "select * from t use index(idx_ae) where a = 1 and e < 21", - hist: "column:5 ndv:20 totColSize:160\n" + + hist: "column:5 ndv:20 totColSize:320\n" + "num: 13 lower_bound: -100000000000000000000000 upper_bound: 6 repeats: 0\n" + "num: 12 lower_bound: 7 upper_bound: 13 repeats: 0\n" + "num: 10 lower_bound: 14 upper_bound: 21 repeats: 0", @@ -1340,7 +1340,7 @@ func (s *testStatsSuite) TestIndexQueryFeedback(c *C) { }, { sql: "select * from t use index(idx_af) where a = 1 and f < 21", - hist: "column:6 ndv:20 totColSize:200\n" + + hist: "column:6 ndv:20 totColSize:400\n" + "num: 13 lower_bound: -999999999999999.99 upper_bound: 6.00 repeats: 0\n" + "num: 12 lower_bound: 7.00 upper_bound: 13.00 repeats: 0\n" + "num: 10 lower_bound: 14.00 upper_bound: 21.00 repeats: 0", @@ -1351,7 +1351,7 @@ func (s *testStatsSuite) TestIndexQueryFeedback(c *C) { }, { sql: "select * from t use index(idx_ag) where a = 1 and g < 21", - hist: "column:7 ndv:20 totColSize:98\n" + + hist: "column:7 ndv:20 totColSize:196\n" + "num: 13 lower_bound: -838:59:59 upper_bound: 00:00:06 repeats: 0\n" + "num: 11 lower_bound: 00:00:07 upper_bound: 00:00:13 repeats: 0\n" + "num: 10 lower_bound: 00:00:14 upper_bound: 00:00:21 repeats: 0", @@ -1362,7 +1362,7 @@ func (s *testStatsSuite) TestIndexQueryFeedback(c *C) { }, { sql: `select * from t use index(idx_ah) where a = 1 and h < "1000-01-21"`, - hist: "column:8 ndv:20 totColSize:180\n" + + hist: "column:8 ndv:20 totColSize:360\n" + "num: 13 lower_bound: 1000-01-01 upper_bound: 1000-01-07 repeats: 0\n" + "num: 11 lower_bound: 1000-01-08 upper_bound: 1000-01-14 repeats: 0\n" + "num: 10 lower_bound: 1000-01-15 upper_bound: 1000-01-21 repeats: 0", @@ -1504,7 +1504,7 @@ func (s *testStatsSuite) TestFeedbackRanges(c *C) { }, { sql: "select * from t use index(idx) where a = 1 and (b <= 50 or (b > 130 and b < 140))", - hist: "column:2 ndv:20 totColSize:20\n" + + hist: "column:2 ndv:20 totColSize:30\n" + "num: 7 lower_bound: -128 upper_bound: 6 repeats: 0\n" + "num: 7 lower_bound: 7 upper_bound: 13 repeats: 1\n" + "num: 6 lower_bound: 14 upper_bound: 19 repeats: 1", @@ -1561,7 +1561,7 @@ func (s *testStatsSuite) TestUnsignedFeedbackRanges(c *C) { }{ { sql: "select * from t where a <= 50", - hist: "column:1 ndv:30 totColSize:0\n" + + hist: "column:1 ndv:30 totColSize:10\n" + "num: 8 lower_bound: 0 upper_bound: 7 repeats: 0\n" + "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + "num: 14 lower_bound: 16 upper_bound: 50 repeats: 0", @@ -1569,7 +1569,7 @@ func (s *testStatsSuite) 
TestUnsignedFeedbackRanges(c *C) { }, { sql: "select count(*) from t", - hist: "column:1 ndv:30 totColSize:0\n" + + hist: "column:1 ndv:30 totColSize:10\n" + "num: 8 lower_bound: 0 upper_bound: 7 repeats: 0\n" + "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + "num: 14 lower_bound: 16 upper_bound: 255 repeats: 0", @@ -1577,7 +1577,7 @@ func (s *testStatsSuite) TestUnsignedFeedbackRanges(c *C) { }, { sql: "select * from t1 where a <= 50", - hist: "column:1 ndv:30 totColSize:0\n" + + hist: "column:1 ndv:30 totColSize:10\n" + "num: 8 lower_bound: 0 upper_bound: 7 repeats: 0\n" + "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + "num: 14 lower_bound: 16 upper_bound: 50 repeats: 0", @@ -1585,7 +1585,7 @@ func (s *testStatsSuite) TestUnsignedFeedbackRanges(c *C) { }, { sql: "select count(*) from t1", - hist: "column:1 ndv:30 totColSize:0\n" + + hist: "column:1 ndv:30 totColSize:10\n" + "num: 8 lower_bound: 0 upper_bound: 7 repeats: 0\n" + "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + "num: 14 lower_bound: 16 upper_bound: 18446744073709551615 repeats: 0", @@ -1593,6 +1593,7 @@ func (s *testStatsSuite) TestUnsignedFeedbackRanges(c *C) { }, } is := s.do.InfoSchema() + c.Assert(h.Update(is), IsNil) for i, t := range tests { table, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr(t.tblName)) c.Assert(err, IsNil) diff --git a/statistics/histogram.go b/statistics/histogram.go index f00873ce3a1dc..2ffbd933909a0 100644 --- a/statistics/histogram.go +++ b/statistics/histogram.go @@ -108,25 +108,34 @@ func (hg *Histogram) GetUpper(idx int) *types.Datum { return &d } -// AvgColSize is the average column size of the histogram. -func (c *Column) AvgColSize(count int64) float64 { +// AvgColSize is the average column size of the histogram. These sizes are derived from function `encode` +// and `Datum::ConvertTo`, so we need to update them if those 2 functions are changed. +func (c *Column) AvgColSize(count int64, isKey bool) float64 { if count == 0 { return 0 } - switch c.Histogram.Tp.Tp { - case mysql.TypeFloat: - return 4 - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, - mysql.TypeDouble, mysql.TypeYear: + // Note that, if the handle column is encoded as value, instead of key, i.e, + // when the handle column is in a unique index, the real column size may be + // smaller than 8 because it is encoded using `EncodeVarint`. Since we don't + // know the exact value size now, use 8 as approximation. + if c.IsHandle { return 8 - case mysql.TypeDuration, mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: - return 16 - case mysql.TypeNewDecimal: - return types.MyDecimalStructSize - default: - // Keep two decimal place. - return math.Round(float64(c.TotColSize)/float64(count)*100) / 100 } + histCount := c.TotalRowCount() + notNullRatio := 1.0 + if histCount > 0 { + notNullRatio = 1.0 - float64(c.NullCount)/histCount + } + switch c.Histogram.Tp.Tp { + case mysql.TypeFloat, mysql.TypeDouble, mysql.TypeDuration, mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: + return 8 * notNullRatio + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear, mysql.TypeEnum, mysql.TypeBit, mysql.TypeSet: + if isKey { + return 8 * notNullRatio + } + } + // Keep two decimal place. + return math.Round(float64(c.TotColSize)/float64(count)*100) / 100 } // AppendBucket appends a bucket into `hg`. 
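The statistics/histogram.go hunk above replaces the old fixed per-type constants in `AvgColSize` with an estimate that discounts NULL values and only costs integer-like columns at 8 bytes when they are encoded as keys, which is why the `totColSize` and `AvgColSize` expectations in the surrounding test files roughly double. Below is a minimal standalone sketch of that estimation order under simplified assumptions; the `colStats` struct, its field names, and the `avgColSize` helper are illustrative stand-ins, not the real `statistics.Column` API.

```go
package main

import (
	"fmt"
	"math"
)

// colStats is a hypothetical stand-in for the per-column histogram fields the
// diff relies on (TotColSize, NullCount, total row count, handle flag).
type colStats struct {
	totColSize int64   // accumulated encoded size of all sampled values
	nullCount  int64   // number of NULL values in the histogram
	histCount  float64 // total row count covered by the histogram
	isFixedInt bool    // int-like types: 8 bytes only when encoded as a key
	isFixed8   bool    // float/double/time-like types: always 8 bytes
	isHandle   bool    // handle columns are approximated as 8 bytes
}

// avgColSize mirrors the estimation order in the hunk above: handle columns
// first, then fixed-width types scaled by the not-null ratio, then the
// TotColSize/count fallback rounded to two decimal places.
func avgColSize(c colStats, count int64, isKey bool) float64 {
	if count == 0 {
		return 0
	}
	if c.isHandle {
		return 8
	}
	notNullRatio := 1.0
	if c.histCount > 0 {
		notNullRatio = 1.0 - float64(c.nullCount)/c.histCount
	}
	if c.isFixed8 || (c.isFixedInt && isKey) {
		return 8 * notNullRatio
	}
	return math.Round(float64(c.totColSize)/float64(count)*100) / 100
}

func main() {
	// An int column read as an index key: 10 rows, 2 NULLs -> 8 * 0.8 = 6.4.
	fmt.Println(avgColSize(colStats{histCount: 10, nullCount: 2, isFixedInt: true}, 10, true))
	// A varchar column with 105 encoded bytes over 10 rows -> 10.5.
	fmt.Println(avgColSize(colStats{totColSize: 105}, 10, false))
}
```

Under these assumptions the second call prints 10.5, the same kind of per-row average that the updated `TestAvgColLen` expectations assert for the varchar column; the first shows how the not-null ratio now shrinks the fixed 8-byte estimate instead of ignoring NULLs.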
diff --git a/store/helper/helper.go b/store/helper/helper.go index f36616eade9d6..ce202f9bb2ef1 100644 --- a/store/helper/helper.go +++ b/store/helper/helper.go @@ -50,12 +50,7 @@ func (h *Helper) GetMvccByEncodedKey(encodedKey kv.Key) (*kvrpcpb.MvccGetByKeyRe return nil, errors.Trace(err) } - tikvReq := &tikvrpc.Request{ - Type: tikvrpc.CmdMvccGetByKey, - MvccGetByKey: &kvrpcpb.MvccGetByKeyRequest{ - Key: encodedKey, - }, - } + tikvReq := tikvrpc.NewRequest(tikvrpc.CmdMvccGetByKey, &kvrpcpb.MvccGetByKeyRequest{Key: encodedKey}) kvResp, err := h.Store.SendReq(tikv.NewBackoffer(context.Background(), 500), tikvReq, keyLocation.Region, time.Minute) if err != nil { logutil.BgLogger().Info("get MVCC by encoded key failed", @@ -67,7 +62,7 @@ func (h *Helper) GetMvccByEncodedKey(encodedKey kv.Key) (*kvrpcpb.MvccGetByKeyRe zap.Error(err)) return nil, errors.Trace(err) } - return kvResp.MvccGetByKey, nil + return kvResp.Resp.(*kvrpcpb.MvccGetByKeyResponse), nil } // StoreHotRegionInfos records all hog region stores. diff --git a/store/mockstore/mocktikv/mock_tikv_test.go b/store/mockstore/mocktikv/mock_tikv_test.go index 250f81ab81809..d36700f8efd75 100644 --- a/store/mockstore/mocktikv/mock_tikv_test.go +++ b/store/mockstore/mocktikv/mock_tikv_test.go @@ -201,6 +201,10 @@ func (s *testMockTiKVSuite) mustBatchResolveLock(c *C, txnInfos map[uint64]uint6 c.Assert(s.store.BatchResolveLock(nil, nil, txnInfos), IsNil) } +func (s *testMockTiKVSuite) mustGC(c *C, safePoint uint64) { + c.Assert(s.store.GC(nil, nil, safePoint), IsNil) +} + func (s *testMockTiKVSuite) mustDeleteRange(c *C, startKey, endKey string) { err := s.store.DeleteRange([]byte(startKey), []byte(endKey)) c.Assert(err, IsNil) @@ -488,6 +492,50 @@ func (s *testMockTiKVSuite) TestBatchResolveLock(c *C) { s.mustScanLock(c, 30, nil) } +func (s *testMockTiKVSuite) TestGC(c *C) { + var safePoint uint64 = 100 + + // Prepare data + s.mustPutOK(c, "k1", "v1", 1, 2) + s.mustPutOK(c, "k1", "v2", 11, 12) + + s.mustPutOK(c, "k2", "v1", 1, 2) + s.mustPutOK(c, "k2", "v2", 11, 12) + s.mustPutOK(c, "k2", "v3", 101, 102) + + s.mustPutOK(c, "k3", "v1", 1, 2) + s.mustPutOK(c, "k3", "v2", 11, 12) + s.mustDeleteOK(c, "k3", 101, 102) + + s.mustPutOK(c, "k4", "v1", 1, 2) + s.mustDeleteOK(c, "k4", 11, 12) + + // Check prepared data + s.mustGetOK(c, "k1", 5, "v1") + s.mustGetOK(c, "k1", 15, "v2") + s.mustGetOK(c, "k2", 5, "v1") + s.mustGetOK(c, "k2", 15, "v2") + s.mustGetOK(c, "k2", 105, "v3") + s.mustGetOK(c, "k3", 5, "v1") + s.mustGetOK(c, "k3", 15, "v2") + s.mustGetNone(c, "k3", 105) + s.mustGetOK(c, "k4", 5, "v1") + s.mustGetNone(c, "k4", 105) + + s.mustGC(c, safePoint) + + s.mustGetNone(c, "k1", 5) + s.mustGetOK(c, "k1", 15, "v2") + s.mustGetNone(c, "k2", 5) + s.mustGetOK(c, "k2", 15, "v2") + s.mustGetOK(c, "k2", 105, "v3") + s.mustGetNone(c, "k3", 5) + s.mustGetOK(c, "k3", 15, "v2") + s.mustGetNone(c, "k3", 105) + s.mustGetNone(c, "k4", 5) + s.mustGetNone(c, "k4", 105) +} + func (s *testMockTiKVSuite) TestRollbackAndWriteConflict(c *C) { s.mustPutOK(c, "test", "test", 1, 3) req := &kvrpcpb.PrewriteRequest{ diff --git a/store/mockstore/mocktikv/mvcc.go b/store/mockstore/mocktikv/mvcc.go index e5825de671494..091373a1030b2 100644 --- a/store/mockstore/mocktikv/mvcc.go +++ b/store/mockstore/mocktikv/mvcc.go @@ -257,6 +257,7 @@ type MVCCStore interface { ScanLock(startKey, endKey []byte, maxTS uint64) ([]*kvrpcpb.LockInfo, error) ResolveLock(startKey, endKey []byte, startTS, commitTS uint64) error BatchResolveLock(startKey, endKey []byte, txnInfos 
map[uint64]uint64) error + GC(startKey, endKey []byte, safePoint uint64) error DeleteRange(startKey, endKey []byte) error Close() error } diff --git a/store/mockstore/mocktikv/mvcc_leveldb.go b/store/mockstore/mocktikv/mvcc_leveldb.go index d4413de96ae53..ad059f1a4cbd4 100644 --- a/store/mockstore/mocktikv/mvcc_leveldb.go +++ b/store/mockstore/mocktikv/mvcc_leveldb.go @@ -1047,6 +1047,72 @@ func (mvcc *MVCCLevelDB) BatchResolveLock(startKey, endKey []byte, txnInfos map[ return mvcc.db.Write(batch, nil) } +// GC implements the MVCCStore interface +func (mvcc *MVCCLevelDB) GC(startKey, endKey []byte, safePoint uint64) error { + mvcc.mu.Lock() + defer mvcc.mu.Unlock() + + iter, currKey, err := newScanIterator(mvcc.db, startKey, endKey) + defer iter.Release() + if err != nil { + return errors.Trace(err) + } + + // Mock TiKV usually doesn't need to process large amount of data. So write it in a single batch. + batch := &leveldb.Batch{} + + for iter.Valid() { + lockDec := lockDecoder{expectKey: currKey} + ok, err := lockDec.Decode(iter) + if err != nil { + return errors.Trace(err) + } + if ok && lockDec.lock.startTS <= safePoint { + return errors.Errorf( + "key %+q has lock with startTs %v which is under safePoint %v", + currKey, + lockDec.lock.startTS, + safePoint) + } + + keepNext := true + dec := valueDecoder{expectKey: currKey} + + for iter.Valid() { + ok, err := dec.Decode(iter) + if err != nil { + return errors.Trace(err) + } + + if !ok { + // Go to the next key + currKey, _, err = mvccDecode(iter.Key()) + if err != nil { + return errors.Trace(err) + } + break + } + + if dec.value.commitTS > safePoint { + continue + } + + if dec.value.valueType == typePut || dec.value.valueType == typeDelete { + // Keep the latest version if it's `typePut` + if !keepNext || dec.value.valueType == typeDelete { + batch.Delete(mvccEncode(currKey, dec.value.commitTS)) + } + keepNext = false + } else { + // Delete all other types + batch.Delete(mvccEncode(currKey, dec.value.commitTS)) + } + } + } + + return mvcc.db.Write(batch, nil) +} + // DeleteRange implements the MVCCStore interface. func (mvcc *MVCCLevelDB) DeleteRange(startKey, endKey []byte) error { return mvcc.doRawDeleteRange(codec.EncodeBytes(nil, startKey), codec.EncodeBytes(nil, endKey)) diff --git a/store/mockstore/mocktikv/pd.go b/store/mockstore/mocktikv/pd.go index b259357a51281..e1c0dda9633de 100644 --- a/store/mockstore/mocktikv/pd.go +++ b/store/mockstore/mocktikv/pd.go @@ -32,6 +32,9 @@ var tsMu = struct { type pdClient struct { cluster *Cluster + // SafePoint set by `UpdateGCSafePoint`. Not to be confused with SafePointKV. 
+ gcSafePoint uint64 + gcSafePointMu sync.Mutex } // NewPDClient creates a mock pd.Client that uses local timestamp and meta data @@ -108,7 +111,13 @@ func (c *pdClient) GetAllStores(ctx context.Context, opts ...pd.GetStoreOption) } func (c *pdClient) UpdateGCSafePoint(ctx context.Context, safePoint uint64) (uint64, error) { - return 0, nil + c.gcSafePointMu.Lock() + defer c.gcSafePointMu.Unlock() + + if safePoint > c.gcSafePoint { + c.gcSafePoint = safePoint + } + return c.gcSafePoint, nil } func (c *pdClient) Close() { diff --git a/store/mockstore/mocktikv/rpc.go b/store/mockstore/mocktikv/rpc.go index 5e9d253cc4ffe..99866195b17a9 100644 --- a/store/mockstore/mocktikv/rpc.go +++ b/store/mockstore/mocktikv/rpc.go @@ -443,6 +443,18 @@ func (h *rpcHandler) handleKvResolveLock(req *kvrpcpb.ResolveLockRequest) *kvrpc return &kvrpcpb.ResolveLockResponse{} } +func (h *rpcHandler) handleKvGC(req *kvrpcpb.GCRequest) *kvrpcpb.GCResponse { + startKey := MvccKey(h.startKey).Raw() + endKey := MvccKey(h.endKey).Raw() + err := h.mvccStore.GC(startKey, endKey, req.GetSafePoint()) + if err != nil { + return &kvrpcpb.GCResponse{ + Error: convertToKeyError(err), + } + } + return &kvrpcpb.GCResponse{} +} + func (h *rpcHandler) handleKvDeleteRange(req *kvrpcpb.DeleteRangeRequest) *kvrpcpb.DeleteRangeResponse { if !h.checkKeyInRegion(req.StartKey) { panic("KvDeleteRange: key not in region") @@ -668,44 +680,43 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R } reqCtx := &req.Context resp := &tikvrpc.Response{} - resp.Type = req.Type switch req.Type { case tikvrpc.CmdGet: - r := req.Get + r := req.Get() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.Get = &kvrpcpb.GetResponse{RegionError: err} + resp.Resp = &kvrpcpb.GetResponse{RegionError: err} return resp, nil } - resp.Get = handler.handleKvGet(r) + resp.Resp = handler.handleKvGet(r) case tikvrpc.CmdScan: - r := req.Scan + r := req.Scan() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.Scan = &kvrpcpb.ScanResponse{RegionError: err} + resp.Resp = &kvrpcpb.ScanResponse{RegionError: err} return resp, nil } - resp.Scan = handler.handleKvScan(r) + resp.Resp = handler.handleKvScan(r) case tikvrpc.CmdPrewrite: - r := req.Prewrite + r := req.Prewrite() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.Prewrite = &kvrpcpb.PrewriteResponse{RegionError: err} + resp.Resp = &kvrpcpb.PrewriteResponse{RegionError: err} return resp, nil } - resp.Prewrite = handler.handleKvPrewrite(r) + resp.Resp = handler.handleKvPrewrite(r) case tikvrpc.CmdPessimisticLock: - r := req.PessimisticLock + r := req.PessimisticLock() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.PessimisticLock = &kvrpcpb.PessimisticLockResponse{RegionError: err} + resp.Resp = &kvrpcpb.PessimisticLockResponse{RegionError: err} return resp, nil } - resp.PessimisticLock = handler.handleKvPessimisticLock(r) + resp.Resp = handler.handleKvPessimisticLock(r) case tikvrpc.CmdPessimisticRollback: - r := req.PessimisticRollback + r := req.PessimisticRollback() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.PessimisticRollback = &kvrpcpb.PessimisticRollbackResponse{RegionError: err} + resp.Resp = &kvrpcpb.PessimisticRollbackResponse{RegionError: err} return resp, nil } - resp.PessimisticRollback = handler.handleKvPessimisticRollback(r) + resp.Resp = handler.handleKvPessimisticRollback(r) case tikvrpc.CmdCommit: failpoint.Inject("rpcCommitResult", func(val failpoint.Value) { 
switch val.(string) { @@ -713,138 +724,136 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R failpoint.Return(nil, errors.New("timeout")) case "notLeader": failpoint.Return(&tikvrpc.Response{ - Type: tikvrpc.CmdCommit, - Commit: &kvrpcpb.CommitResponse{RegionError: &errorpb.Error{NotLeader: &errorpb.NotLeader{}}}, + Resp: &kvrpcpb.CommitResponse{RegionError: &errorpb.Error{NotLeader: &errorpb.NotLeader{}}}, }, nil) case "keyError": failpoint.Return(&tikvrpc.Response{ - Type: tikvrpc.CmdCommit, - Commit: &kvrpcpb.CommitResponse{Error: &kvrpcpb.KeyError{}}, + Resp: &kvrpcpb.CommitResponse{Error: &kvrpcpb.KeyError{}}, }, nil) } }) - r := req.Commit + r := req.Commit() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.Commit = &kvrpcpb.CommitResponse{RegionError: err} + resp.Resp = &kvrpcpb.CommitResponse{RegionError: err} return resp, nil } - resp.Commit = handler.handleKvCommit(r) + resp.Resp = handler.handleKvCommit(r) failpoint.Inject("rpcCommitTimeout", func(val failpoint.Value) { if val.(bool) { failpoint.Return(nil, undeterminedErr) } }) case tikvrpc.CmdCleanup: - r := req.Cleanup + r := req.Cleanup() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.Cleanup = &kvrpcpb.CleanupResponse{RegionError: err} + resp.Resp = &kvrpcpb.CleanupResponse{RegionError: err} return resp, nil } - resp.Cleanup = handler.handleKvCleanup(r) + resp.Resp = handler.handleKvCleanup(r) case tikvrpc.CmdBatchGet: - r := req.BatchGet + r := req.BatchGet() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.BatchGet = &kvrpcpb.BatchGetResponse{RegionError: err} + resp.Resp = &kvrpcpb.BatchGetResponse{RegionError: err} return resp, nil } - resp.BatchGet = handler.handleKvBatchGet(r) + resp.Resp = handler.handleKvBatchGet(r) case tikvrpc.CmdBatchRollback: - r := req.BatchRollback + r := req.BatchRollback() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.BatchRollback = &kvrpcpb.BatchRollbackResponse{RegionError: err} + resp.Resp = &kvrpcpb.BatchRollbackResponse{RegionError: err} return resp, nil } - resp.BatchRollback = handler.handleKvBatchRollback(r) + resp.Resp = handler.handleKvBatchRollback(r) case tikvrpc.CmdScanLock: - r := req.ScanLock + r := req.ScanLock() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.ScanLock = &kvrpcpb.ScanLockResponse{RegionError: err} + resp.Resp = &kvrpcpb.ScanLockResponse{RegionError: err} return resp, nil } - resp.ScanLock = handler.handleKvScanLock(r) + resp.Resp = handler.handleKvScanLock(r) case tikvrpc.CmdResolveLock: - r := req.ResolveLock + r := req.ResolveLock() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.ResolveLock = &kvrpcpb.ResolveLockResponse{RegionError: err} + resp.Resp = &kvrpcpb.ResolveLockResponse{RegionError: err} return resp, nil } - resp.ResolveLock = handler.handleKvResolveLock(r) + resp.Resp = handler.handleKvResolveLock(r) case tikvrpc.CmdGC: - r := req.GC + r := req.GC() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.GC = &kvrpcpb.GCResponse{RegionError: err} + resp.Resp = &kvrpcpb.GCResponse{RegionError: err} return resp, nil } - resp.GC = &kvrpcpb.GCResponse{} + resp.Resp = handler.handleKvGC(r) case tikvrpc.CmdDeleteRange: - r := req.DeleteRange + r := req.DeleteRange() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.DeleteRange = &kvrpcpb.DeleteRangeResponse{RegionError: err} + resp.Resp = &kvrpcpb.DeleteRangeResponse{RegionError: err} return resp, nil 
} - resp.DeleteRange = handler.handleKvDeleteRange(r) + resp.Resp = handler.handleKvDeleteRange(r) case tikvrpc.CmdRawGet: - r := req.RawGet + r := req.RawGet() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.RawGet = &kvrpcpb.RawGetResponse{RegionError: err} + resp.Resp = &kvrpcpb.RawGetResponse{RegionError: err} return resp, nil } - resp.RawGet = handler.handleKvRawGet(r) + resp.Resp = handler.handleKvRawGet(r) case tikvrpc.CmdRawBatchGet: - r := req.RawBatchGet + r := req.RawBatchGet() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.RawBatchGet = &kvrpcpb.RawBatchGetResponse{RegionError: err} + resp.Resp = &kvrpcpb.RawBatchGetResponse{RegionError: err} return resp, nil } - resp.RawBatchGet = handler.handleKvRawBatchGet(r) + resp.Resp = handler.handleKvRawBatchGet(r) case tikvrpc.CmdRawPut: - r := req.RawPut + r := req.RawPut() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.RawPut = &kvrpcpb.RawPutResponse{RegionError: err} + resp.Resp = &kvrpcpb.RawPutResponse{RegionError: err} return resp, nil } - resp.RawPut = handler.handleKvRawPut(r) + resp.Resp = handler.handleKvRawPut(r) case tikvrpc.CmdRawBatchPut: - r := req.RawBatchPut + r := req.RawBatchPut() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.RawBatchPut = &kvrpcpb.RawBatchPutResponse{RegionError: err} + resp.Resp = &kvrpcpb.RawBatchPutResponse{RegionError: err} return resp, nil } - resp.RawBatchPut = handler.handleKvRawBatchPut(r) + resp.Resp = handler.handleKvRawBatchPut(r) case tikvrpc.CmdRawDelete: - r := req.RawDelete + r := req.RawDelete() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.RawDelete = &kvrpcpb.RawDeleteResponse{RegionError: err} + resp.Resp = &kvrpcpb.RawDeleteResponse{RegionError: err} return resp, nil } - resp.RawDelete = handler.handleKvRawDelete(r) + resp.Resp = handler.handleKvRawDelete(r) case tikvrpc.CmdRawBatchDelete: - r := req.RawBatchDelete + r := req.RawBatchDelete() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.RawBatchDelete = &kvrpcpb.RawBatchDeleteResponse{RegionError: err} + resp.Resp = &kvrpcpb.RawBatchDeleteResponse{RegionError: err} } - resp.RawBatchDelete = handler.handleKvRawBatchDelete(r) + resp.Resp = handler.handleKvRawBatchDelete(r) case tikvrpc.CmdRawDeleteRange: - r := req.RawDeleteRange + r := req.RawDeleteRange() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.RawDeleteRange = &kvrpcpb.RawDeleteRangeResponse{RegionError: err} + resp.Resp = &kvrpcpb.RawDeleteRangeResponse{RegionError: err} return resp, nil } - resp.RawDeleteRange = handler.handleKvRawDeleteRange(r) + resp.Resp = handler.handleKvRawDeleteRange(r) case tikvrpc.CmdRawScan: - r := req.RawScan + r := req.RawScan() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.RawScan = &kvrpcpb.RawScanResponse{RegionError: err} + resp.Resp = &kvrpcpb.RawScanResponse{RegionError: err} return resp, nil } - resp.RawScan = handler.handleKvRawScan(r) + resp.Resp = handler.handleKvRawScan(r) case tikvrpc.CmdUnsafeDestroyRange: panic("unimplemented") case tikvrpc.CmdCop: - r := req.Cop + r := req.Cop() if err := handler.checkRequestContext(reqCtx); err != nil { - resp.Cop = &coprocessor.Response{RegionError: err} + resp.Resp = &coprocessor.Response{RegionError: err} return resp, nil } handler.rawStartKey = MvccKey(handler.startKey).Raw() @@ -860,11 +869,11 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R default: 
panic(fmt.Sprintf("unknown coprocessor request type: %v", r.GetTp())) } - resp.Cop = res + resp.Resp = res case tikvrpc.CmdCopStream: - r := req.Cop + r := req.Cop() if err := handler.checkRequestContext(reqCtx); err != nil { - resp.CopStream = &tikvrpc.CopStreamResponse{ + resp.Resp = &tikvrpc.CopStreamResponse{ Tikv_CoprocessorStreamClient: &mockCopStreamErrClient{Error: err}, Response: &coprocessor.Response{ RegionError: err, @@ -893,34 +902,34 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R return nil, errors.Trace(err) } streamResp.Response = first - resp.CopStream = streamResp + resp.Resp = streamResp case tikvrpc.CmdMvccGetByKey: - r := req.MvccGetByKey + r := req.MvccGetByKey() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.MvccGetByKey = &kvrpcpb.MvccGetByKeyResponse{RegionError: err} + resp.Resp = &kvrpcpb.MvccGetByKeyResponse{RegionError: err} return resp, nil } - resp.MvccGetByKey = handler.handleMvccGetByKey(r) + resp.Resp = handler.handleMvccGetByKey(r) case tikvrpc.CmdMvccGetByStartTs: - r := req.MvccGetByStartTs + r := req.MvccGetByStartTs() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.MvccGetByStartTS = &kvrpcpb.MvccGetByStartTsResponse{RegionError: err} + resp.Resp = &kvrpcpb.MvccGetByStartTsResponse{RegionError: err} return resp, nil } - resp.MvccGetByStartTS = handler.handleMvccGetByStartTS(r) + resp.Resp = handler.handleMvccGetByStartTS(r) case tikvrpc.CmdSplitRegion: - r := req.SplitRegion + r := req.SplitRegion() if err := handler.checkRequest(reqCtx, r.Size()); err != nil { - resp.SplitRegion = &kvrpcpb.SplitRegionResponse{RegionError: err} + resp.Resp = &kvrpcpb.SplitRegionResponse{RegionError: err} return resp, nil } - resp.SplitRegion = handler.handleSplitRegion(r) + resp.Resp = handler.handleSplitRegion(r) // DebugGetRegionProperties is for fast analyze in mock tikv. 
case tikvrpc.CmdDebugGetRegionProperties: - r := req.DebugGetRegionProperties + r := req.DebugGetRegionProperties() region, _ := c.Cluster.GetRegion(r.RegionId) scanResp := handler.handleKvScan(&kvrpcpb.ScanRequest{StartKey: MvccKey(region.StartKey).Raw(), EndKey: MvccKey(region.EndKey).Raw(), Version: math.MaxUint64, Limit: math.MaxUint32}) - resp.DebugGetRegionProperties = &debugpb.GetRegionPropertiesResponse{ + resp.Resp = &debugpb.GetRegionPropertiesResponse{ Props: []*debugpb.Property{{ Name: "mvcc.num_rows", Value: strconv.Itoa(len(scanResp.Pairs)), diff --git a/store/tikv/2pc.go b/store/tikv/2pc.go index 04b2e7f2b2c55..1e11c3f891185 100644 --- a/store/tikv/2pc.go +++ b/store/tikv/2pc.go @@ -473,22 +473,16 @@ func (c *twoPhaseCommitter) buildPrewriteRequest(batch batchKeys) *tikvrpc.Reque isPessimisticLock[i] = true } } - return &tikvrpc.Request{ - Type: tikvrpc.CmdPrewrite, - Prewrite: &pb.PrewriteRequest{ - Mutations: mutations, - PrimaryLock: c.primary(), - StartVersion: c.startTS, - LockTtl: c.lockTTL, - IsPessimisticLock: isPessimisticLock, - ForUpdateTs: c.forUpdateTS, - TxnSize: uint64(len(batch.keys)), - }, - Context: pb.Context{ - Priority: c.priority, - SyncLog: c.syncLog, - }, + req := &pb.PrewriteRequest{ + Mutations: mutations, + PrimaryLock: c.primary(), + StartVersion: c.startTS, + LockTtl: c.lockTTL, + IsPessimisticLock: isPessimisticLock, + ForUpdateTs: c.forUpdateTS, + TxnSize: uint64(len(batch.keys)), } + return tikvrpc.NewRequest(tikvrpc.CmdPrewrite, req, pb.Context{Priority: c.priority, SyncLog: c.syncLog}) } func (c *twoPhaseCommitter) prewriteSingleBatch(bo *Backoffer, batch batchKeys) error { @@ -510,10 +504,10 @@ func (c *twoPhaseCommitter) prewriteSingleBatch(bo *Backoffer, batch batchKeys) err = c.prewriteKeys(bo, batch.keys) return errors.Trace(err) } - prewriteResp := resp.Prewrite - if prewriteResp == nil { + if resp.Resp == nil { return errors.Trace(ErrBodyMissing) } + prewriteResp := resp.Resp.(*pb.PrewriteResponse) keyErrs := prewriteResp.GetErrors() if len(keyErrs) == 0 { return nil @@ -572,21 +566,14 @@ func (c *twoPhaseCommitter) pessimisticLockSingleBatch(bo *Backoffer, batch batc mutations[i] = mut } - req := &tikvrpc.Request{ - Type: tikvrpc.CmdPessimisticLock, - PessimisticLock: &pb.PessimisticLockRequest{ - Mutations: mutations, - PrimaryLock: c.primary(), - StartVersion: c.startTS, - ForUpdateTs: c.forUpdateTS, - LockTtl: PessimisticLockTTL, - IsFirstLock: c.isFirstLock, - }, - Context: pb.Context{ - Priority: c.priority, - SyncLog: c.syncLog, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdPessimisticLock, &pb.PessimisticLockRequest{ + Mutations: mutations, + PrimaryLock: c.primary(), + StartVersion: c.startTS, + ForUpdateTs: c.forUpdateTS, + LockTtl: PessimisticLockTTL, + IsFirstLock: c.isFirstLock, + }, pb.Context{Priority: c.priority, SyncLog: c.syncLog}) for { resp, err := c.store.SendReq(bo, req, batch.region, readTimeoutShort) if err != nil { @@ -604,10 +591,10 @@ func (c *twoPhaseCommitter) pessimisticLockSingleBatch(bo *Backoffer, batch batc err = c.pessimisticLockKeys(bo, batch.keys) return errors.Trace(err) } - lockResp := resp.PessimisticLock - if lockResp == nil { + if resp.Resp == nil { return errors.Trace(ErrBodyMissing) } + lockResp := resp.Resp.(*pb.PessimisticLockResponse) keyErrs := lockResp.GetErrors() if len(keyErrs) == 0 { return nil @@ -643,14 +630,11 @@ func (c *twoPhaseCommitter) pessimisticLockSingleBatch(bo *Backoffer, batch batc } func (c *twoPhaseCommitter) pessimisticRollbackSingleBatch(bo *Backoffer, batch 
batchKeys) error { - req := &tikvrpc.Request{ - Type: tikvrpc.CmdPessimisticRollback, - PessimisticRollback: &pb.PessimisticRollbackRequest{ - StartVersion: c.startTS, - ForUpdateTs: c.forUpdateTS, - Keys: batch.keys, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdPessimisticRollback, &pb.PessimisticRollbackRequest{ + StartVersion: c.startTS, + ForUpdateTs: c.forUpdateTS, + Keys: batch.keys, + }) for { resp, err := c.store.SendReq(bo, req, batch.region, readTimeoutShort) if err != nil { @@ -709,19 +693,11 @@ func (c *twoPhaseCommitter) getUndeterminedErr() error { } func (c *twoPhaseCommitter) commitSingleBatch(bo *Backoffer, batch batchKeys) error { - req := &tikvrpc.Request{ - Type: tikvrpc.CmdCommit, - Commit: &pb.CommitRequest{ - StartVersion: c.startTS, - Keys: batch.keys, - CommitVersion: c.commitTS, - }, - Context: pb.Context{ - Priority: c.priority, - SyncLog: c.syncLog, - }, - } - req.Context.Priority = c.priority + req := tikvrpc.NewRequest(tikvrpc.CmdCommit, &pb.CommitRequest{ + StartVersion: c.startTS, + Keys: batch.keys, + CommitVersion: c.commitTS, + }, pb.Context{Priority: c.priority, SyncLog: c.syncLog}) sender := NewRegionRequestSender(c.store.regionCache, c.store.client) resp, err := sender.SendReq(bo, req, batch.region, readTimeoutShort) @@ -752,10 +728,10 @@ func (c *twoPhaseCommitter) commitSingleBatch(bo *Backoffer, batch batchKeys) er err = c.commitKeys(bo, batch.keys) return errors.Trace(err) } - commitResp := resp.Commit - if commitResp == nil { + if resp.Resp == nil { return errors.Trace(ErrBodyMissing) } + commitResp := resp.Resp.(*pb.CommitResponse) // Here we can make sure tikv has processed the commit primary key request. So // we can clean undetermined error. if isPrimary { @@ -789,17 +765,10 @@ func (c *twoPhaseCommitter) commitSingleBatch(bo *Backoffer, batch batchKeys) er } func (c *twoPhaseCommitter) cleanupSingleBatch(bo *Backoffer, batch batchKeys) error { - req := &tikvrpc.Request{ - Type: tikvrpc.CmdBatchRollback, - BatchRollback: &pb.BatchRollbackRequest{ - Keys: batch.keys, - StartVersion: c.startTS, - }, - Context: pb.Context{ - Priority: c.priority, - SyncLog: c.syncLog, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdBatchRollback, &pb.BatchRollbackRequest{ + Keys: batch.keys, + StartVersion: c.startTS, + }, pb.Context{Priority: c.priority, SyncLog: c.syncLog}) resp, err := c.store.SendReq(bo, req, batch.region, readTimeoutShort) if err != nil { return errors.Trace(err) @@ -816,7 +785,7 @@ func (c *twoPhaseCommitter) cleanupSingleBatch(bo *Backoffer, batch batchKeys) e err = c.cleanupKeys(bo, batch.keys) return errors.Trace(err) } - if keyErr := resp.BatchRollback.GetError(); keyErr != nil { + if keyErr := resp.Resp.(*pb.BatchRollbackResponse).GetError(); keyErr != nil { err = errors.Errorf("conn%d 2PC cleanup failed: %s", c.connID, keyErr) logutil.BgLogger().Debug("2PC failed cleanup key", zap.Error(err), diff --git a/store/tikv/2pc_test.go b/store/tikv/2pc_test.go index b1fc5d69a9e5e..63426bd9fb0e8 100644 --- a/store/tikv/2pc_test.go +++ b/store/tikv/2pc_test.go @@ -237,19 +237,16 @@ func (s *testCommitterSuite) isKeyLocked(c *C, key []byte) bool { ver, err := s.store.CurrentVersion() c.Assert(err, IsNil) bo := NewBackoffer(context.Background(), getMaxBackoff) - req := &tikvrpc.Request{ - Type: tikvrpc.CmdGet, - Get: &kvrpcpb.GetRequest{ - Key: key, - Version: ver.Ver, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdGet, &kvrpcpb.GetRequest{ + Key: key, + Version: ver.Ver, + }) loc, err := s.store.regionCache.LocateKey(bo, key) c.Assert(err, 
IsNil) resp, err := s.store.SendReq(bo, req, loc.Region, readTimeoutShort) c.Assert(err, IsNil) - c.Assert(resp.Get, NotNil) - keyErr := resp.Get.GetError() + c.Assert(resp.Resp, NotNil) + keyErr := (resp.Resp.(*kvrpcpb.GetResponse)).GetError() return keyErr.GetLocked() != nil } @@ -455,8 +452,8 @@ func (s *testCommitterSuite) TestPessimisticPrewriteRequest(c *C) { batch.keys = append(batch.keys, []byte("t1")) batch.region = RegionVerID{1, 1, 1} req := commiter.buildPrewriteRequest(batch) - c.Assert(len(req.Prewrite.IsPessimisticLock), Greater, 0) - c.Assert(req.Prewrite.ForUpdateTs, Equals, uint64(100)) + c.Assert(len(req.Prewrite().IsPessimisticLock), Greater, 0) + c.Assert(req.Prewrite().ForUpdateTs, Equals, uint64(100)) } func (s *testCommitterSuite) TestUnsetPrimaryKey(c *C) { diff --git a/store/tikv/client.go b/store/tikv/client.go index 4944830a3bd46..3278a2ebb9fd8 100644 --- a/store/tikv/client.go +++ b/store/tikv/client.go @@ -314,7 +314,7 @@ func (c *rpcClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R } // Put the lease object to the timeout channel, so it would be checked periodically. - copStream := resp.CopStream + copStream := resp.Resp.(*tikvrpc.CopStreamResponse) copStream.Timeout = timeout copStream.Lease.Cancel = cancel connArray.streamTimeout <- &copStream.Lease diff --git a/store/tikv/client_fail_test.go b/store/tikv/client_fail_test.go index 5fc73b6ff1152..74c67ab25b0a5 100644 --- a/store/tikv/client_fail_test.go +++ b/store/tikv/client_fail_test.go @@ -52,10 +52,7 @@ func (s *testClientSuite) TestPanicInRecvLoop(c *C) { c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/tikv/gotErrorInRecvLoop"), IsNil) time.Sleep(time.Second) - req := &tikvrpc.Request{ - Type: tikvrpc.CmdEmpty, - Empty: &tikvpb.BatchCommandsEmptyRequest{}, - } + req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{}) _, err = rpcClient.SendRequest(context.Background(), addr, req, time.Second) c.Assert(err, IsNil) server.Stop() diff --git a/store/tikv/coprocessor.go b/store/tikv/coprocessor.go index e0b74f113b8c8..83f9395c4f96b 100644 --- a/store/tikv/coprocessor.go +++ b/store/tikv/coprocessor.go @@ -612,21 +612,17 @@ func (worker *copIteratorWorker) handleTaskOnce(bo *Backoffer, task *copTask, ch }) sender := NewRegionRequestSender(worker.store.regionCache, worker.store.client) - req := &tikvrpc.Request{ - Type: task.cmdType, - Cop: &coprocessor.Request{ - Tp: worker.req.Tp, - Data: worker.req.Data, - Ranges: task.ranges.toPBRanges(), - }, - Context: kvrpcpb.Context{ - IsolationLevel: pbIsolationLevel(worker.req.IsolationLevel), - Priority: kvPriorityToCommandPri(worker.req.Priority), - NotFillCache: worker.req.NotFillCache, - HandleTime: true, - ScanDetail: true, - }, - } + req := tikvrpc.NewRequest(task.cmdType, &coprocessor.Request{ + Tp: worker.req.Tp, + Data: worker.req.Data, + Ranges: task.ranges.toPBRanges(), + }, kvrpcpb.Context{ + IsolationLevel: pbIsolationLevel(worker.req.IsolationLevel), + Priority: kvPriorityToCommandPri(worker.req.Priority), + NotFillCache: worker.req.NotFillCache, + HandleTime: true, + ScanDetail: true, + }) startTime := time.Now() resp, rpcCtx, err := sender.SendReqCtx(bo, req, task.region, ReadTimeoutMedium) if err != nil { @@ -641,11 +637,11 @@ func (worker *copIteratorWorker) handleTaskOnce(bo *Backoffer, task *copTask, ch metrics.TiKVCoprocessorHistogram.Observe(costTime.Seconds()) if task.cmdType == tikvrpc.CmdCopStream { - return worker.handleCopStreamResult(bo, rpcCtx, resp.CopStream, task, ch) + return 
worker.handleCopStreamResult(bo, rpcCtx, resp.Resp.(*tikvrpc.CopStreamResponse), task, ch) } // Handles the response for non-streaming copTask. - return worker.handleCopResponse(bo, rpcCtx, &copResponse{pbResp: resp.Cop}, task, ch, nil) + return worker.handleCopResponse(bo, rpcCtx, &copResponse{pbResp: resp.Resp.(*coprocessor.Response)}, task, ch, nil) } const ( @@ -661,11 +657,11 @@ func (worker *copIteratorWorker) logTimeCopTask(costTime time.Duration, task *co logStr += fmt.Sprintf(" backoff_ms:%d backoff_types:%s", bo.totalSleep, backoffTypes) } var detail *kvrpcpb.ExecDetails - if resp.Cop != nil { - detail = resp.Cop.ExecDetails - } else if resp.CopStream != nil && resp.CopStream.Response != nil { + if resp.Resp != nil { + detail = resp.Resp.(*coprocessor.Response).ExecDetails + } else if resp.Resp != nil && resp.Resp.(*tikvrpc.CopStreamResponse).Response != nil { // streaming request returns io.EOF, so the first resp.CopStream.Response maybe nil. - detail = resp.CopStream.ExecDetails + detail = (resp.Resp.(*tikvrpc.CopStreamResponse)).ExecDetails } if detail != nil && detail.HandleTime != nil { diff --git a/store/tikv/delete_range.go b/store/tikv/delete_range.go index f721ae6305030..81a8eaab7bf75 100644 --- a/store/tikv/delete_range.go +++ b/store/tikv/delete_range.go @@ -104,14 +104,11 @@ func (t *DeleteRangeTask) sendReqOnRange(ctx context.Context, r kv.KeyRange) (Ra endKey = rangeEndKey } - req := &tikvrpc.Request{ - Type: tikvrpc.CmdDeleteRange, - DeleteRange: &kvrpcpb.DeleteRangeRequest{ - StartKey: startKey, - EndKey: endKey, - NotifyOnly: t.notifyOnly, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdDeleteRange, &kvrpcpb.DeleteRangeRequest{ + StartKey: startKey, + EndKey: endKey, + NotifyOnly: t.notifyOnly, + }) resp, err := t.store.SendReq(bo, req, loc.Region, ReadTimeoutMedium) if err != nil { @@ -128,10 +125,10 @@ func (t *DeleteRangeTask) sendReqOnRange(ctx context.Context, r kv.KeyRange) (Ra } continue } - deleteRangeResp := resp.DeleteRange - if deleteRangeResp == nil { + if resp.Resp == nil { return stat, errors.Trace(ErrBodyMissing) } + deleteRangeResp := resp.Resp.(*kvrpcpb.DeleteRangeResponse) if err := deleteRangeResp.GetError(); err != "" { return stat, errors.Errorf("unexpected delete range err: %v", err) } diff --git a/store/tikv/gcworker/gc_worker.go b/store/tikv/gcworker/gc_worker.go index e26cf40b5c19e..382c0c474ca5f 100644 --- a/store/tikv/gcworker/gc_worker.go +++ b/store/tikv/gcworker/gc_worker.go @@ -233,13 +233,11 @@ func (w *GCWorker) leaderTick(ctx context.Context) error { if err != nil { metrics.GCJobFailureCounter.WithLabelValues("prepare").Inc() } - w.gcIsRunning = false return errors.Trace(err) } // When the worker is just started, or an old GC job has just finished, // wait a while before starting a new job. 
if time.Since(w.lastFinish) < gcWaitTime { - w.gcIsRunning = false logutil.Logger(ctx).Info("[gc worker] another gc job has just finished, skipped.", zap.String("leaderTick on ", w.uuid)) return nil @@ -258,7 +256,9 @@ func (w *GCWorker) leaderTick(ctx context.Context) error { zap.String("uuid", w.uuid), zap.Uint64("safePoint", safePoint), zap.Int("concurrency", concurrency)) - go w.runGCJob(ctx, safePoint, concurrency) + go func() { + w.done <- w.runGCJob(ctx, safePoint, concurrency) + }() return nil } @@ -466,7 +466,7 @@ func (w *GCWorker) calculateNewSafePoint(now time.Time) (*time.Time, error) { return &safePoint, nil } -func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency int) { +func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency int) error { metrics.GCWorkerCounter.WithLabelValues("run_job").Inc() err := w.resolveLocks(ctx, safePoint, concurrency) if err != nil { @@ -474,8 +474,7 @@ func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency i zap.String("uuid", w.uuid), zap.Error(err)) metrics.GCJobFailureCounter.WithLabelValues("resolve_lock").Inc() - w.done <- errors.Trace(err) - return + return errors.Trace(err) } // Save safe point to pd. err = w.saveSafePoint(w.store.GetSafePointKV(), tikv.GcSavedSafePoint, safePoint) @@ -483,10 +482,8 @@ func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency i logutil.Logger(ctx).Error("[gc worker] failed to save safe point to PD", zap.String("uuid", w.uuid), zap.Error(err)) - w.gcIsRunning = false metrics.GCJobFailureCounter.WithLabelValues("save_safe_point").Inc() - w.done <- errors.Trace(err) - return + return errors.Trace(err) } // Sleep to wait for all other tidb instances update their safepoint cache. 
time.Sleep(gcSafePointCacheInterval) @@ -497,8 +494,7 @@ func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency i zap.String("uuid", w.uuid), zap.Error(err)) metrics.GCJobFailureCounter.WithLabelValues("delete_range").Inc() - w.done <- errors.Trace(err) - return + return errors.Trace(err) } err = w.redoDeleteRanges(ctx, safePoint, concurrency) if err != nil { @@ -506,8 +502,7 @@ func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency i zap.String("uuid", w.uuid), zap.Error(err)) metrics.GCJobFailureCounter.WithLabelValues("redo_delete_range").Inc() - w.done <- errors.Trace(err) - return + return errors.Trace(err) } useDistributedGC, err := w.checkUseDistributedGC() @@ -525,10 +520,8 @@ func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency i logutil.Logger(ctx).Error("[gc worker] failed to upload safe point to PD", zap.String("uuid", w.uuid), zap.Error(err)) - w.gcIsRunning = false metrics.GCJobFailureCounter.WithLabelValues("upload_safe_point").Inc() - w.done <- errors.Trace(err) - return + return errors.Trace(err) } } else { err = w.doGC(ctx, safePoint, concurrency) @@ -536,14 +529,12 @@ func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency i logutil.Logger(ctx).Error("[gc worker] do GC returns an error", zap.String("uuid", w.uuid), zap.Error(err)) - w.gcIsRunning = false metrics.GCJobFailureCounter.WithLabelValues("gc").Inc() - w.done <- errors.Trace(err) - return + return errors.Trace(err) } } - w.done <- nil + return nil } // deleteRanges processes all delete range records whose ts < safePoint in table `gc_delete_range` @@ -658,13 +649,10 @@ func (w *GCWorker) doUnsafeDestroyRangeRequest(ctx context.Context, startKey []b return errors.Trace(err) } - req := &tikvrpc.Request{ - Type: tikvrpc.CmdUnsafeDestroyRange, - UnsafeDestroyRange: &kvrpcpb.UnsafeDestroyRangeRequest{ - StartKey: startKey, - EndKey: endKey, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdUnsafeDestroyRange, &kvrpcpb.UnsafeDestroyRangeRequest{ + StartKey: startKey, + EndKey: endKey, + }) var wg sync.WaitGroup errChan := make(chan error, len(stores)) @@ -678,10 +666,10 @@ func (w *GCWorker) doUnsafeDestroyRangeRequest(ctx context.Context, startKey []b resp, err1 := w.store.GetTiKVClient().SendRequest(ctx, address, req, tikv.UnsafeDestroyRangeTimeout) if err1 == nil { - if resp == nil || resp.UnsafeDestroyRange == nil { + if resp == nil || resp.Resp == nil { err1 = errors.Errorf("unsafe destroy range returns nil response from store %v", storeID) } else { - errStr := resp.UnsafeDestroyRange.Error + errStr := (resp.Resp.(*kvrpcpb.UnsafeDestroyRangeResponse)).Error if len(errStr) > 0 { err1 = errors.Errorf("unsafe destroy range failed on store %v: %s", storeID, errStr) } @@ -821,13 +809,10 @@ func (w *GCWorker) resolveLocksForRange(ctx context.Context, safePoint uint64, s // for scan lock request, we must return all locks even if they are generated // by the same transaction. because gc worker need to make sure all locks have been // cleaned. 
- req := &tikvrpc.Request{ - Type: tikvrpc.CmdScanLock, - ScanLock: &kvrpcpb.ScanLockRequest{ - MaxVersion: safePoint, - Limit: gcScanLockLimit, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdScanLock, &kvrpcpb.ScanLockRequest{ + MaxVersion: safePoint, + Limit: gcScanLockLimit, + }) var stat tikv.RangeTaskStat key := startKey @@ -840,7 +825,7 @@ func (w *GCWorker) resolveLocksForRange(ctx context.Context, safePoint uint64, s bo := tikv.NewBackoffer(ctx, tikv.GcResolveLockMaxBackoff) - req.ScanLock.StartKey = key + req.ScanLock().StartKey = key loc, err := w.store.GetRegionCache().LocateKey(bo, key) if err != nil { return stat, errors.Trace(err) @@ -860,10 +845,10 @@ func (w *GCWorker) resolveLocksForRange(ctx context.Context, safePoint uint64, s } continue } - locksResp := resp.ScanLock - if locksResp == nil { + if resp.Resp == nil { return stat, errors.Trace(tikv.ErrBodyMissing) } + locksResp := resp.Resp.(*kvrpcpb.ScanLockResponse) if locksResp.GetError() != nil { return stat, errors.Errorf("unexpected scanlock error: %s", locksResp) } @@ -986,12 +971,9 @@ func (w *GCWorker) doGCForRange(ctx context.Context, startKey []byte, endKey []b // doGCForRegion used for gc for region. // these two errors should not return together, for more, see the func 'doGC' func (w *GCWorker) doGCForRegion(bo *tikv.Backoffer, safePoint uint64, region tikv.RegionVerID) (*errorpb.Error, error) { - req := &tikvrpc.Request{ - Type: tikvrpc.CmdGC, - GC: &kvrpcpb.GCRequest{ - SafePoint: safePoint, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdGC, &kvrpcpb.GCRequest{ + SafePoint: safePoint, + }) resp, err := w.store.SendReq(bo, req, region, tikv.GCTimeout) if err != nil { @@ -1005,10 +987,10 @@ func (w *GCWorker) doGCForRegion(bo *tikv.Backoffer, safePoint uint64, region ti return regionErr, nil } - gcResp := resp.GC - if gcResp == nil { + if resp.Resp == nil { return nil, errors.Trace(tikv.ErrBodyMissing) } + gcResp := resp.Resp.(*kvrpcpb.GCResponse) if gcResp.GetError() != nil { return nil, errors.Errorf("unexpected gc error: %s", gcResp.GetError()) } diff --git a/store/tikv/gcworker/gc_worker_test.go b/store/tikv/gcworker/gc_worker_test.go index fc2f14c9bca54..4f762242dcc94 100644 --- a/store/tikv/gcworker/gc_worker_test.go +++ b/store/tikv/gcworker/gc_worker_test.go @@ -16,7 +16,8 @@ package gcworker import ( "bytes" "context" - "errors" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/store/tikv/oracle" "math" "sort" "strconv" @@ -24,11 +25,12 @@ import ( "time" . 
"github.com/pingcap/check" + "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/errorpb" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" - pd "github.com/pingcap/pd/client" + "github.com/pingcap/pd/client" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/domain" @@ -98,6 +100,83 @@ func (s *testGCWorkerSuite) timeEqual(c *C, t1, t2 time.Time, epsilon time.Durat c.Assert(math.Abs(float64(t1.Sub(t2))), Less, float64(epsilon)) } +func (s *testGCWorkerSuite) mustPut(c *C, key, value string) { + txn, err := s.store.Begin() + c.Assert(err, IsNil) + err = txn.Set([]byte(key), []byte(value)) + c.Assert(err, IsNil) + err = txn.Commit(context.Background()) + c.Assert(err, IsNil) +} + +func (s *testGCWorkerSuite) mustGet(c *C, key string, ts uint64) string { + snap, err := s.store.GetSnapshot(kv.Version{Ver: ts}) + c.Assert(err, IsNil) + value, err := snap.Get([]byte(key)) + c.Assert(err, IsNil) + return string(value) +} + +func (s *testGCWorkerSuite) mustGetNone(c *C, key string, ts uint64) { + snap, err := s.store.GetSnapshot(kv.Version{Ver: ts}) + c.Assert(err, IsNil) + _, err = snap.Get([]byte(key)) + c.Assert(err, Equals, kv.ErrNotExist) +} + +func (s *testGCWorkerSuite) mustAllocTs(c *C) uint64 { + ts, err := s.oracle.GetTimestamp(context.Background()) + c.Assert(err, IsNil) + return ts +} + +func (s *testGCWorkerSuite) mustGetSafePointFromPd(c *C) uint64 { + // UpdateGCSafePoint returns the newest safePoint after the updating, which can be used to check whether the + // safePoint is successfully uploaded. + safePoint, err := s.pdClient.UpdateGCSafePoint(context.Background(), 0) + c.Assert(err, IsNil) + return safePoint +} + +// gcProbe represents a key that contains multiple versions, one of which should be collected. Execution of GC with +// greater ts will be detected, but it may not work properly if there are newer versions of the key. +// This is not used to check the correctness of GC algorithm, but only for checking whether GC has been executed on the +// specified key. Create this using `s.createGCProbe`. +type gcProbe struct { + key string + // The ts that can see the version that should be deleted. + v1Ts uint64 + // The ts that can see the version that should be kept. + v2Ts uint64 +} + +// createGCProbe creates gcProbe on specified key. +func (s *testGCWorkerSuite) createGCProbe(c *C, key string) *gcProbe { + s.mustPut(c, key, "v1") + ts1 := s.mustAllocTs(c) + s.mustPut(c, key, "v2") + ts2 := s.mustAllocTs(c) + p := &gcProbe{ + key: key, + v1Ts: ts1, + v2Ts: ts2, + } + s.checkNotCollected(c, p) + return p +} + +// checkCollected asserts the gcProbe has been correctly collected. +func (s *testGCWorkerSuite) checkCollected(c *C, p *gcProbe) { + s.mustGetNone(c, p.key, p.v1Ts) + c.Assert(s.mustGet(c, p.key, p.v2Ts), Equals, "v2") +} + +// checkNotCollected asserts the gcProbe has not been collected. 
+func (s *testGCWorkerSuite) checkNotCollected(c *C, p *gcProbe) { + c.Assert(s.mustGet(c, p.key, p.v1Ts), Equals, "v1") + c.Assert(s.mustGet(c, p.key, p.v2Ts), Equals, "v2") +} + func (s *testGCWorkerSuite) TestGetOracleTime(c *C) { t1, err := s.gcWorker.getOracleTime() c.Assert(err, IsNil) @@ -241,24 +320,27 @@ func (s *testGCWorkerSuite) TestDoGCForOneRegion(c *C) { loc, err := s.store.GetRegionCache().LocateKey(bo, []byte("")) c.Assert(err, IsNil) var regionErr *errorpb.Error - regionErr, err = s.gcWorker.doGCForRegion(bo, 20, loc.Region) + + p := s.createGCProbe(c, "k1") + regionErr, err = s.gcWorker.doGCForRegion(bo, s.mustAllocTs(c), loc.Region) c.Assert(regionErr, IsNil) c.Assert(err, IsNil) + s.checkCollected(c, p) c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/tikv/tikvStoreSendReqResult", `return("timeout")`), IsNil) - regionErr, err = s.gcWorker.doGCForRegion(bo, 20, loc.Region) + regionErr, err = s.gcWorker.doGCForRegion(bo, s.mustAllocTs(c), loc.Region) c.Assert(regionErr, IsNil) c.Assert(err, NotNil) c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/tikv/tikvStoreSendReqResult"), IsNil) c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/tikv/tikvStoreSendReqResult", `return("GCNotLeader")`), IsNil) - regionErr, err = s.gcWorker.doGCForRegion(bo, 20, loc.Region) + regionErr, err = s.gcWorker.doGCForRegion(bo, s.mustAllocTs(c), loc.Region) c.Assert(regionErr.GetNotLeader(), NotNil) c.Assert(err, IsNil) c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/tikv/tikvStoreSendReqResult"), IsNil) c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/tikv/tikvStoreSendReqResult", `return("GCServerIsBusy")`), IsNil) - regionErr, err = s.gcWorker.doGCForRegion(bo, 20, loc.Region) + regionErr, err = s.gcWorker.doGCForRegion(bo, s.mustAllocTs(c), loc.Region) c.Assert(regionErr.GetServerIsBusy(), NotNil) c.Assert(err, IsNil) c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/tikv/tikvStoreSendReqResult"), IsNil) @@ -292,26 +374,20 @@ func (s *testGCWorkerSuite) TestDoGC(c *C) { gcSafePointCacheInterval = 1 - err = s.gcWorker.saveValueToSysTable(gcConcurrencyKey, strconv.Itoa(gcDefaultConcurrency)) - c.Assert(err, IsNil) - concurrency, err := s.gcWorker.loadGCConcurrencyWithDefault() - c.Assert(err, IsNil) - err = s.gcWorker.doGC(ctx, 20, concurrency) + p := s.createGCProbe(c, "k1") + err = s.gcWorker.doGC(ctx, s.mustAllocTs(c), gcDefaultConcurrency) c.Assert(err, IsNil) + s.checkCollected(c, p) - err = s.gcWorker.saveValueToSysTable(gcConcurrencyKey, strconv.Itoa(gcMinConcurrency)) - c.Assert(err, IsNil) - concurrency, err = s.gcWorker.loadGCConcurrencyWithDefault() - c.Assert(err, IsNil) - err = s.gcWorker.doGC(ctx, 20, concurrency) + p = s.createGCProbe(c, "k1") + err = s.gcWorker.doGC(ctx, s.mustAllocTs(c), gcMinConcurrency) c.Assert(err, IsNil) + s.checkCollected(c, p) - err = s.gcWorker.saveValueToSysTable(gcConcurrencyKey, strconv.Itoa(gcMaxConcurrency)) - c.Assert(err, IsNil) - concurrency, err = s.gcWorker.loadGCConcurrencyWithDefault() - c.Assert(err, IsNil) - err = s.gcWorker.doGC(ctx, 20, concurrency) + p = s.createGCProbe(c, "k1") + err = s.gcWorker.doGC(ctx, s.mustAllocTs(c), gcMaxConcurrency) c.Assert(err, IsNil) + s.checkCollected(c, p) } func (s *testGCWorkerSuite) TestCheckGCMode(c *C) { @@ -399,16 +475,15 @@ func (s *testGCWorkerSuite) testDeleteRangesFailureImpl(c *C, failType int) { s.client.unsafeDestroyRangeHandler = func(addr string, req *tikvrpc.Request) (*tikvrpc.Response, error) { sendReqCh <- SentReq{req, addr} resp := 
&tikvrpc.Response{ - Type: tikvrpc.CmdUnsafeDestroyRange, - UnsafeDestroyRange: &kvrpcpb.UnsafeDestroyRangeResponse{}, + Resp: &kvrpcpb.UnsafeDestroyRangeResponse{}, } - if bytes.Equal(req.UnsafeDestroyRange.GetStartKey(), failKey) && addr == failStore.GetAddress() { + if bytes.Equal(req.UnsafeDestroyRange().GetStartKey(), failKey) && addr == failStore.GetAddress() { if failType == failRPCErr { return nil, errors.New("error") } else if failType == failNilResp { - resp.UnsafeDestroyRange = nil + resp.Resp = nil } else if failType == failErrResp { - resp.UnsafeDestroyRange.Error = "error" + (resp.Resp.(*kvrpcpb.UnsafeDestroyRangeResponse)).Error = "error" } else { panic("unreachable") } @@ -488,7 +563,7 @@ Loop: } sort.Slice(sentReq, func(i, j int) bool { - cmp := bytes.Compare(sentReq[i].req.UnsafeDestroyRange.StartKey, sentReq[j].req.UnsafeDestroyRange.StartKey) + cmp := bytes.Compare(sentReq[i].req.UnsafeDestroyRange().StartKey, sentReq[j].req.UnsafeDestroyRange().StartKey) return cmp < 0 || (cmp == 0 && sentReq[i].addr < sentReq[j].addr) }) @@ -501,8 +576,8 @@ Loop: for storeIndex := range expectedStores { i := rangeIndex*len(expectedStores) + storeIndex c.Assert(sentReq[i].addr, Equals, expectedStores[storeIndex].Address) - c.Assert(sentReq[i].req.UnsafeDestroyRange.GetStartKey(), DeepEquals, sortedRanges[rangeIndex].StartKey) - c.Assert(sentReq[i].req.UnsafeDestroyRange.GetEndKey(), DeepEquals, sortedRanges[rangeIndex].EndKey) + c.Assert(sentReq[i].req.UnsafeDestroyRange().GetStartKey(), DeepEquals, sortedRanges[rangeIndex].StartKey) + c.Assert(sentReq[i].req.UnsafeDestroyRange().GetEndKey(), DeepEquals, sortedRanges[rangeIndex].EndKey) } } } @@ -522,33 +597,163 @@ func (c *testGCWorkerClient) SendRequest(ctx context.Context, addr string, req * return c.Client.SendRequest(ctx, addr, req, timeout) } -func (s *testGCWorkerSuite) TestRunGCJob(c *C) { +func (s *testGCWorkerSuite) TestLeaderTick(c *C) { gcSafePointCacheInterval = 0 - err := RunGCJob(context.Background(), s.store, 0, "mock", 1) + + veryLong := gcDefaultLifeTime * 10 + // Avoid failing at interval check. `lastFinish` is checked by os time. + s.gcWorker.lastFinish = time.Now().Add(-veryLong) + // Use central mode to do this test. + err := s.gcWorker.saveValueToSysTable(gcModeKey, gcModeCentral) c.Assert(err, IsNil) - gcWorker, err := NewGCWorker(s.store, s.pdClient) + p := s.createGCProbe(c, "k1") + s.oracle.AddOffset(gcDefaultLifeTime * 2) + + // Skip if GC is running. + s.gcWorker.gcIsRunning = true + err = s.gcWorker.leaderTick(context.Background()) + c.Assert(err, IsNil) + s.checkNotCollected(c, p) + s.gcWorker.gcIsRunning = false + // Reset GC last run time + err = s.gcWorker.saveTime(gcLastRunTimeKey, oracle.GetTimeFromTS(s.mustAllocTs(c)).Add(-veryLong)) + c.Assert(err, IsNil) + + // Skip if prepare failed (disabling GC will make prepare returns ok = false). + err = s.gcWorker.saveValueToSysTable(gcEnableKey, booleanFalse) + c.Assert(err, IsNil) + err = s.gcWorker.leaderTick(context.Background()) + c.Assert(err, IsNil) + s.checkNotCollected(c, p) + err = s.gcWorker.saveValueToSysTable(gcEnableKey, booleanTrue) + c.Assert(err, IsNil) + // Reset GC last run time + err = s.gcWorker.saveTime(gcLastRunTimeKey, oracle.GetTimeFromTS(s.mustAllocTs(c)).Add(-veryLong)) + c.Assert(err, IsNil) + + // Skip if gcWaitTime not exceeded. 
+ s.gcWorker.lastFinish = time.Now() + err = s.gcWorker.leaderTick(context.Background()) + c.Assert(err, IsNil) + s.checkNotCollected(c, p) + s.gcWorker.lastFinish = time.Now().Add(-veryLong) + // Reset GC last run time + err = s.gcWorker.saveTime(gcLastRunTimeKey, oracle.GetTimeFromTS(s.mustAllocTs(c)).Add(-veryLong)) + c.Assert(err, IsNil) + + // Continue GC if all those checks passed. + err = s.gcWorker.leaderTick(context.Background()) + c.Assert(err, IsNil) + // Wait for GC finish + select { + case err = <-s.gcWorker.done: + s.gcWorker.gcIsRunning = false + break + case <-time.After(time.Second * 10): + err = errors.New("receive from s.gcWorker.done timeout") + } + c.Assert(err, IsNil) + s.checkCollected(c, p) + + // Test again to ensure the synchronization between goroutines is correct. + err = s.gcWorker.saveTime(gcLastRunTimeKey, oracle.GetTimeFromTS(s.mustAllocTs(c)).Add(-veryLong)) + c.Assert(err, IsNil) + s.gcWorker.lastFinish = time.Now().Add(-veryLong) + p = s.createGCProbe(c, "k1") + s.oracle.AddOffset(gcDefaultLifeTime * 2) + + err = s.gcWorker.leaderTick(context.Background()) + c.Assert(err, IsNil) + // Wait for GC finish + select { + case err = <-s.gcWorker.done: + s.gcWorker.gcIsRunning = false + break + case <-time.After(time.Second * 10): + err = errors.New("receive from s.gcWorker.done timeout") + } + c.Assert(err, IsNil) + s.checkCollected(c, p) + + // No more signals in the channel + select { + case err = <-s.gcWorker.done: + err = errors.Errorf("received signal s.gcWorker.done which shouldn't exist: %v", err) + break + case <-time.After(time.Second): + break + } + c.Assert(err, IsNil) +} + +func (s *testGCWorkerSuite) TestRunGCJob(c *C) { + gcSafePointCacheInterval = 0 + + // Test distributed mode + useDistributedGC, err := s.gcWorker.checkUseDistributedGC() c.Assert(err, IsNil) - gcWorker.Start() - useDistributedGC, err := gcWorker.(*GCWorker).checkUseDistributedGC() c.Assert(useDistributedGC, IsTrue) + safePoint := s.mustAllocTs(c) + err = s.gcWorker.runGCJob(context.Background(), safePoint, 1) c.Assert(err, IsNil) - safePoint := uint64(time.Now().Unix()) - gcWorker.(*GCWorker).runGCJob(context.Background(), safePoint, 1) - getSafePoint, err := loadSafePoint(gcWorker.(*GCWorker).store.GetSafePointKV()) + + pdSafePoint := s.mustGetSafePointFromPd(c) + c.Assert(pdSafePoint, Equals, safePoint) + + etcdSafePoint := s.loadEtcdSafePoint(c) + c.Assert(etcdSafePoint, Equals, safePoint) + + // Test distributed mode with safePoint regressing (although this is impossible) + err = s.gcWorker.runGCJob(context.Background(), safePoint-1, 1) + c.Assert(err, NotNil) + + // Test central mode + err = s.gcWorker.saveValueToSysTable(gcModeKey, gcModeCentral) c.Assert(err, IsNil) - c.Assert(getSafePoint, Equals, safePoint) - gcWorker.Close() + useDistributedGC, err = s.gcWorker.checkUseDistributedGC() + c.Assert(err, IsNil) + c.Assert(useDistributedGC, IsFalse) + + p := s.createGCProbe(c, "k1") + safePoint = s.mustAllocTs(c) + err = s.gcWorker.runGCJob(context.Background(), safePoint, 1) + c.Assert(err, IsNil) + s.checkCollected(c, p) + + etcdSafePoint = s.loadEtcdSafePoint(c) + c.Assert(etcdSafePoint, Equals, safePoint) } -func loadSafePoint(kv tikv.SafePointKV) (uint64, error) { - val, err := kv.Get(tikv.GcSavedSafePoint) - if err != nil { - return 0, err - } - return strconv.ParseUint(val, 10, 64) +func (s *testGCWorkerSuite) TestRunGCJobAPI(c *C) { + gcSafePointCacheInterval = 0 + + p := s.createGCProbe(c, "k1") + safePoint := s.mustAllocTs(c) + err := RunGCJob(context.Background(), 
s.store, safePoint, "mock", 1) + c.Assert(err, IsNil) + s.checkCollected(c, p) + etcdSafePoint := s.loadEtcdSafePoint(c) + c.Assert(err, IsNil) + c.Assert(etcdSafePoint, Equals, safePoint) +} + +func (s *testGCWorkerSuite) TestRunDistGCJobAPI(c *C) { + gcSafePointCacheInterval = 0 + + safePoint := s.mustAllocTs(c) + err := RunDistributedGCJob(context.Background(), s.store, s.pdClient, safePoint, "mock", 1) + c.Assert(err, IsNil) + pdSafePoint := s.mustGetSafePointFromPd(c) + c.Assert(pdSafePoint, Equals, safePoint) + etcdSafePoint := s.loadEtcdSafePoint(c) + c.Assert(err, IsNil) + c.Assert(etcdSafePoint, Equals, safePoint) } -func (s *testGCWorkerSuite) TestRunDistGCJob(c *C) { - err := RunDistributedGCJob(context.Background(), s.store, s.pdClient, 0, "mock", 1) +func (s *testGCWorkerSuite) loadEtcdSafePoint(c *C) uint64 { + val, err := s.gcWorker.store.GetSafePointKV().Get(tikv.GcSavedSafePoint) + c.Assert(err, IsNil) + res, err := strconv.ParseUint(val, 10, 64) c.Assert(err, IsNil) + return res } diff --git a/store/tikv/lock_resolver.go b/store/tikv/lock_resolver.go index ebabeb90621fc..f20116c26af55 100644 --- a/store/tikv/lock_resolver.go +++ b/store/tikv/lock_resolver.go @@ -220,12 +220,7 @@ func (lr *LockResolver) BatchResolveLocks(bo *Backoffer, locks []*Lock, loc Regi }) } - req := &tikvrpc.Request{ - Type: tikvrpc.CmdResolveLock, - ResolveLock: &kvrpcpb.ResolveLockRequest{ - TxnInfos: listTxnInfos, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdResolveLock, &kvrpcpb.ResolveLockRequest{TxnInfos: listTxnInfos}) startTime = time.Now() resp, err := lr.store.SendReq(bo, req, loc, readTimeoutShort) if err != nil { @@ -245,10 +240,10 @@ func (lr *LockResolver) BatchResolveLocks(bo *Backoffer, locks []*Lock, loc Regi return false, nil } - cmdResp := resp.ResolveLock - if cmdResp == nil { + if resp.Resp == nil { return false, errors.Trace(ErrBodyMissing) } + cmdResp := resp.Resp.(*kvrpcpb.ResolveLockResponse) if keyErr := cmdResp.GetError(); keyErr != nil { return false, errors.Errorf("unexpected resolve err: %s", keyErr) } @@ -341,13 +336,10 @@ func (lr *LockResolver) getTxnStatus(bo *Backoffer, txnID uint64, primary []byte tikvLockResolverCountWithQueryTxnStatus.Inc() var status TxnStatus - req := &tikvrpc.Request{ - Type: tikvrpc.CmdCleanup, - Cleanup: &kvrpcpb.CleanupRequest{ - Key: primary, - StartVersion: txnID, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdCleanup, &kvrpcpb.CleanupRequest{ + Key: primary, + StartVersion: txnID, + }) for { loc, err := lr.store.GetRegionCache().LocateKey(bo, primary) if err != nil { @@ -368,10 +360,10 @@ func (lr *LockResolver) getTxnStatus(bo *Backoffer, txnID uint64, primary []byte } continue } - cmdResp := resp.Cleanup - if cmdResp == nil { + if resp.Resp == nil { return status, errors.Trace(ErrBodyMissing) } + cmdResp := resp.Resp.(*kvrpcpb.CleanupResponse) if keyErr := cmdResp.GetError(); keyErr != nil { err = errors.Errorf("unexpected cleanup err: %s, tid: %v", keyErr, txnID) logutil.BgLogger().Error("getTxnStatus error", zap.Error(err)) @@ -399,21 +391,19 @@ func (lr *LockResolver) resolveLock(bo *Backoffer, l *Lock, status TxnStatus, cl if _, ok := cleanRegions[loc.Region]; ok { return nil } - req := &tikvrpc.Request{ - Type: tikvrpc.CmdResolveLock, - ResolveLock: &kvrpcpb.ResolveLockRequest{ - StartVersion: l.TxnID, - }, + lreq := &kvrpcpb.ResolveLockRequest{ + StartVersion: l.TxnID, } if status.IsCommitted() { - req.ResolveLock.CommitVersion = status.CommitTS() + lreq.CommitVersion = status.CommitTS() } if l.TxnSize < bigTxnThreshold { // 
Only resolve specified keys when it is a small transaction, // prevent from scanning the whole region in this case. tikvLockResolverCountWithResolveLockLite.Inc() - req.ResolveLock.Keys = [][]byte{l.Key} + lreq.Keys = [][]byte{l.Key} } + req := tikvrpc.NewRequest(tikvrpc.CmdResolveLock, lreq) resp, err := lr.store.SendReq(bo, req, loc.Region, readTimeoutShort) if err != nil { return errors.Trace(err) @@ -429,10 +419,10 @@ func (lr *LockResolver) resolveLock(bo *Backoffer, l *Lock, status TxnStatus, cl } continue } - cmdResp := resp.ResolveLock - if cmdResp == nil { + if resp.Resp == nil { return errors.Trace(ErrBodyMissing) } + cmdResp := resp.Resp.(*kvrpcpb.ResolveLockResponse) if keyErr := cmdResp.GetError(); keyErr != nil { err = errors.Errorf("unexpected resolve err: %s, lock: %v", keyErr, l) logutil.BgLogger().Error("resolveLock error", zap.Error(err)) diff --git a/store/tikv/lock_test.go b/store/tikv/lock_test.go index 3b9ed28f3fded..f2379cd142f87 100644 --- a/store/tikv/lock_test.go +++ b/store/tikv/lock_test.go @@ -213,20 +213,16 @@ func (s *testLockSuite) mustGetLock(c *C, key []byte) *Lock { ver, err := s.store.CurrentVersion() c.Assert(err, IsNil) bo := NewBackoffer(context.Background(), getMaxBackoff) - req := &tikvrpc.Request{ - Type: tikvrpc.CmdGet, - Get: &kvrpcpb.GetRequest{ - Key: key, - Version: ver.Ver, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdGet, &kvrpcpb.GetRequest{ + Key: key, + Version: ver.Ver, + }) loc, err := s.store.regionCache.LocateKey(bo, key) c.Assert(err, IsNil) resp, err := s.store.SendReq(bo, req, loc.Region, readTimeoutShort) c.Assert(err, IsNil) - cmdGetResp := resp.Get - c.Assert(cmdGetResp, NotNil) - keyErr := cmdGetResp.GetError() + c.Assert(resp.Resp, NotNil) + keyErr := resp.Resp.(*kvrpcpb.GetResponse).GetError() c.Assert(keyErr, NotNil) lock, err := extractLockFromKeyErr(keyErr) c.Assert(err, IsNil) diff --git a/store/tikv/rawkv.go b/store/tikv/rawkv.go index aea8c9abf52dd..c510dddba9d9f 100644 --- a/store/tikv/rawkv.go +++ b/store/tikv/rawkv.go @@ -104,20 +104,15 @@ func (c *RawKVClient) Get(key []byte) ([]byte, error) { start := time.Now() defer func() { tikvRawkvCmdHistogramWithGet.Observe(time.Since(start).Seconds()) }() - req := &tikvrpc.Request{ - Type: tikvrpc.CmdRawGet, - RawGet: &kvrpcpb.RawGetRequest{ - Key: key, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdRawGet, &kvrpcpb.RawGetRequest{Key: key}) resp, _, err := c.sendReq(key, req, false) if err != nil { return nil, errors.Trace(err) } - cmdResp := resp.RawGet - if cmdResp == nil { + if resp.Resp == nil { return nil, errors.Trace(ErrBodyMissing) } + cmdResp := resp.Resp.(*kvrpcpb.RawGetResponse) if cmdResp.GetError() != "" { return nil, errors.New(cmdResp.GetError()) } @@ -140,10 +135,10 @@ func (c *RawKVClient) BatchGet(keys [][]byte) ([][]byte, error) { return nil, errors.Trace(err) } - cmdResp := resp.RawBatchGet - if cmdResp == nil { + if resp.Resp == nil { return nil, errors.Trace(ErrBodyMissing) } + cmdResp := resp.Resp.(*kvrpcpb.RawBatchGetResponse) keyToValue := make(map[string][]byte, len(keys)) for _, pair := range cmdResp.Pairs { @@ -168,21 +163,18 @@ func (c *RawKVClient) Put(key, value []byte) error { return errors.New("empty value is not supported") } - req := &tikvrpc.Request{ - Type: tikvrpc.CmdRawPut, - RawPut: &kvrpcpb.RawPutRequest{ - Key: key, - Value: value, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdRawPut, &kvrpcpb.RawPutRequest{ + Key: key, + Value: value, + }) resp, _, err := c.sendReq(key, req, false) if err != nil { return errors.Trace(err) } - 
cmdResp := resp.RawPut - if cmdResp == nil { + if resp.Resp == nil { return errors.Trace(ErrBodyMissing) } + cmdResp := resp.Resp.(*kvrpcpb.RawPutResponse) if cmdResp.GetError() != "" { return errors.New(cmdResp.GetError()) } @@ -214,20 +206,17 @@ func (c *RawKVClient) Delete(key []byte) error { start := time.Now() defer func() { tikvRawkvCmdHistogramWithDelete.Observe(time.Since(start).Seconds()) }() - req := &tikvrpc.Request{ - Type: tikvrpc.CmdRawDelete, - RawDelete: &kvrpcpb.RawDeleteRequest{ - Key: key, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdRawDelete, &kvrpcpb.RawDeleteRequest{ + Key: key, + }) resp, _, err := c.sendReq(key, req, false) if err != nil { return errors.Trace(err) } - cmdResp := resp.RawDelete - if cmdResp == nil { + if resp.Resp == nil { return errors.Trace(ErrBodyMissing) } + cmdResp := resp.Resp.(*kvrpcpb.RawDeleteResponse) if cmdResp.GetError() != "" { return errors.New(cmdResp.GetError()) } @@ -246,10 +235,10 @@ func (c *RawKVClient) BatchDelete(keys [][]byte) error { if err != nil { return errors.Trace(err) } - cmdResp := resp.RawBatchDelete - if cmdResp == nil { + if resp.Resp == nil { return errors.Trace(ErrBodyMissing) } + cmdResp := resp.Resp.(*kvrpcpb.RawBatchDeleteResponse) if cmdResp.GetError() != "" { return errors.New(cmdResp.GetError()) } @@ -276,10 +265,10 @@ func (c *RawKVClient) DeleteRange(startKey []byte, endKey []byte) error { if err != nil { return errors.Trace(err) } - cmdResp := resp.RawDeleteRange - if cmdResp == nil { + if resp.Resp == nil { return errors.Trace(ErrBodyMissing) } + cmdResp := resp.Resp.(*kvrpcpb.RawDeleteRangeResponse) if cmdResp.GetError() != "" { return errors.New(cmdResp.GetError()) } @@ -303,22 +292,19 @@ func (c *RawKVClient) Scan(startKey, endKey []byte, limit int) (keys [][]byte, v } for len(keys) < limit { - req := &tikvrpc.Request{ - Type: tikvrpc.CmdRawScan, - RawScan: &kvrpcpb.RawScanRequest{ - StartKey: startKey, - EndKey: endKey, - Limit: uint32(limit - len(keys)), - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdRawScan, &kvrpcpb.RawScanRequest{ + StartKey: startKey, + EndKey: endKey, + Limit: uint32(limit - len(keys)), + }) resp, loc, err := c.sendReq(startKey, req, false) if err != nil { return nil, nil, errors.Trace(err) } - cmdResp := resp.RawScan - if cmdResp == nil { + if resp.Resp == nil { return nil, nil, errors.Trace(ErrBodyMissing) } + cmdResp := resp.Resp.(*kvrpcpb.RawScanResponse) for _, pair := range cmdResp.Kvs { keys = append(keys, pair.Key) values = append(values, pair.Value) @@ -349,23 +335,20 @@ func (c *RawKVClient) ReverseScan(startKey, endKey []byte, limit int) (keys [][] } for len(keys) < limit { - req := &tikvrpc.Request{ - Type: tikvrpc.CmdRawScan, - RawScan: &kvrpcpb.RawScanRequest{ - StartKey: startKey, - EndKey: endKey, - Limit: uint32(limit - len(keys)), - Reverse: true, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdRawScan, &kvrpcpb.RawScanRequest{ + StartKey: startKey, + EndKey: endKey, + Limit: uint32(limit - len(keys)), + Reverse: true, + }) resp, loc, err := c.sendReq(startKey, req, true) if err != nil { return nil, nil, errors.Trace(err) } - cmdResp := resp.RawScan - if cmdResp == nil { + if resp.Resp == nil { return nil, nil, errors.Trace(ErrBodyMissing) } + cmdResp := resp.Resp.(*kvrpcpb.RawScanResponse) for _, pair := range cmdResp.Kvs { keys = append(keys, pair.Key) values = append(values, pair.Value) @@ -436,9 +419,9 @@ func (c *RawKVClient) sendBatchReq(bo *Backoffer, keys [][]byte, cmdType tikvrpc var resp *tikvrpc.Response switch cmdType { case 
tikvrpc.CmdRawBatchGet: - resp = &tikvrpc.Response{Type: tikvrpc.CmdRawBatchGet, RawBatchGet: &kvrpcpb.RawBatchGetResponse{}} + resp = &tikvrpc.Response{Resp: &kvrpcpb.RawBatchGetResponse{}} case tikvrpc.CmdRawBatchDelete: - resp = &tikvrpc.Response{Type: tikvrpc.CmdRawBatchDelete, RawBatchDelete: &kvrpcpb.RawBatchDeleteResponse{}} + resp = &tikvrpc.Response{Resp: &kvrpcpb.RawBatchDeleteResponse{}} } for i := 0; i < len(batches); i++ { singleResp, ok := <-ches @@ -449,8 +432,8 @@ func (c *RawKVClient) sendBatchReq(bo *Backoffer, keys [][]byte, cmdType tikvrpc firstError = singleResp.err } } else if cmdType == tikvrpc.CmdRawBatchGet { - cmdResp := singleResp.resp.RawBatchGet - resp.RawBatchGet.Pairs = append(resp.RawBatchGet.Pairs, cmdResp.Pairs...) + cmdResp := singleResp.resp.Resp.(*kvrpcpb.RawBatchGetResponse) + resp.Resp.(*kvrpcpb.RawBatchGetResponse).Pairs = append(resp.Resp.(*kvrpcpb.RawBatchGetResponse).Pairs, cmdResp.Pairs...) } } } @@ -462,19 +445,13 @@ func (c *RawKVClient) doBatchReq(bo *Backoffer, batch batch, cmdType tikvrpc.Cmd var req *tikvrpc.Request switch cmdType { case tikvrpc.CmdRawBatchGet: - req = &tikvrpc.Request{ - Type: cmdType, - RawBatchGet: &kvrpcpb.RawBatchGetRequest{ - Keys: batch.keys, - }, - } + req = tikvrpc.NewRequest(cmdType, &kvrpcpb.RawBatchGetRequest{ + Keys: batch.keys, + }) case tikvrpc.CmdRawBatchDelete: - req = &tikvrpc.Request{ - Type: cmdType, - RawBatchDelete: &kvrpcpb.RawBatchDeleteRequest{ - Keys: batch.keys, - }, - } + req = tikvrpc.NewRequest(cmdType, &kvrpcpb.RawBatchDeleteRequest{ + Keys: batch.keys, + }) } sender := NewRegionRequestSender(c.regionCache, c.rpcClient) @@ -506,11 +483,11 @@ func (c *RawKVClient) doBatchReq(bo *Backoffer, batch batch, cmdType tikvrpc.Cmd case tikvrpc.CmdRawBatchGet: batchResp.resp = resp case tikvrpc.CmdRawBatchDelete: - cmdResp := resp.RawBatchDelete - if cmdResp == nil { + if resp.Resp == nil { batchResp.err = errors.Trace(ErrBodyMissing) return batchResp } + cmdResp := resp.Resp.(*kvrpcpb.RawBatchDeleteResponse) if cmdResp.GetError() != "" { batchResp.err = errors.New(cmdResp.GetError()) return batchResp @@ -538,13 +515,10 @@ func (c *RawKVClient) sendDeleteRangeReq(startKey []byte, endKey []byte) (*tikvr actualEndKey = loc.EndKey } - req := &tikvrpc.Request{ - Type: tikvrpc.CmdRawDeleteRange, - RawDeleteRange: &kvrpcpb.RawDeleteRangeRequest{ - StartKey: startKey, - EndKey: actualEndKey, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdRawDeleteRange, &kvrpcpb.RawDeleteRangeRequest{ + StartKey: startKey, + EndKey: actualEndKey, + }) resp, err := sender.SendReq(bo, req, loc.Region, readTimeoutShort) if err != nil { @@ -648,12 +622,7 @@ func (c *RawKVClient) doBatchPut(bo *Backoffer, batch batch) error { kvPair = append(kvPair, &kvrpcpb.KvPair{Key: key, Value: batch.values[i]}) } - req := &tikvrpc.Request{ - Type: tikvrpc.CmdRawBatchPut, - RawBatchPut: &kvrpcpb.RawBatchPutRequest{ - Pairs: kvPair, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdRawBatchPut, &kvrpcpb.RawBatchPutRequest{Pairs: kvPair}) sender := NewRegionRequestSender(c.regionCache, c.rpcClient) resp, err := sender.SendReq(bo, req, batch.regionID, readTimeoutShort) @@ -673,10 +642,10 @@ func (c *RawKVClient) doBatchPut(bo *Backoffer, batch batch) error { return c.sendBatchPut(bo, batch.keys, batch.values) } - cmdResp := resp.RawBatchPut - if cmdResp == nil { + if resp.Resp == nil { return errors.Trace(ErrBodyMissing) } + cmdResp := resp.Resp.(*kvrpcpb.RawBatchPutResponse) if cmdResp.GetError() != "" { return errors.New(cmdResp.GetError()) } diff 
--git a/store/tikv/region_request.go b/store/tikv/region_request.go index 3fcd149707f2b..45c0d53667e44 100644 --- a/store/tikv/region_request.go +++ b/store/tikv/region_request.go @@ -15,6 +15,9 @@ package tikv import ( "context" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" "sync/atomic" "time" @@ -26,9 +29,6 @@ import ( "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/store/tikv/tikvrpc" "github.com/pingcap/tidb/util/logutil" - "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" ) // ShuttingDown is a flag to indicate tidb-server is exiting (Ctrl+C signal @@ -82,15 +82,13 @@ func (s *RegionRequestSender) SendReqCtx(bo *Backoffer, req *tikvrpc.Request, re case "GCNotLeader": if req.Type == tikvrpc.CmdGC { failpoint.Return(&tikvrpc.Response{ - Type: tikvrpc.CmdGC, - GC: &kvrpcpb.GCResponse{RegionError: &errorpb.Error{NotLeader: &errorpb.NotLeader{}}}, + Resp: &kvrpcpb.GCResponse{RegionError: &errorpb.Error{NotLeader: &errorpb.NotLeader{}}}, }, nil, nil) } case "GCServerIsBusy": if req.Type == tikvrpc.CmdGC { failpoint.Return(&tikvrpc.Response{ - Type: tikvrpc.CmdGC, - GC: &kvrpcpb.GCResponse{RegionError: &errorpb.Error{ServerIsBusy: &errorpb.ServerIsBusy{}}}, + Resp: &kvrpcpb.GCResponse{RegionError: &errorpb.Error{ServerIsBusy: &errorpb.ServerIsBusy{}}}, }, nil, nil) } } diff --git a/store/tikv/region_request_test.go b/store/tikv/region_request_test.go index 52cc1636bef7e..9d80af4c67efa 100644 --- a/store/tikv/region_request_test.go +++ b/store/tikv/region_request_test.go @@ -60,19 +60,16 @@ func (s *testRegionRequestSuite) TearDownTest(c *C) { } func (s *testRegionRequestSuite) TestOnSendFailedWithStoreRestart(c *C) { - req := &tikvrpc.Request{ - Type: tikvrpc.CmdRawPut, - RawPut: &kvrpcpb.RawPutRequest{ - Key: []byte("key"), - Value: []byte("value"), - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdRawPut, &kvrpcpb.RawPutRequest{ + Key: []byte("key"), + Value: []byte("value"), + }) region, err := s.cache.LocateRegionByID(s.bo, s.region) c.Assert(err, IsNil) c.Assert(region, NotNil) resp, err := s.regionRequestSender.SendReq(s.bo, req, region.Region, time.Second) c.Assert(err, IsNil) - c.Assert(resp.RawPut, NotNil) + c.Assert(resp.Resp, NotNil) // stop store. 
s.cluster.StopStore(s.store) @@ -89,23 +86,20 @@ func (s *testRegionRequestSuite) TestOnSendFailedWithStoreRestart(c *C) { c.Assert(region, NotNil) resp, err = s.regionRequestSender.SendReq(s.bo, req, region.Region, time.Second) c.Assert(err, IsNil) - c.Assert(resp.RawPut, NotNil) + c.Assert(resp.Resp, NotNil) } func (s *testRegionRequestSuite) TestOnSendFailedWithCloseKnownStoreThenUseNewOne(c *C) { - req := &tikvrpc.Request{ - Type: tikvrpc.CmdRawPut, - RawPut: &kvrpcpb.RawPutRequest{ - Key: []byte("key"), - Value: []byte("value"), - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdRawPut, &kvrpcpb.RawPutRequest{ + Key: []byte("key"), + Value: []byte("value"), + }) region, err := s.cache.LocateRegionByID(s.bo, s.region) c.Assert(err, IsNil) c.Assert(region, NotNil) resp, err := s.regionRequestSender.SendReq(s.bo, req, region.Region, time.Second) c.Assert(err, IsNil) - c.Assert(resp.RawPut, NotNil) + c.Assert(resp.Resp, NotNil) // add new unknown region store2 := s.cluster.AllocID() @@ -131,36 +125,30 @@ func (s *testRegionRequestSuite) TestOnSendFailedWithCloseKnownStoreThenUseNewOn } func (s *testRegionRequestSuite) TestSendReqCtx(c *C) { - req := &tikvrpc.Request{ - Type: tikvrpc.CmdRawPut, - RawPut: &kvrpcpb.RawPutRequest{ - Key: []byte("key"), - Value: []byte("value"), - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdRawPut, &kvrpcpb.RawPutRequest{ + Key: []byte("key"), + Value: []byte("value"), + }) region, err := s.cache.LocateRegionByID(s.bo, s.region) c.Assert(err, IsNil) c.Assert(region, NotNil) resp, ctx, err := s.regionRequestSender.SendReqCtx(s.bo, req, region.Region, time.Second) c.Assert(err, IsNil) - c.Assert(resp.RawPut, NotNil) + c.Assert(resp.Resp, NotNil) c.Assert(ctx, NotNil) } func (s *testRegionRequestSuite) TestOnSendFailedWithCancelled(c *C) { - req := &tikvrpc.Request{ - Type: tikvrpc.CmdRawPut, - RawPut: &kvrpcpb.RawPutRequest{ - Key: []byte("key"), - Value: []byte("value"), - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdRawPut, &kvrpcpb.RawPutRequest{ + Key: []byte("key"), + Value: []byte("value"), + }) region, err := s.cache.LocateRegionByID(s.bo, s.region) c.Assert(err, IsNil) c.Assert(region, NotNil) resp, err := s.regionRequestSender.SendReq(s.bo, req, region.Region, time.Second) c.Assert(err, IsNil) - c.Assert(resp.RawPut, NotNil) + c.Assert(resp.Resp, NotNil) // set store to cancel state. 
s.cluster.CancelStore(s.store) @@ -177,17 +165,14 @@ func (s *testRegionRequestSuite) TestOnSendFailedWithCancelled(c *C) { c.Assert(region, NotNil) resp, err = s.regionRequestSender.SendReq(s.bo, req, region.Region, time.Second) c.Assert(err, IsNil) - c.Assert(resp.RawPut, NotNil) + c.Assert(resp.Resp, NotNil) } func (s *testRegionRequestSuite) TestNoReloadRegionWhenCtxCanceled(c *C) { - req := &tikvrpc.Request{ - Type: tikvrpc.CmdRawPut, - RawPut: &kvrpcpb.RawPutRequest{ - Key: []byte("key"), - Value: []byte("value"), - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdRawPut, &kvrpcpb.RawPutRequest{ + Key: []byte("key"), + Value: []byte("value"), + }) region, err := s.cache.LocateRegionByID(s.bo, s.region) c.Assert(err, IsNil) c.Assert(region, NotNil) @@ -340,13 +325,10 @@ func (s *testRegionRequestSuite) TestNoReloadRegionForGrpcWhenCtxCanceled(c *C) client := newRPCClient(config.Security{}) sender := NewRegionRequestSender(s.cache, client) - req := &tikvrpc.Request{ - Type: tikvrpc.CmdRawPut, - RawPut: &kvrpcpb.RawPutRequest{ - Key: []byte("key"), - Value: []byte("value"), - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdRawPut, &kvrpcpb.RawPutRequest{ + Key: []byte("key"), + Value: []byte("value"), + }) region, err := s.cache.LocateRegionByID(s.bo, s.region) c.Assert(err, IsNil) diff --git a/store/tikv/scan.go b/store/tikv/scan.go index e0220f5e63563..5804c5bf6d3aa 100644 --- a/store/tikv/scan.go +++ b/store/tikv/scan.go @@ -185,25 +185,22 @@ func (s *Scanner) getData(bo *Backoffer) error { } } - req := &tikvrpc.Request{ - Type: tikvrpc.CmdScan, - Scan: &pb.ScanRequest{ - StartKey: s.nextStartKey, - EndKey: reqEndKey, - Limit: uint32(s.batchSize), - Version: s.startTS(), - KeyOnly: s.snapshot.keyOnly, - }, - Context: pb.Context{ - Priority: s.snapshot.priority, - NotFillCache: s.snapshot.notFillCache, - }, + sreq := &pb.ScanRequest{ + StartKey: s.nextStartKey, + EndKey: reqEndKey, + Limit: uint32(s.batchSize), + Version: s.startTS(), + KeyOnly: s.snapshot.keyOnly, } if s.reverse { - req.Scan.StartKey = s.nextEndKey - req.Scan.EndKey = reqStartKey - req.Scan.Reverse = true + sreq.StartKey = s.nextEndKey + sreq.EndKey = reqStartKey + sreq.Reverse = true } + req := tikvrpc.NewRequest(tikvrpc.CmdScan, sreq, pb.Context{ + Priority: s.snapshot.priority, + NotFillCache: s.snapshot.notFillCache, + }) resp, err := sender.SendReq(bo, req, loc.Region, ReadTimeoutMedium) if err != nil { return errors.Trace(err) @@ -221,10 +218,10 @@ func (s *Scanner) getData(bo *Backoffer) error { } continue } - cmdScanResp := resp.Scan - if cmdScanResp == nil { + if resp.Resp == nil { return errors.Trace(ErrBodyMissing) } + cmdScanResp := resp.Resp.(*pb.ScanResponse) err = s.snapshot.store.CheckVisibility(s.startTS()) if err != nil { diff --git a/store/tikv/snapshot.go b/store/tikv/snapshot.go index f940448a6caa0..258856fb6eb3b 100644 --- a/store/tikv/snapshot.go +++ b/store/tikv/snapshot.go @@ -154,17 +154,13 @@ func (s *tikvSnapshot) batchGetSingleRegion(bo *Backoffer, batch batchKeys, coll pending := batch.keys for { - req := &tikvrpc.Request{ - Type: tikvrpc.CmdBatchGet, - BatchGet: &pb.BatchGetRequest{ - Keys: pending, - Version: s.version.Ver, - }, - Context: pb.Context{ - Priority: s.priority, - NotFillCache: s.notFillCache, - }, - } + req := tikvrpc.NewRequest(tikvrpc.CmdBatchGet, &pb.BatchGetRequest{ + Keys: pending, + Version: s.version.Ver, + }, pb.Context{ + Priority: s.priority, + NotFillCache: s.notFillCache, + }) resp, err := sender.SendReq(bo, req, batch.region, ReadTimeoutMedium) if err != nil { 
return errors.Trace(err) @@ -181,10 +177,10 @@ func (s *tikvSnapshot) batchGetSingleRegion(bo *Backoffer, batch batchKeys, coll err = s.batchGetKeysByRegions(bo, pending, collectF) return errors.Trace(err) } - batchGetResp := resp.BatchGet - if batchGetResp == nil { + if resp.Resp == nil { return errors.Trace(ErrBodyMissing) } + batchGetResp := resp.Resp.(*pb.BatchGetResponse) var ( lockedKeys [][]byte locks []*Lock @@ -236,17 +232,14 @@ func (s *tikvSnapshot) Get(k kv.Key) ([]byte, error) { func (s *tikvSnapshot) get(bo *Backoffer, k kv.Key) ([]byte, error) { sender := NewRegionRequestSender(s.store.regionCache, s.store.client) - req := &tikvrpc.Request{ - Type: tikvrpc.CmdGet, - Get: &pb.GetRequest{ + req := tikvrpc.NewRequest(tikvrpc.CmdGet, + &pb.GetRequest{ Key: k, Version: s.version.Ver, - }, - Context: pb.Context{ + }, pb.Context{ Priority: s.priority, NotFillCache: s.notFillCache, - }, - } + }) for { loc, err := s.store.regionCache.LocateKey(bo, k) if err != nil { @@ -267,10 +260,10 @@ func (s *tikvSnapshot) get(bo *Backoffer, k kv.Key) ([]byte, error) { } continue } - cmdGetResp := resp.Get - if cmdGetResp == nil { + if resp.Resp == nil { return nil, errors.Trace(ErrBodyMissing) } + cmdGetResp := resp.Resp.(*pb.GetResponse) val := cmdGetResp.GetValue() if keyErr := cmdGetResp.GetError(); keyErr != nil { lock, err := extractLockFromKeyErr(keyErr) diff --git a/store/tikv/split_region.go b/store/tikv/split_region.go index e949153ae5846..a232573575682 100644 --- a/store/tikv/split_region.go +++ b/store/tikv/split_region.go @@ -33,13 +33,11 @@ func (s *tikvStore) SplitRegion(splitKey kv.Key, scatter bool) (regionID uint64, zap.Binary("at", splitKey)) bo := NewBackoffer(context.Background(), splitRegionBackoff) sender := NewRegionRequestSender(s.regionCache, s.client) - req := &tikvrpc.Request{ - Type: tikvrpc.CmdSplitRegion, - SplitRegion: &kvrpcpb.SplitRegionRequest{ - SplitKey: splitKey, - }, - } - req.Context.Priority = kvrpcpb.CommandPri_Normal + req := tikvrpc.NewRequest(tikvrpc.CmdSplitRegion, &kvrpcpb.SplitRegionRequest{ + SplitKey: splitKey, + }, kvrpcpb.Context{ + Priority: kvrpcpb.CommandPri_Normal, + }) for { loc, err := s.regionCache.LocateKey(bo, splitKey) if err != nil { @@ -65,11 +63,12 @@ func (s *tikvStore) SplitRegion(splitKey kv.Key, scatter bool) (regionID uint64, } continue } + splitRegion := res.Resp.(*kvrpcpb.SplitRegionResponse) logutil.BgLogger().Info("split region complete", zap.Binary("at", splitKey), - zap.Stringer("new region left", res.SplitRegion.GetLeft()), - zap.Stringer("new region right", res.SplitRegion.GetRight())) - left := res.SplitRegion.GetLeft() + zap.Stringer("new region left", splitRegion.GetLeft()), + zap.Stringer("new region right", splitRegion.GetRight())) + left := splitRegion.GetLeft() if left == nil { return 0, nil } diff --git a/store/tikv/store_test.go b/store/tikv/store_test.go index 40fb48a2b004b..5fff7dd47ccc5 100644 --- a/store/tikv/store_test.go +++ b/store/tikv/store_test.go @@ -220,8 +220,8 @@ type checkRequestClient struct { func (c *checkRequestClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) { resp, err := c.Client.SendRequest(ctx, addr, req, timeout) if c.priority != req.Priority { - if resp.Get != nil { - resp.Get.Error = &pb.KeyError{ + if resp.Resp != nil { + (resp.Resp.(*pb.GetResponse)).Error = &pb.KeyError{ Abort: "request check error", } } diff --git a/store/tikv/tikvrpc/tikvrpc.go b/store/tikv/tikvrpc/tikvrpc.go index 
b3cf78ee5e82e..8d2e813e6d9d6 100644 --- a/store/tikv/tikvrpc/tikvrpc.go +++ b/store/tikv/tikvrpc/tikvrpc.go @@ -134,90 +134,215 @@ func (t CmdType) String() string { // Request wraps all kv/coprocessor requests. type Request struct { + Type CmdType + req interface{} kvrpcpb.Context - Type CmdType - Get *kvrpcpb.GetRequest - Scan *kvrpcpb.ScanRequest - Prewrite *kvrpcpb.PrewriteRequest - Commit *kvrpcpb.CommitRequest - Cleanup *kvrpcpb.CleanupRequest - BatchGet *kvrpcpb.BatchGetRequest - BatchRollback *kvrpcpb.BatchRollbackRequest - ScanLock *kvrpcpb.ScanLockRequest - ResolveLock *kvrpcpb.ResolveLockRequest - GC *kvrpcpb.GCRequest - DeleteRange *kvrpcpb.DeleteRangeRequest - RawGet *kvrpcpb.RawGetRequest - RawBatchGet *kvrpcpb.RawBatchGetRequest - RawPut *kvrpcpb.RawPutRequest - RawBatchPut *kvrpcpb.RawBatchPutRequest - RawDelete *kvrpcpb.RawDeleteRequest - RawBatchDelete *kvrpcpb.RawBatchDeleteRequest - RawDeleteRange *kvrpcpb.RawDeleteRangeRequest - RawScan *kvrpcpb.RawScanRequest - UnsafeDestroyRange *kvrpcpb.UnsafeDestroyRangeRequest - Cop *coprocessor.Request - MvccGetByKey *kvrpcpb.MvccGetByKeyRequest - MvccGetByStartTs *kvrpcpb.MvccGetByStartTsRequest - SplitRegion *kvrpcpb.SplitRegionRequest - - PessimisticLock *kvrpcpb.PessimisticLockRequest - PessimisticRollback *kvrpcpb.PessimisticRollbackRequest - - DebugGetRegionProperties *debugpb.GetRegionPropertiesRequest - - Empty *tikvpb.BatchCommandsEmptyRequest +} + +// NewRequest returns new kv rpc request. +func NewRequest(typ CmdType, pointer interface{}, ctxs ...kvrpcpb.Context) *Request { + if len(ctxs) > 0 { + return &Request{ + Type: typ, + req: pointer, + Context: ctxs[0], + } + } + return &Request{ + Type: typ, + req: pointer, + } +} + +// Get returns GetRequest in request. +func (req *Request) Get() *kvrpcpb.GetRequest { + return req.req.(*kvrpcpb.GetRequest) +} + +// Scan returns ScanRequest in request. +func (req *Request) Scan() *kvrpcpb.ScanRequest { + return req.req.(*kvrpcpb.ScanRequest) +} + +// Prewrite returns PrewriteRequest in request. +func (req *Request) Prewrite() *kvrpcpb.PrewriteRequest { + return req.req.(*kvrpcpb.PrewriteRequest) +} + +// Commit returns CommitRequest in request. +func (req *Request) Commit() *kvrpcpb.CommitRequest { + return req.req.(*kvrpcpb.CommitRequest) +} + +// Cleanup returns CleanupRequest in request. +func (req *Request) Cleanup() *kvrpcpb.CleanupRequest { + return req.req.(*kvrpcpb.CleanupRequest) +} + +// BatchGet returns BatchGetRequest in request. +func (req *Request) BatchGet() *kvrpcpb.BatchGetRequest { + return req.req.(*kvrpcpb.BatchGetRequest) +} + +// BatchRollback returns BatchRollbackRequest in request. +func (req *Request) BatchRollback() *kvrpcpb.BatchRollbackRequest { + return req.req.(*kvrpcpb.BatchRollbackRequest) +} + +// ScanLock returns ScanLockRequest in request. +func (req *Request) ScanLock() *kvrpcpb.ScanLockRequest { + return req.req.(*kvrpcpb.ScanLockRequest) +} + +// ResolveLock returns ResolveLockRequest in request. +func (req *Request) ResolveLock() *kvrpcpb.ResolveLockRequest { + return req.req.(*kvrpcpb.ResolveLockRequest) +} + +// GC returns GCRequest in request. +func (req *Request) GC() *kvrpcpb.GCRequest { + return req.req.(*kvrpcpb.GCRequest) +} + +// DeleteRange returns DeleteRangeRequest in request. +func (req *Request) DeleteRange() *kvrpcpb.DeleteRangeRequest { + return req.req.(*kvrpcpb.DeleteRangeRequest) +} + +// RawGet returns RawGetRequest in request. 
+func (req *Request) RawGet() *kvrpcpb.RawGetRequest { + return req.req.(*kvrpcpb.RawGetRequest) +} + +// RawBatchGet returns RawBatchGetRequest in request. +func (req *Request) RawBatchGet() *kvrpcpb.RawBatchGetRequest { + return req.req.(*kvrpcpb.RawBatchGetRequest) +} + +// RawPut returns RawPutRequest in request. +func (req *Request) RawPut() *kvrpcpb.RawPutRequest { + return req.req.(*kvrpcpb.RawPutRequest) +} + +// RawBatchPut returns RawBatchPutRequest in request. +func (req *Request) RawBatchPut() *kvrpcpb.RawBatchPutRequest { + return req.req.(*kvrpcpb.RawBatchPutRequest) +} + +// RawDelete returns RawDeleteRequest in request. +func (req *Request) RawDelete() *kvrpcpb.RawDeleteRequest { + return req.req.(*kvrpcpb.RawDeleteRequest) +} + +// RawBatchDelete returns RawBatchDeleteRequest in request. +func (req *Request) RawBatchDelete() *kvrpcpb.RawBatchDeleteRequest { + return req.req.(*kvrpcpb.RawBatchDeleteRequest) +} + +// RawDeleteRange returns RawDeleteRangeRequest in request. +func (req *Request) RawDeleteRange() *kvrpcpb.RawDeleteRangeRequest { + return req.req.(*kvrpcpb.RawDeleteRangeRequest) +} + +// RawScan returns RawScanRequest in request. +func (req *Request) RawScan() *kvrpcpb.RawScanRequest { + return req.req.(*kvrpcpb.RawScanRequest) +} + +// UnsafeDestroyRange returns UnsafeDestroyRangeRequest in request. +func (req *Request) UnsafeDestroyRange() *kvrpcpb.UnsafeDestroyRangeRequest { + return req.req.(*kvrpcpb.UnsafeDestroyRangeRequest) +} + +// Cop returns coprocessor request in request. +func (req *Request) Cop() *coprocessor.Request { + return req.req.(*coprocessor.Request) +} + +// MvccGetByKey returns MvccGetByKeyRequest in request. +func (req *Request) MvccGetByKey() *kvrpcpb.MvccGetByKeyRequest { + return req.req.(*kvrpcpb.MvccGetByKeyRequest) +} + +// MvccGetByStartTs returns MvccGetByStartTsRequest in request. +func (req *Request) MvccGetByStartTs() *kvrpcpb.MvccGetByStartTsRequest { + return req.req.(*kvrpcpb.MvccGetByStartTsRequest) +} + +// SplitRegion returns SplitRegionRequest in request. +func (req *Request) SplitRegion() *kvrpcpb.SplitRegionRequest { + return req.req.(*kvrpcpb.SplitRegionRequest) +} + +// PessimisticLock returns PessimisticLockRequest in request. +func (req *Request) PessimisticLock() *kvrpcpb.PessimisticLockRequest { + return req.req.(*kvrpcpb.PessimisticLockRequest) +} + +// PessimisticRollback returns PessimisticRollbackRequest in request. +func (req *Request) PessimisticRollback() *kvrpcpb.PessimisticRollbackRequest { + return req.req.(*kvrpcpb.PessimisticRollbackRequest) +} + +// DebugGetRegionProperties returns GetRegionPropertiesRequest in request. +func (req *Request) DebugGetRegionProperties() *debugpb.GetRegionPropertiesRequest { + return req.req.(*debugpb.GetRegionPropertiesRequest) +} + +// Empty returns BatchCommandsEmptyRequest in request. +func (req *Request) Empty() *tikvpb.BatchCommandsEmptyRequest { + return req.req.(*tikvpb.BatchCommandsEmptyRequest) +} // ToBatchCommandsRequest converts the request to an entry in BatchCommands request.
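Each accessor above is an unchecked type assertion on the private req field, so the CmdType handed to NewRequest must agree with the concrete request it wraps; calling a mismatched accessor panics at runtime. A short sketch of the contract (the key value is a placeholder):

// CmdRawGet pairs with *kvrpcpb.RawGetRequest, so RawGet() succeeds;
// calling, say, ScanLock() on the same request would panic with an
// interface-conversion error.
req := tikvrpc.NewRequest(tikvrpc.CmdRawGet, &kvrpcpb.RawGetRequest{Key: []byte("key")})
_ = req.RawGet()      // ok
// _ = req.ScanLock() // panic: req.req is not *kvrpcpb.ScanLockRequest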
func (req *Request) ToBatchCommandsRequest() *tikvpb.BatchCommandsRequest_Request { switch req.Type { case CmdGet: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Get{Get: req.Get}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Get{Get: req.Get()}} case CmdScan: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Scan{Scan: req.Scan}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Scan{Scan: req.Scan()}} case CmdPrewrite: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Prewrite{Prewrite: req.Prewrite}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Prewrite{Prewrite: req.Prewrite()}} case CmdCommit: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Commit{Commit: req.Commit}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Commit{Commit: req.Commit()}} case CmdCleanup: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Cleanup{Cleanup: req.Cleanup}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Cleanup{Cleanup: req.Cleanup()}} case CmdBatchGet: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_BatchGet{BatchGet: req.BatchGet}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_BatchGet{BatchGet: req.BatchGet()}} case CmdBatchRollback: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_BatchRollback{BatchRollback: req.BatchRollback}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_BatchRollback{BatchRollback: req.BatchRollback()}} case CmdScanLock: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_ScanLock{ScanLock: req.ScanLock}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_ScanLock{ScanLock: req.ScanLock()}} case CmdResolveLock: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_ResolveLock{ResolveLock: req.ResolveLock}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_ResolveLock{ResolveLock: req.ResolveLock()}} case CmdGC: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_GC{GC: req.GC}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_GC{GC: req.GC()}} case CmdDeleteRange: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_DeleteRange{DeleteRange: req.DeleteRange}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_DeleteRange{DeleteRange: req.DeleteRange()}} case CmdRawGet: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawGet{RawGet: req.RawGet}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawGet{RawGet: req.RawGet()}} case CmdRawBatchGet: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawBatchGet{RawBatchGet: req.RawBatchGet}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawBatchGet{RawBatchGet: req.RawBatchGet()}} case CmdRawPut: - return 
&tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawPut{RawPut: req.RawPut}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawPut{RawPut: req.RawPut()}} case CmdRawBatchPut: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawBatchPut{RawBatchPut: req.RawBatchPut}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawBatchPut{RawBatchPut: req.RawBatchPut()}} case CmdRawDelete: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawDelete{RawDelete: req.RawDelete}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawDelete{RawDelete: req.RawDelete()}} case CmdRawBatchDelete: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawBatchDelete{RawBatchDelete: req.RawBatchDelete}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawBatchDelete{RawBatchDelete: req.RawBatchDelete()}} case CmdRawDeleteRange: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawDeleteRange{RawDeleteRange: req.RawDeleteRange}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawDeleteRange{RawDeleteRange: req.RawDeleteRange()}} case CmdRawScan: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawScan{RawScan: req.RawScan}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_RawScan{RawScan: req.RawScan()}} case CmdCop: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Coprocessor{Coprocessor: req.Cop}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Coprocessor{Coprocessor: req.Cop()}} case CmdPessimisticLock: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_PessimisticLock{PessimisticLock: req.PessimisticLock}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_PessimisticLock{PessimisticLock: req.PessimisticLock()}} case CmdPessimisticRollback: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_PessimisticRollback{PessimisticRollback: req.PessimisticRollback}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_PessimisticRollback{PessimisticRollback: req.PessimisticRollback()}} case CmdEmpty: - return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Empty{Empty: req.Empty}} + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Empty{Empty: req.Empty()}} } return nil } @@ -233,90 +358,58 @@ func (req *Request) IsDebugReq() bool { // Response wraps all kv/coprocessor responses. 
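With the per-command response pointers collapsed into the single Resp field (see the struct below), every caller in this change follows the same two-step idiom: guard against a missing body, then assert the concrete kvproto type. Roughly, with bo, req, and loc assumed from the surrounding caller:

// Common caller pattern after this change: nil-check Resp, then
// type-assert to the expected response message.
resp, err := s.store.SendReq(bo, req, loc.Region, readTimeoutShort)
if err != nil {
	return nil, errors.Trace(err)
}
if resp.Resp == nil {
	return nil, errors.Trace(ErrBodyMissing)
}
cmdResp := resp.Resp.(*pb.GetResponse)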
type Response struct { - Type CmdType - Get *kvrpcpb.GetResponse - Scan *kvrpcpb.ScanResponse - Prewrite *kvrpcpb.PrewriteResponse - Commit *kvrpcpb.CommitResponse - Cleanup *kvrpcpb.CleanupResponse - BatchGet *kvrpcpb.BatchGetResponse - BatchRollback *kvrpcpb.BatchRollbackResponse - ScanLock *kvrpcpb.ScanLockResponse - ResolveLock *kvrpcpb.ResolveLockResponse - GC *kvrpcpb.GCResponse - DeleteRange *kvrpcpb.DeleteRangeResponse - RawGet *kvrpcpb.RawGetResponse - RawBatchGet *kvrpcpb.RawBatchGetResponse - RawPut *kvrpcpb.RawPutResponse - RawBatchPut *kvrpcpb.RawBatchPutResponse - RawDelete *kvrpcpb.RawDeleteResponse - RawBatchDelete *kvrpcpb.RawBatchDeleteResponse - RawDeleteRange *kvrpcpb.RawDeleteRangeResponse - RawScan *kvrpcpb.RawScanResponse - UnsafeDestroyRange *kvrpcpb.UnsafeDestroyRangeResponse - Cop *coprocessor.Response - CopStream *CopStreamResponse - MvccGetByKey *kvrpcpb.MvccGetByKeyResponse - MvccGetByStartTS *kvrpcpb.MvccGetByStartTsResponse - SplitRegion *kvrpcpb.SplitRegionResponse - - PessimisticLock *kvrpcpb.PessimisticLockResponse - PessimisticRollback *kvrpcpb.PessimisticRollbackResponse - - DebugGetRegionProperties *debugpb.GetRegionPropertiesResponse - - Empty *tikvpb.BatchCommandsEmptyResponse + Resp interface{} } // FromBatchCommandsResponse converts a BatchCommands response to Response. func FromBatchCommandsResponse(res *tikvpb.BatchCommandsResponse_Response) *Response { switch res := res.GetCmd().(type) { case *tikvpb.BatchCommandsResponse_Response_Get: - return &Response{Type: CmdGet, Get: res.Get} + return &Response{Resp: res.Get} case *tikvpb.BatchCommandsResponse_Response_Scan: - return &Response{Type: CmdScan, Scan: res.Scan} + return &Response{Resp: res.Scan} case *tikvpb.BatchCommandsResponse_Response_Prewrite: - return &Response{Type: CmdPrewrite, Prewrite: res.Prewrite} + return &Response{Resp: res.Prewrite} case *tikvpb.BatchCommandsResponse_Response_Commit: - return &Response{Type: CmdCommit, Commit: res.Commit} + return &Response{Resp: res.Commit} case *tikvpb.BatchCommandsResponse_Response_Cleanup: - return &Response{Type: CmdCleanup, Cleanup: res.Cleanup} + return &Response{Resp: res.Cleanup} case *tikvpb.BatchCommandsResponse_Response_BatchGet: - return &Response{Type: CmdBatchGet, BatchGet: res.BatchGet} + return &Response{Resp: res.BatchGet} case *tikvpb.BatchCommandsResponse_Response_BatchRollback: - return &Response{Type: CmdBatchRollback, BatchRollback: res.BatchRollback} + return &Response{Resp: res.BatchRollback} case *tikvpb.BatchCommandsResponse_Response_ScanLock: - return &Response{Type: CmdScanLock, ScanLock: res.ScanLock} + return &Response{Resp: res.ScanLock} case *tikvpb.BatchCommandsResponse_Response_ResolveLock: - return &Response{Type: CmdResolveLock, ResolveLock: res.ResolveLock} + return &Response{Resp: res.ResolveLock} case *tikvpb.BatchCommandsResponse_Response_GC: - return &Response{Type: CmdGC, GC: res.GC} + return &Response{Resp: res.GC} case *tikvpb.BatchCommandsResponse_Response_DeleteRange: - return &Response{Type: CmdDeleteRange, DeleteRange: res.DeleteRange} + return &Response{Resp: res.DeleteRange} case *tikvpb.BatchCommandsResponse_Response_RawGet: - return &Response{Type: CmdRawGet, RawGet: res.RawGet} + return &Response{Resp: res.RawGet} case *tikvpb.BatchCommandsResponse_Response_RawBatchGet: - return &Response{Type: CmdRawBatchGet, RawBatchGet: res.RawBatchGet} + return &Response{Resp: res.RawBatchGet} case *tikvpb.BatchCommandsResponse_Response_RawPut: - return &Response{Type: CmdRawPut, RawPut: res.RawPut} + 
return &Response{Resp: res.RawPut} case *tikvpb.BatchCommandsResponse_Response_RawBatchPut: - return &Response{Type: CmdRawBatchPut, RawBatchPut: res.RawBatchPut} + return &Response{Resp: res.RawBatchPut} case *tikvpb.BatchCommandsResponse_Response_RawDelete: - return &Response{Type: CmdRawDelete, RawDelete: res.RawDelete} + return &Response{Resp: res.RawDelete} case *tikvpb.BatchCommandsResponse_Response_RawBatchDelete: - return &Response{Type: CmdRawBatchDelete, RawBatchDelete: res.RawBatchDelete} + return &Response{Resp: res.RawBatchDelete} case *tikvpb.BatchCommandsResponse_Response_RawDeleteRange: - return &Response{Type: CmdRawDeleteRange, RawDeleteRange: res.RawDeleteRange} + return &Response{Resp: res.RawDeleteRange} case *tikvpb.BatchCommandsResponse_Response_RawScan: - return &Response{Type: CmdRawScan, RawScan: res.RawScan} + return &Response{Resp: res.RawScan} case *tikvpb.BatchCommandsResponse_Response_Coprocessor: - return &Response{Type: CmdCop, Cop: res.Coprocessor} + return &Response{Resp: res.Coprocessor} case *tikvpb.BatchCommandsResponse_Response_PessimisticLock: - return &Response{Type: CmdPessimisticLock, PessimisticLock: res.PessimisticLock} + return &Response{Resp: res.PessimisticLock} case *tikvpb.BatchCommandsResponse_Response_PessimisticRollback: - return &Response{Type: CmdPessimisticRollback, PessimisticRollback: res.PessimisticRollback} + return &Response{Resp: res.PessimisticRollback} case *tikvpb.BatchCommandsResponse_Response_Empty: - return &Response{Type: CmdEmpty, Empty: res.Empty} + return &Response{Resp: res.Empty} } return nil } @@ -340,60 +433,61 @@ func SetContext(req *Request, region *metapb.Region, peer *metapb.Peer) error { switch req.Type { case CmdGet: - req.Get.Context = ctx + req.Get().Context = ctx case CmdScan: - req.Scan.Context = ctx + req.Scan().Context = ctx case CmdPrewrite: - req.Prewrite.Context = ctx + req.Prewrite().Context = ctx case CmdPessimisticLock: - req.PessimisticLock.Context = ctx + req.PessimisticLock().Context = ctx case CmdPessimisticRollback: - req.PessimisticRollback.Context = ctx + req.PessimisticRollback().Context = ctx case CmdCommit: - req.Commit.Context = ctx + req.Commit().Context = ctx case CmdCleanup: - req.Cleanup.Context = ctx + req.Cleanup().Context = ctx case CmdBatchGet: - req.BatchGet.Context = ctx + req.BatchGet().Context = ctx case CmdBatchRollback: - req.BatchRollback.Context = ctx + req.BatchRollback().Context = ctx case CmdScanLock: - req.ScanLock.Context = ctx + req.ScanLock().Context = ctx case CmdResolveLock: - req.ResolveLock.Context = ctx + req.ResolveLock().Context = ctx case CmdGC: - req.GC.Context = ctx + req.GC().Context = ctx case CmdDeleteRange: - req.DeleteRange.Context = ctx + req.DeleteRange().Context = ctx case CmdRawGet: - req.RawGet.Context = ctx + req.RawGet().Context = ctx case CmdRawBatchGet: - req.RawBatchGet.Context = ctx + req.RawBatchGet().Context = ctx case CmdRawPut: - req.RawPut.Context = ctx + req.RawPut().Context = ctx case CmdRawBatchPut: - req.RawBatchPut.Context = ctx + req.RawBatchPut().Context = ctx case CmdRawDelete: - req.RawDelete.Context = ctx + req.RawDelete().Context = ctx case CmdRawBatchDelete: - req.RawBatchDelete.Context = ctx + req.RawBatchDelete().Context = ctx case CmdRawDeleteRange: - req.RawDeleteRange.Context = ctx + req.RawDeleteRange().Context = ctx case CmdRawScan: - req.RawScan.Context = ctx + req.RawScan().Context = ctx case CmdUnsafeDestroyRange: - req.UnsafeDestroyRange.Context = ctx + req.UnsafeDestroyRange().Context = ctx case CmdCop: - 
     case CmdCopStream:
-        req.Cop.Context = ctx
+        req.Cop().Context = ctx
     case CmdMvccGetByKey:
-        req.MvccGetByKey.Context = ctx
+        req.MvccGetByKey().Context = ctx
     case CmdMvccGetByStartTs:
-        req.MvccGetByStartTs.Context = ctx
+        req.MvccGetByStartTs().Context = ctx
     case CmdSplitRegion:
-        req.SplitRegion.Context = ctx
+        req.SplitRegion().Context = ctx
     case CmdEmpty:
     default:
         return fmt.Errorf("invalid request type %v", req.Type)
     }
@@ -403,189 +497,144 @@ func SetContext(req *Request, region *metapb.Region, peer *metapb.Peer) error {
 // GenRegionErrorResp returns corresponding Response with specified RegionError
 // according to the given req.
 func GenRegionErrorResp(req *Request, e *errorpb.Error) (*Response, error) {
+    var p interface{}
     resp := &Response{}
-    resp.Type = req.Type
     switch req.Type {
     case CmdGet:
-        resp.Get = &kvrpcpb.GetResponse{
+        p = &kvrpcpb.GetResponse{
             RegionError: e,
         }
     case CmdScan:
-        resp.Scan = &kvrpcpb.ScanResponse{
+        p = &kvrpcpb.ScanResponse{
             RegionError: e,
         }
     case CmdPrewrite:
-        resp.Prewrite = &kvrpcpb.PrewriteResponse{
+        p = &kvrpcpb.PrewriteResponse{
             RegionError: e,
         }
     case CmdPessimisticLock:
-        resp.PessimisticLock = &kvrpcpb.PessimisticLockResponse{
+        p = &kvrpcpb.PessimisticLockResponse{
             RegionError: e,
         }
     case CmdPessimisticRollback:
-        resp.PessimisticRollback = &kvrpcpb.PessimisticRollbackResponse{
+        p = &kvrpcpb.PessimisticRollbackResponse{
             RegionError: e,
         }
     case CmdCommit:
-        resp.Commit = &kvrpcpb.CommitResponse{
+        p = &kvrpcpb.CommitResponse{
             RegionError: e,
         }
     case CmdCleanup:
-        resp.Cleanup = &kvrpcpb.CleanupResponse{
+        p = &kvrpcpb.CleanupResponse{
             RegionError: e,
         }
     case CmdBatchGet:
-        resp.BatchGet = &kvrpcpb.BatchGetResponse{
+        p = &kvrpcpb.BatchGetResponse{
             RegionError: e,
         }
     case CmdBatchRollback:
-        resp.BatchRollback = &kvrpcpb.BatchRollbackResponse{
+        p = &kvrpcpb.BatchRollbackResponse{
             RegionError: e,
         }
     case CmdScanLock:
-        resp.ScanLock = &kvrpcpb.ScanLockResponse{
+        p = &kvrpcpb.ScanLockResponse{
             RegionError: e,
         }
     case CmdResolveLock:
-        resp.ResolveLock = &kvrpcpb.ResolveLockResponse{
+        p = &kvrpcpb.ResolveLockResponse{
             RegionError: e,
         }
     case CmdGC:
-        resp.GC = &kvrpcpb.GCResponse{
+        p = &kvrpcpb.GCResponse{
             RegionError: e,
         }
     case CmdDeleteRange:
-        resp.DeleteRange = &kvrpcpb.DeleteRangeResponse{
+        p = &kvrpcpb.DeleteRangeResponse{
             RegionError: e,
         }
     case CmdRawGet:
-        resp.RawGet = &kvrpcpb.RawGetResponse{
+        p = &kvrpcpb.RawGetResponse{
             RegionError: e,
         }
     case CmdRawBatchGet:
-        resp.RawBatchGet = &kvrpcpb.RawBatchGetResponse{
+        p = &kvrpcpb.RawBatchGetResponse{
             RegionError: e,
         }
     case CmdRawPut:
-        resp.RawPut = &kvrpcpb.RawPutResponse{
+        p = &kvrpcpb.RawPutResponse{
             RegionError: e,
         }
     case CmdRawBatchPut:
-        resp.RawBatchPut = &kvrpcpb.RawBatchPutResponse{
+        p = &kvrpcpb.RawBatchPutResponse{
             RegionError: e,
         }
     case CmdRawDelete:
-        resp.RawDelete = &kvrpcpb.RawDeleteResponse{
+        p = &kvrpcpb.RawDeleteResponse{
             RegionError: e,
         }
     case CmdRawBatchDelete:
-        resp.RawBatchDelete = &kvrpcpb.RawBatchDeleteResponse{
+        p = &kvrpcpb.RawBatchDeleteResponse{
             RegionError: e,
         }
     case CmdRawDeleteRange:
-        resp.RawDeleteRange = &kvrpcpb.RawDeleteRangeResponse{
+        p = &kvrpcpb.RawDeleteRangeResponse{
             RegionError: e,
         }
     case CmdRawScan:
-        resp.RawScan = &kvrpcpb.RawScanResponse{
+        p = &kvrpcpb.RawScanResponse{
             RegionError: e,
         }
     case CmdUnsafeDestroyRange:
-        resp.UnsafeDestroyRange = &kvrpcpb.UnsafeDestroyRangeResponse{
+        p = &kvrpcpb.UnsafeDestroyRangeResponse{
             RegionError: e,
         }
     case CmdCop:
-        resp.Cop = &coprocessor.Response{
+        p = &coprocessor.Response{
             RegionError: e,
         }
     case CmdCopStream:
-        resp.CopStream = &CopStreamResponse{
+        p = &CopStreamResponse{
             Response: &coprocessor.Response{
                 RegionError: e,
             },
         }
     case CmdMvccGetByKey:
-        resp.MvccGetByKey = &kvrpcpb.MvccGetByKeyResponse{
+        p = &kvrpcpb.MvccGetByKeyResponse{
             RegionError: e,
         }
     case CmdMvccGetByStartTs:
-        resp.MvccGetByStartTS = &kvrpcpb.MvccGetByStartTsResponse{
+        p = &kvrpcpb.MvccGetByStartTsResponse{
             RegionError: e,
         }
     case CmdSplitRegion:
-        resp.SplitRegion = &kvrpcpb.SplitRegionResponse{
+        p = &kvrpcpb.SplitRegionResponse{
             RegionError: e,
         }
     case CmdEmpty:
     default:
         return nil, fmt.Errorf("invalid request type %v", req.Type)
     }
+    resp.Resp = p
     return resp, nil
 }
 
+type getRegionError interface {
+    GetRegionError() *errorpb.Error
+}
+
 // GetRegionError returns the RegionError of the underlying concrete response.
 func (resp *Response) GetRegionError() (*errorpb.Error, error) {
-    var e *errorpb.Error
-    switch resp.Type {
-    case CmdGet:
-        e = resp.Get.GetRegionError()
-    case CmdScan:
-        e = resp.Scan.GetRegionError()
-    case CmdPessimisticLock:
-        e = resp.PessimisticLock.GetRegionError()
-    case CmdPessimisticRollback:
-        e = resp.PessimisticRollback.GetRegionError()
-    case CmdPrewrite:
-        e = resp.Prewrite.GetRegionError()
-    case CmdCommit:
-        e = resp.Commit.GetRegionError()
-    case CmdCleanup:
-        e = resp.Cleanup.GetRegionError()
-    case CmdBatchGet:
-        e = resp.BatchGet.GetRegionError()
-    case CmdBatchRollback:
-        e = resp.BatchRollback.GetRegionError()
-    case CmdScanLock:
-        e = resp.ScanLock.GetRegionError()
-    case CmdResolveLock:
-        e = resp.ResolveLock.GetRegionError()
-    case CmdGC:
-        e = resp.GC.GetRegionError()
-    case CmdDeleteRange:
-        e = resp.DeleteRange.GetRegionError()
-    case CmdRawGet:
-        e = resp.RawGet.GetRegionError()
-    case CmdRawBatchGet:
-        e = resp.RawBatchGet.GetRegionError()
-    case CmdRawPut:
-        e = resp.RawPut.GetRegionError()
-    case CmdRawBatchPut:
-        e = resp.RawBatchPut.GetRegionError()
-    case CmdRawDelete:
-        e = resp.RawDelete.GetRegionError()
-    case CmdRawBatchDelete:
-        e = resp.RawBatchDelete.GetRegionError()
-    case CmdRawDeleteRange:
-        e = resp.RawDeleteRange.GetRegionError()
-    case CmdRawScan:
-        e = resp.RawScan.GetRegionError()
-    case CmdUnsafeDestroyRange:
-        e = resp.UnsafeDestroyRange.GetRegionError()
-    case CmdCop:
-        e = resp.Cop.GetRegionError()
-    case CmdCopStream:
-        e = resp.CopStream.Response.GetRegionError()
-    case CmdMvccGetByKey:
-        e = resp.MvccGetByKey.GetRegionError()
-    case CmdMvccGetByStartTs:
-        e = resp.MvccGetByStartTS.GetRegionError()
-    case CmdSplitRegion:
-        e = resp.SplitRegion.GetRegionError()
-    case CmdEmpty:
-    default:
-        return nil, fmt.Errorf("invalid response type %v", resp.Type)
+    if resp.Resp == nil {
+        return nil, nil
     }
-    return e, nil
+    err, ok := resp.Resp.(getRegionError)
+    if !ok {
+        if _, isEmpty := resp.Resp.(*tikvpb.BatchCommandsEmptyResponse); isEmpty {
+            return nil, nil
+        }
+        return nil, fmt.Errorf("invalid response type %v", resp)
+    }
+    return err.GetRegionError(), nil
 }
 
 // CallRPC launches a rpc call.
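
Every kvproto response type already generates a GetRegionError accessor, so the rewritten method needs only the one-method getRegionError interface instead of a switch over every command. A self-contained sketch of the same pattern, using stand-in types rather than the kvproto ones:

package example

import "fmt"

// RegionError stands in for *errorpb.Error in this sketch.
type RegionError struct{ Message string }

// regionErrProvider mirrors the unexported getRegionError interface: any
// response with a GetRegionError accessor satisfies it implicitly.
type regionErrProvider interface {
	GetRegionError() *RegionError
}

type getResponse struct{ regionErr *RegionError }

func (r *getResponse) GetRegionError() *RegionError { return r.regionErr }

// extractRegionError follows the new Response.GetRegionError shape: a single
// type assertion replaces the per-command switch.
func extractRegionError(resp interface{}) (*RegionError, error) {
	if resp == nil {
		return nil, nil
	}
	provider, ok := resp.(regionErrProvider)
	if !ok {
		return nil, fmt.Errorf("invalid response type %T", resp)
	}
	return provider.GetRegionError(), nil
}
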
@@ -593,69 +642,68 @@ func (resp *Response) GetRegionError() (*errorpb.Error, error) {
 // cancel function will be sent to the channel, together with a lease checked by a background goroutine.
 func CallRPC(ctx context.Context, client tikvpb.TikvClient, req *Request) (*Response, error) {
     resp := &Response{}
-    resp.Type = req.Type
     var err error
     switch req.Type {
     case CmdGet:
-        resp.Get, err = client.KvGet(ctx, req.Get)
+        resp.Resp, err = client.KvGet(ctx, req.Get())
     case CmdScan:
-        resp.Scan, err = client.KvScan(ctx, req.Scan)
+        resp.Resp, err = client.KvScan(ctx, req.Scan())
     case CmdPrewrite:
-        resp.Prewrite, err = client.KvPrewrite(ctx, req.Prewrite)
+        resp.Resp, err = client.KvPrewrite(ctx, req.Prewrite())
     case CmdPessimisticLock:
-        resp.PessimisticLock, err = client.KvPessimisticLock(ctx, req.PessimisticLock)
+        resp.Resp, err = client.KvPessimisticLock(ctx, req.PessimisticLock())
     case CmdPessimisticRollback:
-        resp.PessimisticRollback, err = client.KVPessimisticRollback(ctx, req.PessimisticRollback)
+        resp.Resp, err = client.KVPessimisticRollback(ctx, req.PessimisticRollback())
     case CmdCommit:
-        resp.Commit, err = client.KvCommit(ctx, req.Commit)
+        resp.Resp, err = client.KvCommit(ctx, req.Commit())
     case CmdCleanup:
-        resp.Cleanup, err = client.KvCleanup(ctx, req.Cleanup)
+        resp.Resp, err = client.KvCleanup(ctx, req.Cleanup())
     case CmdBatchGet:
-        resp.BatchGet, err = client.KvBatchGet(ctx, req.BatchGet)
+        resp.Resp, err = client.KvBatchGet(ctx, req.BatchGet())
     case CmdBatchRollback:
-        resp.BatchRollback, err = client.KvBatchRollback(ctx, req.BatchRollback)
+        resp.Resp, err = client.KvBatchRollback(ctx, req.BatchRollback())
     case CmdScanLock:
-        resp.ScanLock, err = client.KvScanLock(ctx, req.ScanLock)
+        resp.Resp, err = client.KvScanLock(ctx, req.ScanLock())
     case CmdResolveLock:
-        resp.ResolveLock, err = client.KvResolveLock(ctx, req.ResolveLock)
+        resp.Resp, err = client.KvResolveLock(ctx, req.ResolveLock())
     case CmdGC:
-        resp.GC, err = client.KvGC(ctx, req.GC)
+        resp.Resp, err = client.KvGC(ctx, req.GC())
     case CmdDeleteRange:
-        resp.DeleteRange, err = client.KvDeleteRange(ctx, req.DeleteRange)
+        resp.Resp, err = client.KvDeleteRange(ctx, req.DeleteRange())
     case CmdRawGet:
-        resp.RawGet, err = client.RawGet(ctx, req.RawGet)
+        resp.Resp, err = client.RawGet(ctx, req.RawGet())
     case CmdRawBatchGet:
-        resp.RawBatchGet, err = client.RawBatchGet(ctx, req.RawBatchGet)
+        resp.Resp, err = client.RawBatchGet(ctx, req.RawBatchGet())
     case CmdRawPut:
-        resp.RawPut, err = client.RawPut(ctx, req.RawPut)
+        resp.Resp, err = client.RawPut(ctx, req.RawPut())
     case CmdRawBatchPut:
-        resp.RawBatchPut, err = client.RawBatchPut(ctx, req.RawBatchPut)
+        resp.Resp, err = client.RawBatchPut(ctx, req.RawBatchPut())
     case CmdRawDelete:
-        resp.RawDelete, err = client.RawDelete(ctx, req.RawDelete)
+        resp.Resp, err = client.RawDelete(ctx, req.RawDelete())
     case CmdRawBatchDelete:
-        resp.RawBatchDelete, err = client.RawBatchDelete(ctx, req.RawBatchDelete)
+        resp.Resp, err = client.RawBatchDelete(ctx, req.RawBatchDelete())
     case CmdRawDeleteRange:
-        resp.RawDeleteRange, err = client.RawDeleteRange(ctx, req.RawDeleteRange)
+        resp.Resp, err = client.RawDeleteRange(ctx, req.RawDeleteRange())
     case CmdRawScan:
-        resp.RawScan, err = client.RawScan(ctx, req.RawScan)
+        resp.Resp, err = client.RawScan(ctx, req.RawScan())
     case CmdUnsafeDestroyRange:
-        resp.UnsafeDestroyRange, err = client.UnsafeDestroyRange(ctx, req.UnsafeDestroyRange)
+        resp.Resp, err = client.UnsafeDestroyRange(ctx, req.UnsafeDestroyRange())
     case CmdCop:
-        resp.Cop, err = client.Coprocessor(ctx, req.Cop)
+        resp.Resp, err = client.Coprocessor(ctx, req.Cop())
     case CmdCopStream:
         var streamClient tikvpb.Tikv_CoprocessorStreamClient
-        streamClient, err = client.CoprocessorStream(ctx, req.Cop)
-        resp.CopStream = &CopStreamResponse{
+        streamClient, err = client.CoprocessorStream(ctx, req.Cop())
+        resp.Resp = &CopStreamResponse{
             Tikv_CoprocessorStreamClient: streamClient,
         }
     case CmdMvccGetByKey:
-        resp.MvccGetByKey, err = client.MvccGetByKey(ctx, req.MvccGetByKey)
+        resp.Resp, err = client.MvccGetByKey(ctx, req.MvccGetByKey())
     case CmdMvccGetByStartTs:
-        resp.MvccGetByStartTS, err = client.MvccGetByStartTs(ctx, req.MvccGetByStartTs)
+        resp.Resp, err = client.MvccGetByStartTs(ctx, req.MvccGetByStartTs())
     case CmdSplitRegion:
-        resp.SplitRegion, err = client.SplitRegion(ctx, req.SplitRegion)
+        resp.Resp, err = client.SplitRegion(ctx, req.SplitRegion())
     case CmdEmpty:
-        resp.Empty, err = &tikvpb.BatchCommandsEmptyResponse{}, nil
+        resp.Resp, err = &tikvpb.BatchCommandsEmptyResponse{}, nil
     default:
         return nil, errors.Errorf("invalid request type: %v", req.Type)
     }
@@ -667,12 +715,11 @@ func CallRPC(ctx context.Context, client tikvpb.TikvClient, req *Request) (*Resp
 
 // CallDebugRPC launches a debug rpc call.
 func CallDebugRPC(ctx context.Context, client debugpb.DebugClient, req *Request) (*Response, error) {
-    resp := &Response{Type: req.Type}
-    resp.Type = req.Type
+    resp := &Response{}
     var err error
     switch req.Type {
     case CmdDebugGetRegionProperties:
-        resp.DebugGetRegionProperties, err = client.GetRegionProperties(ctx, req.DebugGetRegionProperties)
+        resp.Resp, err = client.GetRegionProperties(ctx, req.DebugGetRegionProperties())
     default:
         return nil, errors.Errorf("invalid request type: %v", req.Type)
     }
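
Taken together, the tikvrpc changes shift callers to a build-request/send/assert pattern: NewRequest wraps the concrete protobuf message, CallRPC (or CallDebugRPC) fills Resp, and the caller checks the region error before asserting the concrete type. A rough sketch of that flow, noting that real TiDB code typically routes requests through the region request sender rather than calling CallRPC directly:

package example

import (
	"context"
	"fmt"

	"github.com/pingcap/kvproto/pkg/kvrpcpb"
	"github.com/pingcap/kvproto/pkg/tikvpb"
	"github.com/pingcap/tidb/store/tikv/tikvrpc"
)

// rawGet issues a RawGet and unpacks the single-field response.
func rawGet(ctx context.Context, client tikvpb.TikvClient, key []byte) ([]byte, error) {
	req := tikvrpc.NewRequest(tikvrpc.CmdRawGet, &kvrpcpb.RawGetRequest{Key: key})
	resp, err := tikvrpc.CallRPC(ctx, client, req)
	if err != nil {
		return nil, err
	}
	if regionErr, err := resp.GetRegionError(); err != nil {
		return nil, err
	} else if regionErr != nil {
		return nil, fmt.Errorf("region error: %s", regionErr.String())
	}
	return resp.Resp.(*kvrpcpb.RawGetResponse).Value, nil
}
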
diff --git a/table/tables/tables.go b/table/tables/tables.go
index 5a1aa172482bd..e6b04136e0e3c 100644
--- a/table/tables/tables.go
+++ b/table/tables/tables.go
@@ -314,7 +316,9 @@ func (t *tableCommon) UpdateRecord(ctx sessionctx.Context, h int64, oldData, new
     }
     key := t.RecordKey(h)
-    value, err := tablecodec.EncodeRow(ctx.GetSessionVars().StmtCtx, row, colIDs, nil, nil)
+    sessVars := ctx.GetSessionVars()
+    sc := sessVars.StmtCtx
+    value, err := tablecodec.EncodeRow(sc, row, colIDs, nil, nil)
     if err != nil {
         return err
     }
 
@@ -338,13 +340,21 @@ func (t *tableCommon) UpdateRecord(ctx sessionctx.Context, h int64, oldData, new
         }
     }
     colSize := make(map[int64]int64)
+    encodedCol := make([]byte, 0, 16)
     for id, col := range t.Cols() {
-        val := int64(len(newData[id].GetBytes()) - len(oldData[id].GetBytes()))
-        if val != 0 {
-            colSize[col.ID] = val
+        encodedCol, err = tablecodec.EncodeValue(sc, encodedCol[:0], newData[id])
+        if err != nil {
+            continue
         }
+        newLen := len(encodedCol) - 1
+        encodedCol, err = tablecodec.EncodeValue(sc, encodedCol[:0], oldData[id])
+        if err != nil {
+            continue
+        }
+        oldLen := len(encodedCol) - 1
+        colSize[col.ID] = int64(newLen - oldLen)
     }
-    ctx.GetSessionVars().TxnCtx.UpdateDeltaForTable(t.physicalTableID, 0, 1, colSize)
+    sessVars.TxnCtx.UpdateDeltaForTable(t.physicalTableID, 0, 1, colSize)
     return nil
 }
 
@@ -504,7 +514,8 @@ func (t *tableCommon) AddRecord(ctx sessionctx.Context, r []types.Datum, opts ..
     writeBufs := sessVars.GetWriteStmtBufs()
     adjustRowValuesBuf(writeBufs, len(row))
     key := t.RecordKey(recordID)
-    writeBufs.RowValBuf, err = tablecodec.EncodeRow(ctx.GetSessionVars().StmtCtx, row, colIDs, writeBufs.RowValBuf, writeBufs.AddRowValues)
+    sc := sessVars.StmtCtx
+    writeBufs.RowValBuf, err = tablecodec.EncodeRow(sc, row, colIDs, writeBufs.RowValBuf, writeBufs.AddRowValues)
     if err != nil {
         return 0, err
     }
@@ -532,13 +543,15 @@ func (t *tableCommon) AddRecord(ctx sessionctx.Context, r []types.Datum, opts ..
             return 0, err
         }
     }
-    sessVars.StmtCtx.AddAffectedRows(1)
+    sc.AddAffectedRows(1)
     colSize := make(map[int64]int64)
+    encodedCol := make([]byte, 0, 16)
     for id, col := range t.Cols() {
-        val := int64(len(r[id].GetBytes()))
-        if val != 0 {
-            colSize[col.ID] = val
+        encodedCol, err = tablecodec.EncodeValue(sc, encodedCol[:0], r[id])
+        if err != nil {
+            continue
         }
+        colSize[col.ID] = int64(len(encodedCol) - 1)
     }
     sessVars.TxnCtx.UpdateDeltaForTable(t.physicalTableID, 1, 1, colSize)
     return recordID, nil
@@ -714,11 +727,14 @@ func (t *tableCommon) RemoveRecord(ctx sessionctx.Context, h int64, r []types.Da
         err = t.addDeleteBinlog(ctx, binlogRow, colIDs)
     }
     colSize := make(map[int64]int64)
+    encodedCol := make([]byte, 0, 16)
+    sc := ctx.GetSessionVars().StmtCtx
     for id, col := range t.Cols() {
-        val := -int64(len(r[id].GetBytes()))
-        if val != 0 {
-            colSize[col.ID] = val
+        encodedCol, err = tablecodec.EncodeValue(sc, encodedCol[:0], r[id])
+        if err != nil {
+            continue
         }
+        colSize[col.ID] = -int64(len(encodedCol) - 1)
     }
     ctx.GetSessionVars().TxnCtx.UpdateDeltaForTable(t.physicalTableID, -1, 1, colSize)
     return err
diff --git a/tablecodec/tablecodec.go b/tablecodec/tablecodec.go
index 62adacc54a9e8..0666427666fe0 100644
--- a/tablecodec/tablecodec.go
+++ b/tablecodec/tablecodec.go
@@ -209,14 +209,13 @@ func DecodeRowKey(key kv.Key) (int64, error) {
 }
 
 // EncodeValue encodes a go value to bytes.
-func EncodeValue(sc *stmtctx.StatementContext, raw types.Datum) ([]byte, error) {
+func EncodeValue(sc *stmtctx.StatementContext, b []byte, raw types.Datum) ([]byte, error) {
     var v types.Datum
     err := flatten(sc, raw, &v)
     if err != nil {
-        return nil, errors.Trace(err)
+        return nil, err
     }
-    b, err := codec.EncodeValue(sc, nil, v)
-    return b, errors.Trace(err)
+    return codec.EncodeValue(sc, b, v)
 }
 
 // EncodeRow encode row data and column ids into a slice of byte.
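
EncodeValue now takes the destination buffer, so the column-size bookkeeping above can reuse one slice across columns and charge each column len(encoded)-1 bytes, the minus one dropping the leading type-flag byte. A small sketch of that usage outside the table layer (the column IDs are made up for the illustration):

package example

import (
	"time"

	"github.com/pingcap/tidb/sessionctx/stmtctx"
	"github.com/pingcap/tidb/tablecodec"
	"github.com/pingcap/tidb/types"
)

// approxColumnSizes encodes every datum into a reused buffer and records
// len(encoded)-1 as its size, mirroring how AddRecord now feeds
// UpdateDeltaForTable.
func approxColumnSizes(row []types.Datum) (map[int64]int64, error) {
	sc := &stmtctx.StatementContext{TimeZone: time.UTC}
	colSize := make(map[int64]int64, len(row))
	buf := make([]byte, 0, 16)
	for i, d := range row {
		var err error
		buf, err = tablecodec.EncodeValue(sc, buf[:0], d)
		if err != nil {
			return nil, err
		}
		// The leading byte is the value's type flag, so it is not counted.
		colSize[int64(i+1)] = int64(len(buf) - 1)
	}
	return colSize, nil
}
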
diff --git a/tablecodec/tablecodec_test.go b/tablecodec/tablecodec_test.go
index 7c339d8b7ef14..4d4175282ee28 100644
--- a/tablecodec/tablecodec_test.go
+++ b/tablecodec/tablecodec_test.go
@@ -243,11 +243,11 @@ func (s *testTableCodecSuite) TestCutRow(c *C) {
     sc := &stmtctx.StatementContext{TimeZone: time.UTC}
     data := make([][]byte, 3)
-    data[0], err = EncodeValue(sc, row[0])
+    data[0], err = EncodeValue(sc, nil, row[0])
     c.Assert(err, IsNil)
-    data[1], err = EncodeValue(sc, row[1])
+    data[1], err = EncodeValue(sc, nil, row[1])
     c.Assert(err, IsNil)
-    data[2], err = EncodeValue(sc, row[2])
+    data[2], err = EncodeValue(sc, nil, row[2])
     c.Assert(err, IsNil)
     // Encode
     colIDs := make([]int64, 0, 3)
@@ -490,3 +490,24 @@ func BenchmarkHasTablePrefixBuiltin(b *testing.B) {
         k.HasPrefix(tablePrefix)
     }
 }
+
+// Bench result:
+// BenchmarkEncodeValue 5000000 368 ns/op
+func BenchmarkEncodeValue(b *testing.B) {
+    row := make([]types.Datum, 7)
+    row[0] = types.NewIntDatum(100)
+    row[1] = types.NewBytesDatum([]byte("abc"))
+    row[2] = types.NewDecimalDatum(types.NewDecFromInt(1))
+    row[3] = types.NewMysqlEnumDatum(types.Enum{Name: "a", Value: 0})
+    row[4] = types.NewDatum(types.Set{Name: "a", Value: 0})
+    row[5] = types.NewDatum(types.BinaryLiteral{100})
+    row[6] = types.NewFloat32Datum(1.5)
+    b.ResetTimer()
+    encodedCol := make([]byte, 0, 16)
+    for i := 0; i < b.N; i++ {
+        for _, d := range row {
+            encodedCol = encodedCol[:0]
+            EncodeValue(nil, encodedCol, d)
+        }
+    }
+}
diff --git a/types/convert.go b/types/convert.go
index 5e16ad43ac2d1..7ddfa5e1dd2b1 100644
--- a/types/convert.go
+++ b/types/convert.go
@@ -394,6 +394,9 @@ func roundIntStr(numNextDot byte, intStr string) string {
 // strconv.ParseInt, we can't parse float first then convert it to string because precision will
 // be lost. For example, the string value "18446744073709551615" which is the max number of unsigned
 // int will cause some precision to lose. intStr[0] may be a positive and negative sign like '+' or '-'.
+//
+// This func will detect a serious overflow, i.e. len(intStr) > 20 (without the `+/-` prefix);
+// however, it will not check whether intStr overflows BIGINT.
 func floatStrToIntStr(sc *stmtctx.StatementContext, validFloat string, oriStr string) (intStr string, _ error) {
     var dotIdx = -1
     var eIdx = -1
@@ -443,12 +446,15 @@ func floatStrToIntStr(sc *stmtctx.StatementContext, validFloat string, oriStr st
         if err != nil {
             return validFloat, errors.Trace(err)
         }
-        if exp > 0 && int64(intCnt) > (math.MaxInt64-int64(exp)) {
-            // (exp + incCnt) overflows MaxInt64.
+        intCnt += exp
+        if exp >= 0 && (intCnt > 21 || intCnt < 0) {
+            // MaxInt64 has 19 decimal digits.
+            // MaxUint64 has 20 decimal digits.
+            // And intCnt may include the length of a `+/-` sign,
+            // so 21 is used here for the early detection.
             sc.AppendWarning(ErrOverflow.GenWithStackByArgs("BIGINT", oriStr))
             return validFloat[:eIdx], nil
         }
-        intCnt += exp
         if intCnt <= 0 {
             intStr = "0"
             if intCnt == 0 && len(digits) > 0 {
@@ -474,11 +480,6 @@ func floatStrToIntStr(sc *stmtctx.StatementContext, validFloat string, oriStr st
     } else {
         // convert scientific notation decimal number
         extraZeroCount := intCnt - len(digits)
-        if extraZeroCount > 20 {
-            // Append overflow warning and return to avoid allocating too much memory.
-            sc.AppendWarning(ErrOverflow.GenWithStackByArgs("BIGINT", oriStr))
-            return validFloat[:eIdx], nil
-        }
         intStr = string(digits) + strings.Repeat("0", extraZeroCount)
     }
     return intStr, nil
@@ -580,6 +581,10 @@ func ConvertJSONToDecimal(sc *stmtctx.StatementContext, j json.BinaryJSON) (*MyD
 
 // getValidFloatPrefix gets prefix of string which can be successfully parsed as float.
 func getValidFloatPrefix(sc *stmtctx.StatementContext, s string) (valid string, err error) {
+    if sc.InDeleteStmt && s == "" {
+        return "0", nil
+    }
+
     var (
         sawDot   bool
         sawDigit bool
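
The reworked check flags a BIGINT overflow from the digit count alone: MaxInt64 has 19 decimal digits and MaxUint64 has 20, so once the integer part would exceed 21 characters (one extra position for a possible sign) the value cannot fit and the function warns without building the long string. A standalone sketch of that rule, not the TiDB function itself:

package example

// wouldOverflowBigint applies the early check used above: digits is the
// number of characters before the decimal point of the valid float prefix
// (possibly including a '+'/'-' sign), exp is the exponent from an optional
// "eNN" suffix.
func wouldOverflowBigint(digits, exp int) bool {
	intCnt := digits + exp
	// 21 = the 20 digits of MaxUint64 plus one position for a sign.
	return exp >= 0 && (intCnt > 21 || intCnt < 0)
}

// For example, "1.2e25" has one integer digit and exponent 25, so intCnt is
// 26 and the conversion is reported as a BIGINT overflow early.
var _ = wouldOverflowBigint(1, 25)
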
diff --git a/types/convert_test.go b/types/convert_test.go
index c591650952056..52e1b58832683 100644
--- a/types/convert_test.go
+++ b/types/convert_test.go
@@ -461,6 +461,29 @@ func (s *testTypeConvertSuite) TestStrToNum(c *C) {
     testStrToFloat(c, "1e649", math.MaxFloat64, false, nil)
     testStrToFloat(c, "-1e649", -math.MaxFloat64, true, ErrTruncatedWrongVal)
     testStrToFloat(c, "-1e649", -math.MaxFloat64, false, nil)
+
+    // for issue #10806
+    testDeleteEmptyStringError(c)
+}
+
+func testDeleteEmptyStringError(c *C) {
+    sc := new(stmtctx.StatementContext)
+    sc.InDeleteStmt = true
+
+    str := ""
+    expect := 0
+
+    val, err := StrToInt(sc, str)
+    c.Assert(err, IsNil)
+    c.Assert(val, Equals, int64(expect))
+
+    val1, err := StrToUint(sc, str)
+    c.Assert(err, IsNil)
+    c.Assert(val1, Equals, uint64(expect))
+
+    val2, err := StrToFloat(sc, str)
+    c.Assert(err, IsNil)
+    c.Assert(val2, Equals, float64(expect))
 }
 
 func (s *testTypeConvertSuite) TestFieldTypeToStr(c *C) {
@@ -708,6 +731,7 @@ func (s *testTypeConvertSuite) TestGetValidFloat(c *C) {
         {".5e0", "1"},
         {"+.5e0", "+1"},
         {"-.5e0", "-1"},
+        {".5", "1"},
         {"123.456789e5", "12345679"},
         {"123.456784e5", "12345678"},
     }
diff --git a/types/fsp.go b/types/fsp.go
index fe5a656cd87bb..c9709822c6453 100644
--- a/types/fsp.go
+++ b/types/fsp.go
@@ -86,9 +86,12 @@ func ParseFrac(s string, fsp int) (v int, overflow bool, err error) {
     return
 }
 
-// alignFrac is used to generate alignment frac, like `100` -> `100000`
+// alignFrac is used to generate alignment frac, like `100` -> `100000`, `-100` -> `-100000`
 func alignFrac(s string, fsp int) string {
     sl := len(s)
+    if sl > 0 && s[0] == '-' {
+        sl = sl - 1
+    }
     if sl < fsp {
         return s + strings.Repeat("0", fsp-sl)
     }
     return s
diff --git a/types/fsp_test.go b/types/fsp_test.go
index 8802e87d5b3e4..b8f29cd4077d7 100644
--- a/types/fsp_test.go
+++ b/types/fsp_test.go
@@ -115,4 +115,8 @@ func (s *FspTest) TestAlignFrac(c *C) {
     c.Assert(obtained, Equals, "100000")
     obtained = alignFrac("10000000000", 6)
     c.Assert(obtained, Equals, "10000000000")
+    obtained = alignFrac("-100", 6)
+    c.Assert(obtained, Equals, "-100000")
+    obtained = alignFrac("-10000000000", 6)
+    c.Assert(obtained, Equals, "-10000000000")
 }
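
alignFrac now excludes a leading minus sign from the length it pads against, so negative fractional parts keep the sign and still reach the target fsp width. A standalone sketch of the same padding rule (not the TiDB helper itself):

package example

import "strings"

// padFrac mirrors the fixed behaviour: the '-' sign does not count toward the
// digit length when deciding how many trailing zeros to append.
func padFrac(s string, fsp int) string {
	digits := len(s)
	if digits > 0 && s[0] == '-' {
		digits--
	}
	if digits < fsp {
		return s + strings.Repeat("0", fsp-digits)
	}
	return s
}

// padFrac("100", 6) returns "100000" and padFrac("-100", 6) returns "-100000".
var _, _ = padFrac("100", 6), padFrac("-100", 6)
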
diff --git a/types/mydecimal.go b/types/mydecimal.go
index 06ca50c4204fb..996323fa8706f 100644
--- a/types/mydecimal.go
+++ b/types/mydecimal.go
@@ -107,6 +107,14 @@ var (
     zeroMyDecimal = MyDecimal{}
 )
 
+// get the zero of MyDecimal with the specified result fraction digits
+func zeroMyDecimalWithFrac(frac int8) MyDecimal {
+    zero := MyDecimal{}
+    zero.digitsFrac = frac
+    zero.resultFrac = frac
+    return zero
+}
+
 // add adds a and b and carry, returns the sum and new carry.
 func add(a, b, carry int32) (int32, int32) {
     sum := a + b + carry
@@ -1556,7 +1564,7 @@ func doSub(from1, from2, to *MyDecimal) (cmp int, err error) {
             if to == nil {
                 return 0, nil
             }
-            *to = zeroMyDecimal
+            *to = zeroMyDecimalWithFrac(to.resultFrac)
             return 0, nil
         }
     }
@@ -1911,7 +1919,7 @@ func DecimalMul(from1, from2, to *MyDecimal) error {
             idx++
             /* We got decimal zero */
             if idx == end {
-                *to = zeroMyDecimal
+                *to = zeroMyDecimalWithFrac(to.resultFrac)
                 break
             }
         }
@@ -2010,7 +2018,7 @@ func doDivMod(from1, from2, to, mod *MyDecimal, fracIncr int) error {
     }
     if prec1 <= 0 {
         /* short-circuit everything: from1 == 0 */
-        *to = zeroMyDecimal
+        *to = zeroMyDecimalWithFrac(to.resultFrac)
         return nil
     }
     prec1 -= countLeadingZeroes((prec1-1)%digitsPerWord, from1.wordBuf[idx1])
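
Because these zero results now carry the result fraction, subtracting or dividing equal values keeps the operands' scale in the printed value, which is what the updated test expectations below check. A quick illustration through the exported decimal API (the result string is taken from those tests; DecimalSub is assumed to be the exported wrapper around doSub):

package example

import "github.com/pingcap/tidb/types"

// zeroKeepsFraction shows "12.12" - "12.12" printing as "0.00" instead of "0".
func zeroKeepsFraction() (string, error) {
	var a, b, diff types.MyDecimal
	if err := a.FromString([]byte("12.12")); err != nil {
		return "", err
	}
	if err := b.FromString([]byte("12.12")); err != nil {
		return "", err
	}
	if err := types.DecimalSub(&a, &b, &diff); err != nil {
		return "", err
	}
	return diff.String(), nil // "0.00"
}
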
diff --git a/types/mydecimal_test.go b/types/mydecimal_test.go
index e799692231c6a..551105987b3d0 100644
--- a/types/mydecimal_test.go
+++ b/types/mydecimal_test.go
@@ -694,6 +694,7 @@ func (s *testMyDecimalSuite) TestAdd(c *C) {
         {"-123.45", "12345", "12221.55", nil},
         {"5", "-6.0", "-1.0", nil},
         {"2" + strings.Repeat("1", 71), strings.Repeat("8", 81), "8888888890" + strings.Repeat("9", 71), nil},
+        {"-1234.1234", "1234.1234", "0.0000", nil},
     }
     for _, tt := range tests {
         a := NewDecFromStringForTest(tt.a)
@@ -718,7 +719,7 @@ func (s *testMyDecimalSuite) TestSub(c *C) {
         {"1234500009876.5", ".00012345000098765", "1234500009876.49987654999901235", nil},
         {"9999900000000.5", ".555", "9999899999999.945", nil},
         {"1111.5551", "1111.555", "0.0001", nil},
-        {".555", ".555", "0", nil},
+        {".555", ".555", "0.000", nil},
         {"10000000", "1", "9999999", nil},
         {"1000001000", ".1", "1000000999.9", nil},
         {"1000000000", ".1", "999999999.9", nil},
@@ -728,6 +729,7 @@ func (s *testMyDecimalSuite) TestSub(c *C) {
         {"-123.45", "-12345", "12221.55", nil},
         {"-12345", "123.45", "-12468.45", nil},
         {"12345", "-123.45", "12468.45", nil},
+        {"12.12", "12.12", "0.00", nil},
     }
     for _, tt := range tests {
         var a, b, sum MyDecimal
@@ -759,6 +761,7 @@ func (s *testMyDecimalSuite) TestMul(c *C) {
         {"1" + strings.Repeat("0", 60), "1" + strings.Repeat("0", 60), "0", ErrOverflow},
         {"0.5999991229316", "0.918755041726043", "0.5512522192246113614062276588", nil},
         {"0.5999991229317", "0.918755041726042", "0.5512522192247026369112773314", nil},
+        {"0.000", "-1", "0.000", nil},
     }
     for _, tt := range tests {
         var a, b, product MyDecimal
@@ -786,7 +789,7 @@ func (s *testMyDecimalSuite) TestDivMod(c *C) {
         {"0", "0", "", ErrDivByZero},
         {"-12193185.1853376", "98765.4321", "-123.456000000000000000", nil},
         {"121931851853376", "987654321", "123456.000000000", nil},
-        {"0", "987", "0", nil},
+        {"0", "987", "0.00000", nil},
         {"1", "3", "0.333333333", nil},
         {"1.000000000000", "3", "0.333333333333333333", nil},
         {"1", "1", "1.000000000", nil},
@@ -799,7 +802,7 @@
         var a, b, to MyDecimal
         a.FromString([]byte(tt.a))
         b.FromString([]byte(tt.b))
-        err := doDivMod(&a, &b, &to, nil, 5)
+        err := DecimalDiv(&a, &b, &to, 5)
         c.Check(err, Equals, tt.err)
         if tt.err == ErrDivByZero {
             continue
@@ -816,12 +819,13 @@
         {"99999999999999999999999999999999999999", "3", "0", nil},
         {"51", "0.003430", "0.002760", nil},
         {"0.0000000001", "1.0", "0.0000000001", nil},
+        {"0.000", "0.1", "0.000", nil},
     }
     for _, tt := range tests {
         var a, b, to MyDecimal
         a.FromString([]byte(tt.a))
         b.FromString([]byte(tt.b))
-        ec := doDivMod(&a, &b, nil, &to, 0)
+        ec := DecimalMod(&a, &b, &to)
         c.Check(ec, Equals, tt.err)
         if tt.err == ErrDivByZero {
             continue
@@ -836,6 +840,7 @@ func (s *testMyDecimalSuite) TestDivMod(c *C) {
         {"1", "1.000", "1.0000", nil},
         {"2", "3", "0.6667", nil},
         {"51", "0.003430", "14868.8047", nil},
+        {"0.000", "0.1", "0.0000000", nil},
     }
     for _, tt := range tests {
         var a, b, to MyDecimal
diff --git a/util/mock/context.go b/util/mock/context.go
index befd00241b08f..31ddddb5fe209 100644
--- a/util/mock/context.go
+++ b/util/mock/context.go
@@ -264,6 +264,9 @@ func NewContext() *Context {
     sctx.sessionVars.MaxChunkSize = 32
     sctx.sessionVars.StmtCtx.TimeZone = time.UTC
     sctx.sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor()
+    if err := sctx.GetSessionVars().SetSystemVar(variable.MaxAllowedPacket, "67108864"); err != nil {
+        panic(err)
+    }
     return sctx
 }
 
diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go
index 5c5c0d31711f0..ec42336087752 100644
--- a/util/sqlexec/restricted_sql_executor.go
+++ b/util/sqlexec/restricted_sql_executor.go
@@ -96,3 +96,17 @@ type RecordSet interface {
     // restart the iteration.
     Close() error
 }
+
+// MultiQueryNoDelayResult is an interface for the no-delay result of one statement in multi-queries.
+type MultiQueryNoDelayResult interface {
+    // AffectedRows returns the affected rows of one statement in multi-queries.
+    AffectedRows() uint64
+    // LastMessage returns the last message of one statement in multi-queries.
+    LastMessage() string
+    // WarnCount returns the warning count of one statement in multi-queries.
+    WarnCount() uint16
+    // Status returns the status when executing one statement in multi-queries.
+    Status() uint16
+    // LastInsertID returns the last insert id of one statement in multi-queries.
+    LastInsertID() uint64
+}
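
MultiQueryNoDelayResult only defines the read side of a statement's no-delay result; the concrete type that implements it is not part of this change. A hedged sketch of how a caller might fold the per-statement results of a multi-query into one summary (the aggregation policy here is an assumption made for illustration):

package example

import "github.com/pingcap/tidb/util/sqlexec"

// multiQuerySummary collects what a client-facing layer might report for a
// multi-statement query in this sketch: affected rows summed across
// statements, the remaining fields taken from the last statement.
type multiQuerySummary struct {
	affectedRows uint64
	lastMessage  string
	warnCount    uint16
	status       uint16
	lastInsertID uint64
}

func summarize(results []sqlexec.MultiQueryNoDelayResult) multiQuerySummary {
	var s multiQuerySummary
	for _, r := range results {
		s.affectedRows += r.AffectedRows()
		s.lastMessage = r.LastMessage()
		s.warnCount = r.WarnCount()
		s.status = r.Status()
		s.lastInsertID = r.LastInsertID()
	}
	return s
}
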