Skip to content

Commit

Permalink
*: refine context usage to avoid potential blocks (pingcap#1285) (pin…
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-srebot authored Nov 23, 2020
1 parent cec6f9d commit 13191a8
Show file tree
Hide file tree
Showing 71 changed files with 765 additions and 615 deletions.
2 changes: 1 addition & 1 deletion checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (c *Checker) Init(ctx context.Context) (err error) {
continue
}

mapping, err := utils.FetchTargetDoTables(instance.sourceDB.DB, bw, r)
mapping, err := utils.FetchTargetDoTables(ctx, instance.sourceDB.DB, bw, r)
if err != nil {
return err
}
Expand Down
28 changes: 11 additions & 17 deletions dm/config/source_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"math"
"math/rand"
"strings"
"time"

"github.com/BurntSushi/toml"
"github.com/siddontang/go-mysql/mysql"
Expand All @@ -23,11 +22,6 @@ import (
)

const (
// dbReadTimeout is readTimeout for DB connection in adjust
dbReadTimeout = "30s"
// dbGetTimeout is timeout for getting some information from DB
dbGetTimeout = 30 * time.Second

// the default base (min) server-id generated randomly
defaultBaseServerID = math.MaxUint32 / 10
)
Expand Down Expand Up @@ -214,32 +208,32 @@ func (c *SourceConfig) GenerateDBConfig() *DBConfig {
// decrypt password
clone := c.DecryptPassword()
from := &clone.From
from.RawDBCfg = DefaultRawDBConfig().SetReadTimeout(dbReadTimeout)
from.RawDBCfg = DefaultRawDBConfig().SetReadTimeout(utils.DefaultDBTimeout.String())
return from
}

// Adjust flavor and serverid of SourceConfig
func (c *SourceConfig) Adjust(db *sql.DB) (err error) {
// Adjust flavor and server-id of SourceConfig
func (c *SourceConfig) Adjust(ctx context.Context, db *sql.DB) (err error) {
c.From.Adjust()
c.Checker.Adjust()

// use one timeout for all following DB operations.
ctx2, cancel := context.WithTimeout(ctx, utils.DefaultDBTimeout)
defer cancel()
if c.Flavor == "" || c.ServerID == 0 {
ctx, cancel := context.WithTimeout(context.Background(), dbGetTimeout)
defer cancel()

err = c.AdjustFlavor(ctx, db)
err = c.AdjustFlavor(ctx2, db)
if err != nil {
return err
}

err = c.AdjustServerID(ctx, db)
err = c.AdjustServerID(ctx2, db)
if err != nil {
return err
}
}

if c.EnableGTID {
val, err := utils.GetGTID(db)
val, err := utils.GetGTIDMode(ctx2, db)
if err != nil {
return err
}
Expand All @@ -264,7 +258,7 @@ func (c *SourceConfig) AdjustFlavor(ctx context.Context, db *sql.DB) (err error)

c.Flavor, err = utils.GetFlavor(ctx, db)
if ctx.Err() != nil {
err = terror.Annotatef(err, "time cost to get flavor info exceeds %s", dbGetTimeout)
err = terror.Annotatef(err, "fail to get flavor info %v", ctx.Err())
}
return terror.WithScope(err, terror.ScopeUpstream)
}
Expand All @@ -277,7 +271,7 @@ func (c *SourceConfig) AdjustServerID(ctx context.Context, db *sql.DB) error {

serverIDs, err := getAllServerIDFunc(ctx, db)
if ctx.Err() != nil {
err = terror.Annotatef(err, "time cost to get server-id info exceeds %s", dbGetTimeout)
err = terror.Annotatef(err, "fail to get server-id info %v", ctx.Err())
}
if err != nil {
return terror.WithScope(err, terror.ScopeUpstream)
Expand Down
16 changes: 10 additions & 6 deletions dm/ctl/master/query_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,13 @@ func queryStatusFunc(cmd *cobra.Command, _ []string) (err error) {
}

if resp.Result && taskName == "" && len(sources) == 0 && !more {
result := wrapTaskResult(resp)
common.PrettyPrintInterface(result)
} else {
common.PrettyPrintResponse(resp)
result, hasFalseResult := wrapTaskResult(resp)
if !hasFalseResult { // if any result is false, we still print the full status.
common.PrettyPrintInterface(result)
return
}
}
common.PrettyPrintResponse(resp)
return
}

Expand All @@ -106,10 +108,12 @@ func getRelayStage(relayStatus *pb.RelayStatus) string {
}

// wrapTaskResult picks task info and generates tasks' status and related workers
func wrapTaskResult(resp *pb.QueryStatusListResponse) *taskResult {
func wrapTaskResult(resp *pb.QueryStatusListResponse) (result *taskResult, hasFalseResult bool) {
taskStatusMap := make(map[string]string)
taskCorrespondingSources := make(map[string][]string)
hasFalseResult = !resp.Result
for _, source := range resp.Sources {
hasFalseResult = hasFalseResult || !source.Result
relayStatus := source.SourceStatus.RelayStatus
for _, subTask := range source.SubTaskStatus {
subTaskName := subTask.Name
Expand Down Expand Up @@ -157,5 +161,5 @@ func wrapTaskResult(resp *pb.QueryStatusListResponse) *taskResult {
Result: resp.Result,
Msg: resp.Msg,
Tasks: taskList,
}
}, hasFalseResult
}
19 changes: 17 additions & 2 deletions dm/ctl/master/query_status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ type testCtlMaster struct {
var _ = check.Suite(&testCtlMaster{})

func generateAndCheckTaskResult(c *check.C, resp *pb.QueryStatusListResponse, expectedResult []*taskInfo) {
result := wrapTaskResult(resp)
result, hasFalseResult := wrapTaskResult(resp)
c.Assert(hasFalseResult, check.IsFalse)
c.Assert(result.Result, check.IsTrue)
c.Assert(result.Tasks, check.HasLen, 1)
sort.Strings(result.Tasks[0].Sources)
Expand Down Expand Up @@ -136,7 +137,8 @@ func (t *testCtlMaster) TestWrapTaskResult(c *check.C) {
},
}},
})
result := wrapTaskResult(resp)
result, hasFalseResult := wrapTaskResult(resp)
c.Assert(hasFalseResult, check.IsFalse)
c.Assert(result.Tasks, check.HasLen, 2)
if result.Tasks[0].TaskName == "test2" {
result.Tasks[0], result.Tasks[1] = result.Tasks[1], result.Tasks[0]
Expand All @@ -153,4 +155,17 @@ func (t *testCtlMaster) TestWrapTaskResult(c *check.C) {
},
}
c.Assert(result.Tasks, check.DeepEquals, expectedResult)

resp.Result = false
_, hasFalseResult = wrapTaskResult(resp)
c.Assert(hasFalseResult, check.IsTrue)

resp.Result = true
resp.Sources[0].Result = false
_, hasFalseResult = wrapTaskResult(resp)
c.Assert(hasFalseResult, check.IsTrue)

resp.Sources[0].Result = true
_, hasFalseResult = wrapTaskResult(resp)
c.Assert(hasFalseResult, check.IsFalse)
}
2 changes: 1 addition & 1 deletion dm/master/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (s *Server) collectSourceConfigFilesV1Import(tctx *tcontext.Context) (map[s
return nil, err
}

cfgs2, err := parseAndAdjustSourceConfig([]string{string(content)})
cfgs2, err := parseAndAdjustSourceConfig(tctx.Ctx, []string{string(content)})
if err != nil {
// abort importing if any invalid source config files exist.
return nil, err
Expand Down
9 changes: 5 additions & 4 deletions dm/master/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,7 @@ func (s *Server) CheckTask(ctx context.Context, req *pb.CheckTaskRequest) (*pb.C
}, nil
}

func parseAndAdjustSourceConfig(contents []string) ([]*config.SourceConfig, error) {
func parseAndAdjustSourceConfig(ctx context.Context, contents []string) ([]*config.SourceConfig, error) {
cfgs := make([]*config.SourceConfig, len(contents))
for i, content := range contents {
cfg := config.NewSourceConfig()
Expand All @@ -1000,7 +1000,7 @@ func parseAndAdjustSourceConfig(contents []string) ([]*config.SourceConfig, erro
if err != nil {
return cfgs, err
}
if err = cfg.Adjust(fromDB.DB); err != nil {
if err = cfg.Adjust(ctx, fromDB.DB); err != nil {
fromDB.Close()
return cfgs, err
}
Expand Down Expand Up @@ -1053,7 +1053,7 @@ func (s *Server) OperateSource(ctx context.Context, req *pb.OperateSourceRequest
return resp2, err2
}

cfgs, err := parseAndAdjustSourceConfig(req.Config)
cfgs, err := parseAndAdjustSourceConfig(ctx, req.Config)
resp := &pb.OperateSourceResponse{
Result: false,
}
Expand Down Expand Up @@ -1286,7 +1286,8 @@ func (s *Server) removeMetaData(ctx context.Context, cfg *config.TaskConfig) err
log.L().Warn("fail to close connection", zap.Error(err2))
}
}()
ctctx := tcontext.Background().WithContext(ctx).WithLogger(log.With(zap.String("job", "remove metadata")))

ctctx := tcontext.NewContext(ctx, log.With(zap.String("job", "remove metadata")))

sqls := make([]string, 0, 4)
// clear loader and syncer checkpoints
Expand Down
2 changes: 1 addition & 1 deletion dm/unit/unit.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type Unit interface {
Update(cfg *config.SubTaskConfig) error

// Status returns the unit's current status
Status() interface{}
Status(ctx context.Context) interface{}
// Type returns the unit's type
Type() pb.UnitType
// IsFreshTask return whether is a fresh task (not processed before)
Expand Down
1 change: 1 addition & 0 deletions dm/worker/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func (s *Server) JoinMaster(endpoints []string) error {
return terror.ErrWorkerTLSConfigNotValid.Delegate(err)
}

// join doesn't support being canceled for now, because it returns within (3s+3s) * len(endpoints) at most.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand Down
9 changes: 5 additions & 4 deletions dm/worker/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type RelayHolder interface {
// Close closes the holder
Close()
// Status returns relay unit's status
Status() *pb.RelayStatus
Status(ctx context.Context) *pb.RelayStatus
// Stage returns the stage of the relay
Stage() pb.Stage
// Error returns relay unit's status
Expand Down Expand Up @@ -123,6 +123,7 @@ func (h *realRelayHolder) Init(interceptors []purger.PurgeInterceptor) (purger.P
streamer.GetReaderHub(),
}

// TODO: refine the context usage of relay; it may need to be initialized before handling any subtasks.
ctx, cancel := context.WithTimeout(context.Background(), unit.DefaultInitTimeout)
defer cancel()
if err := h.relay.Init(ctx); err != nil {
Expand Down Expand Up @@ -175,14 +176,14 @@ func (h *realRelayHolder) run() {
}

// Status returns relay unit's status
func (h *realRelayHolder) Status() *pb.RelayStatus {
func (h *realRelayHolder) Status(ctx context.Context) *pb.RelayStatus {
if h.closed.Get() == closedTrue || h.relay.IsClosed() {
return &pb.RelayStatus{
Stage: pb.Stage_Stopped,
}
}

s := h.relay.Status().(*pb.RelayStatus)
s := h.relay.Status(ctx).(*pb.RelayStatus)
s.Stage = h.Stage()
s.Result = h.Result()

Expand Down Expand Up @@ -418,7 +419,7 @@ func (d *dummyRelayHolder) Close() {
}

// Status implements interface of RelayHolder
func (d *dummyRelayHolder) Status() *pb.RelayStatus {
func (d *dummyRelayHolder) Status(ctx context.Context) *pb.RelayStatus {
d.Lock()
defer d.Unlock()
return &pb.RelayStatus{
Expand Down
10 changes: 5 additions & 5 deletions dm/worker/relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (d *DummyRelay) Error() interface{} {
}

// Status implements Process interface
func (d *DummyRelay) Status() interface{} {
func (d *DummyRelay) Status(ctx context.Context) interface{} {
return &pb.RelayStatus{
Stage: pb.Stage_New,
}
Expand Down Expand Up @@ -169,7 +169,7 @@ func (t *testRelay) testStart(c *C, holder *realRelayHolder) {
c.Assert(holder.closed.Get(), Equals, closedFalse)

// test status
status := holder.Status()
status := holder.Status(context.Background())
c.Assert(status.Stage, Equals, pb.Stage_Running)
c.Assert(status.Result, IsNil)

Expand Down Expand Up @@ -210,7 +210,7 @@ func (t *testRelay) testClose(c *C, holder *realRelayHolder) {
c.Assert(holder.closed.Get(), Equals, closedTrue)

// todo: very strange, and can't resume
status := holder.Status()
status := holder.Status(context.Background())
c.Assert(status.Stage, Equals, pb.Stage_Stopped)
c.Assert(status.Result, IsNil)

Expand All @@ -228,7 +228,7 @@ func (t *testRelay) testPauseAndResume(c *C, holder *realRelayHolder) {
c.Assert(err, ErrorMatches, ".*current stage is Paused.*")

// test status
status := holder.Status()
status := holder.Status(context.Background())
c.Assert(status.Stage, Equals, pb.Stage_Paused)

// test update
Expand All @@ -244,7 +244,7 @@ func (t *testRelay) testPauseAndResume(c *C, holder *realRelayHolder) {
c.Assert(err, ErrorMatches, ".*current stage is Running.*")

// test status
status = holder.Status()
status = holder.Status(context.Background())
c.Assert(status.Stage, Equals, pb.Stage_Running)
c.Assert(status.Result, IsNil)

Expand Down
4 changes: 2 additions & 2 deletions dm/worker/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,9 +461,9 @@ func (s *Server) QueryStatus(ctx context.Context, req *pb.QueryStatusRequest) (*
return resp, nil
}

resp.SubTaskStatus = w.QueryStatus(req.Name)
resp.SubTaskStatus = w.QueryStatus(ctx, req.Name)
if w.relayHolder != nil {
sourceStatus.RelayStatus = w.relayHolder.Status()
sourceStatus.RelayStatus = w.relayHolder.Status(ctx)
}

unifyMasterBinlogPos(resp, w.cfg.EnableGTID)
Expand Down
20 changes: 11 additions & 9 deletions dm/worker/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package worker

import (
"context"
"encoding/json"
"fmt"
"sort"
Expand All @@ -25,26 +26,27 @@ import (
)

// Status returns the status of the current sub task
func (st *SubTask) Status() interface{} {
func (st *SubTask) Status(ctx context.Context) interface{} {
if cu := st.CurrUnit(); cu != nil {
return cu.Status()
return cu.Status(ctx)
}
return nil
}

// StatusJSON returns the status of the current sub task as json string
func (st *SubTask) StatusJSON() string {
sj, err := json.Marshal(st.Status())
func (st *SubTask) StatusJSON(ctx context.Context) string {
status := st.Status(ctx)
sj, err := json.Marshal(status)
if err != nil {
st.l.Error("fail to marshal status", zap.Reflect("status", st.Status()), zap.Error(err))
st.l.Error("fail to marshal status", zap.Reflect("status", status), zap.Error(err))
return ""
}
return string(sj)
}

// Status returns the status of the worker (and sub tasks)
// if stName is empty, all sub task's status will be returned
func (w *Worker) Status(stName string) []*pb.SubTaskStatus {
func (w *Worker) Status(ctx context.Context, stName string) []*pb.SubTaskStatus {

sts := w.subTaskHolder.getAllSubTasks()

Expand Down Expand Up @@ -91,7 +93,7 @@ func (w *Worker) Status(stName string) []*pb.SubTaskStatus {
if cu != nil {
stStatus.Unit = cu.Type()
// oneof status
us := cu.Status()
us := cu.Status(ctx)
switch stStatus.Unit {
case pb.UnitType_Check:
stStatus.Status = &pb.SubTaskStatus_Check{Check: us.(*pb.CheckStatus)}
Expand All @@ -111,8 +113,8 @@ func (w *Worker) Status(stName string) []*pb.SubTaskStatus {
}

// StatusJSON returns the status of the worker as json string
func (w *Worker) StatusJSON(stName string) string {
sl := &pb.SubTaskStatusList{Status: w.Status(stName)}
func (w *Worker) StatusJSON(ctx context.Context, stName string) string {
sl := &pb.SubTaskStatusList{Status: w.Status(ctx, stName)}
mar := jsonpb.Marshaler{EmitDefaults: true, Indent: " "}
s, err := mar.MarshalToString(sl)
if err != nil {
Expand Down
Loading

0 comments on commit 13191a8

Please sign in to comment.