Skip to content

Commit

Permalink
This is an automated cherry-pick of #40901
Browse files Browse the repository at this point in the history
Signed-off-by: ti-chi-bot <[email protected]>
  • Loading branch information
wshwsh12 authored and ti-chi-bot committed Feb 2, 2023
1 parent ac6560f commit 1e1832f
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 0 deletions.
9 changes: 9 additions & 0 deletions distsql/distsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/pingcap/tidb/statistics"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/memory"
"github.com/pingcap/tidb/util/trxevents"
"github.com/pingcap/tipb/go-tipb"
"github.com/tikv/client-go/v2/tikvrpc/interceptor"
Expand All @@ -38,11 +39,19 @@ import (
)

// DispatchMPPTasks dispatches all tasks and returns an iterator.
<<<<<<< HEAD
func DispatchMPPTasks(ctx context.Context, sctx sessionctx.Context, tasks []*kv.MPPDispatchRequest, fieldTypes []*types.FieldType, planIDs []int, rootID int, startTs uint64) (SelectResult, error) {
ctx = WithSQLKvExecCounterInterceptor(ctx, sctx.GetSessionVars().StmtCtx)
_, allowTiFlashFallback := sctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash]
ctx = SetTiFlashMaxThreadsInContext(ctx, sctx)
resp := sctx.GetMPPClient().DispatchMPPTasks(ctx, sctx.GetSessionVars().KVVars, tasks, allowTiFlashFallback, startTs)
=======
func DispatchMPPTasks(ctx context.Context, sctx sessionctx.Context, tasks []*kv.MPPDispatchRequest, fieldTypes []*types.FieldType, planIDs []int, rootID int, startTs uint64, mppQueryID kv.MPPQueryID, memTracker *memory.Tracker) (SelectResult, error) {
ctx = WithSQLKvExecCounterInterceptor(ctx, sctx.GetSessionVars().StmtCtx)
_, allowTiFlashFallback := sctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash]
ctx = SetTiFlashMaxThreadsInContext(ctx, sctx)
resp := sctx.GetMPPClient().DispatchMPPTasks(ctx, sctx.GetSessionVars().KVVars, tasks, allowTiFlashFallback, startTs, mppQueryID, sctx.GetSessionVars().ChooseMppVersion(), memTracker)
>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901))
if resp == nil {
return nil, errors.New("client returns nil response")
}
Expand Down
6 changes: 6 additions & 0 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3281,7 +3281,13 @@ func (b *executorBuilder) buildMPPGather(v *plannercore.PhysicalTableReader) Exe
is: b.is,
originalPlan: v.GetTablePlan(),
startTS: startTs,
<<<<<<< HEAD
=======
mppQueryID: kv.MPPQueryID{QueryTs: getMPPQueryTS(b.ctx), LocalQueryID: getMPPQueryID(b.ctx), ServerID: domain.GetDomain(b.ctx).ServerID()},
memTracker: memory.NewTracker(v.ID(), -1),
>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901))
}
gather.memTracker.AttachTo(b.ctx.GetSessionVars().StmtCtx.MemTracker)
return gather
}

Expand Down
7 changes: 7 additions & 0 deletions executor/mpp_gather.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/memory"
"github.com/pingcap/tipb/go-tipb"
"go.uber.org/zap"
)
Expand All @@ -49,6 +50,8 @@ type MPPGather struct {
mppReqs []*kv.MPPDispatchRequest

respIter distsql.SelectResult

memTracker *memory.Tracker
}

