diff --git a/distsql/distsql.go b/distsql/distsql.go index 3219098b646e9..4a4560eea35fb 100644 --- a/distsql/distsql.go +++ b/distsql/distsql.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/trxevents" "github.com/pingcap/tipb/go-tipb" "github.com/tikv/client-go/v2/tikvrpc/interceptor" @@ -38,11 +39,19 @@ import ( ) // DispatchMPPTasks dispatches all tasks and returns an iterator. +<<<<<<< HEAD func DispatchMPPTasks(ctx context.Context, sctx sessionctx.Context, tasks []*kv.MPPDispatchRequest, fieldTypes []*types.FieldType, planIDs []int, rootID int, startTs uint64) (SelectResult, error) { ctx = WithSQLKvExecCounterInterceptor(ctx, sctx.GetSessionVars().StmtCtx) _, allowTiFlashFallback := sctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] ctx = SetTiFlashMaxThreadsInContext(ctx, sctx) resp := sctx.GetMPPClient().DispatchMPPTasks(ctx, sctx.GetSessionVars().KVVars, tasks, allowTiFlashFallback, startTs) +======= +func DispatchMPPTasks(ctx context.Context, sctx sessionctx.Context, tasks []*kv.MPPDispatchRequest, fieldTypes []*types.FieldType, planIDs []int, rootID int, startTs uint64, mppQueryID kv.MPPQueryID, memTracker *memory.Tracker) (SelectResult, error) { + ctx = WithSQLKvExecCounterInterceptor(ctx, sctx.GetSessionVars().StmtCtx) + _, allowTiFlashFallback := sctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] + ctx = SetTiFlashMaxThreadsInContext(ctx, sctx) + resp := sctx.GetMPPClient().DispatchMPPTasks(ctx, sctx.GetSessionVars().KVVars, tasks, allowTiFlashFallback, startTs, mppQueryID, sctx.GetSessionVars().ChooseMppVersion(), memTracker) +>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901)) if resp == nil { return nil, errors.New("client returns nil response") } diff --git a/executor/builder.go b/executor/builder.go index 0c33214f847e4..0b81826001070 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -3281,7 +3281,13 @@ func (b *executorBuilder) buildMPPGather(v *plannercore.PhysicalTableReader) Exe is: b.is, originalPlan: v.GetTablePlan(), startTS: startTs, +<<<<<<< HEAD +======= + mppQueryID: kv.MPPQueryID{QueryTs: getMPPQueryTS(b.ctx), LocalQueryID: getMPPQueryID(b.ctx), ServerID: domain.GetDomain(b.ctx).ServerID()}, + memTracker: memory.NewTracker(v.ID(), -1), +>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901)) } + gather.memTracker.AttachTo(b.ctx.GetSessionVars().StmtCtx.MemTracker) return gather } diff --git a/executor/mpp_gather.go b/executor/mpp_gather.go index 3a8ce366742b6..fe54951dead5b 100644 --- a/executor/mpp_gather.go +++ b/executor/mpp_gather.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tipb/go-tipb" "go.uber.org/zap" ) @@ -49,6 +50,8 @@ type MPPGather struct { mppReqs []*kv.MPPDispatchRequest respIter distsql.SelectResult + + memTracker *memory.Tracker } func (e *MPPGather) appendMPPDispatchReq(pf *plannercore.Fragment) error { @@ -122,7 +125,11 @@ func (e *MPPGather) Open(ctx context.Context) (err error) { failpoint.Return(errors.Errorf("The number of tasks is not right, expect %d tasks but actually there are %d tasks", val.(int), len(e.mppReqs))) } }) +<<<<<<< HEAD e.respIter, err = distsql.DispatchMPPTasks(ctx, e.ctx, e.mppReqs, e.retFieldTypes, planIDs, e.id, e.startTS) +======= + e.respIter, err = distsql.DispatchMPPTasks(ctx, e.ctx, e.mppReqs, e.retFieldTypes, planIDs, e.id, e.startTS, e.mppQueryID, e.memTracker) +>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901)) if err != nil { return errors.Trace(err) } diff --git a/executor/tiflash_test.go b/executor/tiflash_test.go index 999ea336acffa..260d72ddad89e 100644 --- a/executor/tiflash_test.go +++ b/executor/tiflash_test.go @@ -1299,3 +1299,145 @@ func TestTiflashEmptyDynamicPruneResult(t *testing.T) { tk.MustQuery("select /*+ read_from_storage(tiflash[t2]) */ * from IDT_RP24833 partition(p2) t2 where t2. col1 <= -8448770111093677011;").Check(testkit.Rows()) tk.MustQuery("select /*+ read_from_storage(tiflash[t1, t2]) */ * from IDT_RP24833 partition(p3, p4) t1 join IDT_RP24833 partition(p2) t2 on t1.col1 = t2.col1 where t1. col1 between -8448770111093677011 and -8448770111093677011 and t2. col1 <= -8448770111093677011;").Check(testkit.Rows()) } +<<<<<<< HEAD:executor/tiflash_test.go +======= + +func TestDisaggregatedTiFlash(t *testing.T) { + config.UpdateGlobal(func(conf *config.Config) { + conf.DisaggregatedTiFlash = true + }) + defer config.UpdateGlobal(func(conf *config.Config) { + conf.DisaggregatedTiFlash = false + }) + err := tiflashcompute.InitGlobalTopoFetcher(tiflashcompute.TestASStr, "", "", false) + require.NoError(t, err) + + store := testkit.CreateMockStore(t, withMockTiFlash(2)) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(c1 int)") + tk.MustExec("alter table t set tiflash replica 1") + tb := external.GetTableByName(t, tk, "test", "t") + err = domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) + require.NoError(t, err) + tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"") + + err = tk.ExecToErr("select * from t;") + // Expect error, because TestAutoScaler return empty topo. + require.Contains(t, err.Error(), "Cannot find proper topo from AutoScaler") + + err = tiflashcompute.InitGlobalTopoFetcher(tiflashcompute.AWSASStr, "", "", false) + require.NoError(t, err) + err = tk.ExecToErr("select * from t;") + // Expect error, because AWSAutoScaler is not setup, so http request will fail. + require.Contains(t, err.Error(), "[util:1815]Internal : get tiflash_compute topology failed") +} + +// todo: remove this after AutoScaler is stable. +func TestDisaggregatedTiFlashNonAutoScaler(t *testing.T) { + config.UpdateGlobal(func(conf *config.Config) { + conf.DisaggregatedTiFlash = true + conf.UseAutoScaler = false + }) + defer config.UpdateGlobal(func(conf *config.Config) { + conf.DisaggregatedTiFlash = false + conf.UseAutoScaler = true + }) + + // Setting globalTopoFetcher to nil to can make sure cannot fetch topo from AutoScaler. + err := tiflashcompute.InitGlobalTopoFetcher(tiflashcompute.InvalidASStr, "", "", false) + require.Contains(t, err.Error(), "unexpected topo fetch type. expect: mock or aws or gcp, got invalid") + + store := testkit.CreateMockStore(t, withMockTiFlash(2)) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(c1 int)") + tk.MustExec("alter table t set tiflash replica 1") + tb := external.GetTableByName(t, tk, "test", "t") + err = domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) + require.NoError(t, err) + tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"") + + err = tk.ExecToErr("select * from t;") + // This error message means we use PD instead of AutoScaler. + require.Contains(t, err.Error(), "tiflash_compute node is unavailable") +} + +func TestDisaggregatedTiFlashQuery(t *testing.T) { + config.UpdateGlobal(func(conf *config.Config) { + conf.DisaggregatedTiFlash = true + }) + defer config.UpdateGlobal(func(conf *config.Config) { + conf.DisaggregatedTiFlash = false + }) + + store := testkit.CreateMockStore(t, withMockTiFlash(2)) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists tbl_1") + tk.MustExec(`create table tbl_1 ( col_1 bigint not null default -1443635317331776148, + col_2 text ( 176 ) collate utf8mb4_bin not null, + col_3 decimal ( 8, 3 ), + col_4 varchar ( 128 ) collate utf8mb4_bin not null, + col_5 varchar ( 377 ) collate utf8mb4_bin, + col_6 double, + col_7 varchar ( 459 ) collate utf8mb4_bin, + col_8 tinyint default -88 ) charset utf8mb4 collate utf8mb4_bin ;`) + tk.MustExec("alter table tbl_1 set tiflash replica 1") + tb := external.GetTableByName(t, tk, "test", "tbl_1") + err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) + require.NoError(t, err) + tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"") + + tk.MustExec("explain select max( tbl_1.col_1 ) as r0 , sum( tbl_1.col_1 ) as r1 , sum( tbl_1.col_8 ) as r2 from tbl_1 where tbl_1.col_8 != 68 or tbl_1.col_3 between null and 939 order by r0,r1,r2;") + + tk.MustExec("set @@tidb_partition_prune_mode = 'static';") + tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"") + tk.MustExec("create table t1(c1 int, c2 int) partition by hash(c1) partitions 3") + tk.MustExec("insert into t1 values(1, 1), (2, 2), (3, 3)") + tk.MustExec("alter table t1 set tiflash replica 1") + tb = external.GetTableByName(t, tk, "test", "t1") + err = domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) + require.NoError(t, err) + tk.MustQuery("explain select * from t1 where c1 < 2").Check(testkit.Rows( + "PartitionUnion_10 9970.00 root ", + "├─TableReader_15 3323.33 root MppVersion: 1, data:ExchangeSender_14", + "│ └─ExchangeSender_14 3323.33 mpp[tiflash] ExchangeType: PassThrough", + "│ └─Selection_13 3323.33 mpp[tiflash] lt(test.t1.c1, 2)", + "│ └─TableFullScan_12 10000.00 mpp[tiflash] table:t1, partition:p0 keep order:false, stats:pseudo", + "├─TableReader_19 3323.33 root MppVersion: 1, data:ExchangeSender_18", + "│ └─ExchangeSender_18 3323.33 mpp[tiflash] ExchangeType: PassThrough", + "│ └─Selection_17 3323.33 mpp[tiflash] lt(test.t1.c1, 2)", + "│ └─TableFullScan_16 10000.00 mpp[tiflash] table:t1, partition:p1 keep order:false, stats:pseudo", + "└─TableReader_23 3323.33 root MppVersion: 1, data:ExchangeSender_22", + " └─ExchangeSender_22 3323.33 mpp[tiflash] ExchangeType: PassThrough", + " └─Selection_21 3323.33 mpp[tiflash] lt(test.t1.c1, 2)", + " └─TableFullScan_20 10000.00 mpp[tiflash] table:t1, partition:p2 keep order:false, stats:pseudo")) + // tk.MustQuery("select * from t1 where c1 < 2").Check(testkit.Rows("1 1")) +} + +func TestMPPMemoryTracker(t *testing.T) { + store := testkit.CreateMockStore(t, withMockTiFlash(2)) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set tidb_mem_quota_query = 1 << 30") + tk.MustExec("set global tidb_mem_oom_action = 'CANCEL'") + tk.MustExec("use test") + tk.MustExec("create table t(a int);") + tk.MustExec("insert into t values (1);") + tk.MustExec("alter table t set tiflash replica 1") + tb := external.GetTableByName(t, tk, "test", "t") + err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) + require.NoError(t, err) + tk.MustExec("set tidb_enforce_mpp = on;") + tk.MustQuery("select * from t").Check(testkit.Rows("1")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/store/copr/testMPPOOMPanic", `return(true)`)) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/store/copr/testMPPOOMPanic")) + }() + err = tk.QueryToErr("select * from t") + require.NotNil(t, err) + require.True(t, strings.Contains(err.Error(), "Out Of Memory Quota!")) +} +>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901)):executor/tiflashtest/tiflash_test.go diff --git a/kv/mpp.go b/kv/mpp.go index 2e026a8a55a67..e2569da64f03a 100644 --- a/kv/mpp.go +++ b/kv/mpp.go @@ -19,6 +19,11 @@ import ( "time" "github.com/pingcap/kvproto/pkg/mpp" +<<<<<<< HEAD +======= + "github.com/pingcap/tidb/util/memory" + "github.com/pingcap/tipb/go-tipb" +>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901)) ) // MPPTaskMeta means the meta info such as location of a mpp task. @@ -83,7 +88,11 @@ type MPPClient interface { ConstructMPPTasks(context.Context, *MPPBuildTasksRequest, map[string]time.Time, time.Duration) ([]MPPTaskMeta, error) // DispatchMPPTasks dispatches ALL mpp requests at once, and returns an iterator that transfers the data. +<<<<<<< HEAD DispatchMPPTasks(ctx context.Context, vars interface{}, reqs []*MPPDispatchRequest, needTriggerFallback bool, startTs uint64) Response +======= + DispatchMPPTasks(ctx context.Context, vars interface{}, reqs []*MPPDispatchRequest, needTriggerFallback bool, startTs uint64, mppQueryID MPPQueryID, mppVersion MppVersion, memTracker *memory.Tracker) Response +>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901)) } // MPPBuildTasksRequest request the stores allocation for a mpp plan fragment. diff --git a/store/copr/mpp.go b/store/copr/mpp.go index 22061039616a5..5458bf8deb0a9 100644 --- a/store/copr/mpp.go +++ b/store/copr/mpp.go @@ -16,6 +16,7 @@ package copr import ( "context" + "fmt" "io" "strconv" "sync" @@ -34,6 +35,7 @@ import ( derr "github.com/pingcap/tidb/store/driver/error" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/mathutil" + "github.com/pingcap/tidb/util/memory" "github.com/tikv/client-go/v2/tikv" "github.com/tikv/client-go/v2/tikvrpc" "go.uber.org/zap" @@ -159,6 +161,8 @@ type mppIterator struct { mu sync.Mutex enableCollectExecutionInfo bool + + memTracker *memory.Tracker } func (m *mppIterator) run(ctx context.Context) { @@ -196,6 +200,22 @@ func (m *mppIterator) sendError(err error) { } func (m *mppIterator) sendToRespCh(resp *mppResponse) (exit bool) { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Error("mppIterator panic", zap.Stack("stack"), zap.Any("recover", r)) + m.sendError(errors.New(fmt.Sprint(r))) + } + }() + if m.memTracker != nil { + respSize := resp.MemSize() + failpoint.Inject("testMPPOOMPanic", func(val failpoint.Value) { + if val.(bool) && respSize != 0 { + respSize = 1 << 30 + } + }) + m.memTracker.Consume(respSize) + defer m.memTracker.Consume(-respSize) + } select { case m.respChan <- resp: case <-m.finishCh: @@ -502,7 +522,11 @@ func (m *mppIterator) Next(ctx context.Context) (kv.ResultSubset, error) { } // DispatchMPPTasks dispatches all the mpp task and waits for the responses. +<<<<<<< HEAD func (c *MPPClient) DispatchMPPTasks(ctx context.Context, variables interface{}, dispatchReqs []*kv.MPPDispatchRequest, needTriggerFallback bool, startTs uint64) kv.Response { +======= +func (c *MPPClient) DispatchMPPTasks(ctx context.Context, variables interface{}, dispatchReqs []*kv.MPPDispatchRequest, needTriggerFallback bool, startTs uint64, mppQueryID kv.MPPQueryID, mppVersion kv.MppVersion, memTracker *memory.Tracker) kv.Response { +>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901)) vars := variables.(*tikv.Variables) ctxChild, cancelFunc := context.WithCancel(ctx) iter := &mppIterator{ @@ -514,7 +538,12 @@ func (c *MPPClient) DispatchMPPTasks(ctx context.Context, variables interface{}, startTs: startTs, vars: vars, needTriggerFallback: needTriggerFallback, +<<<<<<< HEAD enableCollectExecutionInfo: config.GetGlobalConfig().Instance.EnableCollectExecutionInfo, +======= + enableCollectExecutionInfo: config.GetGlobalConfig().Instance.EnableCollectExecutionInfo.Load(), + memTracker: memTracker, +>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901)) } go iter.run(ctxChild) return iter