diff --git a/distsql/distsql.go b/distsql/distsql.go
index 9c8d8e4e63384..83236ce4dc5f4 100644
--- a/distsql/distsql.go
+++ b/distsql/distsql.go
@@ -30,7 +30,7 @@ import (
 
 // DispatchMPPTasks dispatches all tasks and returns an iterator.
 func DispatchMPPTasks(ctx context.Context, sctx sessionctx.Context, tasks []*kv.MPPDispatchRequest, fieldTypes []*types.FieldType, planIDs []int, rootID int) (SelectResult, error) {
-	resp := sctx.GetMPPClient().DispatchMPPTasks(ctx, tasks)
+	resp := sctx.GetMPPClient().DispatchMPPTasks(ctx, sctx.GetSessionVars().KVVars, tasks)
 	if resp == nil {
 		err := errors.New("client returns nil response")
 		return nil, err
diff --git a/executor/mpp_gather.go b/executor/mpp_gather.go
index f29346a62278f..a8628330e16ef 100644
--- a/executor/mpp_gather.go
+++ b/executor/mpp_gather.go
@@ -82,6 +82,7 @@ func (e *MPPGather) appendMPPDispatchReq(pf *plannercore.Fragment, tasks []*kv.M
 		Timeout:   10,
 		SchemaVar: e.is.SchemaMetaVersion(),
 		StartTs:   e.startTS,
+		State:     kv.MppTaskReady,
 	}
 	e.mppReqs = append(e.mppReqs, req)
 }
diff --git a/executor/tiflash_test.go b/executor/tiflash_test.go
index 1d628375a9c82..20679644dabb6 100644
--- a/executor/tiflash_test.go
+++ b/executor/tiflash_test.go
@@ -15,18 +15,25 @@ package executor_test
 
 import (
 	"fmt"
+	"sync"
+	"sync/atomic"
+	"time"
 
 	. "github.com/pingcap/check"
+	"github.com/pingcap/errors"
 	"github.com/pingcap/failpoint"
 	"github.com/pingcap/kvproto/pkg/metapb"
 	"github.com/pingcap/parser"
+	"github.com/pingcap/parser/terror"
 	"github.com/pingcap/tidb/domain"
+	"github.com/pingcap/tidb/executor"
 	"github.com/pingcap/tidb/kv"
 	"github.com/pingcap/tidb/session"
 	"github.com/pingcap/tidb/store/mockstore"
 	"github.com/pingcap/tidb/store/mockstore/unistore"
 	"github.com/pingcap/tidb/store/tikv/mockstore/cluster"
 	"github.com/pingcap/tidb/util/testkit"
+	"github.com/pingcap/tidb/util/testleak"
 )
 
 type tiflashTestSuite struct {
@@ -36,6 +43,7 @@ type tiflashTestSuite struct {
 }
 
 func (s *tiflashTestSuite) SetUpSuite(c *C) {
+	testleak.BeforeTest()
	var err error
 	s.store, err = mockstore.NewMockStore(
 		mockstore.WithClusterInspector(func(c cluster.Cluster) {
@@ -271,3 +279,74 @@ func (s *tiflashTestSuite) TestPartitionTable(c *C) {
 	failpoint.Disable("github.com/pingcap/tidb/executor/checkTotalMPPTasks")
 	failpoint.Disable("github.com/pingcap/tidb/executor/checkUseMPP")
 }
+
+func (s *tiflashTestSuite) TestCancelMppTasks(c *C) {
+	var hang = "github.com/pingcap/tidb/store/mockstore/unistore/mppRecvHang"
+	tk := testkit.NewTestKit(c, s.store)
+	tk.MustExec("use test")
+	tk.MustExec("drop table if exists t")
+	tk.MustExec("create table t(a int not null primary key, b int not null)")
+	tk.MustExec("alter table t set tiflash replica 1")
+	tk.MustExec("insert into t values(1,0)")
+	tk.MustExec("insert into t values(2,0)")
+	tk.MustExec("insert into t values(3,0)")
+	tk.MustExec("insert into t values(4,0)")
+	tb := testGetTableByName(c, tk.Se, "test", "t")
+	err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
+	c.Assert(err, IsNil)
+	tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"")
+	tk.MustExec("set @@session.tidb_allow_mpp=ON")
+	atomic.StoreUint32(&tk.Se.GetSessionVars().Killed, 0)
+	c.Assert(failpoint.Enable(hang, `return(true)`), IsNil)
+	wg := &sync.WaitGroup{}
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		err := tk.QueryToErr("select count(*) from t as t1 , t where t1.a = t.a")
+		c.Assert(err, NotNil)
+		c.Assert(int(terror.ToSQLError(errors.Cause(err).(*terror.Error)).Code), Equals, int(executor.ErrQueryInterrupted.Code()))
+	}()
+	time.Sleep(1 * time.Second)
+	atomic.StoreUint32(&tk.Se.GetSessionVars().Killed, 1)
+	wg.Wait()
+	c.Assert(failpoint.Disable(hang), IsNil)
+}
+
+// all goroutines should exit if one goroutine hangs while another returns errors
+func (s *tiflashTestSuite) TestMppGoroutinesExitFromErrors(c *C) {
+	defer testleak.AfterTest(c)()
+	// make non-root tasks return an error
+	var mppNonRootTaskError = "github.com/pingcap/tidb/store/copr/mppNonRootTaskError"
+	// make root tasks hang
+	var hang = "github.com/pingcap/tidb/store/mockstore/unistore/mppRecvHang"
+	tk := testkit.NewTestKit(c, s.store)
+	tk.MustExec("use test")
+	tk.MustExec("drop table if exists t")
+	tk.MustExec("create table t(a int not null primary key, b int not null)")
+	tk.MustExec("alter table t set tiflash replica 1")
+	tb := testGetTableByName(c, tk.Se, "test", "t")
+	err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
+	c.Assert(err, IsNil)
+	tk.MustExec("insert into t values(1,0)")
+	tk.MustExec("insert into t values(2,0)")
+	tk.MustExec("insert into t values(3,0)")
+	tk.MustExec("drop table if exists t1")
+	tk.MustExec("create table t1(a int not null primary key, b int not null)")
+	tk.MustExec("alter table t1 set tiflash replica 1")
+	tb = testGetTableByName(c, tk.Se, "test", "t1")
+	err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
+	c.Assert(err, IsNil)
+	tk.MustExec("insert into t1 values(1,0)")
+	tk.MustExec("insert into t1 values(2,0)")
+	tk.MustExec("insert into t1 values(3,0)")
+	tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"")
+	tk.MustExec("set @@session.tidb_allow_mpp=ON")
+	c.Assert(failpoint.Enable(mppNonRootTaskError, `return(true)`), IsNil)
+	c.Assert(failpoint.Enable(hang, `return(true)`), IsNil)
+
+	// generate 2 root tasks: one will hang and the other will return errors
+	err = tk.QueryToErr("select count(*) from t as t1 , t where t1.a = t.a")
+	c.Assert(err, NotNil)
+	c.Assert(failpoint.Disable(mppNonRootTaskError), IsNil)
+	c.Assert(failpoint.Disable(hang), IsNil)
+}
diff --git a/go.mod b/go.mod
index 161178fe0ab9b..5a53c7b55cb36 100644
--- a/go.mod
+++ b/go.mod
@@ -31,7 +31,7 @@ require (
 	github.com/kr/text v0.2.0 // indirect
 	github.com/ngaut/pools v0.0.0-20180318154953-b7bc8c42aac7
 	github.com/ngaut/sync2 v0.0.0-20141008032647-7a24ed77b2ef
-	github.com/ngaut/unistore v0.0.0-20210304095907-0ebafaf44efb
+	github.com/ngaut/unistore v0.0.0-20210310131351-7ad6a204de87
 	github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
 	github.com/opentracing/basictracer-go v1.0.0
 	github.com/opentracing/opentracing-go v1.1.0
diff --git a/go.sum b/go.sum
index 7e76c97bf9a1f..0636203d83bd1 100644
--- a/go.sum
+++ b/go.sum
@@ -353,8 +353,8 @@ github.com/ngaut/pools v0.0.0-20180318154953-b7bc8c42aac7 h1:7KAv7KMGTTqSmYZtNdc
 github.com/ngaut/pools v0.0.0-20180318154953-b7bc8c42aac7/go.mod h1:iWMfgwqYW+e8n5lC/jjNEhwcjbRDpl5NT7n2h+4UNcI=
 github.com/ngaut/sync2 v0.0.0-20141008032647-7a24ed77b2ef h1:K0Fn+DoFqNqktdZtdV3bPQ/0cuYh2H4rkg0tytX/07k=
 github.com/ngaut/sync2 v0.0.0-20141008032647-7a24ed77b2ef/go.mod h1:7WjlapSfwQyo6LNmIvEWzsW1hbBQfpUO4JWnuQRmva8=
-github.com/ngaut/unistore v0.0.0-20210304095907-0ebafaf44efb h1:2rGvEhflp/uK1l1rNUmoHA4CiHpbddHGxg52H71Fke8=
-github.com/ngaut/unistore v0.0.0-20210304095907-0ebafaf44efb/go.mod h1:ZR3NH+HzqfiYetwdoAivApnIy8iefPZHTMLfrFNm8g4=
+github.com/ngaut/unistore v0.0.0-20210310131351-7ad6a204de87 h1:lVRrhmqIT2zMbmoalrgxQLwWzFd3VtFaaWy0fnMwPro=
+github.com/ngaut/unistore v0.0.0-20210310131351-7ad6a204de87/go.mod h1:ZR3NH+HzqfiYetwdoAivApnIy8iefPZHTMLfrFNm8g4=
 github.com/nicksnyder/go-i18n v1.10.0/go.mod h1:HrK7VCrbOvQoUAQ7Vpy7i87N7JZZZ7R2xBGjv0j365Q=
 github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
 github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
@@ -844,6 +844,7 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh
 honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
 honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
 honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
+honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=
 honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
 honnef.co/go/tools v0.1.3 h1:qTakTkI6ni6LFD5sBwwsdSO+AQqbSIxOauHTTQKZ/7o=
 honnef.co/go/tools v0.1.3/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las=
diff --git a/kv/mpp.go b/kv/mpp.go
index f275346e2d477..df38894d0c5d1 100644
--- a/kv/mpp.go
+++ b/kv/mpp.go
@@ -45,6 +45,20 @@ func (t *MPPTask) ToPB() *mpp.TaskMeta {
 	return meta
 }
 
+// MppTaskStates denotes the state of MPP tasks.
+type MppTaskStates uint8
+
+const (
+	// MppTaskReady means the task is ready
+	MppTaskReady MppTaskStates = iota
+	// MppTaskRunning means the task is running
+	MppTaskRunning
+	// MppTaskCancelled means the task is cancelled
+	MppTaskCancelled
+	// MppTaskDone means the task is done
+	MppTaskDone
+)
+
 // MPPDispatchRequest stands for a dispatching task.
 type MPPDispatchRequest struct {
 	Data      []byte // data encodes the dag coprocessor request.
@@ -55,6 +69,7 @@ type MPPDispatchRequest struct {
 	SchemaVar int64
 	StartTs   uint64
 	ID        int64 // identify a single task
+	State     MppTaskStates
 }
 
 // MPPClient accepts and processes mpp requests.
@@ -64,7 +79,7 @@ type MPPClient interface {
 	ConstructMPPTasks(context.Context, *MPPBuildTasksRequest) ([]MPPTaskMeta, error)
 
 	// DispatchMPPTasks dispatches ALL mpp requests at once, and returns an iterator that transfers the data.
-	DispatchMPPTasks(context.Context, []*MPPDispatchRequest) Response
+	DispatchMPPTasks(context.Context, *Variables, []*MPPDispatchRequest) Response
 }
 
 // MPPBuildTasksRequest requests the store allocation for an MPP plan fragment.
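Reviewer note: the new MppTaskStates enum above gives each dispatch request a small lifecycle (Ready -> Running -> Cancelled/Done), guarded by the iterator's mutex. As a rough standalone illustration of the transitions this patch relies on — the task type, tryStart, and cancel helpers below are hypothetical sketches, not part of kv/mpp.go:

package main

import (
	"fmt"
	"sync"
)

// mppTaskState mirrors kv.MppTaskStates from the patch above.
type mppTaskState uint8

const (
	taskReady mppTaskState = iota
	taskRunning
	taskCancelled
	taskDone
)

type task struct {
	mu    sync.Mutex
	state mppTaskState
}

// tryStart performs the Ready -> Running transition that mppIterator.run
// does under its mutex before dispatching; other states are left untouched.
func (t *task) tryStart() bool {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.state != taskReady {
		return false
	}
	t.state = taskRunning
	return true
}

// cancel moves any not-yet-done task to Cancelled, the way cancelMppTasks
// sweeps the whole task list once one root task fails.
func (t *task) cancel() {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.state != taskDone {
		t.state = taskCancelled
	}
}

func main() {
	t := &task{}
	fmt.Println(t.tryStart()) // true: Ready -> Running
	t.cancel()                // Running -> Cancelled
	fmt.Println(t.tryStart()) // false: cancelled tasks must not be redispatched
}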
diff --git a/store/copr/mpp.go b/store/copr/mpp.go
index da008937f65cf..8a6bd3c108b38 100644
--- a/store/copr/mpp.go
+++ b/store/copr/mpp.go
@@ -21,6 +21,7 @@ import (
 	"time"
 
 	"github.com/pingcap/errors"
+	"github.com/pingcap/failpoint"
 	"github.com/pingcap/kvproto/pkg/coprocessor"
 	"github.com/pingcap/kvproto/pkg/kvrpcpb"
 	"github.com/pingcap/kvproto/pkg/metapb"
@@ -123,15 +124,31 @@ type mppIterator struct {
 
 	respChan chan *mppResponse
 
-	rpcCancel *tikv.RPCCanceller
+	cancelFunc context.CancelFunc
 
 	wg sync.WaitGroup
 
 	closed uint32
+
+	vars *kv.Variables
+
+	mu sync.Mutex
 }
 
 func (m *mppIterator) run(ctx context.Context) {
 	for _, task := range m.tasks {
+		if atomic.LoadUint32(&m.closed) == 1 {
+			break
+		}
+		m.mu.Lock()
+		switch task.State {
+		case kv.MppTaskReady:
+			task.State = kv.MppTaskRunning
+			m.mu.Unlock()
+		default:
+			m.mu.Unlock()
+			continue // skip tasks that are no longer ready, e.g. already cancelled
+		}
 		m.wg.Add(1)
 		bo := tikv.NewBackoffer(ctx, copNextMaxBackoff)
 		go m.handleDispatchReq(ctx, bo, task)
@@ -142,6 +159,7 @@ func (m *mppIterator) run(ctx context.Context) {
 
 func (m *mppIterator) sendError(err error) {
 	m.sendToRespCh(&mppResponse{err: err})
+	m.cancelMppTasks()
 }
 
 func (m *mppIterator) sendToRespCh(resp *mppResponse) (exit bool) {
@@ -223,7 +241,13 @@ func (m *mppIterator) handleDispatchReq(ctx context.Context, bo *tikv.Backoffer,
 		m.sendError(errors.New(realResp.Error.Msg))
 		return
 	}
-
+	failpoint.Inject("mppNonRootTaskError", func(val failpoint.Value) {
+		if val.(bool) && !req.IsRoot {
+			time.Sleep(1 * time.Second)
+			m.sendError(tikv.ErrTiFlashServerTimeout)
+			return
+		}
+	})
 	if !req.IsRoot {
 		return
 	}
@@ -231,6 +255,39 @@ func (m *mppIterator) handleDispatchReq(ctx context.Context, bo *tikv.Backoffer,
 	m.establishMPPConns(bo, req, taskMeta)
 }
 
+// NOTE: We do not retry here, because retrying does not help when the error comes from TiFlash itself or from the network; in that case the execution on TiFlash will stop on its own after a few minutes.
+// This function is called under mutual exclusion: only the first call sends the cancel requests and marks every task as cancelled, while later calls find the tasks already cancelled and return early.
+func (m *mppIterator) cancelMppTasks() {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	killReq := &mpp.CancelTaskRequest{
+		Meta: &mpp.TaskMeta{StartTs: m.startTs},
+	}
+
+	wrappedReq := tikvrpc.NewRequest(tikvrpc.CmdMPPCancel, killReq, kvrpcpb.Context{})
+	wrappedReq.StoreTp = kv.TiFlash
+
+	usedStoreAddrs := make(map[string]bool)
+	for _, task := range m.tasks {
+		// collect the store addresses of running tasks; stop early if the tasks were already cancelled
+		if task.State == kv.MppTaskRunning && !usedStoreAddrs[task.Meta.GetAddress()] {
+			usedStoreAddrs[task.Meta.GetAddress()] = true
+		} else if task.State == kv.MppTaskCancelled {
+			return
+		}
+		task.State = kv.MppTaskCancelled
+	}
+
+	// send cancel cmd to all stores where tasks run
+	for addr := range usedStoreAddrs {
+		_, err := m.store.GetTiKVClient().SendRequest(context.Background(), addr, wrappedReq, tikv.ReadTimeoutUltraLong)
+		logutil.BgLogger().Debug("cancel task", zap.Uint64("query id", m.startTs), zap.String("on addr", addr))
+		if err != nil {
+			logutil.BgLogger().Error("cancel task error", zap.Error(err), zap.Uint64("query id", m.startTs), zap.String("on addr", addr))
+		}
+	}
+}
+
 func (m *mppIterator) establishMPPConns(bo *tikv.Backoffer, req *kv.MPPDispatchRequest, taskMeta *mpp.TaskMeta) {
 	connReq := &mpp.EstablishMPPConnectionRequest{
 		SenderMeta: taskMeta,
@@ -260,13 +317,13 @@ func (m *mppIterator) establishMPPConns(bo *tikv.Backoffer, req *kv.MPPDispatchR
 		return
 	}
 
-	// TODO: cancel the whole process when some error happens
 	for {
 		err := m.handleMPPStreamResponse(bo, resp, req)
 		if err != nil {
 			m.sendError(err)
 			return
 		}
+
 		resp, err = stream.Recv()
 		if err != nil {
 			if errors.Cause(err) == io.EOF {
@@ -280,9 +337,7 @@ func (m *mppIterator) establishMPPConns(bo *tikv.Backoffer, req *kv.MPPDispatchR
 				logutil.BgLogger().Info("stream unknown error", zap.Error(err))
 			}
 		}
-		m.sendToRespCh(&mppResponse{
-			err: tikv.ErrTiFlashServerTimeout,
-		})
+		m.sendError(tikv.ErrTiFlashServerTimeout)
 		return
 	}
 }
@@ -293,7 +348,7 @@ func (m *mppIterator) Close() error {
 	if atomic.CompareAndSwapUint32(&m.closed, 0, 1) {
 		close(m.finishCh)
 	}
-	m.rpcCancel.CancelAll()
+	m.cancelFunc()
 	m.wg.Wait()
 	return nil
 }
@@ -336,7 +391,11 @@ func (m *mppIterator) nextImpl(ctx context.Context) (resp *mppResponse, ok bool,
 	case resp, ok = <-m.respChan:
 		return
 	case <-ticker.C:
-		//TODO: kill query
+		if m.vars != nil && m.vars.Killed != nil && atomic.LoadUint32(m.vars.Killed) == 1 {
+			err = tikv.ErrQueryInterrupted
+			exit = true
+			return
+		}
 	case <-m.finishCh:
 		exit = true
 		return
@@ -370,19 +429,18 @@ func (m *mppIterator) Next(ctx context.Context) (kv.ResultSubset, error) {
 	return resp, nil
 }
 
-// DispatchMPPTasks dispatches all the mpp task and waits for the reponses.
-func (c *MPPClient) DispatchMPPTasks(ctx context.Context, dispatchReqs []*kv.MPPDispatchRequest) kv.Response {
+// DispatchMPPTasks dispatches all the mpp tasks and waits for the responses.
+func (c *MPPClient) DispatchMPPTasks(ctx context.Context, vars *kv.Variables, dispatchReqs []*kv.MPPDispatchRequest) kv.Response {
+	ctxChild, cancelFunc := context.WithCancel(ctx)
 	iter := &mppIterator{
-		store:     c.store,
-		tasks:     dispatchReqs,
-		finishCh:  make(chan struct{}),
-		rpcCancel: tikv.NewRPCanceller(),
-		respChan:  make(chan *mppResponse, 4096),
-		startTs:   dispatchReqs[0].StartTs,
+		store:      c.store,
+		tasks:      dispatchReqs,
+		finishCh:   make(chan struct{}),
+		cancelFunc: cancelFunc,
+		respChan:   make(chan *mppResponse, 4096),
+		startTs:    dispatchReqs[0].StartTs,
+		vars:       vars,
 	}
-	ctx = context.WithValue(ctx, tikv.RPCCancellerCtxKey{}, iter.rpcCancel)
-
-	// TODO: Process the case of query cancellation.
-	go iter.run(ctx)
+	go iter.run(ctxChild)
 	return iter
 }
diff --git a/store/mockstore/unistore/cophandler/cop_handler.go b/store/mockstore/unistore/cophandler/cop_handler.go
index 685804f12314e..c119a49db16e2 100644
--- a/store/mockstore/unistore/cophandler/cop_handler.go
+++ b/store/mockstore/unistore/cophandler/cop_handler.go
@@ -15,6 +15,7 @@ package cophandler
 
 import (
 	"bytes"
+	"context"
 	"fmt"
 	"time"
@@ -46,6 +47,7 @@ type MPPCtx struct {
 	RPCClient   client.Client
 	StoreAddr   string
 	TaskHandler *MPPTaskHandler
+	Ctx         context.Context
 }
 
 // HandleCopRequest handles coprocessor request.
diff --git a/store/mockstore/unistore/cophandler/mpp_exec.go b/store/mockstore/unistore/cophandler/mpp_exec.go
index f586230a4ea53..8da2ed1085f1d 100644
--- a/store/mockstore/unistore/cophandler/mpp_exec.go
+++ b/store/mockstore/unistore/cophandler/mpp_exec.go
@@ -14,7 +14,6 @@
 package cophandler
 
 import (
-	"context"
 	"io"
 	"math"
 	"sync"
@@ -258,7 +257,7 @@ func (e *exchRecvExec) next() (*chunk.Chunk, error) {
 func (e *exchRecvExec) EstablishConnAndReceiveData(h *MPPTaskHandler, meta *mpp.TaskMeta) ([]*mpp.MPPDataPacket, error) {
 	req := &mpp.EstablishMPPConnectionRequest{ReceiverMeta: h.Meta, SenderMeta: meta}
 	rpcReq := tikvrpc.NewRequest(tikvrpc.CmdMPPConn, req, kvrpcpb.Context{})
-	rpcResp, err := h.RPCClient.SendRequest(context.Background(), meta.Address, rpcReq, 3600*time.Second)
+	rpcResp, err := h.RPCClient.SendRequest(e.mppCtx.Ctx, meta.Address, rpcReq, 3600*time.Second)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
diff --git a/store/mockstore/unistore/rpc.go b/store/mockstore/unistore/rpc.go
index 52bdc5e34a513..72f36eb239abe 100644
--- a/store/mockstore/unistore/rpc.go
+++ b/store/mockstore/unistore/rpc.go
@@ -248,6 +248,7 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R
 			}
 		})
 		resp.Resp, err = c.handleDispatchMPPTask(ctx, req.DispatchMPPTask(), storeID)
+	case tikvrpc.CmdMPPCancel: // the mock store does not execute cancellation; receivers exit via context cancellation instead
 	case tikvrpc.CmdMvccGetByKey:
 		resp.Resp, err = c.usSvr.MvccGetByKey(ctx, req.MvccGetByKey())
 	case tikvrpc.CmdMvccGetByStartTs:
@@ -297,7 +298,7 @@ func (c *RPCClient) handleEstablishMPPConnection(ctx context.Context, r *mpp.Est
 	if err != nil {
 		return nil, err
 	}
-	var mockClient = mockMPPConnectionClient{mppResponses: mockServer.mppResponses, idx: 0, targetTask: r.ReceiverMeta}
+	var mockClient = mockMPPConnectionClient{mppResponses: mockServer.mppResponses, idx: 0, ctx: ctx, targetTask: r.ReceiverMeta}
 	streamResp := &tikvrpc.MPPStreamResponse{Tikv_EstablishMPPConnectionClient: &mockClient}
 	_, cancel := context.WithCancel(ctx)
 	streamResp.Lease.Cancel = cancel
@@ -472,8 +473,8 @@ type mockMPPConnectionClient struct {
 	mockClientStream
 	mppResponses []*mpp.MPPDataPacket
 	idx          int
-
-	targetTask *mpp.TaskMeta
+	ctx          context.Context
+	targetTask   *mpp.TaskMeta
 }
 
 func (mock *mockMPPConnectionClient) Recv() (*mpp.MPPDataPacket, error) {
@@ -487,6 +488,18 @@ func (mock *mockMPPConnectionClient) Recv() (*mpp.MPPDataPacket, error) {
 		failpoint.Return(nil, context.Canceled)
 	}
 })
+	failpoint.Inject("mppRecvHang", func(val failpoint.Value) {
+		for val.(bool) {
+			select {
+			case <-mock.ctx.Done():
+				{
+					failpoint.Return(nil, context.Canceled)
+				}
+			default:
+				time.Sleep(1 * time.Second)
+			}
+		}
+	})
 	return nil, io.EOF
 }
diff --git a/store/tikv/tikvrpc/tikvrpc.go b/store/tikv/tikvrpc/tikvrpc.go
index c695f8ecb3827..4975719c1238a 100644
--- a/store/tikv/tikvrpc/tikvrpc.go
+++ b/store/tikv/tikvrpc/tikvrpc.go
@@ -73,6 +73,7 @@ const (
 	CmdBatchCop
 	CmdMPPTask
 	CmdMPPConn
+	CmdMPPCancel
 
 	CmdMvccGetByKey CmdType = 1024 + iota
 	CmdMvccGetByStartTs
@@ -147,6 +148,8 @@ func (t CmdType) String() string {
 		return "DispatchMPPTask"
 	case CmdMPPConn:
 		return "EstablishMPPConnection"
+	case CmdMPPCancel:
+		return "CancelMPPTask"
 	case CmdMvccGetByKey:
 		return "MvccGetByKey"
 	case CmdMvccGetByStartTs:
@@ -339,6 +342,11 @@ func (req *Request) EstablishMPPConn() *mpp.EstablishMPPConnectionRequest {
 	return req.Req.(*mpp.EstablishMPPConnectionRequest)
 }
 
+// CancelMPPTask returns the CancelTaskRequest in the request.
+func (req *Request) CancelMPPTask() *mpp.CancelTaskRequest {
+	return req.Req.(*mpp.CancelTaskRequest)
+}
+
 // MvccGetByKey returns MvccGetByKeyRequest in request.
 func (req *Request) MvccGetByKey() *kvrpcpb.MvccGetByKeyRequest {
 	return req.Req.(*kvrpcpb.MvccGetByKeyRequest)
@@ -871,6 +879,9 @@ func CallRPC(ctx context.Context, client tikvpb.TikvClient, req *Request) (*Resp
 		resp.Resp = &MPPStreamResponse{
 			Tikv_EstablishMPPConnectionClient: streamClient,
 		}
+	case CmdMPPCancel:
+		// do not use a cancellable ctx here, otherwise the cancel command itself could be cancelled and fail.
+		resp.Resp, err = client.CancelMPPTask(ctx, req.CancelMPPTask())
 	case CmdCopStream:
 		var streamClient tikvpb.Tikv_CoprocessorStreamClient
 		streamClient, err = client.CoprocessorStream(ctx, req.Cop())
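Reviewer note: the kill-switch wiring in this patch is a polling pattern — the session's Killed flag travels through kv.Variables into mppIterator, and nextImpl checks it on every ticker tick, surfacing ErrQueryInterrupted instead of blocking forever on respChan. A minimal standalone sketch of that pattern follows; watchKill, the 10ms tick, and the error value are illustrative stand-ins, not part of this patch:

package main

import (
	"errors"
	"fmt"
	"sync/atomic"
	"time"
)

// Stand-in for tikv.ErrQueryInterrupted.
var errQueryInterrupted = errors.New("query interrupted")

// watchKill polls a shared kill flag the way mppIterator.nextImpl does on
// each ticker tick, and reports an interruption once the flag is set.
func watchKill(killed *uint32, results <-chan string) (string, error) {
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case r := <-results:
			return r, nil
		case <-ticker.C:
			if atomic.LoadUint32(killed) == 1 {
				return "", errQueryInterrupted
			}
		}
	}
}

func main() {
	var killed uint32
	results := make(chan string) // never receives, simulating a hung root task
	go func() {
		time.Sleep(50 * time.Millisecond)
		atomic.StoreUint32(&killed, 1) // what `kill <connection>` ultimately does to SessionVars.Killed
	}()
	_, err := watchKill(&killed, results)
	fmt.Println(err) // query interrupted
}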