diff --git a/dm/checker/checker.go b/dm/checker/checker.go
index 9c44d1beae7..c49fdae6c3a 100644
--- a/dm/checker/checker.go
+++ b/dm/checker/checker.go
@@ -169,35 +169,6 @@ func (c *Checker) getTablePairInfo(ctx context.Context) (info *tablePairInfo, er
 		return nil, egErr
 	}
 
-	if _, ok := c.checkingItems[config.LightningFreeSpaceChecking]; ok &&
-		c.stCfgs[0].LoaderConfig.ImportMode == config.LoadModePhysical &&
-		c.stCfgs[0].Mode != config.ModeIncrement {
-		// TODO: concurrently read it intra-source later
-		for idx := range c.instances {
-			i := idx
-			eg.Go(func() error {
-				for _, sourceTables := range tableMapPerUpstream[i] {
-					for _, sourceTable := range sourceTables {
-						size, err2 := conn.FetchTableEstimatedBytes(
-							ctx,
-							c.instances[i].sourceDB,
-							sourceTable.Schema,
-							sourceTable.Name,
-						)
-						if err2 != nil {
-							return err2
-						}
-						info.totalDataSize.Add(size)
-					}
-				}
-				return nil
-			})
-		}
-	}
-	if egErr := eg.Wait(); egErr != nil {
-		return nil, egErr
-	}
-
 	info.targetTable2ExtendedColumns = extendedColumnPerTable
 	info.targetTable2SourceTablesMap = make(map[filter.Table]map[string][]filter.Table)
 	info.targetTableShardNum = make(map[filter.Table]int)
@@ -226,6 +197,8 @@ func (c *Checker) getTablePairInfo(ctx context.Context) (info *tablePairInfo, er
 	info.sourceID2SourceTables = make(map[string][]filter.Table, len(c.instances))
 	info.sourceID2InterestedDB = make([]map[string]struct{}, len(c.instances))
 	info.sourceID2TableMap = make(map[string]map[filter.Table][]filter.Table, len(c.instances))
+	sourceIDs := make([]string, 0, len(c.instances))
+	dbs := make(map[string]*conn.BaseDB, len(c.instances))
 	for i, inst := range c.instances {
 		sourceID := inst.cfg.SourceID
 		info.sourceID2InterestedDB[i] = make(map[string]struct{})
@@ -237,7 +210,55 @@ func (c *Checker) getTablePairInfo(ctx context.Context) (info *tablePairInfo, er
 				info.sourceID2InterestedDB[i][table.Schema] = struct{}{}
 			}
 		}
+		sourceIDs = append(sourceIDs, sourceID)
+		dbs[sourceID] = inst.sourceDB
 	}
+
+	if _, ok := c.checkingItems[config.LightningFreeSpaceChecking]; ok &&
+		c.stCfgs[0].LoaderConfig.ImportMode == config.LoadModePhysical &&
+		c.stCfgs[0].Mode != config.ModeIncrement {
+		concurrency, err := checker.GetConcurrency(ctx, sourceIDs, dbs, c.stCfgs[0].MydumperConfig.Threads)
+		if err != nil {
+			return nil, err
+		}
+
+		type job struct {
+			db     *conn.BaseDB
+			schema string
+			table  string
+		}
+
+		pool := checker.NewWorkerPoolWithContext[job, int64](ctx, func(result int64) {
+			info.totalDataSize.Add(result)
+		})
+		for i := 0; i < concurrency; i++ {
+			pool.Go(func(ctx context.Context, job job) (int64, error) {
+				return conn.FetchTableEstimatedBytes(
+					ctx,
+					job.db,
+					job.schema,
+					job.table,
+				)
+			})
+		}
+
+		for idx := range c.instances {
+			for _, sourceTables := range tableMapPerUpstream[idx] {
+				for _, sourceTable := range sourceTables {
+					pool.PutJob(job{
+						db:     c.instances[idx].sourceDB,
+						schema: sourceTable.Schema,
+						table:  sourceTable.Name,
+					})
+				}
+			}
+		}
+		err2 := pool.Wait()
+		if err2 != nil {
+			return nil, err2
+		}
+	}
+
 	return info, nil
 }
diff --git a/dm/pkg/checker/table_structure.go b/dm/pkg/checker/table_structure.go
index 8d0dd0d94e8..44765895c3f 100644
--- a/dm/pkg/checker/table_structure.go
+++ b/dm/pkg/checker/table_structure.go
@@ -82,9 +82,6 @@ type TablesChecker struct {
 	tableMap map[string]map[filter.Table][]filter.Table
 	// downstream table -> extended column names
 	extendedColumnPerTable map[filter.Table][]string
-	reMu                   sync.Mutex
-	inCh                   chan *checkItem
-	optCh                  chan *incompatibilityOption
 	dumpThreads            int
 	// a simple cache for downstream table structure
 	// filter.Table -> *ast.CreateTableStmt
@@ -111,11 +108,97 @@ func NewTablesChecker(
 		dumpThreads: dumpThreads,
 	}
 	log.L().Logger.Debug("check table structure", zap.Int("channel pool size", dumpThreads))
-	c.inCh = make(chan *checkItem, dumpThreads)
-	c.optCh = make(chan *incompatibilityOption, dumpThreads)
 	return c
 }
 
+type tablesCheckerWorker struct {
+	c                *TablesChecker
+	downstreamParser *parser.Parser
+
+	lastSourceID   string
+	upstreamParser *parser.Parser
+}
+
+func (w *tablesCheckerWorker) handle(ctx context.Context, checkItem *checkItem) ([]*incompatibilityOption, error) {
+	var (
+		err   error
+		ret   = make([]*incompatibilityOption, 0, 1)
+		table = checkItem.upstreamTable
+	)
+	log.L().Logger.Debug("checking table", zap.String("db", table.Schema), zap.String("table", table.Name))
+	if w.lastSourceID == "" || w.lastSourceID != checkItem.sourceID {
+		w.lastSourceID = checkItem.sourceID
+		w.upstreamParser, err = dbutil.GetParserForDB(ctx, w.c.upstreamDBs[w.lastSourceID].DB)
+		if err != nil {
+			return nil, err
+		}
+	}
+	db := w.c.upstreamDBs[checkItem.sourceID].DB
+	upstreamSQL, err := dbutil.GetCreateTableSQL(ctx, db, table.Schema, table.Name)
+	if err != nil {
+		// continue if table was deleted when checking
+		if isMySQLError(err, mysql.ErrNoSuchTable) {
+			return nil, nil
+		}
+		return nil, err
+	}
+
+	upstreamStmt, err := getCreateTableStmt(w.upstreamParser, upstreamSQL)
+	if err != nil {
+		opt := &incompatibilityOption{
+			state:      StateWarning,
+			tableID:    dbutil.TableName(table.Schema, table.Name),
+			errMessage: err.Error(),
+		}
+		ret = append(ret, opt)
+		// nolint:nilerr
+		return ret, nil
+	}
+
+	downstreamStmt, ok := w.c.downstreamTables.Load(checkItem.downstreamTable)
+	if !ok {
+		sql, err2 := dbutil.GetCreateTableSQL(
+			ctx,
+			w.c.downstreamDB.DB,
+			checkItem.downstreamTable.Schema,
+			checkItem.downstreamTable.Name,
+		)
+		if err2 != nil && !isMySQLError(err2, mysql.ErrNoSuchTable) {
+			return nil, err2
+		}
+		if sql == "" {
+			downstreamStmt = (*ast.CreateTableStmt)(nil)
+		} else {
+			downstreamStmt, err2 = getCreateTableStmt(w.downstreamParser, sql)
+			if err2 != nil {
+				opt := &incompatibilityOption{
+					state:      StateWarning,
+					tableID:    dbutil.TableName(table.Schema, table.Name),
+					errMessage: err2.Error(),
+				}
+				ret = append(ret, opt)
+			}
+		}
+		w.c.downstreamTables.Store(checkItem.downstreamTable, downstreamStmt)
+	}
+
+	downstreamTable := filter.Table{
+		Schema: checkItem.downstreamTable.Schema,
+		Name:   checkItem.downstreamTable.Name,
+	}
+	opts := w.c.checkAST(
+		upstreamStmt,
+		downstreamStmt.(*ast.CreateTableStmt),
+		w.c.extendedColumnPerTable[downstreamTable],
+	)
+	for _, opt := range opts {
+		opt.tableID = table.String()
+		ret = append(ret, opt)
+	}
+	log.L().Logger.Debug("finish checking table", zap.String("db", table.Schema), zap.String("table", table.Name))
+	return ret, nil
+}
+
 // Check implements RealChecker interface.
 func (c *TablesChecker) Check(ctx context.Context) *Result {
 	r := &Result{
@@ -131,30 +214,30 @@ func (c *TablesChecker) Check(ctx context.Context) *Result {
 		markCheckError(r, err)
 		return r
 	}
-	eg, checkCtx := errgroup.WithContext(ctx)
+
+	everyOptHandler, finalHandler := c.handleOpts(r)
+
+	pool := NewWorkerPoolWithContext[*checkItem, []*incompatibilityOption](
+		ctx, everyOptHandler,
+	)
+
 	for i := 0; i < concurrency; i++ {
-		eg.Go(func() error {
-			return c.startWorker(checkCtx)
-		})
+		worker := &tablesCheckerWorker{c: c}
+		worker.downstreamParser, err = dbutil.GetParserForDB(ctx, c.downstreamDB.DB)
+		if err != nil {
+			markCheckError(r, err)
+			return r
+		}
+		pool.Go(worker.handle)
 	}
-	// start consuming results before dispatching
-	// or the dispatching thread could be blocked when
-	// the output channel is full.
-	var optWg sync.WaitGroup
-	optWg.Add(1)
-	go func() {
-		defer optWg.Done()
-		c.handleOpts(ctx, r)
-	}()
-
-	dispatchTableItemWithDownstreamTable(checkCtx, c.tableMap, c.inCh)
-	if err := eg.Wait(); err != nil {
-		c.reMu.Lock()
+
+	dispatchTableItemWithDownstreamTable(c.tableMap, pool)
+
+	if err := pool.Wait(); err != nil {
 		markCheckError(r, err)
-		c.reMu.Unlock()
+		return r
 	}
-	close(c.optCh)
-	optWg.Wait()
+	finalHandler()
 
 	log.L().Logger.Info("check table structure over", zap.Duration("spend time", time.Since(startTime)))
 	return r
@@ -165,144 +248,41 @@ func (c *TablesChecker) Name() string {
 	return "table structure compatibility check"
 }
 
-func (c *TablesChecker) handleOpts(ctx context.Context, r *Result) {
+// handleOpts returns a handler that should be called on every
+// incompatibilityOption, and a second handler that should be called once after
+// all incompatibilityOptions have been handled.
+func (c *TablesChecker) handleOpts(r *Result) (func(options []*incompatibilityOption), func()) {
 	// extract same instruction from Errors to Result.Instruction
-	resultInstructions := map[string]interface{}{}
-	defer func() {
-		c.reMu.Lock()
-		for k := range resultInstructions {
-			r.Instruction += k + "; "
-		}
-		c.reMu.Unlock()
-	}()
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case opt, ok := <-c.optCh:
-			if !ok {
-				return
-			}
-			tableMsg := "table " + opt.tableID + " "
-			c.reMu.Lock()
-			switch opt.state {
-			case StateWarning:
-				if r.State != StateFailure {
-					r.State = StateWarning
-				}
-				e := NewError(tableMsg + opt.errMessage)
-				e.Severity = StateWarning
-				if _, ok := resultInstructions[opt.instruction]; !ok && opt.instruction != "" {
-					resultInstructions[opt.instruction] = ""
-				}
-				r.Errors = append(r.Errors, e)
-			case StateFailure:
-				r.State = StateFailure
-				e := NewError(tableMsg + opt.errMessage)
-				if _, ok := resultInstructions[opt.instruction]; !ok && opt.instruction != "" {
-					resultInstructions[opt.instruction] = ""
-				}
-				r.Errors = append(r.Errors, e)
-			}
-			c.reMu.Unlock()
-		}
-	}
-}
-
-func (c *TablesChecker) startWorker(ctx context.Context) error {
-	var (
-		sourceID         string
-		upstreamParser   *parser.Parser
-		downstreamParser *parser.Parser
-		err              error
-	)
-
-	downstreamParser, err = dbutil.GetParserForDB(ctx, c.downstreamDB.DB)
-	if err != nil {
-		return err
-	}
-
-	for {
-		select {
-		case <-ctx.Done():
-			return context.Canceled
-		case checkItem, ok := <-c.inCh:
-			if !ok {
-				return nil
-			}
-			table := checkItem.upstreamTable
-			log.L().Logger.Debug("checking table", zap.String("db", table.Schema), zap.String("table", table.Name))
-			if len(sourceID) == 0 || sourceID != checkItem.sourceID {
-				sourceID = checkItem.sourceID
-				upstreamParser, err = dbutil.GetParserForDB(ctx, c.upstreamDBs[sourceID].DB)
-				if err != nil {
-					return err
-				}
-			}
-			db := c.upstreamDBs[checkItem.sourceID]
-			upstreamSQL, err := dbutil.GetCreateTableSQL(ctx, db.DB, table.Schema, table.Name)
-			if err != nil {
-				// continue if table was deleted when checking
-				if isMySQLError(err, mysql.ErrNoSuchTable) {
-					continue
-				}
-				return err
-			}
-
-			upstreamStmt, err := getCreateTableStmt(upstreamParser, upstreamSQL)
-			if err != nil {
-				opt := &incompatibilityOption{
-					state:      StateWarning,
-					tableID:    dbutil.TableName(table.Schema, table.Name),
-					errMessage: err.Error(),
-				}
-				c.optCh <- opt
-				continue
-			}
-
-			downstreamStmt, ok := c.downstreamTables.Load(checkItem.downstreamTable)
-			if !ok {
-				sql, err2 := dbutil.GetCreateTableSQL(
-					ctx,
-					c.downstreamDB.DB,
-					checkItem.downstreamTable.Schema,
-					checkItem.downstreamTable.Name,
-				)
-				if err2 != nil && !isMySQLError(err2, mysql.ErrNoSuchTable) {
-					return err2
-				}
-				if sql == "" {
-					downstreamStmt = (*ast.CreateTableStmt)(nil)
-				} else {
-					downstreamStmt, err2 = getCreateTableStmt(downstreamParser, sql)
-					if err2 != nil {
-						opt := &incompatibilityOption{
-							state:      StateWarning,
-							tableID:    dbutil.TableName(table.Schema, table.Name),
-							errMessage: err2.Error(),
-						}
-						c.optCh <- opt
+	resultInstructions := map[string]struct{}{}
+
+	return func(options []*incompatibilityOption) {
+		for _, opt := range options {
+			tableMsg := "table " + opt.tableID + " "
+			switch opt.state {
+			case StateWarning:
+				if r.State != StateFailure {
+					r.State = StateWarning
+				}
+				e := NewError(tableMsg + opt.errMessage)
+				e.Severity = StateWarning
+				if _, ok := resultInstructions[opt.instruction]; !ok && opt.instruction != "" {
+					resultInstructions[opt.instruction] = struct{}{}
+				}
+				r.Errors = append(r.Errors, e)
+			case StateFailure:
+				r.State = StateFailure
+				e := NewError(tableMsg + opt.errMessage)
+				if _, ok := resultInstructions[opt.instruction]; !ok && opt.instruction != "" {
+					resultInstructions[opt.instruction] = struct{}{}
+				}
+				r.Errors = append(r.Errors, e)
 			}
-				c.downstreamTables.Store(checkItem.downstreamTable, downstreamStmt)
-			}
-
-			downstreamTable := filter.Table{
-				Schema: checkItem.downstreamTable.Schema,
-				Name:   checkItem.downstreamTable.Name,
 			}
-			opts := c.checkAST(
-				upstreamStmt,
-				downstreamStmt.(*ast.CreateTableStmt),
-				c.extendedColumnPerTable[downstreamTable],
-			)
-			for _, opt := range opts {
-				opt.tableID = table.String()
-				c.optCh <- opt
+		}
+	}, func() {
+		for k := range resultInstructions {
+			r.Instruction += k + "; "
 		}
-			log.L().Logger.Debug("finish checking table", zap.String("db", table.Schema), zap.String("table", table.Name))
 	}
-	}
 }
 
 func (c *TablesChecker) checkAST(
@@ -949,23 +929,23 @@ func dispatchTableItem(ctx context.Context, tableMap map[string][]filter.Table,
 }
 
 func dispatchTableItemWithDownstreamTable(
-	ctx context.Context,
 	tableMaps map[string]map[filter.Table][]filter.Table,
-	inCh chan *checkItem,
+	pool *WorkerPool[*checkItem, []*incompatibilityOption],
 ) {
 	for sourceID, tableMap := range tableMaps {
 		for downTable, upTables := range tableMap {
 			for _, upTable := range upTables {
-				select {
-				case <-ctx.Done():
-					log.L().Logger.Warn("ctx canceled before input tables completely")
+				ok := pool.PutJob(&checkItem{
+					upstreamTable:   upTable,
+					downstreamTable: downTable,
+					sourceID:        sourceID,
+				})
+				if !ok {
 					return
-				case inCh <- &checkItem{upstreamTable: upTable, downstreamTable: downTable, sourceID: sourceID}:
 				}
 			}
 		}
 	}
-	close(inCh)
 }
 
 // GetConcurrency gets the concurrency of workers that we can randomly dispatch
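Note how the `TablesChecker` refactor drops `reMu`: the new `WorkerPool` (introduced in the next file) invokes the result handler from a single goroutine, so `handleOpts` can return plain closures over unsynchronized state. A minimal, self-contained sketch of that two-closure pattern; the names `newHandlers`, `every`, and `final` are illustrative, not part of the patch:

```go
package main

import "fmt"

// newHandlers mimics the shape of the refactored handleOpts: the first
// closure is applied to every result, the second is called once after
// pool.Wait() to flush the accumulated state. No mutex is needed because
// the worker pool calls the result handler sequentially.
func newHandlers(out *[]string) (func(string), func()) {
	seen := map[string]struct{}{} // deduplicates, like resultInstructions
	every := func(instruction string) {
		if instruction != "" {
			seen[instruction] = struct{}{}
		}
	}
	final := func() {
		for k := range seen {
			*out = append(*out, k)
		}
	}
	return every, final
}

func main() {
	var report []string
	every, final := newHandlers(&report)
	every("enable GTID")
	every("enable GTID") // duplicate, recorded once
	final()
	fmt.Println(report)
}
```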
diff --git a/dm/pkg/checker/worker_pool.go b/dm/pkg/checker/worker_pool.go
new file mode 100644
index 00000000000..ae035263a41
--- /dev/null
+++ b/dm/pkg/checker/worker_pool.go
@@ -0,0 +1,115 @@
+// Copyright 2022 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package checker
+
+import (
+	"context"
+
+	"golang.org/x/sync/errgroup"
+)
+
+// WorkerPool is an easy-to-use worker pool that starts workers via Go and
+// accepts jobs via PutJob. After a worker finishes a job, the result is
+// passed sequentially to the resultHandler function given to NewWorkerPool
+// or NewWorkerPoolWithContext. After the caller has sent all jobs, it can
+// call Wait to make sure all jobs are finished.
+// The type parameter J means job, R means result. Type J MUST only share
+// concurrent-safe members like *sql.DB.
+type WorkerPool[J, R any] struct {
+	ctx context.Context
+
+	// the closing order is inCh -> outCh -> done
+	inCh     chan J
+	outCh    chan R
+	done     chan struct{}
+	errGroup *errgroup.Group
+}
+
+// NewWorkerPool creates a new worker pool.
+// The type parameter J means job, R means result. Type J MUST only share
+// concurrent-safe members like *sql.DB.
+func NewWorkerPool[J, R any](resultHandler func(R)) *WorkerPool[J, R] {
+	return NewWorkerPoolWithContext[J, R](context.Background(), resultHandler)
+}
+
+// NewWorkerPoolWithContext creates a new worker pool with a context which
+// may be canceled by the caller.
+// The type parameter J means job, R means result. Type J MUST only share
+// concurrent-safe members like *sql.DB.
+func NewWorkerPoolWithContext[J, R any](
+	ctx context.Context,
+	resultHandler func(R),
+) *WorkerPool[J, R] {
+	group, groupCtx := errgroup.WithContext(ctx)
+	ret := &WorkerPool[J, R]{
+		ctx:      groupCtx,
+		errGroup: group,
+		inCh:     make(chan J),
+		outCh:    make(chan R),
+		done:     make(chan struct{}),
+	}
+	go func() {
+		for r := range ret.outCh {
+			resultHandler(r)
+		}
+		close(ret.done)
+	}()
+
+	return ret
+}
+
+// Go is like the builtin go keyword. handler holds the logic of a worker;
+// if the worker needs initialization, the caller can use a method of an
+// initialized structure, or a closure, to refer to the initialized part.
+func (p *WorkerPool[J, R]) Go(handler func(ctx context.Context, job J) (R, error)) {
+	p.errGroup.Go(func() error {
+		for {
+			select {
+			case <-p.ctx.Done():
+				return p.ctx.Err()
+			case job, ok := <-p.inCh:
+				if !ok {
+					return nil
+				}
+				result, err := handler(p.ctx, job)
+				if err != nil {
+					return err
+				}
+				p.outCh <- result
+			}
+		}
+	})
+}
+
+// PutJob sends a job to the worker pool. The returned value reports whether
+// the job was accepted; false means the workers have stopped, so the caller
+// can stop dispatching early.
+func (p *WorkerPool[J, R]) PutJob(job J) bool {
+	select {
+	case <-p.ctx.Done():
+		return false
+	case p.inCh <- job:
+		return true
+	}
+}
+
+// Wait waits for all workers to finish. It returns the first error that
+// occurred in the workers, or nil if there was none.
+// Other methods must not be called concurrently with Wait or after it.
+func (p *WorkerPool[J, R]) Wait() error {
+	close(p.inCh)
+	err := p.errGroup.Wait()
+	close(p.outCh)
+	<-p.done
+	return err
+}
diff --git a/dm/pkg/checker/worker_pool_test.go b/dm/pkg/checker/worker_pool_test.go
new file mode 100644
index 00000000000..4588ef556f1
--- /dev/null
+++ b/dm/pkg/checker/worker_pool_test.go
@@ -0,0 +1,112 @@
+// Copyright 2022 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package checker
+
+import (
+	"context"
+	"errors"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+type slowIncrementer struct {
+	atomic.Int64
+}
+
+func (i *slowIncrementer) Inc() int64 {
+	time.Sleep(100 * time.Millisecond)
+	return i.Add(1)
+}
+
+type job struct {
+	inc *slowIncrementer
+}
+
+type baseAdder struct {
+	base   int64
+	closed atomic.Bool
+}
+
+func (a *baseAdder) add(_ context.Context, j job) (int64, error) {
+	i := j.inc.Inc()
+	return a.base + i, nil
+}
+
+func TestExampleWorkerPool(t *testing.T) {
+	sum := int64(0)
+	concurrency := 100
+	jobNum := 1000
+	incrementer := slowIncrementer{}
+
+	pool := NewWorkerPool[job, int64](func(result int64) {
+		sum += result
+	})
+	for i := 0; i < concurrency; i++ {
+		worker := &baseAdder{base: 666}
+		pool.Go(worker.add)
+	}
+	for i := 0; i < jobNum; i++ {
+		pool.PutJob(job{inc: &incrementer})
+	}
+
+	err := pool.Wait()
+	require.NoError(t, err)
+	// sum 1 to 1000 = 500500
+	require.Equal(t, int64(666*jobNum+500500), sum)
+}
+
+var (
+	errMock = errors.New("mock error")
+	errorAt = int64(500)
+)
+
+func (a *baseAdder) addAndError(_ context.Context, j job) (int64, error) {
+	i := j.inc.Inc()
+	if i == errorAt {
+		if a.closed.Load() {
+			panic("worker is used after closed")
+		}
+		a.closed.Store(true)
+		return 0, errMock
+	}
+	return a.base + i, nil
+}
+
+func TestExampleWorkerPoolError(t *testing.T) {
+	sum := int64(0)
+	concurrency := 100
+	jobNum := 1000
+	incrementer := slowIncrementer{}
+
+	pool := NewWorkerPool[job, int64](func(result int64) {
+		sum += result
+	})
+	for i := 0; i < concurrency; i++ {
+		worker := &baseAdder{base: 666}
+		pool.Go(worker.addAndError)
+	}
+	for i := 0; i < jobNum; i++ {
+		ok := pool.PutJob(job{inc: &incrementer})
+		if !ok {
+			require.GreaterOrEqual(t, int64(i), errorAt)
+			break
+		}
+	}
+
+	err := pool.Wait()
+	require.ErrorIs(t, err, errMock)
+}
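For reference, a minimal sketch of driving the new `WorkerPool` API end to end. It assumes the package's import path (`github.com/pingcap/tiflow/dm/pkg/checker`) and uses a toy job type; only the `NewWorkerPoolWithContext`, `Go`, `PutJob`, and `Wait` calls come from the patch above:

```go
package main

import (
	"context"
	"fmt"

	"github.com/pingcap/tiflow/dm/pkg/checker" // assumed module path
)

func main() {
	total := 0
	// The result handler runs in a single collector goroutine, so plain
	// variables can be mutated without locking.
	pool := checker.NewWorkerPoolWithContext[string, int](context.Background(), func(n int) {
		total += n
	})
	// Start four workers; a non-nil error from any handler cancels the pool.
	for i := 0; i < 4; i++ {
		pool.Go(func(ctx context.Context, s string) (int, error) {
			return len(s), nil
		})
	}
	for _, s := range []string{"a", "bb", "ccc"} {
		if !pool.PutJob(s) {
			break // workers already stopped; no point dispatching more
		}
	}
	if err := pool.Wait(); err != nil { // surfaces the first worker error
		fmt.Println("pool failed:", err)
		return
	}
	fmt.Println(total) // 6
}
```

Note that `Wait` closes the input channel, so it must be called exactly once, after the last `PutJob`.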