From a9d9aa781818445157d61efab854916bced77b8e Mon Sep 17 00:00:00 2001 From: D3Hunter Date: Mon, 18 Nov 2024 16:54:48 +0800 Subject: [PATCH 1/9] systbl --- .../snap_client/systable_restore_test.go | 2 +- pkg/session/bootstrap.go | 21 ++++++++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/br/pkg/restore/snap_client/systable_restore_test.go b/br/pkg/restore/snap_client/systable_restore_test.go index d5952135dbc5b..9e1812aabaebb 100644 --- a/br/pkg/restore/snap_client/systable_restore_test.go +++ b/br/pkg/restore/snap_client/systable_restore_test.go @@ -116,5 +116,5 @@ func TestCheckSysTableCompatibility(t *testing.T) { // // The above variables are in the file br/pkg/restore/systable_restore.go func TestMonitorTheSystemTableIncremental(t *testing.T) { - require.Equal(t, int64(218), session.CurrentBootstrapVersion) + require.Equal(t, int64(239), session.CurrentBootstrapVersion) } diff --git a/pkg/session/bootstrap.go b/pkg/session/bootstrap.go index 9845605414ec8..0a24c5cbaa266 100644 --- a/pkg/session/bootstrap.go +++ b/pkg/session/bootstrap.go @@ -590,8 +590,9 @@ const ( step INT(11), target_scope VARCHAR(256) DEFAULT "", error BLOB, + modify_params json, key(state), - UNIQUE KEY task_key(task_key) + UNIQUE KEY task_key(task_key) );` // CreateGlobalTaskHistory is a table about history global task. @@ -611,8 +612,9 @@ const ( step INT(11), target_scope VARCHAR(256) DEFAULT "", error BLOB, + modify_params json, key(state), - UNIQUE KEY task_key(task_key) + UNIQUE KEY task_key(task_key) );` // CreateDistFrameworkMeta create a system table that distributed task framework use to store meta information @@ -1197,11 +1199,15 @@ const ( // ... // next version should start with 239 + + // version 239 + // add modify_params to tidb_global_task and tidb_global_task_history. + version239 = 239 ) // currentBootstrapVersion is defined as a variable, so we can modify its value for testing. 
// please make sure this is the largest version -var currentBootstrapVersion int64 = version218 +var currentBootstrapVersion int64 = version239 // DDL owner key's expired time is ManagerSessionTTL seconds, we should wait the time and give more time to have a chance to finish it. var internalSQLTimeout = owner.ManagerSessionTTL + 15 @@ -1375,6 +1381,7 @@ var ( upgradeToVer216, upgradeToVer217, upgradeToVer218, + upgradeToVer239, } ) @@ -3272,6 +3279,14 @@ func upgradeToVer218(_ sessiontypes.Session, ver int64) { // empty, just make lint happy. } +func upgradeToVer239(s sessiontypes.Session, ver int64) { + if ver >= version239 { + return + } + doReentrantDDL(s, "ALTER TABLE mysql.tidb_global_task ADD COLUMN modify_params json AFTER `error`;", infoschema.ErrColumnExists) + doReentrantDDL(s, "ALTER TABLE mysql.tidb_global_task_history ADD COLUMN modify_params json AFTER `error`;", infoschema.ErrColumnExists) +} + // initGlobalVariableIfNotExists initialize a global variable with specific val if it does not exist. 
func initGlobalVariableIfNotExists(s sessiontypes.Session, name string, val any) { ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnBootstrap) From 78a48813738e17ae214c51a0314b0d23afcfb228 Mon Sep 17 00:00:00 2001 From: D3Hunter Date: Tue, 19 Nov 2024 20:48:50 +0800 Subject: [PATCH 2/9] scheduler part --- pkg/disttask/framework/proto/BUILD.bazel | 1 + pkg/disttask/framework/proto/modify.go | 40 ++++++++++++++++ pkg/disttask/framework/proto/task.go | 47 +++++++++++++++---- pkg/disttask/framework/scheduler/scheduler.go | 32 +++++++++++-- .../framework/scheduler/scheduler_manager.go | 2 +- pkg/disttask/framework/storage/converter.go | 6 +++ pkg/disttask/framework/storage/task_table.go | 3 +- 7 files changed, 114 insertions(+), 17 deletions(-) create mode 100644 pkg/disttask/framework/proto/modify.go diff --git a/pkg/disttask/framework/proto/BUILD.bazel b/pkg/disttask/framework/proto/BUILD.bazel index 3c17cc47b5f20..fe35478678d68 100644 --- a/pkg/disttask/framework/proto/BUILD.bazel +++ b/pkg/disttask/framework/proto/BUILD.bazel @@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "proto", srcs = [ + "modify.go", "node.go", "step.go", "subtask.go", diff --git a/pkg/disttask/framework/proto/modify.go b/pkg/disttask/framework/proto/modify.go new file mode 100644 index 0000000000000..3b26b3c30e71a --- /dev/null +++ b/pkg/disttask/framework/proto/modify.go @@ -0,0 +1,40 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package proto + +// ModificationType is the type of task modification. +type ModificationType string + +// String implements fmt.Stringer interface. +func (t ModificationType) String() string { + return string(t) +} + +const ( + // ModifyConcurrency is the type for modifying task concurrency. + ModifyConcurrency ModificationType = "modify_concurrency" +) + +// ModifyParam is the parameter for task modification. +type ModifyParam struct { + PrevState TaskState `json:"prev_state"` + Modifications []Modification `json:"modifications"` +} + +// Modification is one modification for task. +type Modification struct { + Type ModificationType `json:"type"` + To int64 `json:"to"` +} diff --git a/pkg/disttask/framework/proto/task.go b/pkg/disttask/framework/proto/task.go index c9378d42c7dd9..2c6c26ae41d1e 100644 --- a/pkg/disttask/framework/proto/task.go +++ b/pkg/disttask/framework/proto/task.go @@ -24,14 +24,13 @@ import ( // The `failed` state is used to mean the framework cannot run the task, such as // invalid task type, scheduler init error(fatal), etc. // -// ┌────────┐ -// ┌───────────│resuming│◄────────┐ -// │ └────────┘ │ -// ┌──────┐ │ ┌───────┐ ┌──┴───┐ -// │failed│ │ ┌────────►│pausing├──────►│paused│ -// └──────┘ │ │ └───────┘ └──────┘ -// ▲ ▼ │ -// ┌──┴────┐ ┌───┴───┐ ┌────────┐ +// normal execution state transition: +// +// ┌──────┐ +// │failed│ +// └──────┘ +// ▲ +// ┌──┴────┐ ┌───────┐ ┌────────┐ // │pending├────►│running├────►│succeed │ // └──┬────┘ └──┬┬───┘ └────────┘ // │ ││ ┌─────────┐ ┌────────┐ @@ -40,6 +39,32 @@ import ( // │ ┌──────────┐ ▲ // └─────────►│cancelling├────┘ // └──────────┘ +// +// pause/resume state transition: +// as we don't know the state of the task before `paused`, so the state after +// `resuming` is always `running`. 
+// +// ┌───────┐ +// │pending├──┐ +// └───────┘ │ ┌───────┐ ┌──────┐ +// ├────►│pausing├──────►│paused│ +// ┌───────┐ │ └───────┘ └───┬──┘ +// │running├──┘ │ +// └───▲───┘ ┌────────┐ │ +// └────────────┤resuming│◄─────────┘ +// └────────┘ +// +// modifying state transition: +// +// ┌───────┐ +// │pending├──┐ +// └───────┘ │ +// ┌───────┐ │ ┌─────────┐ +// │running├──┼────►│modifying├────► original state +// └───────┘ │ └─────────┘ +// ┌───────┐ │ +// │paused ├──┘ +// └───────┘ const ( TaskStatePending TaskState = "pending" TaskStateRunning TaskState = "running" @@ -51,6 +76,7 @@ const ( TaskStatePausing TaskState = "pausing" TaskStatePaused TaskState = "paused" TaskStateResuming TaskState = "resuming" + TaskStateModifying TaskState = "modifying" ) type ( @@ -154,8 +180,9 @@ type Task struct { // changed in below case, and framework will update the task meta in the storage. // - task switches to next step in Scheduler.OnNextSubtasksBatch // - on task cleanup, we might do some redaction on the meta. - Meta []byte - Error error + Meta []byte + Error error + ModifyParam ModifyParam } var ( diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index 9d2228b96292e..4b181b01c7f7c 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -138,7 +138,7 @@ func (s *BaseScheduler) refreshTaskIfNeeded() error { if err != nil { return err } - // state might be changed by user to pausing/resuming/cancelling, or + // state might be changed by user to pausing/resuming/cancelling/modifying, or // in case of network partition, state/step/meta might be changed by other scheduler, // in both cases we refresh the whole task object. if newTaskBase.State != task.State || newTaskBase.Step != task.Step { @@ -227,8 +227,7 @@ func (s *BaseScheduler) scheduleTask() { return } case proto.TaskStateResuming: - // Case with 2 nodes. 
- // Here is the timeline + // need to check allocatedSlots for the following case: // 1. task in pausing state. // 2. node1 and node2 start schedulers with task in pausing state without allocatedSlots. // 3. node1's scheduler transfer the node from pausing to paused state. @@ -242,10 +241,20 @@ func (s *BaseScheduler) scheduleTask() { case proto.TaskStateReverting: err = s.onReverting() case proto.TaskStatePending: + // need to check allocatedSlots for the following case: + // 1. task in modifying state, node A and B start schedulers with + // task in modifying state without allocatedSlots. + // 2. node A's scheduler finishes modifying, and transfers the task + // from modifying to pending state. + // 3. node B's scheduler calls refreshTask and gets the task with pending + // state, but this scheduler has not allocated slots. + if !s.allocatedSlots { + s.logger.Info("scheduler exit since not allocated slots", zap.Stringer("state", task.State)) + return + } err = s.onPending() case proto.TaskStateRunning: - // Case with 2 nodes. - // Here is the timeline + // need to check allocatedSlots for the following case: // 1. task in pausing state. // 2. node1 and node2 start schedulers with task in pausing state without allocatedSlots. // 3. node1's scheduler transfer the node from pausing to paused state. @@ -257,6 +266,12 @@ func (s *BaseScheduler) scheduleTask() { return } err = s.onRunning() + case proto.TaskStateModifying: + var recreateScheduler bool + recreateScheduler, err = s.onModifying() + if err == nil && recreateScheduler { + return + } case proto.TaskStateSucceed, proto.TaskStateReverted, proto.TaskStateFailed: s.onFinished() return @@ -406,6 +421,13 @@ func (s *BaseScheduler) onRunning() error { return nil } +// onModifying is called when task is in modifying state. +// the first return value indicates whether the scheduler should be recreated.
+func (s *BaseScheduler) onModifying() (bool, error) { + // TODO: implement me + panic("implement me") +} + func (s *BaseScheduler) onFinished() { task := s.GetTask() s.logger.Debug("schedule task, task is finished", zap.Stringer("state", task.State)) diff --git a/pkg/disttask/framework/scheduler/scheduler_manager.go b/pkg/disttask/framework/scheduler/scheduler_manager.go index 0b946111a812c..dec9e16134eb4 100644 --- a/pkg/disttask/framework/scheduler/scheduler_manager.go +++ b/pkg/disttask/framework/scheduler/scheduler_manager.go @@ -278,7 +278,7 @@ func (sm *Manager) startSchedulers(schedulableTasks []*proto.TaskBase) error { // task of lower rank might be able to be scheduled. continue } - // reverting/cancelling/pausing + // reverting/cancelling/pausing/modifying, we don't allocate slots for them. default: allocateSlots = false sm.logger.Info("start scheduler without allocating slots", diff --git a/pkg/disttask/framework/storage/converter.go b/pkg/disttask/framework/storage/converter.go index b18a196c041b9..b469a69c880b6 100644 --- a/pkg/disttask/framework/storage/converter.go +++ b/pkg/disttask/framework/storage/converter.go @@ -15,6 +15,7 @@ package storage import ( + "encoding/json" "strconv" "time" @@ -66,6 +67,11 @@ func Row2Task(r chunk.Row) *proto.Task { task.Error = stdErr } } + if !r.IsNull(14) { + if err := json.Unmarshal(r.GetJSON(14).GetString(), &task.ModifyParam); err != nil { + logutil.BgLogger().Error("unmarshal task modify param", zap.Error(err)) + } + } return task } diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index 4cc852610ef40..29064e98a7f43 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -39,7 +39,7 @@ const ( basicTaskColumns = `t.id, t.task_key, t.type, t.state, t.step, t.priority, t.concurrency, t.create_time, t.target_scope` // TaskColumns is the columns for task. 
// TODO: dispatcher_id will update to scheduler_id later - TaskColumns = basicTaskColumns + `, t.start_time, t.state_update_time, t.meta, t.dispatcher_id, t.error` + TaskColumns = basicTaskColumns + `, t.start_time, t.state_update_time, t.meta, t.dispatcher_id, t.error, t.modify_params` // InsertTaskColumns is the columns used in insert task. InsertTaskColumns = `task_key, type, state, priority, concurrency, step, meta, create_time, target_scope` basicSubtaskColumns = `id, step, task_key, type, exec_id, state, concurrency, create_time, ordinal, start_time` @@ -245,6 +245,7 @@ func (mgr *TaskManager) GetTopUnfinishedTasks(ctx context.Context) ([]*proto.Tas proto.TaskStateCancelling, proto.TaskStatePausing, proto.TaskStateResuming, + proto.TaskStateModifying, proto.MaxConcurrentTask*2, ) if err != nil { From 35942680926194ef7d215bed6d7d6f99fdd377a0 Mon Sep 17 00:00:00 2001 From: D3Hunter Date: Wed, 20 Nov 2024 16:19:52 +0800 Subject: [PATCH 3/9] change --- pkg/disttask/framework/scheduler/nodes.go | 1 + pkg/disttask/framework/storage/converter.go | 3 +- pkg/disttask/framework/storage/table_test.go | 36 ++++++++++++--- pkg/disttask/framework/storage/task_table.go | 47 +++++++++++++++++++- 4 files changed, 78 insertions(+), 9 deletions(-) diff --git a/pkg/disttask/framework/scheduler/nodes.go b/pkg/disttask/framework/scheduler/nodes.go index 7a8dc8cb401c4..46bcbce61eae9 100644 --- a/pkg/disttask/framework/scheduler/nodes.go +++ b/pkg/disttask/framework/scheduler/nodes.go @@ -145,6 +145,7 @@ func (nm *NodeManager) refreshNodes(ctx context.Context, taskMgr TaskManager, sl for _, node := range newNodes { if node.CPUCount > 0 { cpuCount = node.CPUCount + break } } slotMgr.updateCapacity(cpuCount) diff --git a/pkg/disttask/framework/storage/converter.go b/pkg/disttask/framework/storage/converter.go index b469a69c880b6..e4c140bbecbc7 100644 --- a/pkg/disttask/framework/storage/converter.go +++ b/pkg/disttask/framework/storage/converter.go @@ -68,7 +68,8 @@ func Row2Task(r 
chunk.Row) *proto.Task { } } if !r.IsNull(14) { - if err := json.Unmarshal(r.GetJSON(14).GetString(), &task.ModifyParam); err != nil { + str := r.GetJSON(14).String() + if err := json.Unmarshal([]byte(str), &task.ModifyParam); err != nil { logutil.BgLogger().Error("unmarshal task modify param", zap.Error(err)) } } diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index e047d53a30e54..445d7d984eb40 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -196,6 +196,20 @@ func TestTaskTable(t *testing.T) { task, err = gm.GetTaskByID(ctx, id) require.NoError(t, err) require.Equal(t, proto.TaskStatePaused, task.State) + // check modifying param + require.ErrorIs(t, gm.ModifyTaskByID(ctx, id, &proto.ModifyParam{ + PrevState: proto.TaskStateReverting, + }), storage.ErrTaskStateNotAllow) + require.ErrorIs(t, gm.ModifyTaskByID(ctx, 123123123, &proto.ModifyParam{ + PrevState: proto.TaskStatePaused, + }), storage.ErrTaskNotFound) + require.NoError(t, gm.ModifyTaskByID(ctx, id, &proto.ModifyParam{ + PrevState: proto.TaskStatePaused, + })) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + require.Equal(t, proto.TaskStateModifying, task.State) + require.Equal(t, proto.TaskStatePaused, task.ModifyParam.PrevState) } func checkAfterSwitchStep(t *testing.T, startTime time.Time, task *proto.Task, subtasks []*proto.Subtask, step proto.Step) { @@ -369,6 +383,7 @@ func TestGetTopUnfinishedTasks(t *testing.T) { proto.TaskStatePending, proto.TaskStatePending, proto.TaskStatePending, + proto.TaskStateModifying, } for i, state := range taskStates { taskKey := fmt.Sprintf("key/%d", i) @@ -403,17 +418,26 @@ func TestGetTopUnfinishedTasks(t *testing.T) { rs, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` select count(1) from mysql.tidb_global_task`) require.Len(t, rs, 1) - require.Equal(t, int64(12), rs[0].GetInt64(0)) + require.Equal(t, int64(13), 
rs[0].GetInt64(0)) return err })) + getTaskKeys := func(tasks []*proto.TaskBase) []string { + taskKeys := make([]string, 0, len(tasks)) + for _, task := range tasks { + taskKeys = append(taskKeys, task.Key) + } + return taskKeys + } tasks, err := gm.GetTopUnfinishedTasks(ctx) require.NoError(t, err) require.Len(t, tasks, 8) - taskKeys := make([]string, 0, len(tasks)) - for _, task := range tasks { - taskKeys = append(taskKeys, task.Key) - } - require.Equal(t, []string{"key/6", "key/5", "key/1", "key/2", "key/3", "key/4", "key/8", "key/9"}, taskKeys) + require.Equal(t, []string{"key/6", "key/5", "key/1", "key/2", "key/3", "key/4", "key/8", "key/9"}, getTaskKeys(tasks)) + + proto.MaxConcurrentTask = 6 + tasks, err = gm.GetTopUnfinishedTasks(ctx) + require.NoError(t, err) + require.Len(t, tasks, 11) + require.Equal(t, []string{"key/6", "key/5", "key/1", "key/2", "key/3", "key/4", "key/8", "key/9", "key/10", "key/11", "key/12"}, getTaskKeys(tasks)) } func TestGetUsedSlotsOnNodes(t *testing.T) { diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index 29064e98a7f43..ed83d58acfad1 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -16,6 +16,7 @@ package storage import ( "context" + "encoding/json" "strconv" "strings" "sync/atomic" @@ -65,6 +66,9 @@ var ( // i.e. SubmitTask in handle may submit a task twice. ErrTaskAlreadyExists = errors.New("task already exists") + // ErrTaskStateNotAllow is the error when the task state is not allowed to do the operation. + ErrTaskStateNotAllow = errors.New("task state not allow to do the operation") + // ErrSubtaskNotFound is the error when can't find subtask by subtask_id and execId, // i.e. scheduler change the subtask's execId when subtask need to balance to other nodes. 
ErrSubtaskNotFound = errors.New("subtask not found") @@ -236,7 +240,7 @@ func (mgr *TaskManager) CreateTaskWithSession( func (mgr *TaskManager) GetTopUnfinishedTasks(ctx context.Context) ([]*proto.TaskBase, error) { rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select `+basicTaskColumns+` from mysql.tidb_global_task t - where state in (%?, %?, %?, %?, %?, %?) + where state in (%?, %?, %?, %?, %?, %?, %?) order by priority asc, create_time asc, id asc limit %?`, proto.TaskStatePending, @@ -319,7 +323,16 @@ func (mgr *TaskManager) GetTaskByID(ctx context.Context, taskID int64) (task *pr // GetTaskBaseByID implements the TaskManager.GetTaskBaseByID interface. func (mgr *TaskManager) GetTaskBaseByID(ctx context.Context, taskID int64) (task *proto.TaskBase, err error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+basicTaskColumns+" from mysql.tidb_global_task t where id = %?", taskID) + err = mgr.WithNewSession(func(se sessionctx.Context) error { + var err2 error + task, err2 = mgr.getTaskBaseByID(ctx, se.GetSQLExecutor(), taskID) + return err2 + }) + return +} + +func (mgr *TaskManager) getTaskBaseByID(ctx context.Context, exec sqlexec.SQLExecutor, taskID int64) (task *proto.TaskBase, err error) { + rs, err := sqlexec.ExecSQL(ctx, exec, "select "+basicTaskColumns+" from mysql.tidb_global_task t where id = %?", taskID) if err != nil { return task, err } @@ -812,3 +825,33 @@ func (mgr *TaskManager) AdjustTaskOverflowConcurrency(ctx context.Context, se se _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), sql, cpuCount, cpuCount) return err } + +// ModifyTaskByID modifies the task by the task ID. 
+func (mgr *TaskManager) ModifyTaskByID(ctx context.Context, taskID int64, param *proto.ModifyParam) error { + if param.PrevState != proto.TaskStatePending && + param.PrevState != proto.TaskStateRunning && + param.PrevState != proto.TaskStatePaused { + return ErrTaskStateNotAllow + } + bytes, err := json.Marshal(param) + if err != nil { + return errors.Trace(err) + } + return mgr.WithNewSession(func(se sessionctx.Context) error { + _, err = mgr.getTaskBaseByID(ctx, se.GetSQLExecutor(), taskID) + if err != nil { + return err + } + _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), + `update mysql.tidb_global_task set state = %?, modify_params = %? where id = %? and state = %?`, + proto.TaskStateModifying, json.RawMessage(bytes), taskID, param.PrevState, + ) + if err != nil { + return err + } + if se.GetSessionVars().StmtCtx.AffectedRows() == 0 { + return ErrTaskStateNotAllow + } + return nil + }) +} From a5cb08f787d1472263e27e539dedb3d8867acd96 Mon Sep 17 00:00:00 2001 From: D3Hunter Date: Wed, 20 Nov 2024 17:52:05 +0800 Subject: [PATCH 4/9] lint --- pkg/disttask/framework/storage/task_table.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index ed83d58acfad1..8e4d54673abc3 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -331,7 +331,7 @@ func (mgr *TaskManager) GetTaskBaseByID(ctx context.Context, taskID int64) (task return } -func (mgr *TaskManager) getTaskBaseByID(ctx context.Context, exec sqlexec.SQLExecutor, taskID int64) (task *proto.TaskBase, err error) { +func (*TaskManager) getTaskBaseByID(ctx context.Context, exec sqlexec.SQLExecutor, taskID int64) (task *proto.TaskBase, err error) { rs, err := sqlexec.ExecSQL(ctx, exec, "select "+basicTaskColumns+" from mysql.tidb_global_task t where id = %?", taskID) if err != nil { return task, err From dbcef0592bdf3cd86a553d65fe168eb957a424c0 Mon 
Sep 17 00:00:00 2001 From: D3Hunter Date: Wed, 20 Nov 2024 18:06:30 +0800 Subject: [PATCH 5/9] lint --- pkg/disttask/framework/scheduler/scheduler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index 4b181b01c7f7c..d8bfaf8868b9e 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -423,7 +423,7 @@ func (s *BaseScheduler) onRunning() error { // onModifying is called when task is in modifying state. // the first return value indicates whether the scheduler should be recreated. -func (s *BaseScheduler) onModifying() (bool, error) { +func (*BaseScheduler) onModifying() (bool, error) { // TODO: implement me panic("implement me") } From 6c3d73dbd8e4ad173645b74c6dcd9803f404c8bc Mon Sep 17 00:00:00 2001 From: D3Hunter Date: Thu, 21 Nov 2024 12:23:58 +0800 Subject: [PATCH 6/9] change --- pkg/disttask/framework/storage/task_table.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index 8e4d54673abc3..e3579121d8794 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -842,8 +842,10 @@ func (mgr *TaskManager) ModifyTaskByID(ctx context.Context, taskID int64, param if err != nil { return err } - _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), - `update mysql.tidb_global_task set state = %?, modify_params = %? where id = %? and state = %?`, + _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` + update mysql.tidb_global_task + set state = %?, modify_params = %?, state_update_time = CURRENT_TIMESTAMP() + where id = %? 
and state = %?`, proto.TaskStateModifying, json.RawMessage(bytes), taskID, param.PrevState, ) if err != nil { From 6d0e2350f3f3871610bc46efac95bb8f9adef1f9 Mon Sep 17 00:00:00 2001 From: D3Hunter Date: Thu, 21 Nov 2024 13:47:40 +0800 Subject: [PATCH 7/9] move --- pkg/disttask/framework/storage/task_state.go | 34 ++++++++++++++++++++ pkg/disttask/framework/storage/task_table.go | 33 ------------------- 2 files changed, 34 insertions(+), 33 deletions(-) diff --git a/pkg/disttask/framework/storage/task_state.go b/pkg/disttask/framework/storage/task_state.go index dbca740eb7b23..6155e5c9cf228 100644 --- a/pkg/disttask/framework/storage/task_state.go +++ b/pkg/disttask/framework/storage/task_state.go @@ -16,7 +16,9 @@ package storage import ( "context" + "encoding/json" + "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/util/sqlexec" @@ -159,6 +161,38 @@ func (mgr *TaskManager) ResumedTask(ctx context.Context, taskID int64) error { return err } +// ModifyTaskByID modifies the task by the task ID. +func (mgr *TaskManager) ModifyTaskByID(ctx context.Context, taskID int64, param *proto.ModifyParam) error { + if param.PrevState != proto.TaskStatePending && + param.PrevState != proto.TaskStateRunning && + param.PrevState != proto.TaskStatePaused { + return ErrTaskStateNotAllow + } + bytes, err := json.Marshal(param) + if err != nil { + return errors.Trace(err) + } + return mgr.WithNewSession(func(se sessionctx.Context) error { + _, err = mgr.getTaskBaseByID(ctx, se.GetSQLExecutor(), taskID) + if err != nil { + return err + } + _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` + update mysql.tidb_global_task + set state = %?, modify_params = %?, state_update_time = CURRENT_TIMESTAMP() + where id = %? 
and state = %?`, + proto.TaskStateModifying, json.RawMessage(bytes), taskID, param.PrevState, + ) + if err != nil { + return err + } + if se.GetSessionVars().StmtCtx.AffectedRows() == 0 { + return ErrTaskStateNotAllow + } + return nil + }) +} + // SucceedTask update task state from running to succeed. func (mgr *TaskManager) SucceedTask(ctx context.Context, taskID int64) error { return mgr.WithNewSession(func(se sessionctx.Context) error { diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index e3579121d8794..3d1f7ffde0129 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -16,7 +16,6 @@ package storage import ( "context" - "encoding/json" "strconv" "strings" "sync/atomic" @@ -825,35 +824,3 @@ func (mgr *TaskManager) AdjustTaskOverflowConcurrency(ctx context.Context, se se _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), sql, cpuCount, cpuCount) return err } - -// ModifyTaskByID modifies the task by the task ID. -func (mgr *TaskManager) ModifyTaskByID(ctx context.Context, taskID int64, param *proto.ModifyParam) error { - if param.PrevState != proto.TaskStatePending && - param.PrevState != proto.TaskStateRunning && - param.PrevState != proto.TaskStatePaused { - return ErrTaskStateNotAllow - } - bytes, err := json.Marshal(param) - if err != nil { - return errors.Trace(err) - } - return mgr.WithNewSession(func(se sessionctx.Context) error { - _, err = mgr.getTaskBaseByID(ctx, se.GetSQLExecutor(), taskID) - if err != nil { - return err - } - _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` - update mysql.tidb_global_task - set state = %?, modify_params = %?, state_update_time = CURRENT_TIMESTAMP() - where id = %? 
and state = %?`, - proto.TaskStateModifying, json.RawMessage(bytes), taskID, param.PrevState, - ) - if err != nil { - return err - } - if se.GetSessionVars().StmtCtx.AffectedRows() == 0 { - return ErrTaskStateNotAllow - } - return nil - }) -} From c6fa65fa00f307b9938f93c580ca86b3caf37e45 Mon Sep 17 00:00:00 2001 From: D3Hunter Date: Thu, 21 Nov 2024 15:10:07 +0800 Subject: [PATCH 8/9] change --- pkg/disttask/framework/proto/task.go | 5 ++ pkg/disttask/framework/storage/table_test.go | 14 ----- pkg/disttask/framework/storage/task_state.go | 22 ++++--- .../framework/storage/task_state_test.go | 57 +++++++++++++++++++ pkg/disttask/framework/storage/task_table.go | 3 + pkg/disttask/framework/testutil/table_util.go | 4 +- 6 files changed, 81 insertions(+), 24 deletions(-) diff --git a/pkg/disttask/framework/proto/task.go b/pkg/disttask/framework/proto/task.go index 2c6c26ae41d1e..9a65e4c52b983 100644 --- a/pkg/disttask/framework/proto/task.go +++ b/pkg/disttask/framework/proto/task.go @@ -94,6 +94,11 @@ func (s TaskState) String() string { return string(s) } +// CanMoveToModifying checks if current state can move to 'modifying' state. +func (s TaskState) CanMoveToModifying() bool { + return s == TaskStatePending || s == TaskStateRunning || s == TaskStatePaused +} + const ( // TaskIDLabelName is the label name of task id. 
TaskIDLabelName = "task_id" diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index 445d7d984eb40..bb05c79780b25 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -196,20 +196,6 @@ func TestTaskTable(t *testing.T) { task, err = gm.GetTaskByID(ctx, id) require.NoError(t, err) require.Equal(t, proto.TaskStatePaused, task.State) - // check modifying param - require.ErrorIs(t, gm.ModifyTaskByID(ctx, id, &proto.ModifyParam{ - PrevState: proto.TaskStateReverting, - }), storage.ErrTaskStateNotAllow) - require.ErrorIs(t, gm.ModifyTaskByID(ctx, 123123123, &proto.ModifyParam{ - PrevState: proto.TaskStatePaused, - }), storage.ErrTaskNotFound) - require.NoError(t, gm.ModifyTaskByID(ctx, id, &proto.ModifyParam{ - PrevState: proto.TaskStatePaused, - })) - task, err = gm.GetTaskByID(ctx, id) - require.NoError(t, err) - require.Equal(t, proto.TaskStateModifying, task.State) - require.Equal(t, proto.TaskStatePaused, task.ModifyParam.PrevState) } func checkAfterSwitchStep(t *testing.T, startTime time.Time, task *proto.Task, subtasks []*proto.Subtask, step proto.Step) { diff --git a/pkg/disttask/framework/storage/task_state.go b/pkg/disttask/framework/storage/task_state.go index 6155e5c9cf228..00723b6c01d19 100644 --- a/pkg/disttask/framework/storage/task_state.go +++ b/pkg/disttask/framework/storage/task_state.go @@ -19,6 +19,7 @@ import ( "encoding/json" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/util/sqlexec" @@ -163,20 +164,22 @@ func (mgr *TaskManager) ResumedTask(ctx context.Context, taskID int64) error { // ModifyTaskByID modifies the task by the task ID. 
func (mgr *TaskManager) ModifyTaskByID(ctx context.Context, taskID int64, param *proto.ModifyParam) error { - if param.PrevState != proto.TaskStatePending && - param.PrevState != proto.TaskStateRunning && - param.PrevState != proto.TaskStatePaused { + if !param.PrevState.CanMoveToModifying() { return ErrTaskStateNotAllow } bytes, err := json.Marshal(param) if err != nil { return errors.Trace(err) } - return mgr.WithNewSession(func(se sessionctx.Context) error { - _, err = mgr.getTaskBaseByID(ctx, se.GetSQLExecutor(), taskID) - if err != nil { - return err + return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { + task, err2 := mgr.getTaskBaseByID(ctx, se.GetSQLExecutor(), taskID) + if err2 != nil { + return err2 + } + if task.State != param.PrevState { + return ErrTaskChanged } + failpoint.InjectCall("beforeMoveToModifying") _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` update mysql.tidb_global_task set state = %?, modify_params = %?, state_update_time = CURRENT_TIMESTAMP() @@ -187,7 +190,10 @@ func (mgr *TaskManager) ModifyTaskByID(ctx context.Context, taskID int64, param return err } if se.GetSessionVars().StmtCtx.AffectedRows() == 0 { - return ErrTaskStateNotAllow + // the txn is pessimistic, it's possible that another txn has + // changed the task state before this txn commits and there is no + // write-conflict. 
+ return ErrTaskChanged } return nil }) diff --git a/pkg/disttask/framework/storage/task_state_test.go b/pkg/disttask/framework/storage/task_state_test.go index 460a197fda910..ba0ad9e9f4b6f 100644 --- a/pkg/disttask/framework/storage/task_state_test.go +++ b/pkg/disttask/framework/storage/task_state_test.go @@ -16,12 +16,16 @@ package storage_test import ( "errors" + "sync" "testing" "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" "github.com/pingcap/tidb/pkg/disttask/framework/testutil" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/testkit/testfailpoint" + tidbutil "github.com/pingcap/tidb/pkg/util" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/util" ) @@ -126,3 +130,56 @@ func TestTaskState(t *testing.T) { require.NoError(t, err) checkTaskStateStep(t, task, proto.TaskStateSucceed, proto.StepDone) } + +func TestModifyTask(t *testing.T) { + _, gm, ctx := testutil.InitTableTest(t) + require.NoError(t, gm.InitMeta(ctx, ":4000", "")) + + id, err := gm.CreateTask(ctx, "key1", "test", 4, "", []byte("test")) + require.NoError(t, err) + + require.ErrorIs(t, gm.ModifyTaskByID(ctx, id, &proto.ModifyParam{ + PrevState: proto.TaskStateReverting, + }), storage.ErrTaskStateNotAllow) + require.ErrorIs(t, gm.ModifyTaskByID(ctx, 123123123, &proto.ModifyParam{ + PrevState: proto.TaskStatePaused, + }), storage.ErrTaskNotFound) + require.ErrorIs(t, gm.ModifyTaskByID(ctx, id, &proto.ModifyParam{ + PrevState: proto.TaskStatePaused, + }), storage.ErrTaskChanged) + + // task changed in middle of modifying + ch := make(chan struct{}) + var wg tidbutil.WaitGroupWrapper + var once sync.Once + testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/storage/beforeMoveToModifying", func() { + once.Do(func() { + <-ch + <-ch + }) + }) + task, err := gm.GetTaskByID(ctx, id) + require.NoError(t, err) + require.Equal(t, 
proto.TaskStatePending, task.State) + wg.Run(func() { + ch <- struct{}{} + require.NoError(t, gm.SwitchTaskStep(ctx, task, proto.TaskStateRunning, proto.StepOne, nil)) + ch <- struct{}{} + }) + require.ErrorIs(t, gm.ModifyTaskByID(ctx, id, &proto.ModifyParam{ + PrevState: proto.TaskStatePending, + }), storage.ErrTaskChanged) + wg.Wait() + + // move to 'modifying' success + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + require.Equal(t, proto.TaskStateRunning, task.State) + require.NoError(t, gm.ModifyTaskByID(ctx, id, &proto.ModifyParam{ + PrevState: proto.TaskStateRunning, + })) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + require.Equal(t, proto.TaskStateModifying, task.State) + require.Equal(t, proto.TaskStateRunning, task.ModifyParam.PrevState) +} diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index 3d1f7ffde0129..f3e3dd86ee969 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -68,6 +68,9 @@ var ( // ErrTaskStateNotAllow is the error when the task state is not allowed to do the operation. ErrTaskStateNotAllow = errors.New("task state not allow to do the operation") + // ErrTaskChanged is the error when task changed by other operation. + ErrTaskChanged = errors.New("task changed by other operation") + // ErrSubtaskNotFound is the error when can't find subtask by subtask_id and execId, // i.e. scheduler change the subtask's execId when subtask need to balance to other nodes. 
ErrSubtaskNotFound = errors.New("subtask not found") diff --git a/pkg/disttask/framework/testutil/table_util.go b/pkg/disttask/framework/testutil/table_util.go index 74969e5a9963c..205104d1a919a 100644 --- a/pkg/disttask/framework/testutil/table_util.go +++ b/pkg/disttask/framework/testutil/table_util.go @@ -55,10 +55,10 @@ func InitTableTestWithCancel(t *testing.T) (*storage.TaskManager, context.Contex func getResourcePool(t *testing.T) (kv.Storage, *pools.ResourcePool) { testfailpoint.Enable(t, "github.com/pingcap/tidb/pkg/domain/MockDisableDistTask", "return(true)") store := testkit.CreateMockStore(t, mockstore.WithStoreType(mockstore.EmbedUnistore)) - tk := testkit.NewTestKit(t, store) pool := pools.NewResourcePool(func() (pools.Resource, error) { + tk := testkit.NewTestKit(t, store) return tk.Session(), nil - }, 1, 1, time.Second) + }, 10, 10, time.Second) t.Cleanup(func() { pool.Close() From 662094f161d9252bde4ad7ffa01e19db53991235 Mon Sep 17 00:00:00 2001 From: D3Hunter Date: Thu, 21 Nov 2024 16:33:50 +0800 Subject: [PATCH 9/9] lint --- pkg/disttask/framework/storage/BUILD.bazel | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/disttask/framework/storage/BUILD.bazel b/pkg/disttask/framework/storage/BUILD.bazel index d1dc457934de7..04481c4993c58 100644 --- a/pkg/disttask/framework/storage/BUILD.bazel +++ b/pkg/disttask/framework/storage/BUILD.bazel @@ -42,7 +42,7 @@ go_test( embed = [":storage"], flaky = True, race = "on", - shard_count = 22, + shard_count = 23, deps = [ "//pkg/config", "//pkg/disttask/framework/proto", @@ -53,6 +53,7 @@ go_test( "//pkg/testkit", "//pkg/testkit/testfailpoint", "//pkg/testkit/testsetup", + "//pkg/util", "//pkg/util/sqlexec", "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//require",