diff --git a/pkg/disttask/framework/taskexecutor/manager.go b/pkg/disttask/framework/taskexecutor/manager.go
index 7f864e9c7b5e1..48fb9ad321a34 100644
--- a/pkg/disttask/framework/taskexecutor/manager.go
+++ b/pkg/disttask/framework/taskexecutor/manager.go
@@ -94,7 +94,7 @@ func (b *ManagerBuilder) BuildManager(ctx context.Context, id string, taskTable
         logCtx:    logutil.WithFields(context.Background()),
         newPool:   b.newPool,
         slotManager: &slotManager{
-            executorSlotInfos: make(map[int64]*slotInfo, 0),
+            executorSlotInfos: make(map[int64]*slotInfo),
             available:         cpu.GetCPUCount(),
         },
     }
@@ -222,7 +222,7 @@ func (m *Manager) onRunnableTasks(tasks []*proto.Task) {
         }
         logutil.Logger(m.logCtx).Info("detect new subtask", zap.Int64("task-id", task.ID))
 
-        if !m.slotManager.canReserve(task) {
+        if !m.slotManager.canAlloc(task) {
             failpoint.Inject("taskTick", func() {
                 <-onRunnableTasksTick
             })
@@ -231,14 +231,13 @@ func (m *Manager) onRunnableTasks(tasks []*proto.Task) {
         m.addHandlingTask(task.ID)
         t := task
         err = m.executorPool.Run(func() {
-            m.slotManager.reserve(t)
-            defer m.slotManager.unReserve(t.ID)
+            m.slotManager.alloc(t)
+            defer m.slotManager.free(t.ID)
             m.onRunnableTask(t)
             m.removeHandlingTask(t.ID)
         })
         // pool closed.
         if err != nil {
-            m.slotManager.unReserve(t.ID)
             m.removeHandlingTask(task.ID)
             m.logErr(err)
             return
diff --git a/pkg/disttask/framework/taskexecutor/slot.go b/pkg/disttask/framework/taskexecutor/slot.go
index cb9a34ef606c2..e788312fbccc3 100644
--- a/pkg/disttask/framework/taskexecutor/slot.go
+++ b/pkg/disttask/framework/taskexecutor/slot.go
@@ -38,7 +38,7 @@ type slotInfo struct {
     slotCount int
 }
 
-func (sm *slotManager) reserve(task *proto.Task) {
+func (sm *slotManager) alloc(task *proto.Task) {
     sm.Lock()
     defer sm.Unlock()
     sm.executorSlotInfos[task.ID] = &slotInfo{
@@ -50,7 +50,7 @@
     sm.available -= task.Concurrency
 }
 
-func (sm *slotManager) unReserve(taskID int64) {
+func (sm *slotManager) free(taskID int64) {
     sm.Lock()
     defer sm.Unlock()
 
@@ -62,9 +62,9 @@
 }
 
-// canReserve is used to check whether the instance has enough slots to run the task.
-func (sm *slotManager) canReserve(task *proto.Task) bool {
-    sm.Lock()
-    defer sm.Unlock()
+// canAlloc is used to check whether the instance has enough slots to run the task.
+func (sm *slotManager) canAlloc(task *proto.Task) bool {
+    sm.RLock()
+    defer sm.RUnlock()
 
     return sm.available >= task.Concurrency
 }
diff --git a/pkg/disttask/framework/taskexecutor/slot_test.go b/pkg/disttask/framework/taskexecutor/slot_test.go
index 8b18c8dce8558..65fff426b9f04 100644
--- a/pkg/disttask/framework/taskexecutor/slot_test.go
+++ b/pkg/disttask/framework/taskexecutor/slot_test.go
@@ -38,18 +38,18 @@ func TestSlotManager(t *testing.T) {
         Priority:    1,
         Concurrency: 1,
     }
-    require.True(t, sm.canReserve(task))
-    sm.reserve(task)
+    require.True(t, sm.canAlloc(task))
+    sm.alloc(task)
     require.Equal(t, 1, sm.executorSlotInfos[taskID].priority)
     require.Equal(t, 1, sm.executorSlotInfos[taskID].slotCount)
     require.Equal(t, 9, sm.available)
 
-    require.False(t, sm.canReserve(&proto.Task{
+    require.False(t, sm.canAlloc(&proto.Task{
         ID:          taskID2,
         Priority:    2,
         Concurrency: 10,
     }))
 
-    sm.unReserve(taskID)
+    sm.free(taskID)
     require.Nil(t, sm.executorSlotInfos[taskID])
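
To make the renamed API easier to follow outside the diff, below is a minimal, self-contained sketch of the slot-accounting pattern as it looks after this change. The `task`, `slotInfo`, and `slotManager` types here are simplified stand-ins for illustration only, not the framework's real `proto.Task` or executor wiring.

```go
package main

import (
	"fmt"
	"sync"
)

// task is a simplified stand-in for proto.Task.
type task struct {
	ID          int64
	Priority    int
	Concurrency int
}

type slotInfo struct {
	priority  int
	slotCount int
}

// slotManager tracks how many executor slots are still available on this node.
type slotManager struct {
	sync.RWMutex
	executorSlotInfos map[int64]*slotInfo
	available         int
}

// canAlloc only reads shared state, which is why the diff downgrades it from
// Lock to RLock: concurrent checks no longer serialize on the write lock.
func (sm *slotManager) canAlloc(t *task) bool {
	sm.RLock()
	defer sm.RUnlock()
	return sm.available >= t.Concurrency
}

// alloc records the task and deducts its concurrency from the available slots.
func (sm *slotManager) alloc(t *task) {
	sm.Lock()
	defer sm.Unlock()
	sm.executorSlotInfos[t.ID] = &slotInfo{priority: t.Priority, slotCount: t.Concurrency}
	sm.available -= t.Concurrency
}

// free returns the slots held by a task. It is only meaningful after a
// successful alloc, which is why the pool-closed error path in manager.go no
// longer calls it: the closure never ran, so alloc never happened.
func (sm *slotManager) free(taskID int64) {
	sm.Lock()
	defer sm.Unlock()
	if info, ok := sm.executorSlotInfos[taskID]; ok {
		delete(sm.executorSlotInfos, taskID)
		sm.available += info.slotCount
	}
}

func main() {
	sm := &slotManager{executorSlotInfos: map[int64]*slotInfo{}, available: 10}
	t := &task{ID: 1, Priority: 1, Concurrency: 4}
	if sm.canAlloc(t) {
		sm.alloc(t)
		defer sm.free(t.ID)
	}
	fmt.Println("available after alloc:", sm.available) // prints 6
}
```

The usage in `main` mirrors what `onRunnableTasks` does after this patch: check `canAlloc` first, call `alloc` only inside the work that actually runs, and pair it with a deferred `free`, so the error path needs no cleanup.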