diff --git a/executor/analyze_col_v2.go b/executor/analyze_col_v2.go index 879d47e85a88e..1d9913d5f23e3 100644 --- a/executor/analyze_col_v2.go +++ b/executor/analyze_col_v2.go @@ -595,6 +595,13 @@ func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResu failpoint.Inject("mockAnalyzeSamplingMergeWorkerPanic", func() { panic("failpoint triggered") }) + failpoint.Inject("mockAnalyzeMergeWorkerSlowConsume", func(val failpoint.Value) { + times := val.(int) + for i := 0; i < times; i++ { + e.memTracker.Consume(5 << 20) + time.Sleep(100 * time.Millisecond) + } + }) retCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l) for i := 0; i < l; i++ { retCollector.Base().FMSketches = append(retCollector.Base().FMSketches, statistics.NewFMSketch(maxSketchSize)) diff --git a/executor/analyze_utils.go b/executor/analyze_utils.go index c6d4886d79b7c..cdf47373d29f0 100644 --- a/executor/analyze_utils.go +++ b/executor/analyze_utils.go @@ -17,6 +17,7 @@ package executor import ( "context" "strconv" + "strings" "sync" "github.com/pingcap/errors" @@ -24,6 +25,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" + "github.com/pingcap/tidb/util/memory" "go.uber.org/atomic" ) @@ -45,8 +47,13 @@ func isAnalyzeWorkerPanic(err error) bool { } func getAnalyzePanicErr(r interface{}) error { - if msg, ok := r.(string); ok && msg == globalPanicAnalyzeMemoryExceed { - return errAnalyzeOOM + if msg, ok := r.(string); ok { + if msg == globalPanicAnalyzeMemoryExceed { + return errAnalyzeOOM + } + if strings.Contains(msg, memory.PanicMemoryExceed) { + return errors.Errorf(msg, errAnalyzeOOM) + } } if err, ok := r.(error); ok { if err.Error() == globalPanicAnalyzeMemoryExceed { diff --git a/executor/executor.go b/executor/executor.go index 68e347e1a4d79..cb96942e8f776 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1978,19 +1978,19 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { vars.MemTracker.AttachTo(GlobalAnalyzeMemoryTracker) } else { sc.InitMemTracker(memory.LabelForSQLText, -1) - logOnQueryExceedMemQuota := domain.GetDomain(ctx).ExpensiveQueryHandle().LogOnQueryExceedMemQuota - switch variable.OOMAction.Load() { - case variable.OOMActionCancel: - action := &memory.PanicOnExceed{ConnID: vars.ConnectionID} - action.SetLogHook(logOnQueryExceedMemQuota) - vars.MemTracker.SetActionOnExceed(action) - case variable.OOMActionLog: - fallthrough - default: - action := &memory.LogOnExceed{ConnID: vars.ConnectionID} - action.SetLogHook(logOnQueryExceedMemQuota) - vars.MemTracker.SetActionOnExceed(action) - } + } + logOnQueryExceedMemQuota := domain.GetDomain(ctx).ExpensiveQueryHandle().LogOnQueryExceedMemQuota + switch variable.OOMAction.Load() { + case variable.OOMActionCancel: + action := &memory.PanicOnExceed{ConnID: vars.ConnectionID} + action.SetLogHook(logOnQueryExceedMemQuota) + vars.MemTracker.SetActionOnExceed(action) + case variable.OOMActionLog: + fallthrough + default: + action := &memory.LogOnExceed{ConnID: vars.ConnectionID} + action.SetLogHook(logOnQueryExceedMemQuota) + vars.MemTracker.SetActionOnExceed(action) } sc.MemTracker.SessionID = vars.ConnectionID sc.MemTracker.AttachTo(vars.MemTracker) diff --git a/executor/executor_test.go b/executor/executor_test.go index c73b3a3df7abd..59e70022727d5 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -6187,6 +6187,38 @@ func TestGlobalMemoryControl2(t *testing.T) { runtime.GC() } +func TestGlobalMemoryControlForAnalyze(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + + tk0 := testkit.NewTestKit(t, store) + tk0.MustExec("set global tidb_mem_oom_action = 'cancel'") + tk0.MustExec("set global tidb_server_memory_limit = 512MB") + tk0.MustExec("set global tidb_server_memory_limit_sess_min_size = 128") + + sm := &testkit.MockSessionManager{ + PS: []*util.ProcessInfo{tk0.Session().ShowProcess()}, + } + dom.ServerMemoryLimitHandle().SetSessionManager(sm) + go dom.ServerMemoryLimitHandle().Run() + + tk0.MustExec("use test") + tk0.MustExec("create table t(a int)") + tk0.MustExec("insert into t select 1") + for i := 1; i <= 8; i++ { + tk0.MustExec("insert into t select * from t") // 256 Lines + } + sql := "analyze table t with 1.0 samplerate;" // Need about 100MB + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/util/memory/ReadMemStats", `return(536870912)`)) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/mockAnalyzeMergeWorkerSlowConsume", `return(100)`)) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/util/memory/ReadMemStats")) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/mockAnalyzeMergeWorkerSlowConsume")) + }() + _, err := tk0.Exec(sql) + require.True(t, strings.Contains(err.Error(), "Out Of Memory Quota!")) + runtime.GC() +} + func TestCompileOutOfMemoryQuota(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git a/util/memory/memstats.go b/util/memory/memstats.go index cd0074ece52c7..9cc4a3b14fb5a 100644 --- a/util/memory/memstats.go +++ b/util/memory/memstats.go @@ -18,6 +18,8 @@ import ( "runtime" "sync/atomic" "time" + + "github.com/pingcap/failpoint" ) var stats atomic.Pointer[globalMstats] @@ -26,12 +28,18 @@ var stats atomic.Pointer[globalMstats] const ReadMemInterval = 300 * time.Millisecond // ReadMemStats read the mem stats from runtime.ReadMemStats -func ReadMemStats() *runtime.MemStats { +func ReadMemStats() (memStats *runtime.MemStats) { s := stats.Load() if s != nil { - return &s.m + memStats = &s.m + } else { + memStats = ForceReadMemStats() } - return ForceReadMemStats() + failpoint.Inject("ReadMemStats", func(val failpoint.Value) { + injectedSize := val.(int) + memStats.HeapInuse += uint64(injectedSize) + }) + return } // ForceReadMemStats is to force read memory stats.