func (e *MPPGather) appendMPPDispatchReq(pf *plannercore.Fragment) error {
Expand Down Expand Up @@ -122,7 +125,11 @@ func (e *MPPGather) Open(ctx context.Context) (err error) {
failpoint.Return(errors.Errorf("The number of tasks is not right, expect %d tasks but actually there are %d tasks", val.(int), len(e.mppReqs)))
}
})
<<<<<<< HEAD
e.respIter, err = distsql.DispatchMPPTasks(ctx, e.ctx, e.mppReqs, e.retFieldTypes, planIDs, e.id, e.startTS)
=======
e.respIter, err = distsql.DispatchMPPTasks(ctx, e.ctx, e.mppReqs, e.retFieldTypes, planIDs, e.id, e.startTS, e.mppQueryID, e.memTracker)
>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901))
if err != nil {
return errors.Trace(err)
}
Expand Down
142 changes: 142 additions & 0 deletions executor/tiflash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1299,3 +1299,145 @@ func TestTiflashEmptyDynamicPruneResult(t *testing.T) {
tk.MustQuery("select /*+ read_from_storage(tiflash[t2]) */ * from IDT_RP24833 partition(p2) t2 where t2. col1 <= -8448770111093677011;").Check(testkit.Rows())
tk.MustQuery("select /*+ read_from_storage(tiflash[t1, t2]) */ * from IDT_RP24833 partition(p3, p4) t1 join IDT_RP24833 partition(p2) t2 on t1.col1 = t2.col1 where t1. col1 between -8448770111093677011 and -8448770111093677011 and t2. col1 <= -8448770111093677011;").Check(testkit.Rows())
}
<<<<<<< HEAD:executor/tiflash_test.go
=======

func TestDisaggregatedTiFlash(t *testing.T) {
config.UpdateGlobal(func(conf *config.Config) {
conf.DisaggregatedTiFlash = true
})
defer config.UpdateGlobal(func(conf *config.Config) {
conf.DisaggregatedTiFlash = false
})
err := tiflashcompute.InitGlobalTopoFetcher(tiflashcompute.TestASStr, "", "", false)
require.NoError(t, err)

store := testkit.CreateMockStore(t, withMockTiFlash(2))
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(c1 int)")
tk.MustExec("alter table t set tiflash replica 1")
tb := external.GetTableByName(t, tk, "test", "t")
err = domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)
tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"")

err = tk.ExecToErr("select * from t;")
// Expect error, because TestAutoScaler return empty topo.
require.Contains(t, err.Error(), "Cannot find proper topo from AutoScaler")

err = tiflashcompute.InitGlobalTopoFetcher(tiflashcompute.AWSASStr, "", "", false)
require.NoError(t, err)
err = tk.ExecToErr("select * from t;")
// Expect error, because AWSAutoScaler is not setup, so http request will fail.
require.Contains(t, err.Error(), "[util:1815]Internal : get tiflash_compute topology failed")
}

// todo: remove this after AutoScaler is stable.
func TestDisaggregatedTiFlashNonAutoScaler(t *testing.T) {
config.UpdateGlobal(func(conf *config.Config) {
conf.DisaggregatedTiFlash = true
conf.UseAutoScaler = false
})
defer config.UpdateGlobal(func(conf *config.Config) {
conf.DisaggregatedTiFlash = false
conf.UseAutoScaler = true
})

// Setting globalTopoFetcher to nil to can make sure cannot fetch topo from AutoScaler.
err := tiflashcompute.InitGlobalTopoFetcher(tiflashcompute.InvalidASStr, "", "", false)
require.Contains(t, err.Error(), "unexpected topo fetch type. expect: mock or aws or gcp, got invalid")

store := testkit.CreateMockStore(t, withMockTiFlash(2))
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(c1 int)")
tk.MustExec("alter table t set tiflash replica 1")
tb := external.GetTableByName(t, tk, "test", "t")
err = domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)
tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"")

err = tk.ExecToErr("select * from t;")
// This error message means we use PD instead of AutoScaler.
require.Contains(t, err.Error(), "tiflash_compute node is unavailable")
}

func TestDisaggregatedTiFlashQuery(t *testing.T) {
config.UpdateGlobal(func(conf *config.Config) {
conf.DisaggregatedTiFlash = true
})
defer config.UpdateGlobal(func(conf *config.Config) {
conf.DisaggregatedTiFlash = false
})

store := testkit.CreateMockStore(t, withMockTiFlash(2))
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("drop table if exists tbl_1")
tk.MustExec(`create table tbl_1 ( col_1 bigint not null default -1443635317331776148,
col_2 text ( 176 ) collate utf8mb4_bin not null,
col_3 decimal ( 8, 3 ),
col_4 varchar ( 128 ) collate utf8mb4_bin not null,
col_5 varchar ( 377 ) collate utf8mb4_bin,
col_6 double,
col_7 varchar ( 459 ) collate utf8mb4_bin,
col_8 tinyint default -88 ) charset utf8mb4 collate utf8mb4_bin ;`)
tk.MustExec("alter table tbl_1 set tiflash replica 1")
tb := external.GetTableByName(t, tk, "test", "tbl_1")
err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)
tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"")

