From f53e3c72cc3c83619d70418ff2c8707fe2168932 Mon Sep 17 00:00:00 2001 From: lance6716 Date: Mon, 27 Jun 2022 10:54:38 +0800 Subject: [PATCH 1/4] ddl: for schema-level DDL method parameter is now XXXStmt (#35722) ref pingcap/tidb#35665, close pingcap/tidb#35734 --- ddl/ddl.go | 4 +- ddl/ddl_api.go | 98 ++++++++++++++++--- domain/domain_test.go | 19 +++- executor/ddl.go | 75 +------------- .../realtikvtest/sessiontest/session_test.go | 3 +- 5 files changed, 104 insertions(+), 95 deletions(-) diff --git a/ddl/ddl.go b/ddl/ddl.go index e77a9da01e4a9..0ea01e1222eb5 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -101,9 +101,9 @@ var ( // DDL is responsible for updating schema in data store and maintaining in-memory InfoSchema cache. type DDL interface { - CreateSchema(ctx sessionctx.Context, schema model.CIStr, charsetInfo *ast.CharsetOpt, placementPolicyRef *model.PolicyRefInfo) error + CreateSchema(ctx sessionctx.Context, stmt *ast.CreateDatabaseStmt) error AlterSchema(sctx sessionctx.Context, stmt *ast.AlterDatabaseStmt) error - DropSchema(ctx sessionctx.Context, schema model.CIStr) error + DropSchema(ctx sessionctx.Context, stmt *ast.DropDatabaseStmt) error CreateTable(ctx sessionctx.Context, stmt *ast.CreateTableStmt) error CreateView(ctx sessionctx.Context, stmt *ast.CreateViewStmt) error DropTable(ctx sessionctx.Context, tableIdent ast.Ident) (err error) diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 5a7758a36ba88..01f57c3ce1d18 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -82,21 +82,76 @@ const ( tiflashCheckPendingTablesRetry = 7 ) -func (d *ddl) CreateSchema(ctx sessionctx.Context, schema model.CIStr, charsetInfo *ast.CharsetOpt, placementPolicyRef *model.PolicyRefInfo) (err error) { - dbInfo := &model.DBInfo{Name: schema} - if charsetInfo != nil { - chs, coll, err := ResolveCharsetCollation(ast.CharsetOpt{Chs: charsetInfo.Chs, Col: charsetInfo.Col}) +func (d *ddl) CreateSchema(ctx sessionctx.Context, stmt *ast.CreateDatabaseStmt) (err error) { + var placementPolicyRef *model.PolicyRefInfo + sessionVars := ctx.GetSessionVars() + + // If no charset and/or collation is specified use collation_server and character_set_server + charsetOpt := &ast.CharsetOpt{} + if sessionVars.GlobalVarsAccessor != nil { + charsetOpt.Col, err = variable.GetSessionOrGlobalSystemVar(sessionVars, variable.CollationServer) if err != nil { - return errors.Trace(err) + return err + } + charsetOpt.Chs, err = variable.GetSessionOrGlobalSystemVar(sessionVars, variable.CharacterSetServer) + if err != nil { + return err + } + } + + explicitCharset := false + explicitCollation := false + if len(stmt.Options) != 0 { + for _, val := range stmt.Options { + switch val.Tp { + case ast.DatabaseOptionCharset: + charsetOpt.Chs = val.Value + explicitCharset = true + case ast.DatabaseOptionCollate: + charsetOpt.Col = val.Value + explicitCollation = true + case ast.DatabaseOptionPlacementPolicy: + placementPolicyRef = &model.PolicyRefInfo{ + Name: model.NewCIStr(val.Value), + } + } } - dbInfo.Charset = chs - dbInfo.Collate = coll - } else { - dbInfo.Charset, dbInfo.Collate = charset.GetDefaultCharsetAndCollate() } + if charsetOpt.Col != "" { + coll, err := collate.GetCollationByName(charsetOpt.Col) + if err != nil { + return err + } + + // The collation is not valid for the specified character set. + // Try to remove any of them, but not if they are explicitly defined. + if coll.CharsetName != charsetOpt.Chs { + if explicitCollation && !explicitCharset { + // Use the explicitly set collation, not the implicit charset. + charsetOpt.Chs = "" + } + if !explicitCollation && explicitCharset { + // Use the explicitly set charset, not the (session) collation. + charsetOpt.Col = "" + } + } + + } + dbInfo := &model.DBInfo{Name: stmt.Name} + chs, coll, err := ResolveCharsetCollation(ast.CharsetOpt{Chs: charsetOpt.Chs, Col: charsetOpt.Col}) + if err != nil { + return errors.Trace(err) + } + dbInfo.Charset = chs + dbInfo.Collate = coll dbInfo.PlacementPolicyRef = placementPolicyRef - return d.CreateSchemaWithInfo(ctx, dbInfo, OnExistError) + + onExist := OnExistError + if stmt.IfNotExists { + onExist = OnExistIgnore + } + return d.CreateSchemaWithInfo(ctx, dbInfo, onExist) } func (d *ddl) CreateSchemaWithInfo( @@ -147,6 +202,12 @@ func (d *ddl) CreateSchemaWithInfo( err = d.DoDDLJob(ctx, job) err = d.callHookOnChanged(job, err) + + if infoschema.ErrDatabaseExists.Equal(err) && onExist == OnExistIgnore { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return errors.Trace(err) } @@ -520,11 +581,14 @@ func (d *ddl) AlterSchema(sctx sessionctx.Context, stmt *ast.AlterDatabaseStmt) return nil } -func (d *ddl) DropSchema(ctx sessionctx.Context, schema model.CIStr) (err error) { +func (d *ddl) DropSchema(ctx sessionctx.Context, stmt *ast.DropDatabaseStmt) (err error) { is := d.GetInfoSchemaWithInterceptor(ctx) - old, ok := is.SchemaByName(schema) + old, ok := is.SchemaByName(stmt.Name) if !ok { - return errors.Trace(infoschema.ErrDatabaseNotExists) + if stmt.IfExists { + return nil + } + return infoschema.ErrDatabaseDropExists.GenWithStackByArgs(stmt.Name) } job := &model.Job{ SchemaID: old.ID, @@ -537,13 +601,19 @@ func (d *ddl) DropSchema(ctx sessionctx.Context, schema model.CIStr) (err error) err = d.DoDDLJob(ctx, job) err = d.callHookOnChanged(job, err) if err != nil { + if infoschema.ErrDatabaseNotExists.Equal(err) { + if stmt.IfExists { + return nil + } + return infoschema.ErrDatabaseDropExists.GenWithStackByArgs(stmt.Name) + } return errors.Trace(err) } if !config.TableLockEnabled() { return nil } // Clear table locks hold by the session. - tbs := is.SchemaTables(schema) + tbs := is.SchemaTables(stmt.Name) lockTableIDs := make([]int64, 0) for _, tb := range tbs { if ok, _ := ctx.CheckTableLocked(tb.Meta().ID); ok { diff --git a/domain/domain_test.go b/domain/domain_test.go index b5783a2b97013..98776381d1ae8 100644 --- a/domain/domain_test.go +++ b/domain/domain_test.go @@ -122,13 +122,22 @@ func TestInfo(t *testing.T) { } require.True(t, syncerStarted) - // Make sure loading schema is normal. - cs := &ast.CharsetOpt{ - Chs: "utf8", - Col: "utf8_bin", + stmt := &ast.CreateDatabaseStmt{ + Name: model.NewCIStr("aaa"), + // Make sure loading schema is normal. + Options: []*ast.DatabaseOption{ + { + Tp: ast.DatabaseOptionCharset, + Value: "utf8", + }, + { + Tp: ast.DatabaseOptionCollate, + Value: "utf8_bin", + }, + }, } ctx := mock.NewContext() - require.NoError(t, dom.ddl.CreateSchema(ctx, model.NewCIStr("aaa"), cs, nil)) + require.NoError(t, dom.ddl.CreateSchema(ctx, stmt)) require.NoError(t, dom.Reload()) require.Equal(t, int64(1), dom.InfoSchema().SchemaMetaVersion()) diff --git a/executor/ddl.go b/executor/ddl.go index af8742fb1f901..1553be2299fd7 100644 --- a/executor/ddl.go +++ b/executor/ddl.go @@ -34,7 +34,6 @@ import ( "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/temptable" "github.com/pingcap/tidb/util/chunk" - "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/dbterror" "github.com/pingcap/tidb/util/gcutil" "github.com/pingcap/tidb/util/logutil" @@ -248,70 +247,7 @@ func (e *DDLExec) executeRenameTable(s *ast.RenameTableStmt) error { } func (e *DDLExec) executeCreateDatabase(s *ast.CreateDatabaseStmt) error { - var opt *ast.CharsetOpt - var placementPolicyRef *model.PolicyRefInfo - var err error - sessionVars := e.ctx.GetSessionVars() - - // If no charset and/or collation is specified use collation_server and character_set_server - opt = &ast.CharsetOpt{} - if sessionVars.GlobalVarsAccessor != nil { - opt.Col, err = variable.GetSessionOrGlobalSystemVar(sessionVars, variable.CollationServer) - if err != nil { - return err - } - opt.Chs, err = variable.GetSessionOrGlobalSystemVar(sessionVars, variable.CharacterSetServer) - if err != nil { - return err - } - } - - explicitCharset := false - explicitCollation := false - if len(s.Options) != 0 { - for _, val := range s.Options { - switch val.Tp { - case ast.DatabaseOptionCharset: - opt.Chs = val.Value - explicitCharset = true - case ast.DatabaseOptionCollate: - opt.Col = val.Value - explicitCollation = true - case ast.DatabaseOptionPlacementPolicy: - placementPolicyRef = &model.PolicyRefInfo{ - Name: model.NewCIStr(val.Value), - } - } - } - } - - if opt.Col != "" { - coll, err := collate.GetCollationByName(opt.Col) - if err != nil { - return err - } - - // The collation is not valid for the specified character set. - // Try to remove any of them, but not if they are explicitly defined. - if coll.CharsetName != opt.Chs { - if explicitCollation && !explicitCharset { - // Use the explicitly set collation, not the implicit charset. - opt.Chs = "" - } - if !explicitCollation && explicitCharset { - // Use the explicitly set charset, not the (session) collation. - opt.Col = "" - } - } - - } - - err = domain.GetDomain(e.ctx).DDL().CreateSchema(e.ctx, s.Name, opt, placementPolicyRef) - if err != nil { - if infoschema.ErrDatabaseExists.Equal(err) && s.IfNotExists { - err = nil - } - } + err := domain.GetDomain(e.ctx).DDL().CreateSchema(e.ctx, s) return err } @@ -383,14 +319,7 @@ func (e *DDLExec) executeDropDatabase(s *ast.DropDatabaseStmt) error { return errors.New("Drop 'mysql' database is forbidden") } - err := domain.GetDomain(e.ctx).DDL().DropSchema(e.ctx, dbName) - if infoschema.ErrDatabaseNotExists.Equal(err) { - if s.IfExists { - err = nil - } else { - err = infoschema.ErrDatabaseDropExists.GenWithStackByArgs(s.Name) - } - } + err := domain.GetDomain(e.ctx).DDL().DropSchema(e.ctx, s) sessionVars := e.ctx.GetSessionVars() if err == nil && strings.ToLower(sessionVars.CurrentDB) == dbName.L { sessionVars.CurrentDB = "" diff --git a/tests/realtikvtest/sessiontest/session_test.go b/tests/realtikvtest/sessiontest/session_test.go index 4bc983d94b40f..9262ebe196498 100644 --- a/tests/realtikvtest/sessiontest/session_test.go +++ b/tests/realtikvtest/sessiontest/session_test.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/parser/format" "github.com/pingcap/tidb/parser/model" @@ -1363,7 +1364,7 @@ func TestDoDDLJobQuit(t *testing.T) { defer func() { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/ddl/storeCloseInLoop")) }() // this DDL call will enter deadloop before this fix - err = dom.DDL().CreateSchema(se, model.NewCIStr("testschema"), nil, nil) + err = dom.DDL().CreateSchema(se, &ast.CreateDatabaseStmt{Name: model.NewCIStr("testschema")}) require.Equal(t, "context canceled", err.Error()) } From 0998cba23d1d7e13f8095ec579d744a4d943ec8d Mon Sep 17 00:00:00 2001 From: Spade A <71589810+SpadeA-Tang@users.noreply.github.com> Date: Mon, 27 Jun 2022 11:58:38 +0800 Subject: [PATCH 2/4] txn: refactor ts acquisition within build and execute phases (#35376) close pingcap/tidb#35377 --- executor/adapter.go | 58 ++-- executor/batch_point_get.go | 20 +- executor/builder.go | 72 ++--- executor/executor.go | 19 +- executor/point_get.go | 35 ++- executor/trace_test.go | 2 +- sessiontxn/failpoint.go | 18 ++ sessiontxn/isolation/readcommitted.go | 100 +++--- sessiontxn/isolation/readcommitted_test.go | 92 +++++- sessiontxn/isolation/repeatable_read.go | 71 +++-- sessiontxn/isolation/repeatable_read_test.go | 304 +++++++++++++++---- sessiontxn/txn_context_test.go | 2 +- 12 files changed, 566 insertions(+), 227 deletions(-) diff --git a/executor/adapter.go b/executor/adapter.go index faf894816c702..31a8c8d20150f 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -351,7 +351,7 @@ func IsFastPlan(p plannercore.Plan) bool { } // Exec builds an Executor from a plan. If the Executor doesn't return result, -// like the INSERT, UPDATE statements, it executes in this function, if the Executor returns +// like the INSERT, UPDATE statements, it executes in this function. If the Executor returns // result, execution is done after this function returns, in the returned sqlexec.RecordSet Next method. func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) { defer func() { @@ -708,7 +708,10 @@ func (a *ExecStmt) handlePessimisticDML(ctx context.Context, e Executor) error { keys = filterTemporaryTableKeys(sctx.GetSessionVars(), keys) seVars := sctx.GetSessionVars() keys = filterLockTableKeys(seVars.StmtCtx, keys) - lockCtx := newLockCtx(seVars, seVars.LockWaitTimeout, len(keys)) + lockCtx, err := newLockCtx(sctx, seVars.LockWaitTimeout, len(keys)) + if err != nil { + return err + } var lockKeyStats *util.LockKeysDetails ctx = context.WithValue(ctx, util.LockKeysDetailCtxKey, &lockKeyStats) startLocking := time.Now() @@ -730,43 +733,18 @@ func (a *ExecStmt) handlePessimisticDML(ctx context.Context, e Executor) error { } } -// UpdateForUpdateTS updates the ForUpdateTS, if newForUpdateTS is 0, it obtain a new TS from PD. -func UpdateForUpdateTS(seCtx sessionctx.Context, newForUpdateTS uint64) error { - txn, err := seCtx.Txn(false) - if err != nil { - return err - } - if !txn.Valid() { - return errors.Trace(kv.ErrInvalidTxn) - } - - // The Oracle serializable isolation is actually SI in pessimistic mode. - // Do not update ForUpdateTS when the user is using the Serializable isolation level. - // It can be used temporarily on the few occasions when an Oracle-like isolation level is needed. - // Support for this does not mean that TiDB supports serializable isolation of MySQL. - // tidb_skip_isolation_level_check should still be disabled by default. - if seCtx.GetSessionVars().IsIsolation(ast.Serializable) { - return nil - } - if newForUpdateTS == 0 { - // Because the ForUpdateTS is used for the snapshot for reading data in DML. - // We can avoid allocating a global TSO here to speed it up by using the local TSO. - version, err := seCtx.GetStore().CurrentVersion(seCtx.GetSessionVars().TxnCtx.TxnScope) - if err != nil { - return err - } - newForUpdateTS = version.Ver - } - seCtx.GetSessionVars().TxnCtx.SetForUpdateTS(newForUpdateTS) - txn.SetOption(kv.SnapshotTS, seCtx.GetSessionVars().TxnCtx.GetForUpdateTS()) - return nil -} - // handlePessimisticLockError updates TS and rebuild executor if the err is write conflict. func (a *ExecStmt) handlePessimisticLockError(ctx context.Context, lockErr error) (_ Executor, err error) { if lockErr == nil { return nil, nil } + failpoint.Inject("assertPessimisticLockErr", func() { + if terror.ErrorEqual(kv.ErrWriteConflict, lockErr) { + sessiontxn.AddAssertEntranceForLockError(a.Ctx, "errWriteConflict") + } else if terror.ErrorEqual(kv.ErrKeyExists, lockErr) { + sessiontxn.AddAssertEntranceForLockError(a.Ctx, "errDuplicateKey") + } + }) defer func() { if _, ok := errors.Cause(err).(*tikverr.ErrDeadlock); ok { @@ -774,7 +752,8 @@ func (a *ExecStmt) handlePessimisticLockError(ctx context.Context, lockErr error } }() - action, err := sessiontxn.GetTxnManager(a.Ctx).OnStmtErrorForNextAction(sessiontxn.StmtErrAfterPessimisticLock, lockErr) + txnManager := sessiontxn.GetTxnManager(a.Ctx) + action, err := txnManager.OnStmtErrorForNextAction(sessiontxn.StmtErrAfterPessimisticLock, lockErr) if err != nil { return nil, err } @@ -789,10 +768,17 @@ func (a *ExecStmt) handlePessimisticLockError(ctx context.Context, lockErr error a.retryCount++ a.retryStartTime = time.Now() - err = sessiontxn.GetTxnManager(a.Ctx).OnStmtRetry(ctx) + err = txnManager.OnStmtRetry(ctx) if err != nil { return nil, err } + + // Without this line of code, the result will still be correct. But it can ensure that the update time of for update read + // is determined which is beneficial for testing. + if _, err = txnManager.GetStmtForUpdateTS(); err != nil { + return nil, err + } + breakpoint.Inject(a.Ctx, sessiontxn.BreakPointOnStmtRetryAfterLockError) e, err := a.buildExecutor() diff --git a/executor/batch_point_get.go b/executor/batch_point_get.go index 0ce745d172e4a..b5eb68a8b12de 100644 --- a/executor/batch_point_get.go +++ b/executor/batch_point_get.go @@ -56,7 +56,6 @@ type BatchPointGetExec struct { singlePart bool partTblID int64 idxVals [][]types.Datum - startTS uint64 readReplicaScope string isStaleness bool snapshotTS uint64 @@ -97,13 +96,9 @@ func (e *BatchPointGetExec) buildVirtualColumnInfo() { // Open implements the Executor interface. func (e *BatchPointGetExec) Open(context.Context) error { - e.snapshotTS = e.startTS sessVars := e.ctx.GetSessionVars() txnCtx := sessVars.TxnCtx stmtCtx := sessVars.StmtCtx - if e.lock { - e.snapshotTS = txnCtx.GetForUpdateTS() - } txn, err := e.ctx.Txn(false) if err != nil { return err @@ -111,8 +106,8 @@ func (e *BatchPointGetExec) Open(context.Context) error { e.txn = txn var snapshot kv.Snapshot if txn.Valid() && txnCtx.StartTS == txnCtx.GetForUpdateTS() && txnCtx.StartTS == e.snapshotTS { - // We can safely reuse the transaction snapshot if startTS is equal to forUpdateTS. - // The snapshot may contains cache that can reduce RPC call. + // We can safely reuse the transaction snapshot if snapshotTS is equal to forUpdateTS. + // The snapshot may contain cache that can reduce RPC call. snapshot = txn.GetSnapshot() } else { snapshot = e.ctx.GetSnapshotWithTS(e.snapshotTS) @@ -540,13 +535,16 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { } // LockKeys locks the keys for pessimistic transaction. -func LockKeys(ctx context.Context, seCtx sessionctx.Context, lockWaitTime int64, keys ...kv.Key) error { - txnCtx := seCtx.GetSessionVars().TxnCtx - lctx := newLockCtx(seCtx.GetSessionVars(), lockWaitTime, len(keys)) +func LockKeys(ctx context.Context, sctx sessionctx.Context, lockWaitTime int64, keys ...kv.Key) error { + txnCtx := sctx.GetSessionVars().TxnCtx + lctx, err := newLockCtx(sctx, lockWaitTime, len(keys)) + if err != nil { + return err + } if txnCtx.IsPessimistic { lctx.InitReturnValues(len(keys)) } - err := doLockKeys(ctx, seCtx, lctx, keys...) + err = doLockKeys(ctx, sctx, lctx, keys...) if err != nil { return err } diff --git a/executor/builder.go b/executor/builder.go index 8a44c09aaf033..73a4995696de0 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -657,11 +657,11 @@ func (b *executorBuilder) buildSelectLock(v *plannercore.PhysicalLock) Executor defer func() { b.inSelectLockStmt = false }() } b.hasLock = true - if b.err = b.updateForUpdateTSIfNeeded(v.Children()[0]); b.err != nil { + + // Build 'select for update' using the 'for update' ts. + if b.forUpdateTS, b.err = b.getSnapshotTS(); b.err != nil { return nil } - // Build 'select for update' using the 'for update' ts. - b.forUpdateTS = b.ctx.GetSessionVars().TxnCtx.GetForUpdateTS() src := b.build(v.Children()[0]) if b.err != nil { @@ -865,14 +865,11 @@ func (b *executorBuilder) buildSetConfig(v *plannercore.SetConfig) Executor { func (b *executorBuilder) buildInsert(v *plannercore.Insert) Executor { b.inInsertStmt = true - if v.SelectPlan != nil { - // Try to update the forUpdateTS for insert/replace into select statements. - // Set the selectPlan parameter to nil to make it always update the forUpdateTS. - if b.err = b.updateForUpdateTSIfNeeded(nil); b.err != nil { - return nil - } + + if b.forUpdateTS, b.err = b.getSnapshotTS(); b.err != nil { + return nil } - b.forUpdateTS = b.ctx.GetSessionVars().TxnCtx.GetForUpdateTS() + selectExec := b.build(v.SelectPlan) if b.err != nil { return nil @@ -2116,10 +2113,11 @@ func (b *executorBuilder) buildUpdate(v *plannercore.Update) Executor { } } } - if b.err = b.updateForUpdateTSIfNeeded(v.SelectPlan); b.err != nil { + + if b.forUpdateTS, b.err = b.getSnapshotTS(); b.err != nil { return nil } - b.forUpdateTS = b.ctx.GetSessionVars().TxnCtx.GetForUpdateTS() + selExec := b.build(v.SelectPlan) if b.err != nil { return nil @@ -2173,10 +2171,11 @@ func (b *executorBuilder) buildDelete(v *plannercore.Delete) Executor { for _, info := range v.TblColPosInfos { tblID2table[info.TblID], _ = b.is.TableByID(info.TblID) } - if b.err = b.updateForUpdateTSIfNeeded(v.SelectPlan); b.err != nil { + + if b.forUpdateTS, b.err = b.getSnapshotTS(); b.err != nil { return nil } - b.forUpdateTS = b.ctx.GetSessionVars().TxnCtx.GetForUpdateTS() + selExec := b.build(v.SelectPlan) if b.err != nil { return nil @@ -2192,34 +2191,6 @@ func (b *executorBuilder) buildDelete(v *plannercore.Delete) Executor { return deleteExec } -// updateForUpdateTSIfNeeded updates the ForUpdateTS for a pessimistic transaction if needed. -// PointGet executor will get conflict error if the ForUpdateTS is older than the latest commitTS, -// so we don't need to update now for better latency. -func (b *executorBuilder) updateForUpdateTSIfNeeded(selectPlan plannercore.PhysicalPlan) error { - txnCtx := b.ctx.GetSessionVars().TxnCtx - if !txnCtx.IsPessimistic { - return nil - } - if _, ok := selectPlan.(*plannercore.PointGetPlan); ok { - return nil - } - // Activate the invalid txn, use the txn startTS as newForUpdateTS - txn, err := b.ctx.Txn(false) - if err != nil { - return err - } - if !txn.Valid() { - _, err := b.ctx.Txn(true) - if err != nil { - return err - } - return nil - } - // GetStmtForUpdateTS will auto update the for update ts if necessary - _, err = sessiontxn.GetTxnManager(b.ctx).GetStmtForUpdateTS() - return err -} - func (b *executorBuilder) buildAnalyzeIndexPushdown(task plannercore.AnalyzeIndexTask, opts map[ast.AnalyzeOptionType]uint64, autoAnalyze string) *analyzeTask { job := &statistics.AnalyzeJob{DBName: task.DBName, TableName: task.TableName, PartitionName: task.PartitionName, JobInfo: autoAnalyze + "analyze index " + task.IndexInfo.Name.O} _, offset := timeutil.Zone(b.ctx.GetSessionVars().Location()) @@ -4663,18 +4634,26 @@ func (b *executorBuilder) buildBatchPointGet(plan *plannercore.BatchPointGetPlan return nil } - startTS, err := b.getSnapshotTS() + if plan.Lock && !b.inSelectLockStmt { + b.inSelectLockStmt = true + defer func() { + b.inSelectLockStmt = false + }() + } + + snapshotTS, err := b.getSnapshotTS() if err != nil { b.err = err return nil } + decoder := NewRowDecoder(b.ctx, plan.Schema(), plan.TblInfo) e := &BatchPointGetExec{ baseExecutor: newBaseExecutor(b.ctx, plan.Schema(), plan.ID()), tblInfo: plan.TblInfo, idxInfo: plan.IndexInfo, rowDecoder: decoder, - startTS: startTS, + snapshotTS: snapshotTS, readReplicaScope: b.readReplicaScope, isStaleness: b.isStaleness, keepOrder: plan.KeepOrder, @@ -4687,9 +4666,11 @@ func (b *executorBuilder) buildBatchPointGet(plan *plannercore.BatchPointGetPlan partTblID: plan.PartTblID, columns: plan.Columns, } + if plan.TblInfo.TableCacheStatusType == model.TableCacheStatusEnable { - e.cacheTable = b.getCacheTable(plan.TblInfo, startTS) + e.cacheTable = b.getCacheTable(plan.TblInfo, snapshotTS) } + if plan.TblInfo.TempTableType != model.TempTableNone { // Temporary table should not do any lock operations e.lock = false @@ -4699,6 +4680,7 @@ func (b *executorBuilder) buildBatchPointGet(plan *plannercore.BatchPointGetPlan if e.lock { b.hasLock = true } + var capacity int if plan.IndexInfo != nil && !isCommonHandleRead(plan.TblInfo, plan.IndexInfo) { e.idxVals = plan.IndexValues diff --git a/executor/executor.go b/executor/executor.go index aa71929859898..d2b726f24adbb 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -49,6 +49,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/sessiontxn" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" @@ -1042,12 +1043,20 @@ func (e *SelectLockExec) Next(ctx context.Context, req *chunk.Chunk) error { for id := range e.tblID2Handle { e.updateDeltaForTableID(id) } - - return doLockKeys(ctx, e.ctx, newLockCtx(e.ctx.GetSessionVars(), lockWaitTime, len(e.keys)), e.keys...) + lockCtx, err := newLockCtx(e.ctx, lockWaitTime, len(e.keys)) + if err != nil { + return err + } + return doLockKeys(ctx, e.ctx, lockCtx, e.keys...) } -func newLockCtx(seVars *variable.SessionVars, lockWaitTime int64, numKeys int) *tikvstore.LockCtx { - lockCtx := tikvstore.NewLockCtx(seVars.TxnCtx.GetForUpdateTS(), lockWaitTime, seVars.StmtCtx.GetLockWaitStartTime()) +func newLockCtx(sctx sessionctx.Context, lockWaitTime int64, numKeys int) (*tikvstore.LockCtx, error) { + seVars := sctx.GetSessionVars() + forUpdateTS, err := sessiontxn.GetTxnManager(sctx).GetStmtForUpdateTS() + if err != nil { + return nil, err + } + lockCtx := tikvstore.NewLockCtx(forUpdateTS, lockWaitTime, seVars.StmtCtx.GetLockWaitStartTime()) lockCtx.Killed = &seVars.Killed lockCtx.PessimisticLockWaited = &seVars.StmtCtx.PessimisticLockWaited lockCtx.LockKeysDuration = &seVars.StmtCtx.LockKeysDuration @@ -1082,7 +1091,7 @@ func newLockCtx(seVars *variable.SessionVars, lockWaitTime int64, numKeys int) * if lockCtx.ForUpdateTS > 0 && seVars.AssertionLevel != variable.AssertionLevelOff { lockCtx.InitCheckExistence(numKeys) } - return lockCtx + return lockCtx, nil } // doLockKeys is the main entry for pessimistic lock keys diff --git a/executor/point_get.go b/executor/point_get.go index f33ba20b5dd5a..1b4d6666663b5 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -49,11 +49,19 @@ func (b *executorBuilder) buildPointGet(p *plannercore.PointGetPlan) Executor { return nil } - startTS, err := b.getSnapshotTS() + if p.Lock && !b.inSelectLockStmt { + b.inSelectLockStmt = true + defer func() { + b.inSelectLockStmt = false + }() + } + + snapshotTS, err := b.getSnapshotTS() if err != nil { b.err = err return nil } + e := &PointGetExecutor{ baseExecutor: newBaseExecutor(b.ctx, p.Schema(), p.ID()), readReplicaScope: b.readReplicaScope, @@ -61,14 +69,17 @@ func (b *executorBuilder) buildPointGet(p *plannercore.PointGetPlan) Executor { } if p.TblInfo.TableCacheStatusType == model.TableCacheStatusEnable { - e.cacheTable = b.getCacheTable(p.TblInfo, startTS) + e.cacheTable = b.getCacheTable(p.TblInfo, snapshotTS) } + e.base().initCap = 1 e.base().maxChunkSize = 1 - e.Init(p, startTS) + e.Init(p, snapshotTS) + if e.lock { b.hasLock = true } + return e } @@ -83,7 +94,7 @@ type PointGetExecutor struct { idxKey kv.Key handleVal []byte idxVals []types.Datum - startTS uint64 + snapshotTS uint64 readReplicaScope string isStaleness bool txn kv.Transaction @@ -106,13 +117,13 @@ type PointGetExecutor struct { } // Init set fields needed for PointGetExecutor reuse, this does NOT change baseExecutor field -func (e *PointGetExecutor) Init(p *plannercore.PointGetPlan, startTs uint64) { +func (e *PointGetExecutor) Init(p *plannercore.PointGetPlan, snapshotTS uint64) { decoder := NewRowDecoder(e.ctx, p.Schema(), p.TblInfo) e.tblInfo = p.TblInfo e.handle = p.Handle e.idxInfo = p.IndexInfo e.idxVals = p.IndexValues - e.startTS = startTs + e.snapshotTS = snapshotTS e.done = false if e.tblInfo.TempTableType == model.TempTableNone { e.lock = p.Lock @@ -142,10 +153,7 @@ func (e *PointGetExecutor) buildVirtualColumnInfo() { // Open implements the Executor interface. func (e *PointGetExecutor) Open(context.Context) error { txnCtx := e.ctx.GetSessionVars().TxnCtx - snapshotTS := e.startTS - if e.lock { - snapshotTS = txnCtx.GetForUpdateTS() - } + snapshotTS := e.snapshotTS var err error e.txn, err = e.ctx.Txn(false) if err != nil { @@ -381,9 +389,12 @@ func (e *PointGetExecutor) lockKeyIfNeeded(ctx context.Context, key []byte) erro } if e.lock { seVars := e.ctx.GetSessionVars() - lockCtx := newLockCtx(seVars, e.lockWaitTime, 1) + lockCtx, err := newLockCtx(e.ctx, e.lockWaitTime, 1) + if err != nil { + return err + } lockCtx.InitReturnValues(1) - err := doLockKeys(ctx, e.ctx, lockCtx, key) + err = doLockKeys(ctx, e.ctx, lockCtx, key) if err != nil { return err } diff --git a/executor/trace_test.go b/executor/trace_test.go index f8e8e91ddebd7..9b448670cc39a 100644 --- a/executor/trace_test.go +++ b/executor/trace_test.go @@ -33,7 +33,7 @@ func TestTraceExec(t *testing.T) { require.GreaterOrEqual(t, len(rows), 1) // +---------------------------+-----------------+------------+ - // | operation | startTS | duration | + // | operation | snapshotTS | duration | // +---------------------------+-----------------+------------+ // | session.getTxnFuture | 22:08:38.247834 | 78.909µs | // | ├─session.Execute | 22:08:38.247829 | 1.478487ms | diff --git a/sessiontxn/failpoint.go b/sessiontxn/failpoint.go index d33984649b371..b41be21165908 100644 --- a/sessiontxn/failpoint.go +++ b/sessiontxn/failpoint.go @@ -43,6 +43,10 @@ var BreakPointBeforeExecutorFirstRun = "beforeExecutorFirstRun" // Only for test var BreakPointOnStmtRetryAfterLockError = "lockErrorAndThenOnStmtRetryCalled" +// AssertLockErr is used to record the lock errors we encountered +// Only for test +var AssertLockErr stringutil.StringerStr = "assertLockError" + // RecordAssert is used only for test func RecordAssert(sctx sessionctx.Context, name string, value interface{}) { records, ok := sctx.Value(AssertRecordsKey).(map[string]interface{}) @@ -94,6 +98,20 @@ func AssertTxnManagerReadTS(sctx sessionctx.Context, expected uint64) { } } +// AddAssertEntranceForLockError is used only for test +func AddAssertEntranceForLockError(sctx sessionctx.Context, name string) { + records, ok := sctx.Value(AssertLockErr).(map[string]int) + if !ok { + records = make(map[string]int) + sctx.SetValue(AssertLockErr, records) + } + if v, ok := records[name]; ok { + records[name] = v + 1 + } else { + records[name] = 1 + } +} + // ExecTestHook is used only for test. It consumes hookKey in session wait do what it gets from it. func ExecTestHook(sctx sessionctx.Context, hookKey fmt.Stringer) { c := sctx.Value(hookKey) diff --git a/sessiontxn/isolation/readcommitted.go b/sessiontxn/isolation/readcommitted.go index 8a409000d7049..5fb316b59f8bf 100644 --- a/sessiontxn/isolation/readcommitted.go +++ b/sessiontxn/isolation/readcommitted.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/terror" + plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/sessiontxn" @@ -31,20 +32,15 @@ import ( ) type stmtState struct { - stmtTS uint64 - stmtTSFuture oracle.Future - stmtUseStartTS bool - onNextRetryOrStmt func() error + stmtTS uint64 + stmtTSFuture oracle.Future + stmtUseStartTS bool } func (s *stmtState) prepareStmt(useStartTS bool) error { - onNextStmt := s.onNextRetryOrStmt *s = stmtState{ stmtUseStartTS: useStartTS, } - if onNextStmt != nil { - return onNextStmt() - } return nil } @@ -52,7 +48,9 @@ func (s *stmtState) prepareStmt(useStartTS bool) error { type PessimisticRCTxnContextProvider struct { baseTxnContextProvider stmtState - availableRCCheckTS uint64 + latestOracleTS uint64 + // latestOracleTSValid shows whether we have already fetched a ts from pd and whether the ts we fetched is still valid. + latestOracleTSValid bool } // NewPessimisticRCTxnContextProvider returns a new PessimisticRCTxnContextProvider @@ -65,12 +63,14 @@ func NewPessimisticRCTxnContextProvider(sctx sessionctx.Context, causalConsisten txnCtx.IsPessimistic = true txnCtx.Isolation = ast.ReadCommitted }, - onTxnActive: func(txn kv.Transaction) { - txn.SetOption(kv.Pessimistic, true) - }, }, } + provider.onTxnActive = func(txn kv.Transaction) { + txn.SetOption(kv.Pessimistic, true) + provider.latestOracleTS = txn.StartTS() + provider.latestOracleTSValid = true + } provider.getStmtReadTSFunc = provider.getStmtTS provider.getStmtForUpdateTSFunc = provider.getStmtTS return provider @@ -86,9 +86,6 @@ func (p *PessimisticRCTxnContextProvider) OnStmtStart(ctx context.Context) error // OnStmtErrorForNextAction is the hook that should be called when a new statement get an error func (p *PessimisticRCTxnContextProvider) OnStmtErrorForNextAction(point sessiontxn.StmtErrorHandlePoint, err error) (sessiontxn.StmtErrorAction, error) { - // Invalid rc check for next statement or retry when error occurs - p.availableRCCheckTS = 0 - switch point { case sessiontxn.StmtErrAfterQuery: return p.handleAfterQueryError(err) @@ -117,15 +114,30 @@ func (p *PessimisticRCTxnContextProvider) prepareStmtTS() { switch { case p.stmtUseStartTS: stmtTSFuture = sessiontxn.FuncFuture(p.getTxnStartTS) - case p.availableRCCheckTS != 0 && sessVars.StmtCtx.RCCheckTS: - stmtTSFuture = sessiontxn.ConstantFuture(p.availableRCCheckTS) + case p.latestOracleTSValid && sessVars.StmtCtx.RCCheckTS: + stmtTSFuture = sessiontxn.ConstantFuture(p.latestOracleTS) default: - stmtTSFuture = sessiontxn.NewOracleFuture(p.ctx, p.sctx, sessVars.TxnCtx.TxnScope) + stmtTSFuture = p.getOracleFuture() } p.stmtTSFuture = stmtTSFuture } +func (p *PessimisticRCTxnContextProvider) getOracleFuture() sessiontxn.FuncFuture { + txnCtx := p.sctx.GetSessionVars().TxnCtx + future := sessiontxn.NewOracleFuture(p.ctx, p.sctx, txnCtx.TxnScope) + return func() (ts uint64, err error) { + if ts, err = future.Wait(); err != nil { + return + } + txnCtx.SetForUpdateTS(ts) + ts = txnCtx.GetForUpdateTS() + p.latestOracleTS = ts + p.latestOracleTSValid = true + return + } +} + func (p *PessimisticRCTxnContextProvider) getStmtTS() (ts uint64, err error) { if p.stmtTS != 0 { return p.stmtTS, nil @@ -141,13 +153,8 @@ func (p *PessimisticRCTxnContextProvider) getStmtTS() (ts uint64, err error) { return 0, err } - // forUpdateTS should exactly equal to the read ts - txnCtx := p.sctx.GetSessionVars().TxnCtx - txnCtx.SetForUpdateTS(ts) txn.SetOption(kv.SnapshotTS, ts) - p.stmtTS = ts - p.availableRCCheckTS = ts return } @@ -155,16 +162,18 @@ func (p *PessimisticRCTxnContextProvider) getStmtTS() (ts uint64, err error) { // At this point the query will be retried from the beginning. func (p *PessimisticRCTxnContextProvider) handleAfterQueryError(queryErr error) (sessiontxn.StmtErrorAction, error) { sessVars := p.sctx.GetSessionVars() - if sessVars.StmtCtx.RCCheckTS && errors.ErrorEqual(queryErr, kv.ErrWriteConflict) { - logutil.Logger(p.ctx).Info("RC read with ts checking has failed, retry RC read", - zap.String("sql", sessVars.StmtCtx.OriginalSQL)) - return sessiontxn.RetryReady() + if !errors.ErrorEqual(queryErr, kv.ErrWriteConflict) || !sessVars.StmtCtx.RCCheckTS { + return sessiontxn.NoIdea() } - return sessiontxn.NoIdea() + p.latestOracleTSValid = false + logutil.Logger(p.ctx).Info("RC read with ts checking has failed, retry RC read", + zap.String("sql", sessVars.StmtCtx.OriginalSQL)) + return sessiontxn.RetryReady() } func (p *PessimisticRCTxnContextProvider) handleAfterPessimisticLockError(lockErr error) (sessiontxn.StmtErrorAction, error) { + p.latestOracleTSValid = false txnCtx := p.sctx.GetSessionVars().TxnCtx retryable := false if deadlock, ok := errors.Cause(lockErr).(*tikverr.ErrDeadlock); ok && deadlock.IsRetryable { @@ -182,16 +191,9 @@ func (p *PessimisticRCTxnContextProvider) handleAfterPessimisticLockError(lockEr retryable = true } - // force refresh ts in next retry or statement when lock error occurs - p.onNextRetryOrStmt = func() error { - _, err := p.getStmtTS() - return err - } - if retryable { return sessiontxn.RetryReady() } - return sessiontxn.ErrorAction(lockErr) } @@ -207,3 +209,31 @@ func (p *PessimisticRCTxnContextProvider) AdviseWarmup() error { p.prepareStmtTS() return nil } + +// AdviseOptimizeWithPlan in RC covers much fewer cases compared with pessimistic repeatable read. +// We only optimize with insert operator with no selection in that we do not fetch latest ts immediately. +// We only update ts if write conflict is incurred. +func (p *PessimisticRCTxnContextProvider) AdviseOptimizeWithPlan(val interface{}) (err error) { + if p.isTidbSnapshotEnabled() || p.isBeginStmtWithStaleRead() { + return nil + } + + if p.stmtUseStartTS || !p.latestOracleTSValid { + return nil + } + + plan, ok := val.(plannercore.Plan) + if !ok { + return nil + } + + if execute, ok := plan.(*plannercore.Execute); ok { + plan = execute.Plan + } + + if v, ok := plan.(*plannercore.Insert); ok && v.SelectPlan == nil { + p.stmtTSFuture = sessiontxn.ConstantFuture(p.latestOracleTS) + } + + return nil +} diff --git a/sessiontxn/isolation/readcommitted_test.go b/sessiontxn/isolation/readcommitted_test.go index d867cc689ca28..5c747eba4fa2c 100644 --- a/sessiontxn/isolation/readcommitted_test.go +++ b/sessiontxn/isolation/readcommitted_test.go @@ -21,6 +21,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/executor" @@ -53,13 +54,13 @@ func TestPessimisticRCTxnContextProviderRCCheck(t *testing.T) { require.NoError(t, err) forUpdateStmt := stmts[0] - compareTS := getOracleTS(t, se) + compareTS := se.GetSessionVars().TxnCtx.StartTS // first ts should request from tso require.NoError(t, executor.ResetContextOfStmt(se, readOnlyStmt)) require.NoError(t, provider.OnStmtStart(context.TODO())) ts, err := provider.GetStmtReadTS() require.NoError(t, err) - require.Greater(t, ts, compareTS) + require.Equal(t, ts, compareTS) rcCheckTS := ts // second ts should reuse first ts @@ -101,14 +102,11 @@ func TestPessimisticRCTxnContextProviderRCCheck(t *testing.T) { nextAction, err = provider.OnStmtErrorForNextAction(sessiontxn.StmtErrAfterQuery, errors.New("err")) require.NoError(t, err) require.Equal(t, sessiontxn.StmtActionNoIdea, nextAction) - compareTS = getOracleTS(t, se) - require.Greater(t, compareTS, rcCheckTS) require.NoError(t, executor.ResetContextOfStmt(se, readOnlyStmt)) require.NoError(t, provider.OnStmtStart(context.TODO())) ts, err = provider.GetStmtReadTS() require.NoError(t, err) - require.Greater(t, ts, compareTS) - rcCheckTS = ts + require.Equal(t, rcCheckTS, ts) // `StmtErrAfterPessimisticLock` will still disable rc check require.NoError(t, executor.ResetContextOfStmt(se, readOnlyStmt)) @@ -381,6 +379,88 @@ func TestTidbSnapshotVarInRC(t *testing.T) { } } +func TestConflictErrorsInRC(t *testing.T) { + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/assertPessimisticLockErr", "return")) + store, _, clean := testkit.CreateMockStoreAndDomain(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + se := tk.Session() + tk2 := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk2.MustExec("use test") + tk.MustExec("create table t (id int primary key, v int)") + + tk.MustExec("set tx_isolation='READ-COMMITTED'") + + // Test for insert + tk.MustExec("begin pessimistic") + tk2.MustExec("insert into t values (1, 2)") + se.SetValue(sessiontxn.AssertLockErr, nil) + _, err := tk.Exec("insert into t values (1, 1), (2, 2)") + require.Error(t, err) + records, ok := se.Value(sessiontxn.AssertLockErr).(map[string]int) + require.True(t, ok) + for _, name := range errorsInInsert { + require.Equal(t, records[name], 1) + } + + se.SetValue(sessiontxn.AssertLockErr, nil) + tk.MustExec("rollback") + + // Test for delete + tk.MustExec("truncate t") + tk.MustExec("insert into t values (1, 1), (2, 2)") + + tk.MustExec("begin pessimistic") + tk2.MustExec("insert into t values (3, 1)") + + se.SetValue(sessiontxn.AssertLockErr, nil) + tk.MustExec("delete from t where v = 1") + _, ok = se.Value(sessiontxn.AssertLockErr).(map[string]int) + require.False(t, ok) + tk.MustQuery("select * from t").Check(testkit.Rows("2 2")) + tk.MustExec("commit") + + // Unlike RR, in RC, we will always fetch the latest ts. So write conflict will not be happened + tk.MustExec("begin pessimistic") + tk2.MustExec("update t set id = 1 where id = 2") + se.SetValue(sessiontxn.AssertLockErr, nil) + tk.MustExec("delete from t where id = 1") + _, ok = se.Value(sessiontxn.AssertLockErr).(map[string]int) + require.False(t, ok) + tk.MustQuery("select * from t for update").Check(testkit.Rows()) + + tk.MustExec("rollback") + + // Test for update + tk.MustExec("truncate t") + tk.MustExec("insert into t values (1, 1), (2, 2)") + + tk.MustExec("begin pessimistic") + tk2.MustExec("update t set v = v + 10") + + se.SetValue(sessiontxn.AssertLockErr, nil) + tk.MustExec("update t set v = v + 10") + _, ok = se.Value(sessiontxn.AssertLockErr).(map[string]int) + require.False(t, ok) + tk.MustQuery("select * from t").Check(testkit.Rows("1 21", "2 22")) + tk.MustExec("commit") + + // Unlike RR, in RC, we will always fetch the latest ts. So write conflict will not be happened + tk.MustExec("begin pessimistic") + tk2.MustExec("update t set v = v + 10 where id = 1") + tk.MustExec("update t set v = v + 10 where id = 1") + _, ok = se.Value(sessiontxn.AssertLockErr).(map[string]int) + require.False(t, ok) + tk.MustQuery("select * from t for update").Check(testkit.Rows("1 41", "2 22")) + + tk.MustExec("rollback") + + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/assertPessimisticLockErr")) +} + func activeRCTxnAssert(t *testing.T, sctx sessionctx.Context, inTxn bool) *txnAssert[*isolation.PessimisticRCTxnContextProvider] { return &txnAssert[*isolation.PessimisticRCTxnContextProvider]{ sctx: sctx, diff --git a/sessiontxn/isolation/repeatable_read.go b/sessiontxn/isolation/repeatable_read.go index ef678f40c614c..571d2754be9a3 100644 --- a/sessiontxn/isolation/repeatable_read.go +++ b/sessiontxn/isolation/repeatable_read.go @@ -35,10 +35,11 @@ type PessimisticRRTxnContextProvider struct { baseTxnContextProvider // Used for ForUpdateRead statement - forUpdateTS uint64 + forUpdateTS uint64 + latestForUpdateTS uint64 // It may decide whether to update forUpdateTs when calling provider's getForUpdateTs // See more details in the comments of optimizeWithPlan - followingOperatorIsPointGetForUpdate bool + optimizeForNotFetchingLatestTS bool } // NewPessimisticRRTxnContextProvider returns a new PessimisticRRTxnContextProvider @@ -73,7 +74,7 @@ func (p *PessimisticRRTxnContextProvider) getForUpdateTs() (ts uint64, err error return 0, err } - if p.followingOperatorIsPointGetForUpdate { + if p.optimizeForNotFetchingLatestTS { p.forUpdateTS = p.sctx.GetSessionVars().TxnCtx.GetForUpdateTS() return p.forUpdateTS, nil } @@ -114,7 +115,8 @@ func (p *PessimisticRRTxnContextProvider) updateForUpdateTS() (err error) { } sctx.GetSessionVars().TxnCtx.SetForUpdateTS(version.Ver) - txn.SetOption(kv.SnapshotTS, sctx.GetSessionVars().TxnCtx.GetForUpdateTS()) + p.latestForUpdateTS = version.Ver + txn.SetOption(kv.SnapshotTS, version.Ver) return nil } @@ -126,7 +128,7 @@ func (p *PessimisticRRTxnContextProvider) OnStmtStart(ctx context.Context) error } p.forUpdateTS = 0 - p.followingOperatorIsPointGetForUpdate = false + p.optimizeForNotFetchingLatestTS = false return nil } @@ -137,15 +139,14 @@ func (p *PessimisticRRTxnContextProvider) OnStmtRetry(ctx context.Context) (err return err } - txnCtxForUpdateTS := p.sctx.GetSessionVars().TxnCtx.GetForUpdateTS() // If TxnCtx.forUpdateTS is updated in OnStmtErrorForNextAction, we assign the value to the provider - if txnCtxForUpdateTS > p.forUpdateTS { - p.forUpdateTS = txnCtxForUpdateTS + if p.latestForUpdateTS > p.forUpdateTS { + p.forUpdateTS = p.latestForUpdateTS } else { p.forUpdateTS = 0 } - p.followingOperatorIsPointGetForUpdate = false + p.optimizeForNotFetchingLatestTS = false return nil } @@ -165,6 +166,8 @@ func (p *PessimisticRRTxnContextProvider) OnStmtErrorForNextAction(point session // We expect that the data that the point get acquires has not been changed. // Benefit: Save the cost of acquiring ts from PD. // Drawbacks: If the data has been changed since the ts we used, we need to retry. +// One exception is insert operation, when it has no select plan, we do not fetch the latest ts immediately. We only update ts +// if write conflict is incurred. func (p *PessimisticRRTxnContextProvider) AdviseOptimizeWithPlan(val interface{}) (err error) { if p.isTidbSnapshotEnabled() || p.isBeginStmtWithStaleRead() { return nil @@ -179,24 +182,44 @@ func (p *PessimisticRRTxnContextProvider) AdviseOptimizeWithPlan(val interface{} plan = execute.Plan } - mayOptimizeForPointGet := false - if v, ok := plan.(*plannercore.PhysicalLock); ok { - if _, ok := v.Children()[0].(*plannercore.PointGetPlan); ok { - mayOptimizeForPointGet = true - } - } else if v, ok := plan.(*plannercore.Update); ok { - if _, ok := v.SelectPlan.(*plannercore.PointGetPlan); ok { - mayOptimizeForPointGet = true + p.optimizeForNotFetchingLatestTS = notNeedGetLatestTSFromPD(plan, false) + + return nil +} + +// notNeedGetLatestTSFromPD searches for optimization condition recursively +// Note: For point get and batch point get (name it plan), if one of the ancestor node is update/delete/physicalLock, +// we should check whether the plan.Lock is true or false. See comments in needNotToBeOptimized. +// inLockOrWriteStmt = true means one of the ancestor node is update/delete/physicalLock. +func notNeedGetLatestTSFromPD(plan plannercore.Plan, inLockOrWriteStmt bool) bool { + switch v := plan.(type) { + case *plannercore.PointGetPlan: + // We do not optimize the point get/ batch point get if plan.lock = false and inLockOrWriteStmt = true. + // Theoretically, the plan.lock should be true if the flag is true. But due to the bug describing in Issue35524, + // the plan.lock can be false in the case of inLockOrWriteStmt being true. In this case, optimization here can lead to different results + // which cannot be accepted as AdviseOptimizeWithPlan cannot change results. + return !inLockOrWriteStmt || v.Lock + case *plannercore.BatchPointGetPlan: + return !inLockOrWriteStmt || v.Lock + case plannercore.PhysicalPlan: + if len(v.Children()) == 0 { + return false } - } else if v, ok := plan.(*plannercore.Delete); ok { - if _, ok := v.SelectPlan.(*plannercore.PointGetPlan); ok { - mayOptimizeForPointGet = true + _, isPhysicalLock := v.(*plannercore.PhysicalLock) + for _, p := range v.Children() { + if !notNeedGetLatestTSFromPD(p, isPhysicalLock || inLockOrWriteStmt) { + return false + } } + return true + case *plannercore.Update: + return notNeedGetLatestTSFromPD(v.SelectPlan, true) + case *plannercore.Delete: + return notNeedGetLatestTSFromPD(v.SelectPlan, true) + case *plannercore.Insert: + return v.SelectPlan == nil } - - p.followingOperatorIsPointGetForUpdate = mayOptimizeForPointGet - - return nil + return false } func (p *PessimisticRRTxnContextProvider) handleAfterPessimisticLockError(lockErr error) (sessiontxn.StmtErrorAction, error) { diff --git a/sessiontxn/isolation/repeatable_read_test.go b/sessiontxn/isolation/repeatable_read_test.go index dfeed73e9af33..c60c1c3da560d 100644 --- a/sessiontxn/isolation/repeatable_read_test.go +++ b/sessiontxn/isolation/repeatable_read_test.go @@ -21,6 +21,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/executor" @@ -344,60 +345,89 @@ func TestOptimizeWithPlanInPessimisticRR(t *testing.T) { tk.MustExec("insert into t values (1,1), (2,2)") se := tk.Session() provider := initializeRepeatableReadProvider(t, tk, true) - forUpdateTS := se.GetSessionVars().TxnCtx.GetForUpdateTS() + lastFetchedForUpdateTS := se.GetSessionVars().TxnCtx.GetForUpdateTS() txnManager := sessiontxn.GetTxnManager(se) - require.NoError(t, txnManager.OnStmtStart(context.TODO())) - stmt, err := parser.New().ParseOneStmt("delete from t where id = 1", "", "") - require.NoError(t, err) - compareTs := getOracleTS(t, se) - compiler := executor.Compiler{Ctx: se} - execStmt, err := compiler.Compile(context.TODO(), stmt) - require.NoError(t, err) - err = txnManager.AdviseOptimizeWithPlan(execStmt.Plan) - require.NoError(t, err) - ts, err := provider.GetStmtForUpdateTS() - require.NoError(t, err) - require.Greater(t, compareTs, ts) - require.Equal(t, ts, forUpdateTS) + type testStruct struct { + sql string + shouldOptimize bool + } - require.NoError(t, txnManager.OnStmtStart(context.TODO())) - stmt, err = parser.New().ParseOneStmt("update t set v = v + 10 where id = 1", "", "") - require.NoError(t, err) - compiler = executor.Compiler{Ctx: se} - execStmt, err = compiler.Compile(context.TODO(), stmt) - require.NoError(t, err) - err = txnManager.AdviseOptimizeWithPlan(execStmt.Plan) - require.NoError(t, err) - ts, err = provider.GetStmtForUpdateTS() - require.NoError(t, err) - require.Equal(t, ts, forUpdateTS) + cases := []testStruct{ + { + "delete from t where id = 1", + true, + }, + { + "update t set v = v + 10 where id = 1", + true, + }, + { + "select * from (select * from t where id = 1 for update) as t1 for update", + true, + }, + { + "select * from t where id = 1 for update", + true, + }, + { + "select * from t where id = 1 or id = 2 for update", + true, + }, + { + "select * from t for update", + false, + }, + } - require.NoError(t, txnManager.OnStmtStart(context.TODO())) - stmt, err = parser.New().ParseOneStmt("select * from (select * from t where id = 1 for update) as t1 for update", "", "") - require.NoError(t, err) - compiler = executor.Compiler{Ctx: se} - execStmt, err = compiler.Compile(context.TODO(), stmt) - require.NoError(t, err) - err = txnManager.AdviseOptimizeWithPlan(execStmt.Plan) - require.NoError(t, err) - ts, err = provider.GetStmtForUpdateTS() - require.NoError(t, err) - require.Equal(t, ts, forUpdateTS) + var stmt ast.StmtNode + var err error + var execStmt *executor.ExecStmt + var compiler executor.Compiler + var ts, compareTS uint64 + var action sessiontxn.StmtErrorAction - // Now, test for one that does not use the optimization - require.NoError(t, txnManager.OnStmtStart(context.TODO())) - stmt, err = parser.New().ParseOneStmt("select * from t for update", "", "") - compareTs = getOracleTS(t, se) - require.NoError(t, err) - compiler = executor.Compiler{Ctx: se} - execStmt, err = compiler.Compile(context.TODO(), stmt) - require.NoError(t, err) - err = txnManager.AdviseOptimizeWithPlan(execStmt.Plan) - require.NoError(t, err) - ts, err = provider.GetStmtForUpdateTS() - require.NoError(t, err) - require.Greater(t, ts, compareTs) + for _, c := range cases { + compareTS = getOracleTS(t, se) + + require.NoError(t, txnManager.OnStmtStart(context.TODO())) + stmt, err = parser.New().ParseOneStmt(c.sql, "", "") + require.NoError(t, err) + + err = provider.OnStmtStart(context.TODO()) + require.NoError(t, err) + + compiler = executor.Compiler{Ctx: se} + execStmt, err = compiler.Compile(context.TODO(), stmt) + require.NoError(t, err) + + err = txnManager.AdviseOptimizeWithPlan(execStmt.Plan) + require.NoError(t, err) + + ts, err = provider.GetStmtForUpdateTS() + require.NoError(t, err) + + if c.shouldOptimize { + require.Greater(t, compareTS, ts) + require.Equal(t, ts, lastFetchedForUpdateTS) + } else { + require.Greater(t, ts, compareTS) + } + + // retry + if c.shouldOptimize { + action, err = provider.OnStmtErrorForNextAction(sessiontxn.StmtErrAfterPessimisticLock, kv.ErrWriteConflict) + require.NoError(t, err) + require.Equal(t, sessiontxn.StmtActionRetryReady, action) + err = provider.OnStmtRetry(context.TODO()) + require.NoError(t, err) + ts, err = provider.GetStmtForUpdateTS() + require.NoError(t, err) + require.Greater(t, ts, compareTS) + + lastFetchedForUpdateTS = ts + } + } // Test use startTS after optimize when autocommit=0 activeAssert := activePessimisticRRAssert(t, tk.Session(), true) @@ -415,7 +445,7 @@ func TestOptimizeWithPlanInPessimisticRR(t *testing.T) { require.Equal(t, tk.Session().GetSessionVars().TxnCtx.StartTS, ts) // Test still fetch for update ts after optimize when autocommit=0 - compareTs = getOracleTS(t, se) + compareTS = getOracleTS(t, se) activeAssert = activePessimisticRRAssert(t, tk.Session(), true) provider = initializeRepeatableReadProvider(t, tk, false) require.NoError(t, txnManager.OnStmtStart(context.TODO())) @@ -427,7 +457,179 @@ func TestOptimizeWithPlanInPessimisticRR(t *testing.T) { require.NoError(t, err) ts, err = provider.GetStmtForUpdateTS() require.NoError(t, err) - require.Greater(t, ts, compareTs) + require.Greater(t, ts, compareTS) +} + +var errorsInInsert = []string{ + "errWriteConflict", + "errDuplicateKey", +} + +func TestConflictErrorInInsertInRR(t *testing.T) { + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/assertPessimisticLockErr", "return")) + store, _, clean := testkit.CreateMockStoreAndDomain(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + se := tk.Session() + tk2 := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk2.MustExec("use test") + tk.MustExec("create table t (id int primary key, v int)") + + tk.MustExec("begin pessimistic") + tk2.MustExec("insert into t values (1, 2)") + se.SetValue(sessiontxn.AssertLockErr, nil) + _, err := tk.Exec("insert into t values (1, 1), (2, 2)") + require.Error(t, err) + records, ok := se.Value(sessiontxn.AssertLockErr).(map[string]int) + require.True(t, ok) + for _, name := range errorsInInsert { + require.Equal(t, records[name], 1) + } + + se.SetValue(sessiontxn.AssertLockErr, nil) + tk.MustExec("rollback") + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/assertPessimisticLockErr")) +} + +func TestConflictErrorInPointGetForUpdateInRR(t *testing.T) { + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/assertPessimisticLockErr", "return")) + store, _, clean := testkit.CreateMockStoreAndDomain(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + se := tk.Session() + tk2 := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk2.MustExec("use test") + tk.MustExec("create table t (id int primary key, v int)") + tk.MustExec("insert into t values (1, 1), (2, 2)") + + tk.MustExec("begin pessimistic") + tk2.MustExec("update t set v = v + 10 where id = 1") + se.SetValue(sessiontxn.AssertLockErr, nil) + tk.MustQuery("select * from t where id = 1 for update").Check(testkit.Rows("1 11")) + records, ok := se.Value(sessiontxn.AssertLockErr).(map[string]int) + require.True(t, ok) + require.Equal(t, records["errWriteConflict"], 1) + tk.MustExec("commit") + + // batch point get + tk.MustExec("begin pessimistic") + tk2.MustExec("update t set v = v + 10 where id = 1") + se.SetValue(sessiontxn.AssertLockErr, nil) + tk.MustQuery("select * from t where id = 1 or id = 2 for update").Check(testkit.Rows("1 21", "2 2")) + records, ok = se.Value(sessiontxn.AssertLockErr).(map[string]int) + require.True(t, ok) + require.Equal(t, records["errWriteConflict"], 1) + tk.MustExec("commit") + + tk.MustExec("rollback") + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/assertPessimisticLockErr")) +} + +// Delete should get the latest ts and thus does not incur write conflict +func TestConflictErrorInDeleteInRR(t *testing.T) { + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/assertPessimisticLockErr", "return")) + store, _, clean := testkit.CreateMockStoreAndDomain(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + se := tk.Session() + tk2 := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk2.MustExec("use test") + tk.MustExec("create table t (id int primary key, v int)") + tk.MustExec("insert into t values (1, 1), (2, 2)") + + tk.MustExec("begin pessimistic") + tk2.MustExec("insert into t values (3, 1)") + se.SetValue(sessiontxn.AssertLockErr, nil) + tk.MustExec("delete from t where v = 1") + _, ok := se.Value(sessiontxn.AssertLockErr).(map[string]int) + require.False(t, ok) + tk.MustQuery("select * from t").Check(testkit.Rows("2 2")) + tk.MustExec("commit") + + tk.MustExec("begin pessimistic") + // However, if sub select in delete is point get, we will incur one write conflict + tk2.MustExec("update t set id = 1 where id = 2") + se.SetValue(sessiontxn.AssertLockErr, nil) + tk.MustExec("delete from t where id = 1") + + records, ok := se.Value(sessiontxn.AssertLockErr).(map[string]int) + require.True(t, ok) + require.Equal(t, records["errWriteConflict"], 1) + tk.MustQuery("select * from t for update").Check(testkit.Rows()) + + tk.MustExec("rollback") + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/assertPessimisticLockErr")) +} + +func TestConflictErrorInUpdateInRR(t *testing.T) { + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/assertPessimisticLockErr", "return")) + store, _, clean := testkit.CreateMockStoreAndDomain(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + se := tk.Session() + tk2 := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk2.MustExec("use test") + tk.MustExec("create table t (id int primary key, v int)") + tk.MustExec("insert into t values (1, 1), (2, 2)") + + tk.MustExec("begin pessimistic") + tk2.MustExec("update t set v = v + 10") + se.SetValue(sessiontxn.AssertLockErr, nil) + tk.MustExec("update t set v = v + 10") + _, ok := se.Value(sessiontxn.AssertLockErr).(map[string]int) + require.False(t, ok) + tk.MustQuery("select * from t").Check(testkit.Rows("1 21", "2 22")) + tk.MustExec("commit") + + tk.MustExec("begin pessimistic") + // However, if the sub select plan is point get, we should incur one write conflict + tk2.MustExec("update t set v = v + 10 where id = 1") + tk.MustExec("update t set v = v + 10 where id = 1") + records, ok := se.Value(sessiontxn.AssertLockErr).(map[string]int) + require.True(t, ok) + require.Equal(t, records["errWriteConflict"], 1) + tk.MustQuery("select * from t for update").Check(testkit.Rows("1 41", "2 22")) + + tk.MustExec("rollback") + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/assertPessimisticLockErr")) +} + +func TestConflictErrorInOtherQueryContainingPointGet(t *testing.T) { + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/assertPessimisticLockErr", "return")) + store, _, clean := testkit.CreateMockStoreAndDomain(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + se := tk.Session() + tk2 := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk2.MustExec("use test") + tk.MustExec("create table t (id int primary key, v int)") + tk.MustExec("insert into t values (1, 1)") + + tk.MustExec("begin pessimistic") + tk2.MustExec("update t set v = v + 10 where id = 1") + se.SetValue(sessiontxn.AssertLockErr, nil) + tk.MustQuery("select * from t where id=1 and v > 1 for update").Check(testkit.Rows("1 11")) + records, ok := se.Value(sessiontxn.AssertLockErr).(map[string]int) + require.True(t, ok) + require.Equal(t, records["errWriteConflict"], 1) + + tk.MustExec("rollback") + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/assertPessimisticLockErr")) } func activePessimisticRRAssert(t *testing.T, sctx sessionctx.Context, diff --git a/sessiontxn/txn_context_test.go b/sessiontxn/txn_context_test.go index d3e1f32124575..e247a14c86f2b 100644 --- a/sessiontxn/txn_context_test.go +++ b/sessiontxn/txn_context_test.go @@ -445,7 +445,7 @@ func TestTxnContextForHistoricalRead(t *testing.T) { }) doWithCheckPath(t, se, normalPathRecords, func() { - tk.MustQuery("select * from t1 where id=1 for update").Check(testkit.Rows("1 11")) + tk.MustQuery("select * from t1 where id=1 for update").Check(testkit.Rows("1 10")) }) tk.MustExec("rollback") From 31c92c67bc6c2b5d8e38a90f48d2d72aac8525f2 Mon Sep 17 00:00:00 2001 From: djshow832 Date: Mon, 27 Jun 2022 12:20:39 +0800 Subject: [PATCH 3/4] sessionctx: support encoding and decoding statement context (#35688) close pingcap/tidb#35664 --- executor/executor.go | 9 +- session/session.go | 9 +- sessionctx/sessionstates/session_states.go | 4 + .../sessionstates/session_states_test.go | 119 ++++++++++++++++++ sessionctx/stmtctx/stmtctx.go | 49 ++++++++ sessionctx/stmtctx/stmtctx_test.go | 43 +++++++ sessionctx/variable/session.go | 10 ++ 7 files changed, 237 insertions(+), 6 deletions(-) diff --git a/executor/executor.go b/executor/executor.go index d2b726f24adbb..491f3b2b4e27a 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1928,7 +1928,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.IgnoreTruncate = true sc.IgnoreZeroInDate = true sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() - if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors { + if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors || stmt.Tp == ast.ShowSessionStates { sc.InShowWarning = true sc.SetWarnings(vars.StmtCtx.GetWarnings()) } @@ -1936,6 +1936,11 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.IgnoreTruncate = false sc.IgnoreZeroInDate = true sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() + case *ast.SetSessionStatesStmt: + sc.InSetSessionStatesStmt = true + sc.IgnoreTruncate = true + sc.IgnoreZeroInDate = true + sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() default: sc.IgnoreTruncate = true sc.IgnoreZeroInDate = true @@ -1954,7 +1959,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.PrevLastInsertID = vars.StmtCtx.PrevLastInsertID } sc.PrevAffectedRows = 0 - if vars.StmtCtx.InUpdateStmt || vars.StmtCtx.InDeleteStmt || vars.StmtCtx.InInsertStmt { + if vars.StmtCtx.InUpdateStmt || vars.StmtCtx.InDeleteStmt || vars.StmtCtx.InInsertStmt || vars.StmtCtx.InSetSessionStatesStmt { sc.PrevAffectedRows = int64(vars.StmtCtx.AffectedRows()) } else if vars.StmtCtx.InSelectStmt { sc.PrevAffectedRows = -1 diff --git a/session/session.go b/session/session.go index c5c1ead4c65b4..d01c3e7e549d3 100644 --- a/session/session.go +++ b/session/session.go @@ -3543,15 +3543,16 @@ func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Conte // DecodeSessionStates implements SessionStatesHandler.DecodeSessionStates interface. func (s *session) DecodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) { - if err = s.sessionVars.DecodeSessionStates(ctx, sessionStates); err != nil { - return err - } - // Decode session variables. for name, val := range sessionStates.SystemVars { if err = variable.SetSessionSystemVar(s.sessionVars, name, val); err != nil { return err } } + + // Decode stmt ctx after session vars because setting session vars may override stmt ctx, such as warnings. + if err = s.sessionVars.DecodeSessionStates(ctx, sessionStates); err != nil { + return err + } return err } diff --git a/sessionctx/sessionstates/session_states.go b/sessionctx/sessionstates/session_states.go index baf876ff87b4f..10a2756dd04f4 100644 --- a/sessionctx/sessionstates/session_states.go +++ b/sessionctx/sessionstates/session_states.go @@ -18,6 +18,7 @@ import ( "time" ptypes "github.com/pingcap/tidb/parser/types" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" ) @@ -52,4 +53,7 @@ type SessionStates struct { FoundInBinding bool `json:"in-binding,omitempty"` SequenceLatestValues map[int64]int64 `json:"seq-values,omitempty"` MPPStoreLastFailTime map[string]time.Time `json:"store-fail-time,omitempty"` + LastAffectedRows int64 `json:"affected-rows,omitempty"` + LastInsertID uint64 `json:"last-insert-id,omitempty"` + Warnings []stmtctx.SQLWarn `json:"warnings,omitempty"` } diff --git a/sessionctx/sessionstates/session_states_test.go b/sessionctx/sessionstates/session_states_test.go index 847f50f4e9a2b..29101af06f392 100644 --- a/sessionctx/sessionstates/session_states_test.go +++ b/sessionctx/sessionstates/session_states_test.go @@ -435,6 +435,125 @@ func TestSessionCtx(t *testing.T) { } } +func TestStatementCtx(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("create table test.t1(id int auto_increment primary key, str char(1))") + + tests := []struct { + setFunc func(tk *testkit.TestKit) any + checkFunc func(tk *testkit.TestKit, param any) + }{ + { + // check LastAffectedRows + setFunc: func(tk *testkit.TestKit) any { + tk.MustQuery("show warnings") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select row_count()").Check(testkit.Rows("0")) + tk.MustQuery("select row_count()").Check(testkit.Rows("-1")) + }, + }, + { + // check LastAffectedRows + setFunc: func(tk *testkit.TestKit) any { + tk.MustQuery("select 1") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select row_count()").Check(testkit.Rows("-1")) + }, + }, + { + // check LastAffectedRows + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("insert into test.t1(str) value('a'), ('b'), ('c')") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select row_count()").Check(testkit.Rows("3")) + tk.MustQuery("select row_count()").Check(testkit.Rows("-1")) + }, + }, + { + // check LastInsertID + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@last_insert_id").Check(testkit.Rows("0")) + }, + }, + { + // check LastInsertID + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("insert into test.t1(str) value('d')") + rows := tk.MustQuery("select @@last_insert_id").Rows() + require.NotEqual(t, "0", rows[0][0].(string)) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@last_insert_id").Check(param.([][]any)) + }, + }, + { + // check Warning + setFunc: func(tk *testkit.TestKit) any { + tk.MustQuery("select 1") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("show errors").Check(testkit.Rows()) + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("show errors").Check(testkit.Rows()) + tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("0 0")) + }, + }, + { + // check Warning + setFunc: func(tk *testkit.TestKit) any { + tk.MustGetErrCode("insert into test.t1(str) value('ef')", errno.ErrDataTooLong) + rows := tk.MustQuery("show warnings").Rows() + require.Equal(t, 1, len(rows)) + tk.MustQuery("show errors").Check(rows) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("show warnings").Check(param.([][]any)) + tk.MustQuery("show errors").Check(param.([][]any)) + tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("1 1")) + }, + }, + { + // check Warning + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("set sql_mode=''") + tk.MustExec("insert into test.t1(str) value('ef'), ('ef')") + rows := tk.MustQuery("show warnings").Rows() + require.Equal(t, 2, len(rows)) + tk.MustQuery("show errors").Check(testkit.Rows()) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("show warnings").Check(param.([][]any)) + tk.MustQuery("show errors").Check(testkit.Rows()) + tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("2 0")) + }, + }, + } + + for _, tt := range tests { + tk1 := testkit.NewTestKit(t, store) + var param any + if tt.setFunc != nil { + param = tt.setFunc(tk1) + } + tk2 := testkit.NewTestKit(t, store) + showSessionStatesAndSet(t, tk1, tk2) + tt.checkFunc(tk2, param) + } +} + func showSessionStatesAndSet(t *testing.T, tk1, tk2 *testkit.TestKit) { rows := tk1.MustQuery("show session_states").Rows() require.Len(t, rows, 1) diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 70cd4bec5f898..4d623015492cc 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -15,6 +15,7 @@ package stmtctx import ( + "encoding/json" "math" "sort" "strconv" @@ -22,10 +23,12 @@ import ( "sync/atomic" "time" + "github.com/pingcap/errors" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/parser/terror" "github.com/pingcap/tidb/util/disk" "github.com/pingcap/tidb/util/execdetails" "github.com/pingcap/tidb/util/memory" @@ -60,6 +63,43 @@ type SQLWarn struct { Err error } +type jsonSQLWarn struct { + Level string `json:"level"` + SQLErr *terror.Error `json:"err,omitempty"` + Msg string `json:"msg,omitempty"` +} + +// MarshalJSON implements the Marshaler.MarshalJSON interface. +func (warn *SQLWarn) MarshalJSON() ([]byte, error) { + w := &jsonSQLWarn{ + Level: warn.Level, + } + e := errors.Cause(warn.Err) + switch x := e.(type) { + case *terror.Error: + // Omit outter errors because only the most inner error matters. + w.SQLErr = x + default: + w.Msg = e.Error() + } + return json.Marshal(w) +} + +// UnmarshalJSON implements the Unmarshaler.UnmarshalJSON interface. +func (warn *SQLWarn) UnmarshalJSON(data []byte) error { + var w jsonSQLWarn + if err := json.Unmarshal(data, &w); err != nil { + return err + } + warn.Level = w.Level + if w.SQLErr != nil { + warn.Err = w.SQLErr + } else { + warn.Err = errors.New(w.Msg) + } + return nil +} + // StatementContext contains variables for a statement. // It should be reset before executing a statement. type StatementContext struct { @@ -76,6 +116,7 @@ type StatementContext struct { InLoadDataStmt bool InExplainStmt bool InCreateOrAlterStmt bool + InSetSessionStatesStmt bool InPreparedPlanBuilding bool IgnoreTruncate bool IgnoreZeroInDate bool @@ -406,6 +447,13 @@ func (sc *StatementContext) AddAffectedRows(rows uint64) { sc.mu.affectedRows += rows } +// SetAffectedRows sets affected rows. +func (sc *StatementContext) SetAffectedRows(rows uint64) { + sc.mu.Lock() + sc.mu.affectedRows = rows + sc.mu.Unlock() +} + // AffectedRows gets affected rows. func (sc *StatementContext) AffectedRows() uint64 { sc.mu.Lock() @@ -558,6 +606,7 @@ func (sc *StatementContext) SetWarnings(warns []SQLWarn) { sc.mu.Lock() defer sc.mu.Unlock() sc.mu.warnings = warns + sc.mu.errorCount = 0 for _, w := range warns { if w.Level == WarnLevelError { sc.mu.errorCount++ diff --git a/sessionctx/stmtctx/stmtctx_test.go b/sessionctx/stmtctx/stmtctx_test.go index 7a4ec77a90660..b8f36dcb25055 100644 --- a/sessionctx/stmtctx/stmtctx_test.go +++ b/sessionctx/stmtctx/stmtctx_test.go @@ -16,12 +16,15 @@ package stmtctx_test import ( "context" + "encoding/json" "fmt" "testing" "time" + "github.com/pingcap/errors" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util/execdetails" "github.com/stretchr/testify/require" @@ -143,3 +146,43 @@ func TestWeakConsistencyRead(t *testing.T) { execAndCheck("execute s", testkit.Rows("1 1 2"), kv.SI) tk.MustExec("rollback") } + +func TestMarshalSQLWarn(t *testing.T) { + warns := []stmtctx.SQLWarn{ + { + Level: stmtctx.WarnLevelError, + Err: errors.New("any error"), + }, + { + Level: stmtctx.WarnLevelError, + Err: errors.Trace(errors.New("any error")), + }, + { + Level: stmtctx.WarnLevelWarning, + Err: variable.ErrUnknownSystemVar.GenWithStackByArgs("unknown"), + }, + { + Level: stmtctx.WarnLevelWarning, + Err: errors.Trace(variable.ErrUnknownSystemVar.GenWithStackByArgs("unknown")), + }, + } + + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + // First query can trigger loading global variables, which produces warnings. + tk.MustQuery("select 1") + tk.Session().GetSessionVars().StmtCtx.SetWarnings(warns) + rows := tk.MustQuery("show warnings").Rows() + require.Equal(t, len(warns), len(rows)) + + // The unmarshalled result doesn't need to be exactly the same with the original one. + // We only need that the results of `show warnings` are the same. + bytes, err := json.Marshal(warns) + require.NoError(t, err) + var newWarns []stmtctx.SQLWarn + err = json.Unmarshal(bytes, &newWarns) + require.NoError(t, err) + tk.Session().GetSessionVars().StmtCtx.SetWarnings(newWarns) + tk.MustQuery("show warnings").Check(rows) +} diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index fe4f469e76134..12546cde3c0ad 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -1867,6 +1867,11 @@ func (s *SessionVars) EncodeSessionStates(ctx context.Context, sessionStates *se sessionStates.MPPStoreLastFailTime = s.MPPStoreLastFailTime sessionStates.FoundInPlanCache = s.PrevFoundInPlanCache sessionStates.FoundInBinding = s.PrevFoundInBinding + + // Encode StatementContext. We encode it here to avoid circle dependency. + sessionStates.LastAffectedRows = s.StmtCtx.PrevAffectedRows + sessionStates.LastInsertID = s.StmtCtx.PrevLastInsertID + sessionStates.Warnings = s.StmtCtx.GetWarnings() return } @@ -1902,6 +1907,11 @@ func (s *SessionVars) DecodeSessionStates(ctx context.Context, sessionStates *se } s.FoundInPlanCache = sessionStates.FoundInPlanCache s.FoundInBinding = sessionStates.FoundInBinding + + // Decode StatementContext. + s.StmtCtx.SetAffectedRows(uint64(sessionStates.LastAffectedRows)) + s.StmtCtx.PrevLastInsertID = sessionStates.LastInsertID + s.StmtCtx.SetWarnings(sessionStates.Warnings) return } From ab27d4918a5a07d565ce3ef98761b6de6e90c9cc Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Mon, 27 Jun 2022 14:24:38 +0800 Subject: [PATCH 4/4] planner: fix the wrong cost formula of MPPExchanger on cost model ver2 (#35718) ref pingcap/tidb#35240 --- planner/core/plan_cost.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/planner/core/plan_cost.go b/planner/core/plan_cost.go index 5612933b9cf9b..73758b562536a 100644 --- a/planner/core/plan_cost.go +++ b/planner/core/plan_cost.go @@ -1209,8 +1209,12 @@ func (p *PhysicalExchangeReceiver) GetPlanCost(taskType property.TaskType, costF } p.planCost = childCost // accumulate net cost - // TODO: this formula is wrong since it doesn't consider tableRowSize, fix it later - p.planCost += getCardinality(p.children[0], costFlag) * p.ctx.GetSessionVars().GetNetworkFactor(nil) + if p.ctx.GetSessionVars().CostModelVersion == modelVer1 { + p.planCost += getCardinality(p.children[0], costFlag) * p.ctx.GetSessionVars().GetNetworkFactor(nil) + } else { // to avoid regression, only consider row-size on model ver2 + rowSize := getTblStats(p.children[0]).GetAvgRowSize(p.ctx, p.children[0].Schema().Columns, false, false) + p.planCost += getCardinality(p.children[0], costFlag) * rowSize * p.ctx.GetSessionVars().GetNetworkFactor(nil) + } p.planCostInit = true return p.planCost, nil }