Skip to content

Commit

Permalink
roachtest: add groups to task manager
Browse files Browse the repository at this point in the history
Previously, a new API for managing tasks were introduced (see #133263). However,
this did not address roachtests that want to manage groups. In an effort to replace
`errgroup`, and possibly using `monitor.Go` for task management this change
introduces a group provider in the task manager.

A group adds the ability to wait on a subset of tasks to finish before
proceeding. The task handler will still report returned errors or panics to the
test framework.

Informs: #118214

Epic: None
Release note: None
  • Loading branch information
herkolategan committed Nov 21, 2024
1 parent 62c52f4 commit 73c9af6
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 44 deletions.
3 changes: 2 additions & 1 deletion pkg/cmd/roachtest/roachtestutil/task/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "task",
srcs = [
"group.go",
"manager.go",
"options.go",
"tasker.go",
"task.go",
"utils.go",
],
importpath = "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/roachtestutil/task",
Expand Down
24 changes: 24 additions & 0 deletions pkg/cmd/roachtest/roachtestutil/task/group.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright 2024 The Cockroach Authors.
//
// Use of this software is governed by the CockroachDB Software License
// included in the /LICENSE file.

package task

// Group is an interface for managing a group of tasks. It is intended for use in
// tests, for creating a group and waiting for all tasks in the group to complete.
type Group interface {
Task
// Wait waits for all tasks in the group to complete. Errors from tasks are reported to the
// test framework automatically and will cause the test to fail, which also
// cancels the context passed to the group.
Wait()
}

// GroupProvider is an interface for creating new Group(s). Generally, the test
// framework will supply a GroupProvider to tests.
type GroupProvider interface {
// NewGroup creates a new Group to manage tasks. Any options passed to this
// function will be applied to all tasks started by the group.
NewGroup(opts ...Option) Group
}
136 changes: 99 additions & 37 deletions pkg/cmd/roachtest/roachtestutil/task/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ type (
// Manager is responsible for managing a group of tasks initiated during
// tests. The interface is designed for the test framework to control tasks.
// Typically, tests will only interact, and be provided with the smaller
// Tasker interface to start tasks.
// Group interface to start tasks.
Manager interface {
Tasker
Task
GroupProvider
Terminate(*logger.Logger)
CompletedEvents() <-chan Event
}
Expand All @@ -34,26 +35,39 @@ type (
}

manager struct {
group ctxgroup.Group
ctx context.Context
logger *logger.Logger
events chan Event
id atomic.Uint32
mu struct {
group *group
}

group struct {
manager *manager
options []Option
ctxGroup ctxgroup.Group
mu struct {
syncutil.Mutex
cancelFns []context.CancelFunc
groups []*group
}
}
)

// NewManager creates a new Manager. The context passed to the manager is used
// to control the lifetime of all tasks started by the manager. The logger is
// the default logger used by all tasks started by the manager.
func NewManager(ctx context.Context, l *logger.Logger) Manager {
g := ctxgroup.WithContext(ctx)
return &manager{
group: g,
m := &manager{
ctx: ctx,
logger: l,
events: make(chan Event),
}
m.group = &group{
manager: m,
ctxGroup: ctxgroup.WithContext(ctx),
}
return m
}

func (m *manager) defaultOptions() []Option {
Expand All @@ -73,9 +87,61 @@ func (m *manager) defaultOptions() []Option {
}
}

// Terminate will call the stop functions for every task, including tasks
// created by subgroups started during the test. Returns when all task functions
// have returned.
func (m *manager) Terminate(l *logger.Logger) {
m.group.cancelAll()

doneCh := make(chan error)
go func() {
defer close(doneCh)
m.group.Wait()
}()

WaitForChannel(doneCh, "tasks", l)
}

// CompletedEvents returns a channel that will receive events for all tasks
// started by the manager.
func (m *manager) CompletedEvents() <-chan Event {
return m.events
}

// NewGroup creates a new group of tasks as a subgroup under the manager's
// default group.
func (m *manager) NewGroup(opts ...Option) Group {
return m.group.NewGroup(opts...)
}

// GoWithCancel runs GoWithCancel on the manager's default group.
func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc {
opt := CombineOptions(OptionList(m.defaultOptions()...), OptionList(opts...))
groupCtx, cancel := context.WithCancel(m.ctx)
return m.group.GoWithCancel(fn, opts...)
}

// Go runs Go on the manager's default group.
func (m *manager) Go(fn Func, opts ...Option) {
_ = m.group.GoWithCancel(fn, opts...)
}

func (t *group) NewGroup(opts ...Option) Group {
subgroup := &group{
manager: t.manager,
options: opts,
ctxGroup: ctxgroup.WithContext(t.manager.ctx),
}
return subgroup
}

func (t *group) GoWithCancel(fn Func, opts ...Option) context.CancelFunc {
// Combine options in order of precedence: default options, task options, and
// options passed to GoWithCancel.
opt := CombineOptions(
OptionList(t.manager.defaultOptions()...),
OptionList(t.options...),
OptionList(opts...),
)
groupCtx, cancel := context.WithCancel(t.manager.ctx)
var expectedContextCancellation atomic.Bool

// internalFunc is a wrapper around the user-provided function that
Expand All @@ -91,7 +157,7 @@ func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc {
return retErr
}

m.group.Go(func() error {
t.ctxGroup.Go(func() error {
l, err := opt.L(opt.Name)
if err != nil {
return err
Expand All @@ -103,9 +169,9 @@ func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc {
ExpectedCancel: err != nil && IsContextCanceled(groupCtx) && expectedContextCancellation.Load(),
}
select {
case m.events <- event:
case t.manager.events <- event:
// exit goroutine
case <-m.ctx.Done():
case <-t.manager.ctx.Done():
// Parent context already finished, exit goroutine.
return nil
}
Expand All @@ -118,36 +184,32 @@ func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc {
}
// Collect all taskCancelFn(s) so that we can explicitly stop all tasks when
// the tasker is terminated.
m.mu.Lock()
defer m.mu.Unlock()
m.mu.cancelFns = append(m.mu.cancelFns, taskCancelFn)
t.mu.Lock()
defer t.mu.Unlock()
t.mu.cancelFns = append(t.mu.cancelFns, taskCancelFn)
return taskCancelFn
}

func (m *manager) Go(fn Func, opts ...Option) {
_ = m.GoWithCancel(fn, opts...)
func (t *group) Go(fn Func, opts ...Option) {
_ = t.GoWithCancel(fn, opts...)
}

// Terminate will call the stop functions for every task started during the
// test. Returns when all task functions have returned.
func (m *manager) Terminate(l *logger.Logger) {
func() {
m.mu.Lock()
defer m.mu.Unlock()
for _, cancel := range m.mu.cancelFns {
cancel()
}
}()

doneCh := make(chan error)
go func() {
defer close(doneCh)
_ = m.group.Wait()
}()

WaitForChannel(doneCh, "tasks", l)
func (t *group) cancelAll() {
t.mu.Lock()
defer t.mu.Unlock()
for _, cancel := range t.mu.cancelFns {
cancel()
}
for _, g := range t.mu.groups {
g.cancelAll()
}
}

func (m *manager) CompletedEvents() <-chan Event {
return m.events
func (t *group) Wait() {
t.mu.Lock()
defer t.mu.Unlock()
_ = t.ctxGroup.Wait()
for _, g := range t.mu.groups {
g.Wait()
}
}
47 changes: 47 additions & 0 deletions pkg/cmd/roachtest/roachtestutil/task/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,53 @@ func TestTerminate(t *testing.T) {
require.Equal(t, uint32(numTasks), counter.Load())
}

func TestGroups(t *testing.T) {
defer leaktest.AfterTest(t)()
ctx := context.Background()
m := NewManager(ctx, nilLogger())

numTasks := 10
g := m.NewGroup()
channels := make([]chan struct{}, 10)

// Start tasks.
for i := 0; i < numTasks; i++ {
channels[i] = make(chan struct{})
g.Go(func(ctx context.Context, l *logger.Logger) error {
<-channels[i]
return nil
})
}

// Start a goroutine that waits for all tasks in the group to complete.
done := make(chan struct{})
go func() {
g.Wait()
close(done)
}()

// Close channel one by one to complete all tasks, and ensure the group is not
// done yet.
for i := 0; i < numTasks; i++ {
select {
case <-done:
t.Fatal("group should not be done yet")
default:
}
close(channels[i])
<-m.CompletedEvents()
}

select {
case <-done:
case <-time.After(30 * time.Second):
t.Fatal("group should be done")
}

m.Terminate(nilLogger())
<-done
}

func nilLogger() *logger.Logger {
lcfg := logger.Config{
Stdout: io.Discard,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import (

type Func func(context.Context, *logger.Logger) error

// Tasker is an interface for executing tasks (goroutines). It is intended for
// Task is an interface for executing tasks (goroutines). It is intended for
// use in tests, enabling the test framework to manage panics and errors.
type Tasker interface {
type Task interface {
// Go runs the given function in a goroutine.
Go(fn Func, opts ...Option)
// GoWithCancel runs the given function in a goroutine and returns a
Expand Down
8 changes: 4 additions & 4 deletions pkg/cmd/roachtest/tests/mixed_version_backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -1826,7 +1826,7 @@ func (d *BackupRestoreTestDriver) saveContents(
func (d *BackupRestoreTestDriver) runBackup(
ctx context.Context,
l *logger.Logger,
tasker task.Tasker,
tasker task.Task,
rng *rand.Rand,
nodes option.NodeListOption,
pauseProbability float64,
Expand Down Expand Up @@ -2016,7 +2016,7 @@ func (mvb *mixedVersionBackup) createBackupCollection(
func (d *BackupRestoreTestDriver) createBackupCollection(
ctx context.Context,
l *logger.Logger,
tasker task.Tasker,
task task.Task,
rng *rand.Rand,
fullBackupSpec backupSpec,
incBackupSpec backupSpec,
Expand All @@ -2032,7 +2032,7 @@ func (d *BackupRestoreTestDriver) createBackupCollection(
if err := d.testUtils.runJobOnOneOf(ctx, l, fullBackupSpec.Execute.Nodes, func() error {
var err error
collection, fullBackupEndTime, err = d.runBackup(
ctx, l, tasker, rng, fullBackupSpec.Plan.Nodes, fullBackupSpec.PauseProbability,
ctx, l, task, rng, fullBackupSpec.Plan.Nodes, fullBackupSpec.PauseProbability,
fullBackup{backupNamePrefix}, internalSystemJobs, isMultitenant,
)
return err
Expand All @@ -2054,7 +2054,7 @@ func (d *BackupRestoreTestDriver) createBackupCollection(
if err := d.testUtils.runJobOnOneOf(ctx, l, incBackupSpec.Execute.Nodes, func() error {
var err error
collection, latestIncBackupEndTime, err = d.runBackup(
ctx, l, tasker, rng, incBackupSpec.Plan.Nodes, incBackupSpec.PauseProbability,
ctx, l, task, rng, incBackupSpec.Plan.Nodes, incBackupSpec.PauseProbability,
incrementalBackup{collection: collection, incNum: i + 1}, internalSystemJobs, isMultitenant,
)
return err
Expand Down

0 comments on commit 73c9af6

Please sign in to comment.