tk.MustExec("explain select max( tbl_1.col_1 ) as r0 , sum( tbl_1.col_1 ) as r1 , sum( tbl_1.col_8 ) as r2 from tbl_1 where tbl_1.col_8 != 68 or tbl_1.col_3 between null and 939 order by r0,r1,r2;")

tk.MustExec("set @@tidb_partition_prune_mode = 'static';")
tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"")
tk.MustExec("create table t1(c1 int, c2 int) partition by hash(c1) partitions 3")
tk.MustExec("insert into t1 values(1, 1), (2, 2), (3, 3)")
tk.MustExec("alter table t1 set tiflash replica 1")
tb = external.GetTableByName(t, tk, "test", "t1")
err = domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)
tk.MustQuery("explain select * from t1 where c1 < 2").Check(testkit.Rows(
"PartitionUnion_10 9970.00 root ",
"├─TableReader_15 3323.33 root MppVersion: 1, data:ExchangeSender_14",
"│ └─ExchangeSender_14 3323.33 mpp[tiflash] ExchangeType: PassThrough",
"│ └─Selection_13 3323.33 mpp[tiflash] lt(test.t1.c1, 2)",
"│ └─TableFullScan_12 10000.00 mpp[tiflash] table:t1, partition:p0 keep order:false, stats:pseudo",
"├─TableReader_19 3323.33 root MppVersion: 1, data:ExchangeSender_18",
"│ └─ExchangeSender_18 3323.33 mpp[tiflash] ExchangeType: PassThrough",
"│ └─Selection_17 3323.33 mpp[tiflash] lt(test.t1.c1, 2)",
"│ └─TableFullScan_16 10000.00 mpp[tiflash] table:t1, partition:p1 keep order:false, stats:pseudo",
"└─TableReader_23 3323.33 root MppVersion: 1, data:ExchangeSender_22",
" └─ExchangeSender_22 3323.33 mpp[tiflash] ExchangeType: PassThrough",
" └─Selection_21 3323.33 mpp[tiflash] lt(test.t1.c1, 2)",
" └─TableFullScan_20 10000.00 mpp[tiflash] table:t1, partition:p2 keep order:false, stats:pseudo"))
// tk.MustQuery("select * from t1 where c1 < 2").Check(testkit.Rows("1 1"))
}

func TestMPPMemoryTracker(t *testing.T) {
store := testkit.CreateMockStore(t, withMockTiFlash(2))
tk := testkit.NewTestKit(t, store)
tk.MustExec("set tidb_mem_quota_query = 1 << 30")
tk.MustExec("set global tidb_mem_oom_action = 'CANCEL'")
tk.MustExec("use test")
tk.MustExec("create table t(a int);")
tk.MustExec("insert into t values (1);")
tk.MustExec("alter table t set tiflash replica 1")
tb := external.GetTableByName(t, tk, "test", "t")
err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)
tk.MustExec("set tidb_enforce_mpp = on;")
tk.MustQuery("select * from t").Check(testkit.Rows("1"))
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/store/copr/testMPPOOMPanic", `return(true)`))
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/store/copr/testMPPOOMPanic"))
}()
err = tk.QueryToErr("select * from t")
require.NotNil(t, err)
require.True(t, strings.Contains(err.Error(), "Out Of Memory Quota!"))
}
>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901)):executor/tiflashtest/tiflash_test.go
9 changes: 9 additions & 0 deletions kv/mpp.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ import (
"time"

"github.com/pingcap/kvproto/pkg/mpp"
<<<<<<< HEAD
=======
"github.com/pingcap/tidb/util/memory"
"github.com/pingcap/tipb/go-tipb"
>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901))
)

