Skip to content
This repository has been archived by the owner on Nov 24, 2023. It is now read-only.

*: fix context usage for SQL operation (#377) #400

Merged
merged 12 commits into from
Dec 9, 2019
2 changes: 1 addition & 1 deletion checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func NewChecker(cfgs []*config.SubTaskConfig, checkingItems map[string]string) *
}

// Init implements Unit interface
func (c *Checker) Init() (err error) {
func (c *Checker) Init(ctx context.Context) (err error) {
rollbackHolder := fr.NewRollbackHolder("checker")
defer func() {
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion checker/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func CheckSyncConfig(ctx context.Context, cfgs []*config.SubTaskConfig) error {

c := NewChecker(cfgs, checkingItems)

err := c.Init()
err := c.Init(ctx)
if err != nil {
return terror.Annotate(err, "fail to initial checker")
}
Expand Down
10 changes: 8 additions & 2 deletions dm/unit/unit.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package unit

import (
"context"
"time"

"github.com/pingcap/errors"

Expand All @@ -23,14 +24,19 @@ import (
"github.com/pingcap/dm/pkg/terror"
)

const (
// DefaultInitTimeout represents the default timeout value when initializing a process unit.
DefaultInitTimeout = time.Minute
)

// Unit defines interface for sub task process units, like syncer, loader, relay, etc.
type Unit interface {
// Init initializes the dm process unit
// every unit does base initialization in `Init`, and this must pass before start running the sub task
// other setups can be done in `Process`, but this should be treated carefully, let it's compatible with Pause / Resume
// if initialing successfully, the outer caller should call `Close` when the unit (or the task) finished, stopped or canceled (because other units Init fail).
// if initialing fail, Init itself should release resources it acquired before (rolling itself back).
Init() error
Init(ctx context.Context) error
// Process processes sub task
// When ctx.Done, stops the process and returns
// When not in processing, call Process to continue or resume the process
Expand All @@ -52,7 +58,7 @@ type Unit interface {
Type() pb.UnitType
// IsFreshTask return whether is a fresh task (not processed before)
// it will be used to decide where the task should become restoring
IsFreshTask() (bool, error)
IsFreshTask(ctx context.Context) (bool, error)
}

// NewProcessError creates a new ProcessError
Expand Down
5 changes: 4 additions & 1 deletion dm/worker/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"go.uber.org/zap"

"github.com/pingcap/dm/dm/pb"
"github.com/pingcap/dm/dm/unit"
"github.com/pingcap/dm/pkg/log"
"github.com/pingcap/dm/pkg/streamer"
"github.com/pingcap/dm/pkg/terror"
Expand Down Expand Up @@ -125,7 +126,9 @@ func (h *realRelayHolder) Init(interceptors []purger.PurgeInterceptor) (purger.P
streamer.GetReaderHub(),
}

if err := h.relay.Init(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), unit.DefaultInitTimeout)
defer cancel()
if err := h.relay.Init(ctx); err != nil {
return nil, terror.Annotate(err, "initial relay unit")
}

Expand Down
2 changes: 1 addition & 1 deletion dm/worker/relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func NewDummyRelay(cfg *relay.Config) relay.Process {
}

// Init implements Process interface
func (d *DummyRelay) Init() error {
func (d *DummyRelay) Init(ctx context.Context) error {
return d.initErr
}

Expand Down
8 changes: 6 additions & 2 deletions dm/worker/subtask.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ func (st *SubTask) Init() error {
// other setups can be done in `Process`, like Loader's prepare which depends on Mydumper's output
// but setups in `Process` should be treated carefully, let it's compatible with Pause / Resume
for i, u := range st.units {
err := u.Init()
ctx, cancel := context.WithTimeout(context.Background(), unit.DefaultInitTimeout)
err := u.Init(ctx)
cancel()
if err != nil {
initializeUnitSuccess = false
// when init fail, other units initialized before should be closed
Expand All @@ -140,7 +142,9 @@ func (st *SubTask) Init() error {
var skipIdx = 0
for i := len(st.units) - 1; i > 0; i-- {
u := st.units[i]
isFresh, err := u.IsFreshTask()
ctx, cancel := context.WithTimeout(context.Background(), unit.DefaultInitTimeout)
isFresh, err := u.IsFreshTask(ctx)
cancel()
if err != nil {
initializeUnitSuccess = false
return terror.Annotatef(err, "fail to get fresh status of subtask %s %s", st.cfg.Name, u.Type())
Expand Down
4 changes: 2 additions & 2 deletions dm/worker/subtask_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func NewMockUnit(typ pb.UnitType) *MockUnit {
}
}

func (m *MockUnit) Init() error {
func (m *MockUnit) Init(ctx context.Context) error {
return m.errInit
}

Expand Down Expand Up @@ -121,7 +121,7 @@ func (m *MockUnit) Error() interface{} { return nil }

func (m *MockUnit) Type() pb.UnitType { return m.typ }

func (m *MockUnit) IsFreshTask() (bool, error) { return m.isFresh, m.errFresh }
func (m *MockUnit) IsFreshTask(ctx context.Context) (bool, error) { return m.isFresh, m.errFresh }

func (m *MockUnit) InjectProcessError(ctx context.Context, err error) error {
newCtx, cancel := context.WithTimeout(ctx, time.Second)
Expand Down
73 changes: 38 additions & 35 deletions loader/checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type CheckPoint interface {
// Load loads all checkpoints recorded before.
// because of no checkpoints updated in memory when error occurred
// when resuming, Load will be called again to load checkpoints
Load() error
Load(tctx *tcontext.Context) error

// GetRestoringFileInfo get restoring data files for table
GetRestoringFileInfo(db, table string) map[string][]int64
Expand All @@ -48,19 +48,19 @@ type CheckPoint interface {
CalcProgress(allFiles map[string]Tables2DataFiles) error

// Init initialize checkpoint data in tidb
Init(filename string, endpos int64) error
Init(tctx *tcontext.Context, filename string, endpos int64) error

// ResetConn resets database connections owned by the Checkpoint
ResetConn() error
ResetConn(tctx *tcontext.Context) error

// Close closes the CheckPoint
Close()

// Clear clears all recorded checkpoints
Clear() error
Clear(tctx *tcontext.Context) error

// Count returns recorded checkpoints' count
Count() (int, error)
Count(tctx *tcontext.Context) (int, error)

// GenSQL generates sql to update checkpoint to DB
GenSQL(filename string, offset int64) string
Expand All @@ -76,7 +76,7 @@ type RemoteCheckPoint struct {
table string
restoringFiles map[string]map[string]FilePosSet
finishedTables map[string]struct{}
tctx *tcontext.Context
logCtx *tcontext.Context
}

func newRemoteCheckPoint(tctx *tcontext.Context, cfg *config.SubTaskConfig, id string) (CheckPoint, error) {
Expand All @@ -85,8 +85,6 @@ func newRemoteCheckPoint(tctx *tcontext.Context, cfg *config.SubTaskConfig, id s
return nil, err
}

newtctx := tctx.WithLogger(tctx.L().WithFields(zap.String("component", "remote checkpoint")))

cp := &RemoteCheckPoint{
db: db,
conn: dbConns[0],
Expand All @@ -95,36 +93,36 @@ func newRemoteCheckPoint(tctx *tcontext.Context, cfg *config.SubTaskConfig, id s
finishedTables: make(map[string]struct{}),
schema: cfg.MetaSchema,
table: fmt.Sprintf("%s_loader_checkpoint", cfg.Name),
tctx: newtctx,
logCtx: tcontext.Background().WithLogger(tctx.L().WithFields(zap.String("component", "remote checkpoint"))),
}

err = cp.prepare()
err = cp.prepare(tctx)
if err != nil {
return nil, err
}

return cp, nil
}

func (cp *RemoteCheckPoint) prepare() error {
func (cp *RemoteCheckPoint) prepare(tctx *tcontext.Context) error {
// create schema
if err := cp.createSchema(); err != nil {
if err := cp.createSchema(tctx); err != nil {
return err
}
// create table
if err := cp.createTable(); err != nil {
if err := cp.createTable(tctx); err != nil {
return err
}
return nil
}

func (cp *RemoteCheckPoint) createSchema() error {
func (cp *RemoteCheckPoint) createSchema(tctx *tcontext.Context) error {
sql2 := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS `%s`", cp.schema)
err := cp.conn.executeSQL(cp.tctx, []string{sql2})
err := cp.conn.executeSQL(tctx, []string{sql2})
return terror.WithScope(err, terror.ScopeDownstream)
}

func (cp *RemoteCheckPoint) createTable() error {
func (cp *RemoteCheckPoint) createTable(tctx *tcontext.Context) error {
tableName := fmt.Sprintf("`%s`.`%s`", cp.schema, cp.table)
createTable := `CREATE TABLE IF NOT EXISTS %s (
id char(32) NOT NULL,
Expand All @@ -139,19 +137,19 @@ func (cp *RemoteCheckPoint) createTable() error {
);
`
sql2 := fmt.Sprintf(createTable, tableName)
err := cp.conn.executeSQL(cp.tctx, []string{sql2})
err := cp.conn.executeSQL(tctx, []string{sql2})
return terror.WithScope(err, terror.ScopeDownstream)
}

// Load implements CheckPoint.Load
func (cp *RemoteCheckPoint) Load() error {
func (cp *RemoteCheckPoint) Load(tctx *tcontext.Context) error {
begin := time.Now()
defer func() {
cp.tctx.L().Info("load checkpoint", zap.Duration("cost time", time.Since(begin)))
cp.logCtx.L().Info("load checkpoint", zap.Duration("cost time", time.Since(begin)))
}()

query := fmt.Sprintf("SELECT `filename`,`cp_schema`,`cp_table`,`offset`,`end_pos` from `%s`.`%s` where `id`=?", cp.schema, cp.table)
rows, err := cp.conn.querySQL(cp.tctx, query, cp.id)
rows, err := cp.conn.querySQL(tctx, query, cp.id)
if err != nil {
return terror.WithScope(err, terror.ScopeDownstream)
}
Expand Down Expand Up @@ -248,14 +246,14 @@ func (cp *RemoteCheckPoint) CalcProgress(allFiles map[string]Tables2DataFiles) e
}
}

cp.tctx.L().Info("calculate checkpoint finished.", zap.Reflect("finished tables", cp.finishedTables))
cp.logCtx.L().Info("calculate checkpoint finished.", zap.Reflect("finished tables", cp.finishedTables))
return nil
}

func (cp *RemoteCheckPoint) allFilesFinished(files map[string][]int64) bool {
for file, pos := range files {
if len(pos) != 2 {
cp.tctx.L().Error("unexpected checkpoint record", zap.String("data file", file), zap.Int64s("position", pos))
cp.logCtx.L().Error("unexpected checkpoint record", zap.String("data file", file), zap.Int64s("position", pos))
return false
}
if pos[0] != pos[1] {
Expand All @@ -266,7 +264,7 @@ func (cp *RemoteCheckPoint) allFilesFinished(files map[string][]int64) bool {
}

// Init implements CheckPoint.Init
func (cp *RemoteCheckPoint) Init(filename string, endPos int64) error {
func (cp *RemoteCheckPoint) Init(tctx *tcontext.Context, filename string, endPos int64) error {
idx := strings.Index(filename, ".sql")
if idx < 0 {
return terror.ErrCheckpointInvalidTableFile.Generate(filename)
Expand All @@ -279,7 +277,7 @@ func (cp *RemoteCheckPoint) Init(filename string, endPos int64) error {

// fields[0] -> db name, fields[1] -> table name
sql2 := fmt.Sprintf("INSERT INTO `%s`.`%s` (`id`, `filename`, `cp_schema`, `cp_table`, `offset`, `end_pos`) VALUES(?,?,?,?,?,?)", cp.schema, cp.table)
cp.tctx.L().Debug("initial checkpoint record",
cp.logCtx.L().Debug("initial checkpoint record",
zap.String("sql", sql2),
zap.String("id", cp.id),
zap.String("filename", filename),
Expand All @@ -288,10 +286,10 @@ func (cp *RemoteCheckPoint) Init(filename string, endPos int64) error {
zap.Int64("offset", 0),
zap.Int64("end position", endPos))
args := []interface{}{cp.id, filename, fields[0], fields[1], 0, endPos}
err := cp.conn.executeSQL(cp.tctx, []string{sql2}, args)
err := cp.conn.executeSQL(tctx, []string{sql2}, args)
if err != nil {
if isErrDupEntry(err) {
cp.tctx.L().Info("checkpoint record already exists, skip it.", zap.String("id", cp.id), zap.String("filename", filename))
cp.logCtx.L().Info("checkpoint record already exists, skip it.", zap.String("id", cp.id), zap.String("filename", filename))
return nil
}
return terror.WithScope(terror.Annotate(err, "initialize checkpoint"), terror.ScopeDownstream)
Expand All @@ -300,15 +298,15 @@ func (cp *RemoteCheckPoint) Init(filename string, endPos int64) error {
}

// ResetConn implements CheckPoint.ResetConn
func (cp *RemoteCheckPoint) ResetConn() error {
return cp.conn.resetConn(cp.tctx)
func (cp *RemoteCheckPoint) ResetConn(tctx *tcontext.Context) error {
return cp.conn.resetConn(tctx)
}

// Close implements CheckPoint.Close
func (cp *RemoteCheckPoint) Close() {
err := cp.db.Close()
if err != nil {
cp.tctx.L().Error("close checkpoint db", log.ShortError(err))
cp.logCtx.L().Error("close checkpoint db", log.ShortError(err))
}
}

Expand All @@ -320,16 +318,16 @@ func (cp *RemoteCheckPoint) GenSQL(filename string, offset int64) string {
}

// Clear implements CheckPoint.Clear
func (cp *RemoteCheckPoint) Clear() error {
func (cp *RemoteCheckPoint) Clear(tctx *tcontext.Context) error {
sql2 := fmt.Sprintf("DELETE FROM `%s`.`%s` WHERE `id` = '%s'", cp.schema, cp.table, cp.id)
err := cp.conn.executeSQL(cp.tctx, []string{sql2})
err := cp.conn.executeSQL(tctx, []string{sql2})
return terror.WithScope(err, terror.ScopeDownstream)
}

// Count implements CheckPoint.Count
func (cp *RemoteCheckPoint) Count() (int, error) {
func (cp *RemoteCheckPoint) Count(tctx *tcontext.Context) (int, error) {
query := fmt.Sprintf("SELECT COUNT(id) FROM `%s`.`%s` WHERE `id` = ?", cp.schema, cp.table)
rows, err := cp.conn.querySQL(cp.tctx, query, cp.id)
rows, err := cp.conn.querySQL(tctx, query, cp.id)
if err != nil {
return 0, terror.WithScope(err, terror.ScopeDownstream)
}
Expand All @@ -344,12 +342,17 @@ func (cp *RemoteCheckPoint) Count() (int, error) {
if rows.Err() != nil {
return 0, terror.WithScope(terror.DBErrorAdapt(rows.Err(), terror.ErrDBDriverError), terror.ScopeDownstream)
}
cp.tctx.L().Debug("checkpoint record", zap.Int("count", count))
cp.logCtx.L().Debug("checkpoint record", zap.Int("count", count))
return count, nil
}

func (cp *RemoteCheckPoint) String() string {
if err := cp.Load(); err != nil {
// `String` is often used to log something, it's not a big problem even fail,
// so 1min should be enough.
tctx2, cancel := cp.logCtx.WithTimeout(time.Minute)
defer cancel()

if err := cp.Load(tctx2); err != nil {
return err.Error()
}

Expand Down
Loading