diff --git a/dm/worker/relay.go b/dm/worker/relay.go index 36384ac518..5eda3d9674 100644 --- a/dm/worker/relay.go +++ b/dm/worker/relay.go @@ -381,7 +381,11 @@ func (h *realRelayHolder) Migrate(ctx context.Context, binlogName string, binlog /******************** dummy relay holder ********************/ type dummyRelayHolder struct { - initError error + sync.RWMutex + + initError error + stage pb.Stage + relayBinlog string cfg *Config } @@ -389,7 +393,8 @@ type dummyRelayHolder struct { // NewDummyRelayHolder creates a new RelayHolder func NewDummyRelayHolder(cfg *Config) RelayHolder { return &dummyRelayHolder{ - cfg: cfg, + cfg: cfg, + stage: pb.Stage_New, } } @@ -398,6 +403,16 @@ func NewDummyRelayHolderWithInitError(cfg *Config) RelayHolder { return &dummyRelayHolder{ initError: errors.New("init error"), cfg: cfg, + stage: pb.Stage_New, + } +} + +// NewDummyRelayHolderWithRelayBinlog creates a new RelayHolder with relayBinlog in relayStatus +func NewDummyRelayHolderWithRelayBinlog(cfg *Config, relayBinlog string) RelayHolder { + return &dummyRelayHolder{ + cfg: cfg, + relayBinlog: relayBinlog, + stage: pb.Stage_New, } } @@ -419,7 +434,12 @@ func (d *dummyRelayHolder) Close() {} // Status implements interface of RelayHolder func (d *dummyRelayHolder) Status() *pb.RelayStatus { - return nil + d.RLock() + defer d.RUnlock() + return &pb.RelayStatus{ + Stage: d.stage, + RelayBinlog: d.relayBinlog, + } } // Error implements interface of RelayHolder diff --git a/dm/worker/subtask.go b/dm/worker/subtask.go index 66d24ce84c..1b3403e1fc 100644 --- a/dm/worker/subtask.go +++ b/dm/worker/subtask.go @@ -68,9 +68,13 @@ type SubTask struct { l log.Logger sync.RWMutex - wg sync.WaitGroup + wg sync.WaitGroup + // ctx is used for the whole subtask. It will be created only when we new a subtask. ctx context.Context cancel context.CancelFunc + // currCtx is used for one loop. It will be created each time we use st.run/st.Resume + currCtx context.Context + currCancel context.CancelFunc units []unit.Unit // units do job one by one currUnit unit.Unit @@ -92,11 +96,14 @@ func NewSubTask(cfg *config.SubTaskConfig) *SubTask { // NewSubTaskWithStage creates a new SubTask with stage func NewSubTaskWithStage(cfg *config.SubTaskConfig, stage pb.Stage) *SubTask { + ctx, cancel := context.WithCancel(context.Background()) st := SubTask{ cfg: cfg, units: createUnits(cfg), stage: stage, l: log.With(zap.String("subtask", cfg.Name)), + ctx: ctx, + cancel: cancel, DDLInfo: make(chan *pb.DDLInfo, 1), } taskState.WithLabelValues(st.cfg.Name).Set(float64(st.stage)) @@ -180,26 +187,45 @@ func (st *SubTask) Run() { } func (st *SubTask) run() { - st.setStage(pb.Stage_Paused) - err := st.unitTransWaitCondition() + st.setStage(pb.Stage_Running) + ctx, cancel := context.WithCancel(st.ctx) + st.setCurrCtx(ctx, cancel) + err := st.unitTransWaitCondition(ctx) if err != nil { st.l.Error("wait condition", log.ShortError(err)) st.fail(err) return + } else if ctx.Err() != nil { + return } - st.setStage(pb.Stage_Running) st.setResult(nil) // clear previous result cu := st.CurrUnit() st.l.Info("start to run", zap.Stringer("unit", cu.Type())) - st.ctx, st.cancel = context.WithCancel(context.Background()) pr := make(chan pb.ProcessResult, 1) st.wg.Add(1) go st.fetchResult(pr) - go cu.Process(st.ctx, pr) + go cu.Process(ctx, pr) st.wg.Add(1) - go st.fetchUnitDDLInfo(st.ctx) + go st.fetchUnitDDLInfo(ctx) +} + +func (st *SubTask) setCurrCtx(ctx context.Context, cancel context.CancelFunc) { + st.Lock() + // call previous cancel func for safety + if st.currCancel != nil { + st.currCancel() + } + st.currCtx = ctx + st.currCancel = cancel + st.Unlock() +} + +func (st *SubTask) callCurrCancel() { + st.RLock() + st.currCancel() + st.RUnlock() } // fetchResult fetches units process result @@ -207,12 +233,16 @@ func (st *SubTask) run() { func (st *SubTask) fetchResult(pr chan pb.ProcessResult) { defer st.wg.Done() + st.RLock() + ctx := st.currCtx + st.RUnlock() + select { - case <-st.ctx.Done(): + case <-ctx.Done(): return case result := <-pr: st.setResult(&result) // save result - st.cancel() // dm-unit finished, canceled or error occurred, always cancel processing + st.callCurrCancel() // dm-unit finished, canceled or error occurred, always cancel processing if len(result.Errors) == 0 && st.Stage() == pb.Stage_Paused { return // paused by external request @@ -381,16 +411,16 @@ func (st *SubTask) Result() *pb.ProcessResult { // Close stops the sub task func (st *SubTask) Close() { st.l.Info("closing") - if st.cancel == nil { - st.l.Info("not run yet, no need to close") + if st.Stage() == pb.Stage_Stopped { + st.l.Info("subTask is already closed, no need to close") return } st.cancel() st.closeUnits() // close all un-closed units - st.setStageIfNot(pb.Stage_Finished, pb.Stage_Stopped) st.removeLabelValuesWithTaskInMetrics(st.cfg.Name) st.wg.Wait() + st.setStageIfNot(pb.Stage_Finished, pb.Stage_Stopped) } // Pause pauses the running sub task @@ -399,7 +429,7 @@ func (st *SubTask) Pause() error { return terror.ErrWorkerNotRunningStage.Generate() } - st.cancel() + st.callCurrCancel() st.wg.Wait() // wait fetchResult return cu := st.CurrUnit() @@ -417,30 +447,32 @@ func (st *SubTask) Resume() error { return nil } + if !st.stageCAS(pb.Stage_Paused, pb.Stage_Running) { + return terror.ErrWorkerNotPausedStage.Generate() + } + ctx, cancel := context.WithCancel(st.ctx) + st.setCurrCtx(ctx, cancel) // NOTE: this may block if user resume a task - err := st.unitTransWaitCondition() + err := st.unitTransWaitCondition(ctx) if err != nil { st.l.Error("wait condition", log.ShortError(err)) st.setStage(pb.Stage_Paused) return err - } - - if !st.stageCAS(pb.Stage_Paused, pb.Stage_Running) { - return terror.ErrWorkerNotPausedStage.Generate() + } else if ctx.Err() != nil { + return nil } st.setResult(nil) // clear previous result cu := st.CurrUnit() st.l.Info("resume with unit", zap.Stringer("unit", cu.Type())) - st.ctx, st.cancel = context.WithCancel(context.Background()) pr := make(chan pb.ProcessResult, 1) st.wg.Add(1) go st.fetchResult(pr) - go cu.Resume(st.ctx, pr) + go cu.Resume(ctx, pr) st.wg.Add(1) - go st.fetchUnitDDLInfo(st.ctx) + go st.fetchUnitDDLInfo(ctx) return nil } @@ -621,7 +653,7 @@ func (st *SubTask) ClearDDLInfo() { // unitTransWaitCondition waits when transferring from current unit to next unit. // Currently there is only one wait condition // from Load unit to Sync unit, wait for relay-log catched up with mydumper binlog position. -func (st *SubTask) unitTransWaitCondition() error { +func (st *SubTask) unitTransWaitCondition(subTaskCtx context.Context) error { pu := st.PrevUnit() cu := st.CurrUnit() if pu != nil && pu.Type() == pb.UnitType_Load && cu.Type() == pb.UnitType_Sync { @@ -650,6 +682,8 @@ func (st *SubTask) unitTransWaitCondition() error { select { case <-ctx.Done(): return terror.ErrWorkerWaitRelayCatchupTimeout.Generate(waitRelayCatchupTimeout, pos1, pos2) + case <-subTaskCtx.Done(): + return nil case <-time.After(time.Millisecond * 50): } } diff --git a/dm/worker/subtask_test.go b/dm/worker/subtask_test.go index fa06a9d5cc..aecc134a88 100644 --- a/dm/worker/subtask_test.go +++ b/dm/worker/subtask_test.go @@ -29,6 +29,12 @@ import ( "github.com/pingcap/errors" ) +const ( + // mocked loadMetaBinlog must be greater than relayHolderBinlog + loadMetaBinlog = "(mysql-bin.00001,154)" + relayHolderBinlog = "(mysql-bin.00001,150)" +) + type testSubTask struct{} var _ = Suite(&testSubTask{}) @@ -115,7 +121,20 @@ func (m *MockUnit) Update(cfg *config.SubTaskConfig) error { return m.errUpdate } -func (m *MockUnit) Status() interface{} { return nil } +func (m *MockUnit) Status() interface{} { + switch m.typ { + case pb.UnitType_Check: + return &pb.CheckStatus{} + case pb.UnitType_Dump: + return &pb.DumpStatus{} + case pb.UnitType_Load: + return &pb.LoadStatus{MetaBinlog: loadMetaBinlog} + case pb.UnitType_Sync: + return &pb.SyncStatus{} + default: + return struct{}{} + } +} func (m *MockUnit) Error() interface{} { return nil } @@ -496,3 +515,68 @@ func (t *testSubTask) TestDDLInfo(c *C) { c.Assert(st.SaveDDLInfo(ddlInfo), IsNil) c.Assert(st.GetDDLInfo(), DeepEquals, ddlInfo) } + +func (t *testSubTask) TestSubtaskFastQuit(c *C) { + // case: test subtask stuck into unitTransWaitCondition + cfg := &config.SubTaskConfig{ + Name: "testSubtaskFastQuit", + Mode: config.ModeAll, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + w := &Worker{ + ctx: ctx, + // loadStatus relay MetaBinlog must be greater + relayHolder: NewDummyRelayHolderWithRelayBinlog(NewConfig(), relayHolderBinlog), + } + conditionHub = &ConditionHub{ + w: w, + } + + mockLoader := NewMockUnit(pb.UnitType_Load) + mockSyncer := NewMockUnit(pb.UnitType_Sync) + + st := NewSubTaskWithStage(cfg, pb.Stage_Paused) + st.prevUnit = mockLoader + st.currUnit = mockSyncer + + finished := make(chan struct{}) + go func() { + st.run() + close(finished) + }() + + // test Pause + time.Sleep(time.Second) // wait for task to run for some time + c.Assert(st.Stage(), Equals, pb.Stage_Running) + c.Assert(st.Pause(), IsNil) + select { + case <-time.After(500 * time.Millisecond): + c.Fatal("fail to pause subtask in 0.5s when stuck into unitTransWaitCondition") + case <-finished: + } + c.Assert(st.Stage(), Equals, pb.Stage_Paused) + + st = NewSubTaskWithStage(cfg, pb.Stage_Paused) + st.prevUnit = mockLoader + st.currUnit = mockSyncer + + finished = make(chan struct{}) + go func() { + st.run() + close(finished) + }() + + time.Sleep(time.Second) + c.Assert(st.Stage(), Equals, pb.Stage_Running) + // test Close + st.Close() + select { + case <-time.After(500 * time.Millisecond): + c.Fatal("fail to stop subtask in 0.5s when stuck into unitTransWaitCondition") + case <-finished: + } + c.Assert(st.Stage(), Equals, pb.Stage_Stopped) +}