From 2e29174309d75a6c5f6bfe81d250d3bc92ffd9d1 Mon Sep 17 00:00:00 2001 From: Ti Chi Robot Date: Fri, 24 Nov 2023 15:14:44 +0800 Subject: [PATCH] session: fix select for update statement can't get stmt-count-limit error (#48412) (#48466) close pingcap/tidb#48411 --- server/server_test.go | 74 +++++++++++++++++++ server/tidb_test.go | 6 ++ session/session.go | 8 ++ session/tidb.go | 22 ++++-- .../realtikvtest/sessiontest/session_test.go | 10 +++ 5 files changed, 115 insertions(+), 5 deletions(-) diff --git a/server/server_test.go b/server/server_test.go index b463303666baa..b73e754d91a5f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -35,6 +35,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/log" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/kv" tmysql "github.com/pingcap/tidb/parser/mysql" @@ -2292,3 +2293,76 @@ func (cli *testServerClient) runTestInfoschemaClientErrors(t *testing.T) { }) } + +func (cli *testServerClient) RunTestStmtCountLimit(t *testing.T) { + originalStmtCountLimit := config.GetGlobalConfig().Performance.StmtCountLimit + config.UpdateGlobal(func(conf *config.Config) { + conf.Performance.StmtCountLimit = 3 + }) + defer func() { + config.UpdateGlobal(func(conf *config.Config) { + conf.Performance.StmtCountLimit = originalStmtCountLimit + }) + }() + + cli.runTests(t, nil, func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table t (id int key);") + dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;") + dbt.MustExec("set autocommit=0;") + dbt.MustExec("begin optimistic;") + dbt.MustExec("insert into t values (1);") + dbt.MustExec("insert into t values (2);") + _, err := dbt.GetDB().Query("select * from t for update;") + require.Error(t, err) + require.Equal(t, "Error 1105: statement count 4 exceeds the transaction limitation, transaction has been rollback, autocommit = false", err.Error()) + dbt.MustExec("insert into t values (3);") + dbt.MustExec("commit;") + rows := dbt.MustQuery("select * from t;") + var id int + count := 0 + for rows.Next() { + rows.Scan(&id) + count++ + } + require.NoError(t, rows.Close()) + require.Equal(t, 3, id) + require.Equal(t, 1, count) + + dbt.MustExec("delete from t;") + dbt.MustExec("commit;") + dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;") + dbt.MustExec("set autocommit=0;") + dbt.MustExec("begin optimistic;") + dbt.MustExec("insert into t values (1);") + dbt.MustExec("insert into t values (2);") + _, err = dbt.GetDB().Exec("insert into t values (3);") + require.Error(t, err) + require.Equal(t, "Error 1105: statement count 4 exceeds the transaction limitation, transaction has been rollback, autocommit = false", err.Error()) + dbt.MustExec("commit;") + rows = dbt.MustQuery("select count(*) from t;") + for rows.Next() { + rows.Scan(&count) + } + require.NoError(t, rows.Close()) + require.Equal(t, 0, count) + + dbt.MustExec("delete from t;") + dbt.MustExec("commit;") + dbt.MustExec("set @@tidb_batch_commit=1;") + dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;") + dbt.MustExec("set autocommit=0;") + dbt.MustExec("begin optimistic;") + dbt.MustExec("insert into t values (1);") + dbt.MustExec("insert into t values (2);") + dbt.MustExec("insert into t values (3);") + dbt.MustExec("insert into t values (4);") + dbt.MustExec("insert into t values (5);") + dbt.MustExec("commit;") + rows = dbt.MustQuery("select count(*) from t;") + for rows.Next() { + rows.Scan(&count) + } + require.NoError(t, rows.Close()) + require.Equal(t, 5, count) + }) +} diff --git a/server/tidb_test.go b/server/tidb_test.go index c2b4c591b8eb3..2aea04536ebcc 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -1136,6 +1136,12 @@ func TestSumAvg(t *testing.T) { ts.runTestSumAvg(t) } +func TestStmtCountLimit(t *testing.T) { + ts, cleanup := createTidbTestSuite(t) + defer cleanup() + ts.RunTestStmtCountLimit(t) +} + func TestNullFlag(t *testing.T) { ts, cleanup := createTidbTestSuite(t) defer cleanup() diff --git a/session/session.go b/session/session.go index 25a3f05441352..6a04cfc0eb780 100644 --- a/session/session.go +++ b/session/session.go @@ -2114,6 +2114,14 @@ func runStmt(ctx context.Context, se *session, s sqlexec.Statement) (rs sqlexec. if err != nil { return nil, err } + if sessVars.TxnCtx.CouldRetry && !s.IsReadOnly(sessVars) { + // Only when the txn is could retry and the statement is not read only, need to do stmt-count-limit check, + // otherwise, the stmt won't be add into stmt history, and also don't need check. + // About `stmt-count-limit`, see more in https://docs.pingcap.com/tidb/stable/tidb-configuration-file#stmt-count-limit + if err := checkStmtLimit(ctx, se, false); err != nil { + return nil, err + } + } rs, err = s.Exec(ctx) se.updateTelemetryMetric(s.(*executor.ExecStmt)) diff --git a/session/tidb.go b/session/tidb.go index 9ba9c63d9fcc3..ac3431802178d 100644 --- a/session/tidb.go +++ b/session/tidb.go @@ -253,7 +253,7 @@ func finishStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.St if err != nil { return err } - return checkStmtLimit(ctx, se) + return checkStmtLimit(ctx, se, true) } func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.Statement) error { @@ -283,18 +283,29 @@ func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql s return nil } -func checkStmtLimit(ctx context.Context, se *session) error { +func checkStmtLimit(ctx context.Context, se *session, isFinish bool) error { // If the user insert, insert, insert ... but never commit, TiDB would OOM. // So we limit the statement count in a transaction here. var err error sessVars := se.GetSessionVars() history := GetHistory(se) - if history.Count() > int(config.GetGlobalConfig().Performance.StmtCountLimit) { + stmtCount := history.Count() + if !isFinish { + // history stmt count + current stmt, since current stmt is not finish, it has not add to history. + stmtCount++ + } + if stmtCount > int(config.GetGlobalConfig().Performance.StmtCountLimit) { if !sessVars.BatchCommit { se.RollbackTxn(ctx) - return errors.Errorf("statement count %d exceeds the transaction limitation, autocommit = %t", - history.Count(), sessVars.IsAutocommit()) + return errors.Errorf("statement count %d exceeds the transaction limitation, transaction has been rollback, autocommit = %t", + stmtCount, sessVars.IsAutocommit()) + } + if !isFinish { + // if the stmt is not finish execute, then just return, since some work need to be done such as StmtCommit. + return nil } + // If the stmt is finish execute, and exceed the StmtCountLimit, and BatchCommit is true, + // then commit the current transaction and create a new transaction. err = sessiontxn.NewTxn(ctx, se) // The transaction does not committed yet, we need to keep it in transaction. // The last history could not be "commit"/"rollback" statement. @@ -306,6 +317,7 @@ func checkStmtLimit(ctx context.Context, se *session) error { } // GetHistory get all stmtHistory in current txn. Exported only for test. +// If stmtHistory is nil, will create a new one for current txn. func GetHistory(ctx sessionctx.Context) *StmtHistory { hist, ok := ctx.GetSessionVars().TxnCtx.History.(*StmtHistory) if ok { diff --git a/tests/realtikvtest/sessiontest/session_test.go b/tests/realtikvtest/sessiontest/session_test.go index 075e0d0ef161b..76eb441973c59 100644 --- a/tests/realtikvtest/sessiontest/session_test.go +++ b/tests/realtikvtest/sessiontest/session_test.go @@ -1169,6 +1169,16 @@ func TestBatchCommit(t *testing.T) { tk.MustExec("insert into t values (7)") tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7")) + tk.MustExec("delete from t") + tk.MustExec("commit") + tk.MustExec("begin") + tk.MustExec("explain analyze insert into t values (5)") + tk1.MustQuery("select * from t").Check(testkit.Rows()) + tk.MustExec("explain analyze insert into t values (6)") + tk1.MustQuery("select * from t").Check(testkit.Rows()) + tk.MustExec("explain analyze insert into t values (7)") + tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7")) + // The session is still in transaction. tk.MustExec("insert into t values (8)") tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7"))