diff --git a/lightning/restore/restore.go b/lightning/restore/restore.go index ca4f68ce6..56cdda9cc 100644 --- a/lightning/restore/restore.go +++ b/lightning/restore/restore.go @@ -137,6 +137,7 @@ type RestoreController struct { indexWorkers *worker.Pool regionWorkers *worker.Pool ioWorkers *worker.Pool + pauser *common.Pauser backend kv.Backend tidbMgr *TiDBManager postProcessLock sync.Mutex // a simple way to ensure post-processing is not concurrent without using complicated goroutines @@ -153,6 +154,10 @@ type RestoreController struct { } func NewRestoreController(ctx context.Context, dbMetas []*mydump.MDDatabaseMeta, cfg *config.Config) (*RestoreController, error) { + return NewRestoreControllerWithPauser(ctx, dbMetas, cfg, DeliverPauser) +} + +func NewRestoreControllerWithPauser(ctx context.Context, dbMetas []*mydump.MDDatabaseMeta, cfg *config.Config, pauser *common.Pauser) (*RestoreController, error) { cpdb, err := OpenCheckpointsDB(ctx, cfg) if err != nil { return nil, errors.Trace(err) @@ -184,6 +189,7 @@ func NewRestoreController(ctx context.Context, dbMetas []*mydump.MDDatabaseMeta, indexWorkers: worker.NewPool(ctx, cfg.App.IndexConcurrency, "index"), regionWorkers: worker.NewPool(ctx, cfg.App.RegionConcurrency, "region"), ioWorkers: worker.NewPool(ctx, cfg.App.IOConcurrency, "io"), + pauser: pauser, backend: backend, tidbMgr: tidbMgr, @@ -1663,6 +1669,7 @@ func (cr *chunkRestore) encodeLoop( logger log.Logger, kvEncoder kv.Encoder, deliverCompleteCh <-chan deliverResult, + pauser *common.Pauser, ) (readTotalDur time.Duration, encodeTotalDur time.Duration, err error) { send := func(kvs deliveredKVs) error { select { @@ -1685,7 +1692,7 @@ func (cr *chunkRestore) encodeLoop( initializedColumns := false outside: for { - if err = DeliverPauser.Wait(ctx); err != nil { + if err = pauser.Wait(ctx); err != nil { return } @@ -1775,7 +1782,7 @@ func (cr *chunkRestore) restore( zap.Stringer("path", &cr.chunk.Key), ).Begin(zap.InfoLevel, "restore file") - readTotalDur, encodeTotalDur, err := cr.encodeLoop(ctx, kvsCh, t, logTask.Logger, kvEncoder, deliverCompleteCh) + readTotalDur, encodeTotalDur, err := cr.encodeLoop(ctx, kvsCh, t, logTask.Logger, kvEncoder, deliverCompleteCh, rc.pauser) if err != nil { return err } diff --git a/lightning/restore/restore_test.go b/lightning/restore/restore_test.go index 478c798e8..5a81dfd20 100644 --- a/lightning/restore/restore_test.go +++ b/lightning/restore/restore_test.go @@ -747,7 +747,7 @@ func (s *chunkRestoreSuite) TestEncodeLoop(c *C) { deliverCompleteCh := make(chan deliverResult) kvEncoder := kv.NewTableKVEncoder(s.tr.encTable, s.cfg.TiDB.SQLMode, 1234567895) - _, _, err := s.cr.encodeLoop(ctx, kvsCh, s.tr, s.tr.logger, kvEncoder, deliverCompleteCh) + _, _, err := s.cr.encodeLoop(ctx, kvsCh, s.tr, s.tr.logger, kvEncoder, deliverCompleteCh, DeliverPauser) c.Assert(err, IsNil) c.Assert(kvsCh, HasLen, 2) @@ -767,7 +767,7 @@ func (s *chunkRestoreSuite) TestEncodeLoopCanceled(c *C) { kvEncoder := kv.NewTableKVEncoder(s.tr.encTable, s.cfg.TiDB.SQLMode, 1234567896) go cancel() - _, _, err := s.cr.encodeLoop(ctx, kvsCh, s.tr, s.tr.logger, kvEncoder, deliverCompleteCh) + _, _, err := s.cr.encodeLoop(ctx, kvsCh, s.tr, s.tr.logger, kvEncoder, deliverCompleteCh, DeliverPauser) c.Assert(errors.Cause(err), Equals, context.Canceled) c.Assert(kvsCh, HasLen, 0) } @@ -781,7 +781,7 @@ func (s *chunkRestoreSuite) TestEncodeLoopForcedError(c *C) { // close the chunk so reading it will result in the "file already closed" error. s.cr.parser.Close() - _, _, err := s.cr.encodeLoop(ctx, kvsCh, s.tr, s.tr.logger, kvEncoder, deliverCompleteCh) + _, _, err := s.cr.encodeLoop(ctx, kvsCh, s.tr, s.tr.logger, kvEncoder, deliverCompleteCh, DeliverPauser) c.Assert(err, ErrorMatches, `in file .*/db.table.2.sql:0 at offset 0:.*file already closed`) c.Assert(kvsCh, HasLen, 0) } @@ -797,7 +797,7 @@ func (s *chunkRestoreSuite) TestEncodeLoopDeliverErrored(c *C) { err: errors.New("fake deliver error"), } }() - _, _, err := s.cr.encodeLoop(ctx, kvsCh, s.tr, s.tr.logger, kvEncoder, deliverCompleteCh) + _, _, err := s.cr.encodeLoop(ctx, kvsCh, s.tr, s.tr.logger, kvEncoder, deliverCompleteCh, DeliverPauser) c.Assert(err, ErrorMatches, "fake deliver error") c.Assert(kvsCh, HasLen, 0) } @@ -848,6 +848,7 @@ func (s *chunkRestoreSuite) TestRestore(c *C) { cfg: s.cfg, saveCpCh: saveCpCh, backend: importer, + pauser: DeliverPauser, }) c.Assert(err, IsNil) c.Assert(saveCpCh, HasLen, 2)