diff --git a/pkg/server/internal/testserverclient/BUILD.bazel b/pkg/server/internal/testserverclient/BUILD.bazel index 1f751d9f38580..b29f420f7e1f5 100644 --- a/pkg/server/internal/testserverclient/BUILD.bazel +++ b/pkg/server/internal/testserverclient/BUILD.bazel @@ -6,6 +6,7 @@ go_library( importpath = "github.com/pingcap/tidb/pkg/server/internal/testserverclient", visibility = ["//pkg/server:__subpackages__"], deps = [ + "//pkg/config", "//pkg/errno", "//pkg/kv", "//pkg/parser/mysql", diff --git a/pkg/server/internal/testserverclient/server_client.go b/pkg/server/internal/testserverclient/server_client.go index f02a288baa514..f7d614328ef1a 100644 --- a/pkg/server/internal/testserverclient/server_client.go +++ b/pkg/server/internal/testserverclient/server_client.go @@ -35,6 +35,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/config" "github.com/pingcap/tidb/pkg/errno" "github.com/pingcap/tidb/pkg/kv" tmysql "github.com/pingcap/tidb/pkg/parser/mysql" @@ -2446,4 +2447,77 @@ func (cli *TestServerClient) RunTestInfoschemaClientErrors(t *testing.T) { }) } +func (cli *TestServerClient) RunTestStmtCountLimit(t *testing.T) { + originalStmtCountLimit := config.GetGlobalConfig().Performance.StmtCountLimit + config.UpdateGlobal(func(conf *config.Config) { + conf.Performance.StmtCountLimit = 3 + }) + defer func() { + config.UpdateGlobal(func(conf *config.Config) { + conf.Performance.StmtCountLimit = originalStmtCountLimit + }) + }() + + cli.RunTests(t, nil, func(dbt *testkit.DBTestKit) { + dbt.MustExec("create table t (id int key);") + dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;") + dbt.MustExec("set autocommit=0;") + dbt.MustExec("begin optimistic;") + dbt.MustExec("insert into t values (1);") + dbt.MustExec("insert into t values (2);") + _, err := dbt.GetDB().Query("select * from t for update;") + require.Error(t, err) + require.Equal(t, "Error 1105 (HY000): statement count 4 exceeds the transaction limitation, transaction has been rollback, autocommit = false", err.Error()) + dbt.MustExec("insert into t values (3);") + dbt.MustExec("commit;") + rows := dbt.MustQuery("select * from t;") + var id int + count := 0 + for rows.Next() { + rows.Scan(&id) + count++ + } + require.NoError(t, rows.Close()) + require.Equal(t, 3, id) + require.Equal(t, 1, count) + + dbt.MustExec("delete from t;") + dbt.MustExec("commit;") + dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;") + dbt.MustExec("set autocommit=0;") + dbt.MustExec("begin optimistic;") + dbt.MustExec("insert into t values (1);") + dbt.MustExec("insert into t values (2);") + _, err = dbt.GetDB().Exec("insert into t values (3);") + require.Error(t, err) + require.Equal(t, "Error 1105 (HY000): statement count 4 exceeds the transaction limitation, transaction has been rollback, autocommit = false", err.Error()) + dbt.MustExec("commit;") + rows = dbt.MustQuery("select count(*) from t;") + for rows.Next() { + rows.Scan(&count) + } + require.NoError(t, rows.Close()) + require.Equal(t, 0, count) + + dbt.MustExec("delete from t;") + dbt.MustExec("commit;") + dbt.MustExec("set @@tidb_batch_commit=1;") + dbt.MustExec("set @@tidb_disable_txn_auto_retry=0;") + dbt.MustExec("set autocommit=0;") + dbt.MustExec("begin optimistic;") + dbt.MustExec("insert into t values (1);") + dbt.MustExec("insert into t values (2);") + dbt.MustExec("insert into t values (3);") + dbt.MustExec("insert into t values (4);") + dbt.MustExec("insert into t values (5);") + dbt.MustExec("commit;") + rows = dbt.MustQuery("select count(*) from t;") + for rows.Next() { + rows.Scan(&count) + } + require.NoError(t, rows.Close()) + require.Equal(t, 5, count) + }) +} + //revive:enable:exported diff --git a/pkg/server/tests/tidb_test.go b/pkg/server/tests/tidb_test.go index f4f15f5d7ab08..365fad6d60ecc 100644 --- a/pkg/server/tests/tidb_test.go +++ b/pkg/server/tests/tidb_test.go @@ -1126,6 +1126,11 @@ func TestSumAvg(t *testing.T) { ts.RunTestSumAvg(t) } +func TestStmtCountLimit(t *testing.T) { + ts := createTidbTestSuite(t) + ts.RunTestStmtCountLimit(t) +} + func TestNullFlag(t *testing.T) { ts := createTidbTestSuite(t) diff --git a/pkg/session/session.go b/pkg/session/session.go index 4c1a62e08e6ad..74e4877d8bae1 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -2417,6 +2417,14 @@ func runStmt(ctx context.Context, se *session, s sqlexec.Statement) (rs sqlexec. if err != nil { return nil, err } + if sessVars.TxnCtx.CouldRetry && !s.IsReadOnly(sessVars) { + // Only when the txn is could retry and the statement is not read only, need to do stmt-count-limit check, + // otherwise, the stmt won't be add into stmt history, and also don't need check. + // About `stmt-count-limit`, see more in https://docs.pingcap.com/tidb/stable/tidb-configuration-file#stmt-count-limit + if err := checkStmtLimit(ctx, se, false); err != nil { + return nil, err + } + } rs, err = s.Exec(ctx) se.updateTelemetryMetric(s.(*executor.ExecStmt)) diff --git a/pkg/session/test/txn/txn_test.go b/pkg/session/test/txn/txn_test.go index a0af220f1077a..3f5893157ea34 100644 --- a/pkg/session/test/txn/txn_test.go +++ b/pkg/session/test/txn/txn_test.go @@ -379,6 +379,16 @@ func TestBatchCommit(t *testing.T) { tk.MustExec("insert into t values (7)") tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7")) + tk.MustExec("delete from t") + tk.MustExec("commit") + tk.MustExec("begin") + tk.MustExec("explain analyze insert into t values (5)") + tk1.MustQuery("select * from t").Check(testkit.Rows()) + tk.MustExec("explain analyze insert into t values (6)") + tk1.MustQuery("select * from t").Check(testkit.Rows()) + tk.MustExec("explain analyze insert into t values (7)") + tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7")) + // The session is still in transaction. tk.MustExec("insert into t values (8)") tk1.MustQuery("select * from t").Check(testkit.Rows("5", "6", "7")) diff --git a/pkg/session/tidb.go b/pkg/session/tidb.go index d6c9c59a6a4d3..16e3a2f423678 100644 --- a/pkg/session/tidb.go +++ b/pkg/session/tidb.go @@ -271,7 +271,7 @@ func finishStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.St if err != nil { return err } - return checkStmtLimit(ctx, se) + return checkStmtLimit(ctx, se, true) } func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.Statement) error { @@ -305,18 +305,29 @@ func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql s return nil } -func checkStmtLimit(ctx context.Context, se *session) error { +func checkStmtLimit(ctx context.Context, se *session, isFinish bool) error { // If the user insert, insert, insert ... but never commit, TiDB would OOM. // So we limit the statement count in a transaction here. var err error sessVars := se.GetSessionVars() history := GetHistory(se) - if history.Count() > int(config.GetGlobalConfig().Performance.StmtCountLimit) { + stmtCount := history.Count() + if !isFinish { + // history stmt count + current stmt, since current stmt is not finish, it has not add to history. + stmtCount++ + } + if stmtCount > int(config.GetGlobalConfig().Performance.StmtCountLimit) { if !sessVars.BatchCommit { se.RollbackTxn(ctx) - return errors.Errorf("statement count %d exceeds the transaction limitation, autocommit = %t", - history.Count(), sessVars.IsAutocommit()) + return errors.Errorf("statement count %d exceeds the transaction limitation, transaction has been rollback, autocommit = %t", + stmtCount, sessVars.IsAutocommit()) + } + if !isFinish { + // if the stmt is not finish execute, then just return, since some work need to be done such as StmtCommit. + return nil } + // If the stmt is finish execute, and exceed the StmtCountLimit, and BatchCommit is true, + // then commit the current transaction and create a new transaction. err = sessiontxn.NewTxn(ctx, se) // The transaction does not committed yet, we need to keep it in transaction. // The last history could not be "commit"/"rollback" statement. @@ -328,6 +339,7 @@ func checkStmtLimit(ctx context.Context, se *session) error { } // GetHistory get all stmtHistory in current txn. Exported only for test. +// If stmtHistory is nil, will create a new one for current txn. func GetHistory(ctx sessionctx.Context) *StmtHistory { hist, ok := ctx.GetSessionVars().TxnCtx.History.(*StmtHistory) if ok {