diff --git a/ddl/index.go b/ddl/index.go index 2a521542f0909..9887eea569ffa 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -1189,7 +1189,7 @@ func (w *addIndexWorker) BackfillDataInTxn(handleRange reorgBackfillTask) (taskC } // Create the index. - handle, err := w.index.Create(w.sessCtx, txn.GetUnionStore(), idxRecord.vals, idxRecord.handle, idxRecord.rsData) + handle, err := w.index.Create(w.sessCtx, txn, idxRecord.vals, idxRecord.handle, idxRecord.rsData) if err != nil { if kv.ErrKeyExists.Equal(err) && idxRecord.handle.Equal(handle) { // Index already exists, skip it. diff --git a/executor/admin.go b/executor/admin.go index 4769fb760d744..1b81dcf342c43 100644 --- a/executor/admin.go +++ b/executor/admin.go @@ -465,7 +465,7 @@ func (e *RecoverIndexExec) backfillIndexInTxn(ctx context.Context, txn kv.Transa return result, err } - _, err = e.index.Create(e.ctx, txn.GetUnionStore(), row.idxVals, row.handle, row.rsData) + _, err = e.index.Create(e.ctx, txn, row.idxVals, row.handle, row.rsData) if err != nil { return result, err } diff --git a/executor/admin_test.go b/executor/admin_test.go index 427fa7ab54678..ba44194944bcc 100644 --- a/executor/admin_test.go +++ b/executor/admin_test.go @@ -386,21 +386,21 @@ func (s *testSuite5) TestAdminCleanupIndex(c *C) { txn, err := s.store.Begin() c.Assert(err, IsNil) - _, err = indexOpr2.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(1), kv.IntHandle(-100), nil) + _, err = indexOpr2.Create(s.ctx, txn, types.MakeDatums(1), kv.IntHandle(-100), nil) c.Assert(err, IsNil) - _, err = indexOpr2.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(6), kv.IntHandle(100), nil) + _, err = indexOpr2.Create(s.ctx, txn, types.MakeDatums(6), kv.IntHandle(100), nil) c.Assert(err, IsNil) - _, err = indexOpr2.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(8), kv.IntHandle(100), nil) + _, err = indexOpr2.Create(s.ctx, txn, types.MakeDatums(8), kv.IntHandle(100), nil) c.Assert(err, IsNil) - _, err = indexOpr2.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(nil), kv.IntHandle(101), nil) + _, err = indexOpr2.Create(s.ctx, txn, types.MakeDatums(nil), kv.IntHandle(101), nil) c.Assert(err, IsNil) - _, err = indexOpr2.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(nil), kv.IntHandle(102), nil) + _, err = indexOpr2.Create(s.ctx, txn, types.MakeDatums(nil), kv.IntHandle(102), nil) c.Assert(err, IsNil) - _, err = indexOpr3.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(6), kv.IntHandle(200), nil) + _, err = indexOpr3.Create(s.ctx, txn, types.MakeDatums(6), kv.IntHandle(200), nil) c.Assert(err, IsNil) - _, err = indexOpr3.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(6), kv.IntHandle(-200), nil) + _, err = indexOpr3.Create(s.ctx, txn, types.MakeDatums(6), kv.IntHandle(-200), nil) c.Assert(err, IsNil) - _, err = indexOpr3.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(8), kv.IntHandle(-200), nil) + _, err = indexOpr3.Create(s.ctx, txn, types.MakeDatums(8), kv.IntHandle(-200), nil) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) @@ -455,9 +455,9 @@ func (s *testSuite5) TestAdminCleanupIndexForPartitionTable(c *C) { txn, err := s.store.Begin() c.Assert(err, IsNil) - _, err = indexOpr2.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(idxValue), kv.IntHandle(handle), nil) + _, err = indexOpr2.Create(s.ctx, txn, types.MakeDatums(idxValue), kv.IntHandle(handle), nil) c.Assert(err, IsNil) - _, err = indexOpr3.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(idxValue), kv.IntHandle(handle), nil) + _, err = indexOpr3.Create(s.ctx, txn, types.MakeDatums(idxValue), kv.IntHandle(handle), nil) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) @@ -537,11 +537,11 @@ func (s *testSuite5) TestAdminCleanupIndexPKNotHandle(c *C) { txn, err := s.store.Begin() c.Assert(err, IsNil) - _, err = indexOpr.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(7, 10), kv.IntHandle(-100), nil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(7, 10), kv.IntHandle(-100), nil) c.Assert(err, IsNil) - _, err = indexOpr.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(4, 6), kv.IntHandle(100), nil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(4, 6), kv.IntHandle(100), nil) c.Assert(err, IsNil) - _, err = indexOpr.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(-7, 4), kv.IntHandle(101), nil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(-7, 4), kv.IntHandle(101), nil) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) @@ -590,9 +590,9 @@ func (s *testSuite5) TestAdminCleanupIndexMore(c *C) { for i := 0; i < 2000; i++ { c1 := int64(2*i + 7) c2 := int64(2*i + 8) - _, err = indexOpr1.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(c1, c2), kv.IntHandle(c1), nil) + _, err = indexOpr1.Create(s.ctx, txn, types.MakeDatums(c1, c2), kv.IntHandle(c1), nil) c.Assert(err, IsNil, Commentf(errors.ErrorStack(err))) - _, err = indexOpr2.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(c2), kv.IntHandle(c1), nil) + _, err = indexOpr2.Create(s.ctx, txn, types.MakeDatums(c2), kv.IntHandle(c1), nil) c.Assert(err, IsNil) } err = txn.Commit(context.Background()) @@ -669,11 +669,11 @@ func (s *testSuite5) TestClusteredAdminCleanupIndex(c *C) { txn, err := s.store.Begin() c.Assert(err, IsNil) for _, di := range c2DanglingIdx { - _, err := indexOpr2.Create(s.ctx, txn.GetUnionStore(), di.idxVal, di.handle, nil) + _, err := indexOpr2.Create(s.ctx, txn, di.idxVal, di.handle, nil) c.Assert(err, IsNil) } for _, di := range c3DanglingIdx { - _, err := indexOpr3.Create(s.ctx, txn.GetUnionStore(), di.idxVal, di.handle, nil) + _, err := indexOpr3.Create(s.ctx, txn, di.idxVal, di.handle, nil) c.Assert(err, IsNil) } err = txn.Commit(context.Background()) @@ -742,7 +742,7 @@ func (s *testSuite3) TestAdminCheckPartitionTableFailed(c *C) { // Manual recover index. txn, err = s.store.Begin() c.Assert(err, IsNil) - _, err = indexOpr.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(i), kv.IntHandle(i), nil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(i), kv.IntHandle(i), nil) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) @@ -756,7 +756,7 @@ func (s *testSuite3) TestAdminCheckPartitionTableFailed(c *C) { indexOpr := tables.NewIndex(tblInfo.GetPartitionInfo().Definitions[partitionIdx].ID, tblInfo, idxInfo) txn, err := s.store.Begin() c.Assert(err, IsNil) - _, err = indexOpr.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(i+8), kv.IntHandle(i+8), nil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(i+8), kv.IntHandle(i+8), nil) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) @@ -779,7 +779,7 @@ func (s *testSuite3) TestAdminCheckPartitionTableFailed(c *C) { indexOpr := tables.NewIndex(tblInfo.GetPartitionInfo().Definitions[partitionIdx].ID, tblInfo, idxInfo) txn, err := s.store.Begin() c.Assert(err, IsNil) - _, err = indexOpr.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(i+8), kv.IntHandle(i), nil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(i+8), kv.IntHandle(i), nil) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) @@ -842,7 +842,7 @@ func (s *testSuite5) TestAdminCheckTableFailed(c *C) { // Index c2 has one more values than table data: 0, and the handle 0 hasn't correlative record. txn, err = s.store.Begin() c.Assert(err, IsNil) - _, err = indexOpr.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(0), kv.IntHandle(0), nil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(0), kv.IntHandle(0), nil) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) @@ -858,9 +858,9 @@ func (s *testSuite5) TestAdminCheckTableFailed(c *C) { err = indexOpr.Delete(sc, txn.GetUnionStore(), types.MakeDatums(0), kv.IntHandle(0)) c.Assert(err, IsNil) // Make sure the index value "19" is smaller "21". Then we scan to "19" before "21". - _, err = indexOpr.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(19), kv.IntHandle(10), nil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(19), kv.IntHandle(10), nil) c.Assert(err, IsNil) - _, err = indexOpr.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(13), kv.IntHandle(2), nil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(13), kv.IntHandle(2), nil) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) @@ -886,7 +886,7 @@ func (s *testSuite5) TestAdminCheckTableFailed(c *C) { // Index c2 has one line of data is 19, the corresponding table data is 20. txn, err = s.store.Begin() c.Assert(err, IsNil) - _, err = indexOpr.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(12), kv.IntHandle(2), nil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(12), kv.IntHandle(2), nil) c.Assert(err, IsNil) err = indexOpr.Delete(sc, txn.GetUnionStore(), types.MakeDatums(20), kv.IntHandle(10)) c.Assert(err, IsNil) @@ -901,7 +901,7 @@ func (s *testSuite5) TestAdminCheckTableFailed(c *C) { c.Assert(err, IsNil) err = indexOpr.Delete(sc, txn.GetUnionStore(), types.MakeDatums(19), kv.IntHandle(10)) c.Assert(err, IsNil) - _, err = indexOpr.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(20), kv.IntHandle(10), nil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(20), kv.IntHandle(10), nil) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) @@ -1058,7 +1058,7 @@ func (s *testSuite5) TestAdminCheckWithSnapshot(c *C) { idxOpr := tables.NewIndex(tblInfo.ID, tblInfo, idxInfo) txn, err := s.store.Begin() c.Assert(err, IsNil) - _, err = idxOpr.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(2), kv.IntHandle(100), nil) + _, err = idxOpr.Create(s.ctx, txn, types.MakeDatums(2), kv.IntHandle(100), nil) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) diff --git a/executor/distsql_test.go b/executor/distsql_test.go index e5f079963b285..d027534021e73 100644 --- a/executor/distsql_test.go +++ b/executor/distsql_test.go @@ -221,7 +221,7 @@ func (s *testSuite3) TestInconsistentIndex(c *C) { for i := 0; i < 10; i++ { txn, err := s.store.Begin() c.Assert(err, IsNil) - _, err = idxOp.Create(ctx, txn.GetUnionStore(), types.MakeDatums(i+10), kv.IntHandle(100+i), nil) + _, err = idxOp.Create(ctx, txn, types.MakeDatums(i+10), kv.IntHandle(100+i), nil) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) diff --git a/executor/executor_test.go b/executor/executor_test.go index 9b35a4a288a57..43d96b1a06730 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -456,7 +456,7 @@ func (s *testSuite3) TestAdmin(c *C) { tb, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("admin_test")) c.Assert(err, IsNil) c.Assert(tb.Indices(), HasLen, 1) - _, err = tb.Indices()[0].Create(mock.NewContext(), txn.GetUnionStore(), types.MakeDatums(int64(10)), kv.IntHandle(1), nil) + _, err = tb.Indices()[0].Create(mock.NewContext(), txn, types.MakeDatums(int64(10)), kv.IntHandle(1), nil) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) @@ -3716,7 +3716,7 @@ func (s *testSuite) TestCheckIndex(c *C) { // table data (handle, data): (1, 10), (2, 20), (4, 40) txn, err = s.store.Begin() c.Assert(err, IsNil) - _, err = idx.Create(mockCtx, txn.GetUnionStore(), types.MakeDatums(int64(30)), kv.IntHandle(3), nil) + _, err = idx.Create(mockCtx, txn, types.MakeDatums(int64(30)), kv.IntHandle(3), nil) c.Assert(err, IsNil) key := tablecodec.EncodeRowKey(tb.Meta().ID, kv.IntHandle(4).Encoded()) setColValue(c, txn, key, types.NewDatum(int64(40))) @@ -3731,7 +3731,7 @@ func (s *testSuite) TestCheckIndex(c *C) { // table data (handle, data): (1, 10), (2, 20), (4, 40) txn, err = s.store.Begin() c.Assert(err, IsNil) - _, err = idx.Create(mockCtx, txn.GetUnionStore(), types.MakeDatums(int64(40)), kv.IntHandle(4), nil) + _, err = idx.Create(mockCtx, txn, types.MakeDatums(int64(40)), kv.IntHandle(4), nil) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) diff --git a/executor/write_test.go b/executor/write_test.go index 65118ec012ba3..9bcaa68199b4b 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -2689,7 +2689,7 @@ func (s *testSuite7) TestReplaceLog(c *C) { txn, err := s.store.Begin() c.Assert(err, IsNil) - _, err = indexOpr.Create(s.ctx, txn.GetUnionStore(), types.MakeDatums(1), kv.IntHandle(1), nil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(1), kv.IntHandle(1), nil) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) diff --git a/go.mod b/go.mod index a04c6589d348e..fe7143d1340dd 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 github.com/pierrec/lz4 v2.5.2+incompatible // indirect github.com/pingcap/badger v1.5.1-0.20200908111422-2e78ee155d19 - github.com/pingcap/br v4.0.0-beta.2.0.20210220133344-578be7fb5165+incompatible + github.com/pingcap/br v4.0.0-beta.2.0.20210302095941-59e4efeaeb47+incompatible github.com/pingcap/check v0.0.0-20200212061837-5e12011dc712 github.com/pingcap/errors v0.11.5-0.20201126102027-b0a155152ca3 github.com/pingcap/failpoint v0.0.0-20200702092429-9f69995143ce diff --git a/go.sum b/go.sum index b4c4d54c3b3de..2cdd1744e9040 100644 --- a/go.sum +++ b/go.sum @@ -385,8 +385,8 @@ github.com/pierrec/lz4 v2.5.2+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi github.com/pingcap-incubator/tidb-dashboard v0.0.0-20210104140916-41a0a3a87e75/go.mod h1:EONGys2gM5n14pII2vjmU/5VG3Dtj6kpqUT1GUZ4ysw= github.com/pingcap/badger v1.5.1-0.20200908111422-2e78ee155d19 h1:IXpGy7y9HyoShAFmzW2OPF0xCA5EOoSTyZHwsgYk9Ro= github.com/pingcap/badger v1.5.1-0.20200908111422-2e78ee155d19/go.mod h1:LyrqUOHZrUDf9oGi1yoz1+qw9ckSIhQb5eMa1acOLNQ= -github.com/pingcap/br v4.0.0-beta.2.0.20210220133344-578be7fb5165+incompatible h1:Zd4LjoIYVmGF9KW484B0F+XvFHlcp9hraI5FAB9h1/I= -github.com/pingcap/br v4.0.0-beta.2.0.20210220133344-578be7fb5165+incompatible/go.mod h1:ymVmo50lQydxib0tmK5hHk4oteB7hZ0IMCArunwy3UQ= +github.com/pingcap/br v4.0.0-beta.2.0.20210302095941-59e4efeaeb47+incompatible h1:0B1CQlmaky9VEa1STBH/WM81wLOuFJ2Rmb5APHzPefU= +github.com/pingcap/br v4.0.0-beta.2.0.20210302095941-59e4efeaeb47+incompatible/go.mod h1:ymVmo50lQydxib0tmK5hHk4oteB7hZ0IMCArunwy3UQ= github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8/go.mod h1:B1+S9LNcuMyLH/4HMTViQOJevkGiik3wW2AN9zb2fNQ= github.com/pingcap/check v0.0.0-20191107115940-caf2b9e6ccf4/go.mod h1:PYMCGwN0JHjoqGr3HrZoD+b8Tgx8bKnArhSq8YVzUMc= github.com/pingcap/check v0.0.0-20191216031241-8a5a85928f12/go.mod h1:PYMCGwN0JHjoqGr3HrZoD+b8Tgx8bKnArhSq8YVzUMc= diff --git a/kv/interface_mock_test.go b/kv/interface_mock_test.go index c2114bd1e722a..461250901233a 100644 --- a/kv/interface_mock_test.go +++ b/kv/interface_mock_test.go @@ -16,6 +16,7 @@ package kv import ( "context" + "github.com/pingcap/parser/model" "github.com/pingcap/tidb/store/tikv/oracle" ) @@ -133,6 +134,14 @@ func (t *mockTxn) GetVars() *Variables { return nil } +func (t *mockTxn) CacheTableInfo(id int64, info *model.TableInfo) { + +} + +func (t *mockTxn) GetTableInfo(id int64) *model.TableInfo { + return nil +} + // newMockTxn new a mockTxn. func newMockTxn() Transaction { return &mockTxn{ diff --git a/kv/kv.go b/kv/kv.go index 950afddba212a..de2aad488cfb0 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -19,6 +19,7 @@ import ( "sync" "time" + "github.com/pingcap/parser/model" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/store/tikv/oracle" "github.com/pingcap/tidb/util/execdetails" @@ -292,6 +293,12 @@ type Transaction interface { // If a key doesn't exist, there shouldn't be any corresponding entry in the result map. BatchGet(ctx context.Context, keys []Key) (map[string][]byte, error) IsPessimistic() bool + // CacheIndexName caches the index name. + // PresumeKeyNotExists will use this to help decode error message. + CacheTableInfo(id int64, info *model.TableInfo) + // GetIndexName returns the cached index name. + // If there is no such index already inserted through CacheIndexName, it will return UNKNOWN. + GetTableInfo(id int64) *model.TableInfo } // LockCtx contains information for LockKeys method. diff --git a/kv/union_store.go b/kv/union_store.go index 669048a9d0bf7..96ff6c8965a4b 100644 --- a/kv/union_store.go +++ b/kv/union_store.go @@ -15,8 +15,6 @@ package kv import ( "context" - - "github.com/pingcap/parser/model" ) // UnionStore is a store that wraps a snapshot for read and a MemBuffer for buffered write. @@ -28,12 +26,6 @@ type UnionStore interface { HasPresumeKeyNotExists(k Key) bool // UnmarkPresumeKeyNotExists deletes the key presume key not exists error flag for the lazy check. UnmarkPresumeKeyNotExists(k Key) - // CacheIndexName caches the index name. - // PresumeKeyNotExists will use this to help decode error message. - CacheTableInfo(id int64, info *model.TableInfo) - // GetIndexName returns the cached index name. - // If there is no such index already inserted through CacheIndexName, it will return UNKNOWN. - GetTableInfo(id int64) *model.TableInfo // SetOption sets an option with a value, when val is nil, uses the default // value of this option. @@ -68,19 +60,17 @@ type Options interface { // unionStore is an in-memory Store which contains a buffer for write and a // snapshot for read. type unionStore struct { - memBuffer *memdb - snapshot Snapshot - idxNameCache map[int64]*model.TableInfo - opts options + memBuffer *memdb + snapshot Snapshot + opts options } // NewUnionStore builds a new unionStore. func NewUnionStore(snapshot Snapshot) UnionStore { return &unionStore{ - snapshot: snapshot, - memBuffer: newMemDB(), - idxNameCache: make(map[int64]*model.TableInfo), - opts: make(map[Option]interface{}), + snapshot: snapshot, + memBuffer: newMemDB(), + opts: make(map[Option]interface{}), } } @@ -144,14 +134,6 @@ func (us *unionStore) UnmarkPresumeKeyNotExists(k Key) { us.memBuffer.UpdateFlags(k, DelPresumeKeyNotExists) } -func (us *unionStore) GetTableInfo(id int64) *model.TableInfo { - return us.idxNameCache[id] -} - -func (us *unionStore) CacheTableInfo(id int64, info *model.TableInfo) { - us.idxNameCache[id] = info -} - // SetOption implements the unionStore SetOption interface. func (us *unionStore) SetOption(opt Option, val interface{}) { us.opts[opt] = val diff --git a/session/txn.go b/session/txn.go index 7c32829bc30fc..4f7175c789477 100644 --- a/session/txn.go +++ b/session/txn.go @@ -24,6 +24,7 @@ import ( "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/failpoint" + "github.com/pingcap/parser/model" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/kv" @@ -53,87 +54,97 @@ type TxnState struct { mutations map[int64]*binlog.TableMutation } -func (st *TxnState) init() { - st.mutations = make(map[int64]*binlog.TableMutation) +// GetTableInfo returns the cached index name. +func (txn *TxnState) GetTableInfo(id int64) *model.TableInfo { + return txn.Transaction.GetTableInfo(id) } -func (st *TxnState) initStmtBuf() { - if st.Transaction == nil { +// CacheTableInfo caches the index name. +func (txn *TxnState) CacheTableInfo(id int64, info *model.TableInfo) { + txn.Transaction.CacheTableInfo(id, info) +} + +func (txn *TxnState) init() { + txn.mutations = make(map[int64]*binlog.TableMutation) +} + +func (txn *TxnState) initStmtBuf() { + if txn.Transaction == nil { return } - buf := st.Transaction.GetMemBuffer() - st.initCnt = buf.Len() - st.stagingHandle = buf.Staging() + buf := txn.Transaction.GetMemBuffer() + txn.initCnt = buf.Len() + txn.stagingHandle = buf.Staging() } // countHint is estimated count of mutations. -func (st *TxnState) countHint() int { - if st.stagingHandle == kv.InvalidStagingHandle { +func (txn *TxnState) countHint() int { + if txn.stagingHandle == kv.InvalidStagingHandle { return 0 } - return st.Transaction.GetMemBuffer().Len() - st.initCnt + return txn.Transaction.GetMemBuffer().Len() - txn.initCnt } -func (st *TxnState) flushStmtBuf() { - if st.stagingHandle == kv.InvalidStagingHandle { +func (txn *TxnState) flushStmtBuf() { + if txn.stagingHandle == kv.InvalidStagingHandle { return } - buf := st.Transaction.GetMemBuffer() - buf.Release(st.stagingHandle) - st.initCnt = buf.Len() + buf := txn.Transaction.GetMemBuffer() + buf.Release(txn.stagingHandle) + txn.initCnt = buf.Len() } -func (st *TxnState) cleanupStmtBuf() { - if st.stagingHandle == kv.InvalidStagingHandle { +func (txn *TxnState) cleanupStmtBuf() { + if txn.stagingHandle == kv.InvalidStagingHandle { return } - buf := st.Transaction.GetMemBuffer() - buf.Cleanup(st.stagingHandle) - st.initCnt = buf.Len() + buf := txn.Transaction.GetMemBuffer() + buf.Cleanup(txn.stagingHandle) + txn.initCnt = buf.Len() } // Size implements the MemBuffer interface. -func (st *TxnState) Size() int { - if st.Transaction == nil { +func (txn *TxnState) Size() int { + if txn.Transaction == nil { return 0 } - return st.Transaction.Size() + return txn.Transaction.Size() } // Valid implements the kv.Transaction interface. -func (st *TxnState) Valid() bool { - return st.Transaction != nil && st.Transaction.Valid() +func (txn *TxnState) Valid() bool { + return txn.Transaction != nil && txn.Transaction.Valid() } -func (st *TxnState) pending() bool { - return st.Transaction == nil && st.txnFuture != nil +func (txn *TxnState) pending() bool { + return txn.Transaction == nil && txn.txnFuture != nil } -func (st *TxnState) validOrPending() bool { - return st.txnFuture != nil || st.Valid() +func (txn *TxnState) validOrPending() bool { + return txn.txnFuture != nil || txn.Valid() } -func (st *TxnState) String() string { - if st.Transaction != nil { - return st.Transaction.String() +func (txn *TxnState) String() string { + if txn.Transaction != nil { + return txn.Transaction.String() } - if st.txnFuture != nil { + if txn.txnFuture != nil { return "txnFuture" } return "invalid transaction" } // GoString implements the "%#v" format for fmt.Printf. -func (st *TxnState) GoString() string { +func (txn *TxnState) GoString() string { var s strings.Builder s.WriteString("Txn{") - if st.pending() { + if txn.pending() { s.WriteString("state=pending") - } else if st.Valid() { + } else if txn.Valid() { s.WriteString("state=valid") - fmt.Fprintf(&s, ", txnStartTS=%d", st.Transaction.StartTS()) - if len(st.mutations) > 0 { - fmt.Fprintf(&s, ", len(mutations)=%d, %#v", len(st.mutations), st.mutations) + fmt.Fprintf(&s, ", txnStartTS=%d", txn.Transaction.StartTS()) + if len(txn.mutations) > 0 { + fmt.Fprintf(&s, ", len(mutations)=%d, %#v", len(txn.mutations), txn.mutations) } } else { s.WriteString("state=invalid") @@ -143,43 +154,43 @@ func (st *TxnState) GoString() string { return s.String() } -func (st *TxnState) changeInvalidToValid(txn kv.Transaction) { - st.Transaction = txn - st.initStmtBuf() - st.txnFuture = nil +func (txn *TxnState) changeInvalidToValid(kvTxn kv.Transaction) { + txn.Transaction = kvTxn + txn.initStmtBuf() + txn.txnFuture = nil } -func (st *TxnState) changeInvalidToPending(future *txnFuture) { - st.Transaction = nil - st.txnFuture = future +func (txn *TxnState) changeInvalidToPending(future *txnFuture) { + txn.Transaction = nil + txn.txnFuture = future } -func (st *TxnState) changePendingToValid(ctx context.Context) error { - if st.txnFuture == nil { +func (txn *TxnState) changePendingToValid(ctx context.Context) error { + if txn.txnFuture == nil { return errors.New("transaction future is not set") } - future := st.txnFuture - st.txnFuture = nil + future := txn.txnFuture + txn.txnFuture = nil defer trace.StartRegion(ctx, "WaitTsoFuture").End() - txn, err := future.wait() + t, err := future.wait() if err != nil { - st.Transaction = nil + txn.Transaction = nil return err } - st.Transaction = txn - st.initStmtBuf() + txn.Transaction = t + txn.initStmtBuf() return nil } -func (st *TxnState) changeToInvalid() { - if st.stagingHandle != kv.InvalidStagingHandle { - st.Transaction.GetMemBuffer().Cleanup(st.stagingHandle) +func (txn *TxnState) changeToInvalid() { + if txn.stagingHandle != kv.InvalidStagingHandle { + txn.Transaction.GetMemBuffer().Cleanup(txn.stagingHandle) } - st.stagingHandle = kv.InvalidStagingHandle - st.Transaction = nil - st.txnFuture = nil + txn.stagingHandle = kv.InvalidStagingHandle + txn.Transaction = nil + txn.txnFuture = nil } var hasMockAutoIncIDRetry = int64(0) @@ -209,12 +220,12 @@ func ResetMockAutoRandIDRetryCount(failTimes int64) { } // Commit overrides the Transaction interface. -func (st *TxnState) Commit(ctx context.Context) error { - defer st.reset() - if len(st.mutations) != 0 || st.countHint() != 0 { +func (txn *TxnState) Commit(ctx context.Context) error { + defer txn.reset() + if len(txn.mutations) != 0 || txn.countHint() != 0 { logutil.BgLogger().Error("the code should never run here", - zap.String("TxnState", st.GoString()), - zap.Int("staging handler", int(st.stagingHandle)), + zap.String("TxnState", txn.GoString()), + zap.Int("staging handler", int(txn.stagingHandle)), zap.Stack("something must be wrong")) return errors.Trace(kv.ErrInvalidTxn) } @@ -241,36 +252,36 @@ func (st *TxnState) Commit(ctx context.Context) error { } }) - return st.Transaction.Commit(ctx) + return txn.Transaction.Commit(ctx) } // Rollback overrides the Transaction interface. -func (st *TxnState) Rollback() error { - defer st.reset() - return st.Transaction.Rollback() +func (txn *TxnState) Rollback() error { + defer txn.reset() + return txn.Transaction.Rollback() } -func (st *TxnState) reset() { - st.cleanup() - st.changeToInvalid() +func (txn *TxnState) reset() { + txn.cleanup() + txn.changeToInvalid() } -func (st *TxnState) cleanup() { - st.cleanupStmtBuf() - st.initStmtBuf() - for key := range st.mutations { - delete(st.mutations, key) +func (txn *TxnState) cleanup() { + txn.cleanupStmtBuf() + txn.initStmtBuf() + for key := range txn.mutations { + delete(txn.mutations, key) } } // KeysNeedToLock returns the keys need to be locked. -func (st *TxnState) KeysNeedToLock() ([]kv.Key, error) { - if st.stagingHandle == kv.InvalidStagingHandle { +func (txn *TxnState) KeysNeedToLock() ([]kv.Key, error) { + if txn.stagingHandle == kv.InvalidStagingHandle { return nil, nil } - keys := make([]kv.Key, 0, st.countHint()) - buf := st.Transaction.GetMemBuffer() - buf.InspectStage(st.stagingHandle, func(k kv.Key, flags kv.KeyFlags, v []byte) { + keys := make([]kv.Key, 0, txn.countHint()) + buf := txn.Transaction.GetMemBuffer() + buf.InspectStage(txn.stagingHandle, func(k kv.Key, flags kv.KeyFlags, v []byte) { if !keyNeedToLock(k, v, flags) { return } diff --git a/store/driver/tikv_driver.go b/store/driver/tikv_driver.go index 3305c3b6b4c43..128de806261b3 100644 --- a/store/driver/tikv_driver.go +++ b/store/driver/tikv_driver.go @@ -299,7 +299,7 @@ func (s *tikvStore) Begin() (kv.Transaction, error) { func (s *tikvStore) BeginWithTxnScope(txnScope string) (kv.Transaction, error) { txn, err := s.KVStore.BeginWithTxnScope(txnScope) if err != nil { - return txn, errors.Trace(err) + return nil, errors.Trace(err) } return txn_driver.NewTiKVTxn(txn), err } @@ -308,7 +308,7 @@ func (s *tikvStore) BeginWithTxnScope(txnScope string) (kv.Transaction, error) { func (s *tikvStore) BeginWithStartTS(txnScope string, startTS uint64) (kv.Transaction, error) { txn, err := s.KVStore.BeginWithStartTS(txnScope, startTS) if err != nil { - return txn, errors.Trace(err) + return nil, errors.Trace(err) } return txn_driver.NewTiKVTxn(txn), err } @@ -317,7 +317,7 @@ func (s *tikvStore) BeginWithStartTS(txnScope string, startTS uint64) (kv.Transa func (s *tikvStore) BeginWithExactStaleness(txnScope string, prevSec uint64) (kv.Transaction, error) { txn, err := s.KVStore.BeginWithExactStaleness(txnScope, prevSec) if err != nil { - return txn, errors.Trace(err) + return nil, errors.Trace(err) } return txn_driver.NewTiKVTxn(txn), err } diff --git a/store/driver/txn/txn_driver.go b/store/driver/txn/txn_driver.go index 9bfe19623bdbc..90d1e0daf9b93 100644 --- a/store/driver/txn/txn_driver.go +++ b/store/driver/txn/txn_driver.go @@ -14,15 +14,187 @@ package txn import ( + "context" + "fmt" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/parser/model" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/store/tikv" + "github.com/pingcap/tidb/store/tikv/logutil" + "github.com/pingcap/tidb/table/tables" + "github.com/pingcap/tidb/tablecodec" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/rowcodec" + "go.uber.org/zap" ) type tikvTxn struct { *tikv.KVTxn + idxNameCache map[int64]*model.TableInfo } // NewTiKVTxn returns a new Transaction. func NewTiKVTxn(txn *tikv.KVTxn) kv.Transaction { - return &tikvTxn{txn} + return &tikvTxn{txn, make(map[int64]*model.TableInfo)} +} + +func (txn *tikvTxn) GetTableInfo(id int64) *model.TableInfo { + return txn.idxNameCache[id] +} + +func (txn *tikvTxn) CacheTableInfo(id int64, info *model.TableInfo) { + txn.idxNameCache[id] = info +} + +// lockWaitTime in ms, except that kv.LockAlwaysWait(0) means always wait lock, kv.LockNowait(-1) means nowait lock +func (txn *tikvTxn) LockKeys(ctx context.Context, lockCtx *kv.LockCtx, keysInput ...kv.Key) error { + err := txn.KVTxn.LockKeys(ctx, lockCtx, keysInput...) + return txn.extractKeyErr(err) +} + +func (txn *tikvTxn) Commit(ctx context.Context) error { + err := txn.KVTxn.Commit(ctx) + return txn.extractKeyErr(err) +} + +func (txn *tikvTxn) extractKeyErr(err error) error { + if e, ok := errors.Cause(err).(*tikv.ErrKeyExist); ok { + return txn.extractKeyExistsErr(e.GetKey()) + } + return errors.Trace(err) +} + +func (txn *tikvTxn) extractKeyExistsErr(key kv.Key) error { + tableID, indexID, isRecord, err := tablecodec.DecodeKeyHead(key) + if err != nil { + return genKeyExistsError("UNKNOWN", key.String(), err) + } + + tblInfo := txn.GetTableInfo(tableID) + if tblInfo == nil { + return genKeyExistsError("UNKNOWN", key.String(), errors.New("cannot find table info")) + } + + value, err := txn.GetUnionStore().GetMemBuffer().SelectValueHistory(key, func(value []byte) bool { return len(value) != 0 }) + if err != nil { + return genKeyExistsError("UNKNOWN", key.String(), err) + } + + if isRecord { + return extractKeyExistsErrFromHandle(key, value, tblInfo) + } + return extractKeyExistsErrFromIndex(key, value, tblInfo, indexID) +} + +func genKeyExistsError(name string, value string, err error) error { + if err != nil { + logutil.BgLogger().Info("extractKeyExistsErr meets error", zap.Error(err)) + } + return kv.ErrKeyExists.FastGenByArgs(value, name) +} + +func extractKeyExistsErrFromHandle(key kv.Key, value []byte, tblInfo *model.TableInfo) error { + const name = "PRIMARY" + _, handle, err := tablecodec.DecodeRecordKey(key) + if err != nil { + return genKeyExistsError(name, key.String(), err) + } + + if handle.IsInt() { + if pkInfo := tblInfo.GetPkColInfo(); pkInfo != nil { + if mysql.HasUnsignedFlag(pkInfo.Flag) { + handleStr := fmt.Sprintf("%d", uint64(handle.IntValue())) + return genKeyExistsError(name, handleStr, nil) + } + } + return genKeyExistsError(name, handle.String(), nil) + } + + if len(value) == 0 { + return genKeyExistsError(name, handle.String(), errors.New("missing value")) + } + + idxInfo := tables.FindPrimaryIndex(tblInfo) + if idxInfo == nil { + return genKeyExistsError(name, handle.String(), errors.New("cannot find index info")) + } + + cols := make(map[int64]*types.FieldType, len(tblInfo.Columns)) + for _, col := range tblInfo.Columns { + cols[col.ID] = &col.FieldType + } + handleColIDs := make([]int64, 0, len(idxInfo.Columns)) + for _, col := range idxInfo.Columns { + handleColIDs = append(handleColIDs, tblInfo.Columns[col.Offset].ID) + } + + row, err := tablecodec.DecodeRowToDatumMap(value, cols, time.Local) + if err != nil { + return genKeyExistsError(name, handle.String(), err) + } + + data, err := tablecodec.DecodeHandleToDatumMap(handle, handleColIDs, cols, time.Local, row) + if err != nil { + return genKeyExistsError(name, handle.String(), err) + } + + valueStr := make([]string, 0, len(data)) + for _, col := range idxInfo.Columns { + d := data[tblInfo.Columns[col.Offset].ID] + str, err := d.ToString() + if err != nil { + return genKeyExistsError(name, key.String(), err) + } + valueStr = append(valueStr, str) + } + return genKeyExistsError(name, strings.Join(valueStr, "-"), nil) +} + +func extractKeyExistsErrFromIndex(key kv.Key, value []byte, tblInfo *model.TableInfo, indexID int64) error { + var idxInfo *model.IndexInfo + for _, index := range tblInfo.Indices { + if index.ID == indexID { + idxInfo = index + } + } + if idxInfo == nil { + return genKeyExistsError("UNKNOWN", key.String(), errors.New("cannot find index info")) + } + name := idxInfo.Name.String() + + if len(value) == 0 { + return genKeyExistsError(name, key.String(), errors.New("missing value")) + } + + colInfo := make([]rowcodec.ColInfo, 0, len(idxInfo.Columns)) + for _, idxCol := range idxInfo.Columns { + col := tblInfo.Columns[idxCol.Offset] + colInfo = append(colInfo, rowcodec.ColInfo{ + ID: col.ID, + IsPKHandle: tblInfo.PKIsHandle && mysql.HasPriKeyFlag(col.Flag), + Ft: rowcodec.FieldTypeFromModelColumn(col), + }) + } + + values, err := tablecodec.DecodeIndexKV(key, value, len(idxInfo.Columns), tablecodec.HandleNotNeeded, colInfo) + if err != nil { + return genKeyExistsError(name, key.String(), err) + } + valueStr := make([]string, 0, len(values)) + for i, val := range values { + d, err := tablecodec.DecodeColumnValue(val, colInfo[i].Ft, time.Local) + if err != nil { + return genKeyExistsError(name, key.String(), err) + } + str, err := d.ToString() + if err != nil { + return genKeyExistsError(name, key.String(), err) + } + valueStr = append(valueStr, str) + } + return genKeyExistsError(name, strings.Join(valueStr, "-"), nil) } diff --git a/store/tikv/2pc.go b/store/tikv/2pc.go index e32a992de03ef..20435d2776a4b 100644 --- a/store/tikv/2pc.go +++ b/store/tikv/2pc.go @@ -17,7 +17,6 @@ import ( "bytes" "context" "encoding/hex" - "fmt" "math" "math/rand" "strings" @@ -29,8 +28,6 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" pb "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/parser/model" - "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/store/tikv/config" @@ -39,11 +36,8 @@ import ( "github.com/pingcap/tidb/store/tikv/oracle" "github.com/pingcap/tidb/store/tikv/tikvrpc" "github.com/pingcap/tidb/store/tikv/util" - "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" - "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/execdetails" - "github.com/pingcap/tidb/util/rowcodec" "github.com/prometheus/client_golang/prometheus" zap "go.uber.org/zap" ) @@ -311,142 +305,11 @@ func newTwoPhaseCommitter(txn *KVTxn, sessionID uint64) (*twoPhaseCommitter, err }, nil } -func (c *twoPhaseCommitter) extractKeyExistsErr(key kv.Key) error { - if !c.txn.us.HasPresumeKeyNotExists(key) { - return errors.Errorf("session %d, existErr for key:%s should not be nil", c.sessionID, key) +func (c *twoPhaseCommitter) extractKeyExistsErr(err *ErrKeyExist) error { + if !c.txn.us.HasPresumeKeyNotExists(err.GetKey()) { + return errors.Errorf("session %d, existErr for key:%s should not be nil", c.sessionID, err.GetKey()) } - - tableID, indexID, isRecord, err := tablecodec.DecodeKeyHead(key) - if err != nil { - return c.genKeyExistsError("UNKNOWN", key.String(), err) - } - - tblInfo := c.txn.us.GetTableInfo(tableID) - if tblInfo == nil { - return c.genKeyExistsError("UNKNOWN", key.String(), errors.New("cannot find table info")) - } - - value, err := c.txn.us.GetMemBuffer().SelectValueHistory(key, func(value []byte) bool { return len(value) != 0 }) - if err != nil { - return c.genKeyExistsError("UNKNOWN", key.String(), err) - } - - if isRecord { - return c.extractKeyExistsErrFromHandle(key, value, tblInfo) - } - return c.extractKeyExistsErrFromIndex(key, value, tblInfo, indexID) -} - -func (c *twoPhaseCommitter) extractKeyExistsErrFromIndex(key kv.Key, value []byte, tblInfo *model.TableInfo, indexID int64) error { - var idxInfo *model.IndexInfo - for _, index := range tblInfo.Indices { - if index.ID == indexID { - idxInfo = index - } - } - if idxInfo == nil { - return c.genKeyExistsError("UNKNOWN", key.String(), errors.New("cannot find index info")) - } - name := idxInfo.Name.String() - - if len(value) == 0 { - return c.genKeyExistsError(name, key.String(), errors.New("missing value")) - } - - colInfo := make([]rowcodec.ColInfo, 0, len(idxInfo.Columns)) - for _, idxCol := range idxInfo.Columns { - col := tblInfo.Columns[idxCol.Offset] - colInfo = append(colInfo, rowcodec.ColInfo{ - ID: col.ID, - IsPKHandle: tblInfo.PKIsHandle && mysql.HasPriKeyFlag(col.Flag), - Ft: rowcodec.FieldTypeFromModelColumn(col), - }) - } - - values, err := tablecodec.DecodeIndexKV(key, value, len(idxInfo.Columns), tablecodec.HandleNotNeeded, colInfo) - if err != nil { - return c.genKeyExistsError(name, key.String(), err) - } - valueStr := make([]string, 0, len(values)) - for i, val := range values { - d, err := tablecodec.DecodeColumnValue(val, colInfo[i].Ft, time.Local) - if err != nil { - return c.genKeyExistsError(name, key.String(), err) - } - str, err := d.ToString() - if err != nil { - return c.genKeyExistsError(name, key.String(), err) - } - valueStr = append(valueStr, str) - } - return c.genKeyExistsError(name, strings.Join(valueStr, "-"), nil) -} - -func (c *twoPhaseCommitter) extractKeyExistsErrFromHandle(key kv.Key, value []byte, tblInfo *model.TableInfo) error { - const name = "PRIMARY" - _, handle, err := tablecodec.DecodeRecordKey(key) - if err != nil { - return c.genKeyExistsError(name, key.String(), err) - } - - if handle.IsInt() { - if pkInfo := tblInfo.GetPkColInfo(); pkInfo != nil { - if mysql.HasUnsignedFlag(pkInfo.Flag) { - handleStr := fmt.Sprintf("%d", uint64(handle.IntValue())) - return c.genKeyExistsError(name, handleStr, nil) - } - } - return c.genKeyExistsError(name, handle.String(), nil) - } - - if len(value) == 0 { - return c.genKeyExistsError(name, handle.String(), errors.New("missing value")) - } - - idxInfo := tables.FindPrimaryIndex(tblInfo) - if idxInfo == nil { - return c.genKeyExistsError(name, handle.String(), errors.New("cannot find index info")) - } - - cols := make(map[int64]*types.FieldType, len(tblInfo.Columns)) - for _, col := range tblInfo.Columns { - cols[col.ID] = &col.FieldType - } - handleColIDs := make([]int64, 0, len(idxInfo.Columns)) - for _, col := range idxInfo.Columns { - handleColIDs = append(handleColIDs, tblInfo.Columns[col.Offset].ID) - } - - row, err := tablecodec.DecodeRowToDatumMap(value, cols, time.Local) - if err != nil { - return c.genKeyExistsError(name, handle.String(), err) - } - - data, err := tablecodec.DecodeHandleToDatumMap(handle, handleColIDs, cols, time.Local, row) - if err != nil { - return c.genKeyExistsError(name, handle.String(), err) - } - - valueStr := make([]string, 0, len(data)) - for _, col := range idxInfo.Columns { - d := data[tblInfo.Columns[col.Offset].ID] - str, err := d.ToString() - if err != nil { - return c.genKeyExistsError(name, key.String(), err) - } - if col.Length > 0 && len([]rune(str)) > col.Length { - str = string([]rune(str)[:col.Length]) - } - valueStr = append(valueStr, str) - } - return c.genKeyExistsError(name, strings.Join(valueStr, "-"), nil) -} - -func (c *twoPhaseCommitter) genKeyExistsError(name string, value string, err error) error { - if err != nil { - logutil.BgLogger().Info("extractKeyExistsErr meets error", zap.Error(err)) - } - return kv.ErrKeyExists.FastGenByArgs(value, name) + return errors.Trace(err) } func (c *twoPhaseCommitter) initKeysAndMutations() error { diff --git a/store/tikv/async_commit_test.go b/store/tikv/async_commit_test.go index 498023680d14e..e63e356f6ea9f 100644 --- a/store/tikv/async_commit_test.go +++ b/store/tikv/async_commit_test.go @@ -69,7 +69,7 @@ func (s *testAsyncCommitCommon) putKV(c *C, key, value []byte, enableAsyncCommit return txn.StartTS(), txn.commitTS } -func (s *testAsyncCommitCommon) mustGetFromTxn(c *C, txn kv.Transaction, key, expectedValue []byte) { +func (s *testAsyncCommitCommon) mustGetFromTxn(c *C, txn *KVTxn, key, expectedValue []byte) { v, err := txn.Get(context.Background(), key) c.Assert(err, IsNil) c.Assert(v, BytesEquals, expectedValue) diff --git a/store/tikv/error.go b/store/tikv/error.go index a93adc16ea6ca..33521d62c15bf 100644 --- a/store/tikv/error.go +++ b/store/tikv/error.go @@ -77,3 +77,12 @@ type PDError struct { func (d *PDError) Error() string { return d.Err.String() } + +// ErrKeyExist wraps *pdpb.AlreadyExist to implement the error interface. +type ErrKeyExist struct { + *kvrpcpb.AlreadyExist +} + +func (k *ErrKeyExist) Error() string { + return k.AlreadyExist.String() +} diff --git a/store/tikv/pessimistic.go b/store/tikv/pessimistic.go index ef4ac925babc2..2f85cc723a7c8 100644 --- a/store/tikv/pessimistic.go +++ b/store/tikv/pessimistic.go @@ -147,8 +147,8 @@ func (action actionPessimisticLock) handleSingleBatch(c *twoPhaseCommitter, bo * for _, keyErr := range keyErrs { // Check already exists error if alreadyExist := keyErr.GetAlreadyExist(); alreadyExist != nil { - key := alreadyExist.GetKey() - return c.extractKeyExistsErr(key) + e := &ErrKeyExist{AlreadyExist: alreadyExist} + return c.extractKeyExistsErr(e) } if deadlock := keyErr.Deadlock; deadlock != nil { return &ErrDeadlock{Deadlock: deadlock} diff --git a/store/tikv/prewrite.go b/store/tikv/prewrite.go index dcb2fb522332f..5583a1e8525e2 100644 --- a/store/tikv/prewrite.go +++ b/store/tikv/prewrite.go @@ -248,8 +248,8 @@ func (action actionPrewrite) handleSingleBatch(c *twoPhaseCommitter, bo *Backoff for _, keyErr := range keyErrs { // Check already exists error if alreadyExist := keyErr.GetAlreadyExist(); alreadyExist != nil { - key := alreadyExist.GetKey() - return c.extractKeyExistsErr(key) + e := &ErrKeyExist{AlreadyExist: alreadyExist} + return c.extractKeyExistsErr(e) } // Extract lock from key error diff --git a/store/tikv/txn.go b/store/tikv/txn.go index df63dda6bcd23..77b610247a7f6 100644 --- a/store/tikv/txn.go +++ b/store/tikv/txn.go @@ -28,6 +28,7 @@ import ( "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/store/tikv/logutil" @@ -412,7 +413,9 @@ func (txn *KVTxn) LockKeys(ctx context.Context, lockCtx *kv.LockCtx, keysInput . keys = append(keys, key) } else if txn.IsPessimistic() { if checkKeyExists && valueExist { - return txn.committer.extractKeyExistsErr(key) + alreadyExist := kvrpcpb.AlreadyExist{Key: key} + e := &ErrKeyExist{AlreadyExist: &alreadyExist} + return txn.committer.extractKeyExistsErr(e) } } if lockCtx.ReturnValues && locked { diff --git a/table/index.go b/table/index.go index af823e80b998b..5a9f32fbbfd3f 100644 --- a/table/index.go +++ b/table/index.go @@ -64,7 +64,7 @@ type Index interface { // Meta returns IndexInfo. Meta() *model.IndexInfo // Create supports insert into statement. - Create(ctx sessionctx.Context, us kv.UnionStore, indexedValues []types.Datum, h kv.Handle, handleRestoreData []types.Datum, opts ...CreateIdxOptFunc) (kv.Handle, error) + Create(ctx sessionctx.Context, txn kv.Transaction, indexedValues []types.Datum, h kv.Handle, handleRestoreData []types.Datum, opts ...CreateIdxOptFunc) (kv.Handle, error) // Delete supports delete from statement. Delete(sc *stmtctx.StatementContext, us kv.UnionStore, indexedValues []types.Datum, h kv.Handle) error // Drop supports drop table, drop index statements. diff --git a/table/tables/index.go b/table/tables/index.go index 95169fdad1bab..25f57dc170e70 100644 --- a/table/tables/index.go +++ b/table/tables/index.go @@ -140,9 +140,9 @@ func (c *index) GenIndexKey(sc *stmtctx.StatementContext, indexedValues []types. // Create creates a new entry in the kvIndex data. // If the index is unique and there is an existing entry with the same key, // Create will return the existing entry's handle as the first return value, ErrKeyExists as the second return value. -func (c *index) Create(sctx sessionctx.Context, us kv.UnionStore, indexedValues []types.Datum, h kv.Handle, handleRestoreData []types.Datum, opts ...table.CreateIdxOptFunc) (kv.Handle, error) { +func (c *index) Create(sctx sessionctx.Context, txn kv.Transaction, indexedValues []types.Datum, h kv.Handle, handleRestoreData []types.Datum, opts ...table.CreateIdxOptFunc) (kv.Handle, error) { if c.Meta().Unique { - us.CacheTableInfo(c.phyTblID, c.tblInfo) + txn.CacheTableInfo(c.phyTblID, c.tblInfo) } var opt table.CreateIdxOpt for _, fn := range opts { @@ -178,6 +178,7 @@ func (c *index) Create(sctx sessionctx.Context, us kv.UnionStore, indexedValues return nil, err } + us := txn.GetUnionStore() if !distinct || skipCheck || opt.Untouched { err = us.GetMemBuffer().Set(key, idxVal) return nil, err diff --git a/table/tables/index_test.go b/table/tables/index_test.go index 6004533c8a8ba..3648118088879 100644 --- a/table/tables/index_test.go +++ b/table/tables/index_test.go @@ -90,7 +90,7 @@ func (s *testIndexSuite) TestIndex(c *C) { values := types.MakeDatums(1, 2) mockCtx := mock.NewContext() - _, err = index.Create(mockCtx, txn.GetUnionStore(), values, kv.IntHandle(1), nil) + _, err = index.Create(mockCtx, txn, values, kv.IntHandle(1), nil) c.Assert(err, IsNil) it, err := index.SeekFirst(txn) @@ -122,7 +122,7 @@ func (s *testIndexSuite) TestIndex(c *C) { c.Assert(terror.ErrorEqual(err, io.EOF), IsTrue, Commentf("err %v", err)) it.Close() - _, err = index.Create(mockCtx, txn.GetUnionStore(), values, kv.IntHandle(0), nil) + _, err = index.Create(mockCtx, txn, values, kv.IntHandle(0), nil) c.Assert(err, IsNil) _, err = index.SeekFirst(txn) @@ -177,10 +177,10 @@ func (s *testIndexSuite) TestIndex(c *C) { txn, err = s.s.Begin() c.Assert(err, IsNil) - _, err = index.Create(mockCtx, txn.GetUnionStore(), values, kv.IntHandle(1), nil) + _, err = index.Create(mockCtx, txn, values, kv.IntHandle(1), nil) c.Assert(err, IsNil) - _, err = index.Create(mockCtx, txn.GetUnionStore(), values, kv.IntHandle(2), nil) + _, err = index.Create(mockCtx, txn, values, kv.IntHandle(2), nil) c.Assert(err, NotNil) it, err = index.SeekFirst(txn) @@ -215,7 +215,7 @@ func (s *testIndexSuite) TestIndex(c *C) { // Test the function of Next when the value of unique key is nil. values2 := types.MakeDatums(nil, nil) - _, err = index.Create(mockCtx, txn.GetUnionStore(), values2, kv.IntHandle(2), nil) + _, err = index.Create(mockCtx, txn, values2, kv.IntHandle(2), nil) c.Assert(err, IsNil) it, err = index.SeekFirst(txn) c.Assert(err, IsNil) @@ -257,7 +257,7 @@ func (s *testIndexSuite) TestCombineIndexSeek(c *C) { mockCtx := mock.NewContext() values := types.MakeDatums("abc", "def") - _, err = index.Create(mockCtx, txn.GetUnionStore(), values, kv.IntHandle(1), nil) + _, err = index.Create(mockCtx, txn, values, kv.IntHandle(1), nil) c.Assert(err, IsNil) index2 := tables.NewIndex(tblInfo.ID, tblInfo, tblInfo.Indices[0]) @@ -298,7 +298,7 @@ func (s *testIndexSuite) TestSingleColumnCommonHandle(c *C) { for _, idx := range []table.Index{idxUnique, idxNonUnique} { key, _, err := idx.GenIndexKey(sc, idxColVals, commonHandle, nil) c.Assert(err, IsNil) - _, err = idx.Create(mockCtx, txn.GetUnionStore(), idxColVals, commonHandle, nil) + _, err = idx.Create(mockCtx, txn, idxColVals, commonHandle, nil) c.Assert(err, IsNil) val, err := txn.Get(context.Background(), key) c.Assert(err, IsNil) @@ -362,7 +362,7 @@ func (s *testIndexSuite) TestMultiColumnCommonHandle(c *C) { for _, idx := range []table.Index{idxUnique, idxNonUnique} { key, _, err := idx.GenIndexKey(sc, idxColVals, commonHandle, nil) c.Assert(err, IsNil) - _, err = idx.Create(mockCtx, txn.GetUnionStore(), idxColVals, commonHandle, nil) + _, err = idx.Create(mockCtx, txn, idxColVals, commonHandle, nil) c.Assert(err, IsNil) val, err := txn.Get(context.Background(), key) c.Assert(err, IsNil) diff --git a/table/tables/tables.go b/table/tables/tables.go index f22e5a2127a5d..0a29dd7f4cf28 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -617,7 +617,7 @@ func (t *TableCommon) AddRecord(sctx sessionctx.Context, r []types.Datum, opts . hasRecordID = true } else { tblInfo := t.Meta() - txn.GetUnionStore().CacheTableInfo(t.physicalTableID, tblInfo) + txn.CacheTableInfo(t.physicalTableID, tblInfo) if tblInfo.PKIsHandle { recordID = kv.IntHandle(r[tblInfo.GetPkColInfo().Offset].GetInt64()) hasRecordID = true @@ -847,7 +847,7 @@ func (t *TableCommon) addIndices(sctx sessionctx.Context, recordID kv.Handle, r dupErr = kv.ErrKeyExists.FastGenByArgs(entryKey, idxMeta.Name.String()) } rsData := TryGetHandleRestoredDataWrapper(t, r, nil) - if dupHandle, err := v.Create(sctx, txn.GetUnionStore(), indexVals, recordID, rsData, opts...); err != nil { + if dupHandle, err := v.Create(sctx, txn, indexVals, recordID, rsData, opts...); err != nil { if kv.ErrKeyExists.Equal(err) { return dupHandle, dupErr } @@ -1146,7 +1146,7 @@ func (t *TableCommon) buildIndexForRow(ctx sessionctx.Context, h kv.Handle, vals opts = append(opts, table.IndexIsUntouched) } rsData := TryGetHandleRestoredDataWrapper(t, newData, nil) - if _, err := idx.Create(ctx, txn.GetUnionStore(), vals, h, rsData, opts...); err != nil { + if _, err := idx.Create(ctx, txn, vals, h, rsData, opts...); err != nil { if kv.ErrKeyExists.Equal(err) { // Make error message consistent with MySQL. entryKey, err1 := t.genIndexKeyStr(vals) diff --git a/util/mock/context.go b/util/mock/context.go index 3077669c36ec3..4350de1d81529 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -66,6 +66,20 @@ func (txn *wrapTxn) GetUnionStore() kv.UnionStore { return txn.Transaction.GetUnionStore() } +func (txn *wrapTxn) CacheTableInfo(id int64, info *model.TableInfo) { + if txn.Transaction == nil { + return + } + txn.Transaction.CacheTableInfo(id, info) +} + +func (txn *wrapTxn) GetTableInfo(id int64) *model.TableInfo { + if txn.Transaction == nil { + return nil + } + return txn.Transaction.GetTableInfo(id) +} + // Execute implements sqlexec.SQLExecutor Execute interface. func (c *Context) Execute(ctx context.Context, sql string) ([]sqlexec.RecordSet, error) { return nil, errors.Errorf("Not Supported.")