diff --git a/executor/adapter.go b/executor/adapter.go index 01e832288d8e4..5f09ad138afd3 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -740,12 +740,7 @@ func (a *ExecStmt) handleNoDelay(ctx context.Context, e Executor, isPessimistic // done in the `defer` function. If the rs is not nil, the detachment will be done in // `rs.Close` in `handleStmt` if handled && sc != nil && rs == nil { - if sc.MemTracker != nil { - sc.MemTracker.Detach() - } - if sc.DiskTracker != nil { - sc.DiskTracker.Detach() - } + sc.DetachMemDiskTracker() } }() @@ -1419,15 +1414,7 @@ func (a *ExecStmt) checkPlanReplayerCapture(txnTS uint64) { func (a *ExecStmt) CloseRecordSet(txnStartTS uint64, lastErr error) { a.FinishExecuteStmt(txnStartTS, lastErr, false) a.logAudit() - // Detach the Memory and disk tracker for the previous stmtCtx from GlobalMemoryUsageTracker and GlobalDiskUsageTracker - if stmtCtx := a.Ctx.GetSessionVars().StmtCtx; stmtCtx != nil { - if stmtCtx.DiskTracker != nil { - stmtCtx.DiskTracker.Detach() - } - if stmtCtx.MemTracker != nil { - stmtCtx.MemTracker.Detach() - } - } + a.Ctx.GetSessionVars().StmtCtx.DetachMemDiskTracker() } // LogSlowQuery is used to print the slow query in the log files. diff --git a/executor/issuetest/BUILD.bazel b/executor/issuetest/BUILD.bazel index 12282131b98e0..2680bd5126f9c 100644 --- a/executor/issuetest/BUILD.bazel +++ b/executor/issuetest/BUILD.bazel @@ -17,10 +17,12 @@ go_test( "//parser/auth", "//parser/charset", "//parser/mysql", + "//session", "//sessionctx/variable", "//statistics", "//testkit", "//util", + "//util/memory", "@com_github_pingcap_failpoint//:failpoint", "@com_github_stretchr_testify//require", "@com_github_tikv_client_go_v2//tikv", diff --git a/executor/issuetest/executor_issue_test.go b/executor/issuetest/executor_issue_test.go index b29fc54031705..6c132ef088f87 100644 --- a/executor/issuetest/executor_issue_test.go +++ b/executor/issuetest/executor_issue_test.go @@ -20,6 +20,7 @@ import ( "math/rand" "strings" "testing" + "time" "github.com/pingcap/failpoint" _ "github.com/pingcap/tidb/autoid_service" @@ -28,10 +29,12 @@ import ( "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/parser/charset" "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/memory" "github.com/stretchr/testify/require" ) @@ -1415,3 +1418,46 @@ func TestIssue42298(t *testing.T) { res = tk.MustQuery("admin show ddl job queries limit 999 offset 268430000") require.Zero(t, len(res.Rows()), len(res.Rows())) } + +func TestIssue42662(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + tk.Session().GetSessionVars().ConnectionID = 12345 + tk.Session().GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSession, -1) + tk.Session().GetSessionVars().MemTracker.SessionID = 12345 + tk.Session().GetSessionVars().MemTracker.IsRootTrackerOfSess = true + + sm := &testkit.MockSessionManager{ + PS: []*util.ProcessInfo{tk.Session().ShowProcess()}, + } + sm.Conn = make(map[uint64]session.Session) + sm.Conn[tk.Session().GetSessionVars().ConnectionID] = tk.Session() + dom.ServerMemoryLimitHandle().SetSessionManager(sm) + go dom.ServerMemoryLimitHandle().Run() + + tk.MustExec("use test") + tk.MustQuery("select connection_id()").Check(testkit.Rows("12345")) + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1 (a int, b int, c int)") + tk.MustExec("create table t2 (a int, b int, c int)") + tk.MustExec("insert into t1 values (1, 1, 1), (1, 2, 2), (2, 1, 3), (2, 2, 4)") + tk.MustExec("insert into t2 values (1, 1, 1), (1, 2, 2), (2, 1, 3), (2, 2, 4)") + // set tidb_server_memory_limit to 1.6GB, tidb_server_memory_limit_sess_min_size to 128MB + tk.MustExec("set global tidb_server_memory_limit='1600MB'") + tk.MustExec("set global tidb_server_memory_limit_sess_min_size=128*1024*1024") + tk.MustExec("set global tidb_mem_oom_action = 'cancel'") + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/issue42662_1", `return(true)`)) + // tk.Session() should be marked as MemoryTop1Tracker but not killed. + tk.MustQuery("select /*+ hash_join(t1)*/ * from t1 join t2 on t1.a = t2.a and t1.b = t2.b") + + // try to trigger the kill top1 logic + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/util/servermemorylimit/issue42662_2", `return(true)`)) + time.Sleep(1 * time.Second) + + // no error should be returned + tk.MustQuery("select count(*) from t1") + tk.MustQuery("select count(*) from t1") + + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/issue42662_1")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/util/servermemorylimit/issue42662_2")) +} diff --git a/executor/join.go b/executor/join.go index 6ac59cc15d3f3..759e92fa0d392 100644 --- a/executor/join.go +++ b/executor/join.go @@ -315,6 +315,15 @@ func (w *buildWorker) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chun return } }) + failpoint.Inject("issue42662_1", func(val failpoint.Value) { + if val.(bool) { + if w.hashJoinCtx.sessCtx.GetSessionVars().ConnectionID != 0 { + // consume 170MB memory, this sql should be tracked into MemoryTop1Tracker + w.hashJoinCtx.memTracker.Consume(170 * 1024 * 1024) + } + return + } + }) sessVars := w.hashJoinCtx.sessCtx.GetSessionVars() for { if w.hashJoinCtx.finished.Load() { diff --git a/server/conn.go b/server/conn.go index 6eacc3cfaf186..a058b39b9c614 100644 --- a/server/conn.go +++ b/server/conn.go @@ -2058,12 +2058,20 @@ func (cc *clientConn) handleStmt(ctx context.Context, stmt ast.StmtNode, warns [ cc.audit(plugin.Starting) rs, err := cc.ctx.ExecuteStmt(ctx, stmt) reg.End() - // The session tracker detachment from global tracker is solved in the `rs.Close` in most cases. - // If the rs is nil, the detachment will be done in the `handleNoDelay`. + // - If rs is not nil, the statement tracker detachment from session tracker + // is done in the `rs.Close` in most cases. + // - If the rs is nil and err is not nil, the detachment will be done in + // the `handleNoDelay`. if rs != nil { defer terror.Call(rs.Close) } if err != nil { + // If error is returned during the planner phase or the executor.Open + // phase, the rs will be nil, and StmtCtx.MemTracker StmtCtx.DiskTracker + // will not be detached. We need to detach them manually. + if sv := cc.ctx.GetSessionVars(); sv != nil && sv.StmtCtx != nil { + sv.StmtCtx.DetachMemDiskTracker() + } return true, err } diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index d9e4cc5a2376d..8fd12745f0635 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -1202,6 +1202,19 @@ func (sc *StatementContext) UseDynamicPartitionPrune() bool { return sc.UseDynamicPruneMode } +// DetachMemDiskTracker detaches the memory and disk tracker from the sessionTracker. +func (sc *StatementContext) DetachMemDiskTracker() { + if sc == nil { + return + } + if sc.MemTracker != nil { + sc.MemTracker.Detach() + } + if sc.DiskTracker != nil { + sc.DiskTracker.Detach() + } +} + // CopTasksDetails collects some useful information of cop-tasks during execution. type CopTasksDetails struct { NumCopTasks int diff --git a/testkit/mocksessionmanager.go b/testkit/mocksessionmanager.go index 67280bc2e4cbe..5cf67cec251fb 100644 --- a/testkit/mocksessionmanager.go +++ b/testkit/mocksessionmanager.go @@ -33,7 +33,7 @@ type MockSessionManager struct { SerID uint64 TxnInfo []*txninfo.TxnInfo Dom *domain.Domain - conn map[uint64]session.Session + Conn map[uint64]session.Session mu sync.Mutex } @@ -44,8 +44,8 @@ func (msm *MockSessionManager) ShowTxnList() []*txninfo.TxnInfo { if len(msm.TxnInfo) > 0 { return msm.TxnInfo } - rs := make([]*txninfo.TxnInfo, 0, len(msm.conn)) - for _, se := range msm.conn { + rs := make([]*txninfo.TxnInfo, 0, len(msm.Conn)) + for _, se := range msm.Conn { info := se.TxnInfo() if info != nil { rs = append(rs, info) @@ -66,7 +66,7 @@ func (msm *MockSessionManager) ShowProcessList() map[uint64]*util.ProcessInfo { return ret } msm.mu.Lock() - for connID, pi := range msm.conn { + for connID, pi := range msm.Conn { ret[connID] = pi.ShowProcess() } msm.mu.Unlock() @@ -89,7 +89,7 @@ func (msm *MockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, boo } msm.mu.Lock() defer msm.mu.Unlock() - if sess := msm.conn[id]; sess != nil { + if sess := msm.Conn[id]; sess != nil { return sess.ShowProcess(), true } if msm.Dom != nil { @@ -130,7 +130,7 @@ func (*MockSessionManager) GetInternalSessionStartTSList() []uint64 { // KillNonFlashbackClusterConn implement SessionManager interface. func (msm *MockSessionManager) KillNonFlashbackClusterConn() { - for _, se := range msm.conn { + for _, se := range msm.Conn { processInfo := se.ShowProcess() ddl, ok := processInfo.StmtCtx.GetPlan().(*core.DDL) if !ok { @@ -148,7 +148,7 @@ func (msm *MockSessionManager) KillNonFlashbackClusterConn() { // CheckOldRunningTxn is to get all startTS of every transactions running in the current internal sessions func (msm *MockSessionManager) CheckOldRunningTxn(job2ver map[int64]int64, job2ids map[int64]string) { msm.mu.Lock() - for _, se := range msm.conn { + for _, se := range msm.Conn { session.RemoveLockDDLJobs(se, job2ver, job2ids) } msm.mu.Unlock() diff --git a/testkit/testkit.go b/testkit/testkit.go index 5352e095d7d51..ef0ffd11acf50 100644 --- a/testkit/testkit.go +++ b/testkit/testkit.go @@ -74,10 +74,10 @@ func NewTestKit(t testing.TB, store kv.Storage) *TestKit { mockSm, ok := sm.(*MockSessionManager) if ok { mockSm.mu.Lock() - if mockSm.conn == nil { - mockSm.conn = make(map[uint64]session.Session) + if mockSm.Conn == nil { + mockSm.Conn = make(map[uint64]session.Session) } - mockSm.conn[tk.session.GetSessionVars().ConnectionID] = tk.session + mockSm.Conn[tk.session.GetSessionVars().ConnectionID] = tk.session mockSm.mu.Unlock() } tk.session.SetSessionManager(sm) diff --git a/util/memory/tracker.go b/util/memory/tracker.go index 4ca49ad0bba3f..ba424b5cf9145 100644 --- a/util/memory/tracker.go +++ b/util/memory/tracker.go @@ -302,6 +302,9 @@ func (t *Tracker) AttachTo(parent *Tracker) { // Detach de-attach the tracker child from its parent, then set its parent property as nil func (t *Tracker) Detach() { + if t == nil { + return + } parent := t.getParent() if parent == nil { return @@ -446,6 +449,7 @@ func (t *Tracker) Consume(bs int64) { currentAction = nextAction nextAction = currentAction.GetFallback() } + logutil.BgLogger().Warn("global memory controller, lastAction", zap.Any("action", currentAction)) currentAction.Action(tracker) } } @@ -471,6 +475,7 @@ func (t *Tracker) Consume(bs int64) { } oldTracker = MemUsageTop1Tracker.Load() } + logutil.BgLogger().Error("global memory controller, update the Top1 session", zap.Int64("memUsage", memUsage), zap.Uint64("conn", sessionRootTracker.SessionID), zap.Uint64("limitSessMinSize", limitSessMinSize)) } } diff --git a/util/servermemorylimit/BUILD.bazel b/util/servermemorylimit/BUILD.bazel index c8fcc3e3c4c79..4abe43930abc8 100644 --- a/util/servermemorylimit/BUILD.bazel +++ b/util/servermemorylimit/BUILD.bazel @@ -11,6 +11,7 @@ go_library( "//util", "//util/logutil", "//util/memory", + "@com_github_pingcap_failpoint//:failpoint", "@org_uber_go_atomic//:atomic", "@org_uber_go_zap//:zap", ], diff --git a/util/servermemorylimit/servermemorylimit.go b/util/servermemorylimit/servermemorylimit.go index 3cdf8b73ff758..e76b31cbb4c49 100644 --- a/util/servermemorylimit/servermemorylimit.go +++ b/util/servermemorylimit/servermemorylimit.go @@ -21,6 +21,7 @@ import ( "sync/atomic" "time" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" @@ -88,6 +89,15 @@ type sessionToBeKilled struct { lastLogTime time.Time } +func (s *sessionToBeKilled) reset() { + s.isKilling = false + s.sqlStartTime = time.Time{} + s.sessionID = 0 + s.sessionTracker = nil + s.killStartTime = time.Time{} + s.lastLogTime = time.Time{} +} + func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) { if s.isKilling { if info, ok := sm.GetProcessInfo(s.sessionID); ok { @@ -104,7 +114,7 @@ func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) { return } } - s.isKilling = false + s.reset() IsKilling.Store(false) memory.MemUsageTop1Tracker.CompareAndSwap(s.sessionTracker, nil) //nolint: all_revive,revive @@ -115,14 +125,25 @@ func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) { if bt == 0 { return } + failpoint.Inject("issue42662_2", func(val failpoint.Value) { + if val.(bool) { + bt = 1 + } + }) instanceStats := memory.ReadMemStats() if instanceStats.HeapInuse > MemoryMaxUsed.Load() { MemoryMaxUsed.Store(instanceStats.HeapInuse) } + limitSessMinSize := memory.ServerMemoryLimitSessMinSize.Load() if instanceStats.HeapInuse > bt { t := memory.MemUsageTop1Tracker.Load() if t != nil { - if info, ok := sm.GetProcessInfo(t.SessionID); ok { + memUsage := t.BytesConsumed() + // If the memory usage of the top1 session is less than tidb_server_memory_limit_sess_min_size, we do not need to kill it. + if uint64(memUsage) < limitSessMinSize { + memory.MemUsageTop1Tracker.CompareAndSwap(t, nil) + t = nil + } else if info, ok := sm.GetProcessInfo(t.SessionID); ok { logutil.BgLogger().Warn("global memory controller tries to kill the top1 memory consumer", zap.Uint64("conn", info.ID), zap.String("sql digest", info.Digest), @@ -146,6 +167,17 @@ func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) { s.killStartTime = time.Now() } } + // If no one larger than tidb_server_memory_limit_sess_min_size is found, we will not kill any one. + if t == nil { + if s.lastLogTime.IsZero() { + s.lastLogTime = time.Now() + } + if time.Since(s.lastLogTime) < 5*time.Second { + return + } + logutil.BgLogger().Warn("global memory controller tries to kill the top1 memory consumer, but no one larger than tidb_server_memory_limit_sess_min_size is found", zap.Uint64("tidb_server_memory_limit_sess_min_size", limitSessMinSize)) + s.lastLogTime = time.Now() + } } }