Skip to content

Commit

Permalink
resourcemanager: fix unavailable Stop
Browse files Browse the repository at this point in the history
Signed-off-by: Weizhen Wang <[email protected]>
  • Loading branch information
hawkingrei committed Feb 3, 2023
1 parent 4cf07cc commit e741e62
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 67 deletions.
38 changes: 20 additions & 18 deletions resourcemanager/pooltask/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,41 +129,43 @@ type GPool[T any, U any, C any, CT any, TF Context[CT]] interface {

// TaskController is a controller that can control or watch the pool.
type TaskController[T any, U any, C any, CT any, TF Context[CT]] struct {
pool GPool[T, U, C, CT, TF]
productCloseCh chan struct{}
wg *sync.WaitGroup
prodWg *sync.WaitGroup
taskID uint64
resultCh chan U
inputCh chan Task[T]
pool GPool[T, U, C, CT, TF]
productExitCh chan struct{}
wg *sync.WaitGroup
taskID uint64
resultCh chan U
inputCh chan Task[T]
}

// NewTaskController create a controller to deal with pooltask's status.
func NewTaskController[T any, U any, C any, CT any, TF Context[CT]](p GPool[T, U, C, CT, TF], taskID uint64, productCloseCh chan struct{}, wg, prodWg *sync.WaitGroup, inputCh chan Task[T], resultCh chan U) TaskController[T, U, C, CT, TF] {
func NewTaskController[T any, U any, C any, CT any, TF Context[CT]](p GPool[T, U, C, CT, TF], taskID uint64, productExitCh chan struct{}, wg *sync.WaitGroup, inputCh chan Task[T], resultCh chan U) TaskController[T, U, C, CT, TF] {
return TaskController[T, U, C, CT, TF]{
pool: p,
taskID: taskID,
productCloseCh: productCloseCh,
wg: wg,
prodWg: prodWg,
resultCh: resultCh,
inputCh: inputCh,
pool: p,
taskID: taskID,
productExitCh: productExitCh,
wg: wg,
resultCh: resultCh,
inputCh: inputCh,
}
}

// Wait is to wait the pool task to stop.
func (t *TaskController[T, U, C, CT, TF]) Wait() {
t.prodWg.Wait()
t.wg.Wait()
close(t.resultCh)
t.pool.DeleteTask(t.taskID)
}

// Stop is to send stop command to the task. But you still need to wait the task to stop.
func (t *TaskController[T, U, C, CT, TF]) Stop() {
close(t.productCloseCh)
channel.Clear(t.inputCh)
close(t.productExitCh)
// Clear all the task in the task queue and mark all task complete.
// so that ```t.Wait``` is able to close resultCh
for range t.inputCh {
t.wg.Done()
}
t.pool.StopTask(t.TaskID())
// Clear the resultCh to avoid blocking the consumer put result into the channel and cannot exit.
channel.Clear(t.resultCh)
}

Expand Down
1 change: 0 additions & 1 deletion util/gpool/spmc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ go_test(
embed = [":spmc"],
flaky = True,
race = "on",
shard_count = 2,
deps = [
"//resourcemanager/pooltask",
"//resourcemanager/util",
Expand Down
4 changes: 3 additions & 1 deletion util/gpool/spmc/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"time"
)

const defaultTaskChanLen = 1

// Option represents the optional function.
type Option func(opts *Options)

Expand Down Expand Up @@ -104,7 +106,7 @@ func loadTaskOptions(options ...TaskOption) *TaskOptions {
opts.ResultChanLen = uint64(opts.Concurrency)
}
if opts.TaskChanLen == 0 {
opts.TaskChanLen = uint64(opts.Concurrency)
opts.TaskChanLen = defaultTaskChanLen
}
return opts
}
Expand Down
76 changes: 35 additions & 41 deletions util/gpool/spmc/spmcpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ func (p *Pool[T, U, C, CT, TF]) release() {
// There might be some callers waiting in retrieveWorker(), so we need to wake them up to prevent
// those callers blocking infinitely.
p.cond.Broadcast()
close(p.taskCh)
}

func isClose(exitCh chan struct{}) bool {
Expand Down Expand Up @@ -258,11 +257,11 @@ func (p *Pool[T, U, C, CT, TF]) SetConsumerFunc(consumerFunc func(T, C, CT) U) {
func (p *Pool[T, U, C, CT, TF]) AddProduceBySlice(producer func() ([]T, error), constArg C, contextFn TF, options ...TaskOption) (<-chan U, pooltask.TaskController[T, U, C, CT, TF]) {
opt := loadTaskOptions(options...)
taskID := p.NewTaskID()
var wg, prodWg sync.WaitGroup
var wg sync.WaitGroup
result := make(chan U, opt.ResultChanLen)
productCloseCh := make(chan struct{})
productExitCh := make(chan struct{})
inputCh := make(chan pooltask.Task[T], opt.TaskChanLen)
tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, productCloseCh, &wg, &prodWg, inputCh, result)
tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, productExitCh, &wg, inputCh, result)
p.taskManager.RegisterTask(taskID, int32(opt.Concurrency))
for i := 0; i < opt.Concurrency; i++ {
err := p.run()
Expand All @@ -274,37 +273,34 @@ func (p *Pool[T, U, C, CT, TF]) AddProduceBySlice(producer func() ([]T, error),
p.taskManager.AddSubTask(taskID, &taskBox)
p.taskCh <- &taskBox
}
prodWg.Add(1)
wg.Add(1)
go func() {
defer func() {
if r := recover(); r != nil {
logutil.BgLogger().Error("producer panic", zap.Any("recover", r), zap.Stack("stack"))
}
close(inputCh)
prodWg.Done()
wg.Done()
}()
for {
select {
case <-productCloseCh:
if isClose(productExitCh) {
return
default:
tasks, err := producer()
if err != nil {
if errors.Is(err, gpool.ErrProducerClosed) {
return
}
log.Error("producer error", zap.Error(err))
}
tasks, err := producer()
if err != nil {
if errors.Is(err, gpool.ErrProducerClosed) {
return
}
for _, task := range tasks {
wg.Add(1)
task := pooltask.Task[T]{
Task: task,
}
inputCh <- task
log.Error("producer error", zap.Error(err))
return
}
for _, task := range tasks {
wg.Add(1)
task := pooltask.Task[T]{
Task: task,
}
inputCh <- task
}

}
}()
return result, tc
Expand All @@ -315,12 +311,12 @@ func (p *Pool[T, U, C, CT, TF]) AddProduceBySlice(producer func() ([]T, error),
func (p *Pool[T, U, C, CT, TF]) AddProducer(producer func() (T, error), constArg C, contextFn TF, options ...TaskOption) (<-chan U, pooltask.TaskController[T, U, C, CT, TF]) {
opt := loadTaskOptions(options...)
taskID := p.NewTaskID()
var wg, prodWg sync.WaitGroup
var wg sync.WaitGroup
result := make(chan U, opt.ResultChanLen)
productCloseCh := make(chan struct{})
productExitCh := make(chan struct{})
inputCh := make(chan pooltask.Task[T], opt.TaskChanLen)
p.taskManager.RegisterTask(taskID, int32(opt.Concurrency))
tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, productCloseCh, &wg, &prodWg, inputCh, result)
tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, productExitCh, &wg, inputCh, result)
for i := 0; i < opt.Concurrency; i++ {
err := p.run()
if err == gpool.ErrPoolClosed {
Expand All @@ -331,34 +327,32 @@ func (p *Pool[T, U, C, CT, TF]) AddProducer(producer func() (T, error), constArg
p.taskManager.AddSubTask(taskID, &taskBox)
p.taskCh <- &taskBox
}
prodWg.Add(1)
wg.Add(1)
go func() {
defer func() {
if r := recover(); r != nil {
logutil.BgLogger().Error("producer panic", zap.Any("recover", r), zap.Stack("stack"))
}
close(inputCh)
prodWg.Done()
wg.Done()
}()
for {
select {
case <-productCloseCh:
if isClose(productExitCh) {
return
default:
task, err := producer()
if err != nil {
if errors.Is(err, gpool.ErrProducerClosed) {
return
}
log.Error("producer error", zap.Error(err))
}
task, err := producer()
if err != nil {
if errors.Is(err, gpool.ErrProducerClosed) {
return
}
wg.Add(1)
t := pooltask.Task[T]{
Task: task,
}
inputCh <- t
log.Error("producer error", zap.Error(err))
return
}
wg.Add(1)
t := pooltask.Task[T]{
Task: task,
}
inputCh <- t
}
}()
return result, tc
Expand Down
19 changes: 13 additions & 6 deletions util/gpool/spmc/spmcpool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestPool(t *testing.T) {
}
}
// add new task
resultCh, control := pool.AddProducer(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(4))
resultCh, control := pool.AddProducer(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(5))

var count atomic.Uint32
var wg sync.WaitGroup
Expand Down Expand Up @@ -112,8 +112,12 @@ func TestStopPool(t *testing.T) {
require.Greater(t, result, 10)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
control.Stop()
}()
// Waiting task finishing
control.Stop()
control.Wait()
wg.Wait()
// close pool
Expand Down Expand Up @@ -152,10 +156,10 @@ func TestStopPoolWithSlice(t *testing.T) {
defer wg.Done()
for result := range resultCh {
require.Greater(t, result, 10)
control.Stop()
}
}()
// Waiting task finishing
control.Stop()
control.Wait()
wg.Wait()
// close pool
Expand Down Expand Up @@ -230,9 +234,12 @@ func testTunePool(t *testing.T, name string) {
for n := pool.Cap(); n > 1; n-- {
downclockPool(t, pool, tid)
}

// exit test
control.Stop()
wg.Add(1)
go func() {
// exit test
control.Stop()
wg.Done()
}()
control.Wait()
wg.Wait()
// close pool
Expand Down

0 comments on commit e741e62

Please sign in to comment.