From eb637b188fbd1630e3c7d4c3737039fa4bcda68f Mon Sep 17 00:00:00 2001 From: Chunzhu Li Date: Mon, 25 Apr 2022 15:38:49 +0800 Subject: [PATCH] worker: add lock protection for w.cfg (#5250) close pingcap/tiflow#5249 --- dm/dm/worker/server.go | 7 +++++-- dm/dm/worker/source_worker.go | 15 +++++++++++---- dm/dm/worker/status.go | 5 ++++- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/dm/dm/worker/server.go b/dm/dm/worker/server.go index 7ff41340595..0d416c2b83f 100644 --- a/dm/dm/worker/server.go +++ b/dm/dm/worker/server.go @@ -790,11 +790,14 @@ func (s *Server) OperateSchema(ctx context.Context, req *pb.OperateWorkerSchemaR log.L().Info("", zap.String("request", "OperateSchema"), zap.Stringer("payload", req)) w := s.getSourceWorker(true) + w.RLock() + sourceID := w.cfg.SourceID + w.RUnlock() if w == nil { log.L().Warn("fail to call OperateSchema, because no mysql source is being handled in the worker") return makeCommonWorkerResponse(terror.ErrWorkerNoStart.Generate()), nil - } else if req.Source != w.cfg.SourceID { - log.L().Error("fail to call OperateSchema, because source mismatch", zap.String("request", req.Source), zap.String("current", w.cfg.SourceID)) + } else if req.Source != sourceID { + log.L().Error("fail to call OperateSchema, because source mismatch", zap.String("request", req.Source), zap.String("current", sourceID)) return makeCommonWorkerResponse(terror.ErrWorkerSourceNotMatch.Generate()), nil } diff --git a/dm/dm/worker/source_worker.go b/dm/dm/worker/source_worker.go index 9d43eff9406..2dd7d245c0f 100644 --- a/dm/dm/worker/source_worker.go +++ b/dm/dm/worker/source_worker.go @@ -43,9 +43,10 @@ import ( // SourceWorker manages a source(upstream) which is mainly related to subtasks and relay. type SourceWorker struct { - // ensure no other operation can be done when closing (we can use `WatGroup`/`Context` to archive this) + // ensure no other operation can be done when closing (we can use `WaitGroup`/`Context` to archive this) // TODO: check what does it guards. Now it's used to guard relayHolder and relayPurger (maybe subTaskHolder?) since // query-status maybe access them when closing/disable functionalities + // This lock is used to guards source worker's source config and subtask holder(subtask configs) sync.RWMutex wg sync.WaitGroup @@ -249,9 +250,12 @@ func (w *SourceWorker) Stop(graceful bool) { // updateSourceStatus updates w.sourceStatus. func (w *SourceWorker) updateSourceStatus(ctx context.Context) error { w.sourceDBMu.Lock() + w.RLock() + cfg := w.cfg + w.RUnlock() if w.sourceDB == nil { var err error - w.sourceDB, err = conn.DefaultDBProvider.Apply(&w.cfg.DecryptPassword().From) + w.sourceDB, err = conn.DefaultDBProvider.Apply(&cfg.DecryptPassword().From) if err != nil { w.sourceDBMu.Unlock() return err @@ -262,7 +266,7 @@ func (w *SourceWorker) updateSourceStatus(ctx context.Context) error { var status binlog.SourceStatus ctx, cancel := context.WithTimeout(ctx, utils.DefaultDBTimeout) defer cancel() - pos, gtidSet, err := utils.GetPosAndGs(ctx, w.sourceDB.DB, w.cfg.Flavor) + pos, gtidSet, err := utils.GetPosAndGs(ctx, w.sourceDB.DB, cfg.Flavor) if err != nil { return err } @@ -1160,7 +1164,10 @@ func (w *SourceWorker) observeValidatorStage(ctx context.Context, lastUsedRev in case <-ctx.Done(): return nil case <-time.After(500 * time.Millisecond): - startRevision, err = w.getCurrentValidatorRevision(w.cfg.SourceID) + w.RLock() + sourceID := w.cfg.SourceID + w.RUnlock() + startRevision, err = w.getCurrentValidatorRevision(sourceID) if err != nil { log.L().Error("reset validator stage failed, will retry later", zap.Error(err), zap.Int("retryNum", retryNum)) } diff --git a/dm/dm/worker/status.go b/dm/dm/worker/status.go index 1fa1237af37..5e749c98f54 100644 --- a/dm/dm/worker/status.go +++ b/dm/dm/worker/status.go @@ -131,7 +131,10 @@ func (w *SourceWorker) GetValidateStatus(stName string, filterStatus pb.Stage) [ if st == nil { return res } - sourceIP := w.cfg.From.Host + ":" + strconv.Itoa(w.cfg.From.Port) + w.RLock() + cfg := w.cfg + w.RUnlock() + sourceIP := cfg.From.Host + ":" + strconv.Itoa(cfg.From.Port) tblStats := st.GetValidatorStatus() for _, stat := range tblStats { if filterStatus == pb.Stage_InvalidStage || stat.ValidationStatus == filterStatus.String() {