// MPPTaskMeta means the meta info such as location of a mpp task.
Expand Down Expand Up @@ -83,7 +88,11 @@ type MPPClient interface {
ConstructMPPTasks(context.Context, *MPPBuildTasksRequest, map[string]time.Time, time.Duration) ([]MPPTaskMeta, error)

// DispatchMPPTasks dispatches ALL mpp requests at once, and returns an iterator that transfers the data.
<<<<<<< HEAD
DispatchMPPTasks(ctx context.Context, vars interface{}, reqs []*MPPDispatchRequest, needTriggerFallback bool, startTs uint64) Response
=======
DispatchMPPTasks(ctx context.Context, vars interface{}, reqs []*MPPDispatchRequest, needTriggerFallback bool, startTs uint64, mppQueryID MPPQueryID, mppVersion MppVersion, memTracker *memory.Tracker) Response
>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901))
}

// MPPBuildTasksRequest request the stores allocation for a mpp plan fragment.
Expand Down
29 changes: 29 additions & 0 deletions store/copr/mpp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package copr

import (
"context"
"fmt"
"io"
"strconv"
"sync"
Expand All @@ -34,6 +35,7 @@ import (
derr "github.com/pingcap/tidb/store/driver/error"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/mathutil"
"github.com/pingcap/tidb/util/memory"
"github.com/tikv/client-go/v2/tikv"
"github.com/tikv/client-go/v2/tikvrpc"
"go.uber.org/zap"
Expand Down Expand Up @@ -159,6 +161,8 @@ type mppIterator struct {
mu sync.Mutex

enableCollectExecutionInfo bool

memTracker *memory.Tracker
}

func (m *mppIterator) run(ctx context.Context) {
Expand Down Expand Up @@ -196,6 +200,22 @@ func (m *mppIterator) sendError(err error) {
}

func (m *mppIterator) sendToRespCh(resp *mppResponse) (exit bool) {
defer func() {
if r := recover(); r != nil {
logutil.BgLogger().Error("mppIterator panic", zap.Stack("stack"), zap.Any("recover", r))
m.sendError(errors.New(fmt.Sprint(r)))
}
}()
if m.memTracker != nil {
respSize := resp.MemSize()
failpoint.Inject("testMPPOOMPanic", func(val failpoint.Value) {
if val.(bool) && respSize != 0 {
respSize = 1 << 30
}
})
m.memTracker.Consume(respSize)
defer m.memTracker.Consume(-respSize)
}
select {
case m.respChan <- resp:
case <-m.finishCh:
Expand Down Expand Up @@ -502,7 +522,11 @@ func (m *mppIterator) Next(ctx context.Context) (kv.ResultSubset, error) {
}

// DispatchMPPTasks dispatches all the mpp task and waits for the responses.
<<<<<<< HEAD
func (c *MPPClient) DispatchMPPTasks(ctx context.Context, variables interface{}, dispatchReqs []*kv.MPPDispatchRequest, needTriggerFallback bool, startTs uint64) kv.Response {
=======
func (c *MPPClient) DispatchMPPTasks(ctx context.Context, variables interface{}, dispatchReqs []*kv.MPPDispatchRequest, needTriggerFallback bool, startTs uint64, mppQueryID kv.MPPQueryID, mppVersion kv.MppVersion, memTracker *memory.Tracker) kv.Response {
>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901))
vars := variables.(*tikv.Variables)
ctxChild, cancelFunc := context.WithCancel(ctx)
iter := &mppIterator{
Expand All @@ -514,7 +538,12 @@ func (c *MPPClient) DispatchMPPTasks(ctx context.Context, variables interface{},
startTs: startTs,
vars: vars,
needTriggerFallback: needTriggerFallback,
<<<<<<< HEAD
enableCollectExecutionInfo: config.GetGlobalConfig().Instance.EnableCollectExecutionInfo,
=======
enableCollectExecutionInfo: config.GetGlobalConfig().Instance.EnableCollectExecutionInfo.Load(),
memTracker: memTracker,
>>>>>>> 07af605381 (*: add memory tracker for mppIterator (#40901))
}
go iter.run(ctxChild)
return iter
Expand Down

0 comments on commit 1e1832f

Please sign in to comment.