diff --git a/pkg/cmd/roachtest/roachtestutil/task/BUILD.bazel b/pkg/cmd/roachtest/roachtestutil/task/BUILD.bazel index 82a5d94c5017..5cb630643424 100644 --- a/pkg/cmd/roachtest/roachtestutil/task/BUILD.bazel +++ b/pkg/cmd/roachtest/roachtestutil/task/BUILD.bazel @@ -3,9 +3,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "task", srcs = [ + "group.go", "manager.go", "options.go", - "tasker.go", + "task.go", "utils.go", ], importpath = "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/roachtestutil/task", diff --git a/pkg/cmd/roachtest/roachtestutil/task/group.go b/pkg/cmd/roachtest/roachtestutil/task/group.go new file mode 100644 index 000000000000..0d530c3bf4d3 --- /dev/null +++ b/pkg/cmd/roachtest/roachtestutil/task/group.go @@ -0,0 +1,24 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package task + +// Group is an interface for managing a group of tasks. It is intended for use in +// tests, for creating a group and waiting for all tasks in the group to complete. +type Group interface { + Task + // Wait waits for all tasks in the group to complete. Errors from tasks are reported to the + // test framework automatically and will cause the test to fail, which also + // cancels the context passed to the group. + Wait() +} + +// GroupProvider is an interface for creating new Group(s). Generally, the test +// framework will supply a GroupProvider to tests. +type GroupProvider interface { + // NewGroup creates a new Group to manage tasks. Any options passed to this + // function will be applied to all tasks started by the group. + NewGroup(opts ...Option) Group +} diff --git a/pkg/cmd/roachtest/roachtestutil/task/manager.go b/pkg/cmd/roachtest/roachtestutil/task/manager.go index d15d783d5e37..8dfcecc08f50 100644 --- a/pkg/cmd/roachtest/roachtestutil/task/manager.go +++ b/pkg/cmd/roachtest/roachtestutil/task/manager.go @@ -19,9 +19,10 @@ type ( // Manager is responsible for managing a group of tasks initiated during // tests. The interface is designed for the test framework to control tasks. // Typically, tests will only interact, and be provided with the smaller - // Tasker interface to start tasks. + // Group interface to start tasks. Manager interface { - Tasker + Task + GroupProvider Terminate(*logger.Logger) CompletedEvents() <-chan Event } @@ -34,26 +35,39 @@ type ( } manager struct { - group ctxgroup.Group ctx context.Context logger *logger.Logger events chan Event id atomic.Uint32 - mu struct { + group *group + } + + group struct { + manager *manager + options []Option + ctxGroup ctxgroup.Group + mu struct { syncutil.Mutex cancelFns []context.CancelFunc + groups []*group } } ) +// NewManager creates a new Manager. The context passed to the manager is used +// to control the lifetime of all tasks started by the manager. The logger is +// the default logger used by all tasks started by the manager. func NewManager(ctx context.Context, l *logger.Logger) Manager { - g := ctxgroup.WithContext(ctx) - return &manager{ - group: g, + m := &manager{ ctx: ctx, logger: l, events: make(chan Event), } + m.group = &group{ + manager: m, + ctxGroup: ctxgroup.WithContext(ctx), + } + return m } func (m *manager) defaultOptions() []Option { @@ -73,9 +87,61 @@ func (m *manager) defaultOptions() []Option { } } +// Terminate will call the stop functions for every task, including tasks +// created by subgroups started during the test. Returns when all task functions +// have returned. +func (m *manager) Terminate(l *logger.Logger) { + m.group.cancelAll() + + doneCh := make(chan error) + go func() { + defer close(doneCh) + m.group.Wait() + }() + + WaitForChannel(doneCh, "tasks", l) +} + +// CompletedEvents returns a channel that will receive events for all tasks +// started by the manager. +func (m *manager) CompletedEvents() <-chan Event { + return m.events +} + +// NewGroup creates a new group of tasks as a subgroup under the manager's +// default group. +func (m *manager) NewGroup(opts ...Option) Group { + return m.group.NewGroup(opts...) +} + +// GoWithCancel runs GoWithCancel on the manager's default group. func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc { - opt := CombineOptions(OptionList(m.defaultOptions()...), OptionList(opts...)) - groupCtx, cancel := context.WithCancel(m.ctx) + return m.group.GoWithCancel(fn, opts...) +} + +// Go runs Go on the manager's default group. +func (m *manager) Go(fn Func, opts ...Option) { + _ = m.group.GoWithCancel(fn, opts...) +} + +func (t *group) NewGroup(opts ...Option) Group { + subgroup := &group{ + manager: t.manager, + options: opts, + ctxGroup: ctxgroup.WithContext(t.manager.ctx), + } + return subgroup +} + +func (t *group) GoWithCancel(fn Func, opts ...Option) context.CancelFunc { + // Combine options in order of precedence: default options, task options, and + // options passed to GoWithCancel. + opt := CombineOptions( + OptionList(t.manager.defaultOptions()...), + OptionList(t.options...), + OptionList(opts...), + ) + groupCtx, cancel := context.WithCancel(t.manager.ctx) var expectedContextCancellation atomic.Bool // internalFunc is a wrapper around the user-provided function that @@ -91,7 +157,7 @@ func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc { return retErr } - m.group.Go(func() error { + t.ctxGroup.Go(func() error { l, err := opt.L(opt.Name) if err != nil { return err @@ -103,9 +169,9 @@ func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc { ExpectedCancel: err != nil && IsContextCanceled(groupCtx) && expectedContextCancellation.Load(), } select { - case m.events <- event: + case t.manager.events <- event: // exit goroutine - case <-m.ctx.Done(): + case <-t.manager.ctx.Done(): // Parent context already finished, exit goroutine. return nil } @@ -118,36 +184,32 @@ func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc { } // Collect all taskCancelFn(s) so that we can explicitly stop all tasks when // the tasker is terminated. - m.mu.Lock() - defer m.mu.Unlock() - m.mu.cancelFns = append(m.mu.cancelFns, taskCancelFn) + t.mu.Lock() + defer t.mu.Unlock() + t.mu.cancelFns = append(t.mu.cancelFns, taskCancelFn) return taskCancelFn } -func (m *manager) Go(fn Func, opts ...Option) { - _ = m.GoWithCancel(fn, opts...) +func (t *group) Go(fn Func, opts ...Option) { + _ = t.GoWithCancel(fn, opts...) } -// Terminate will call the stop functions for every task started during the -// test. Returns when all task functions have returned. -func (m *manager) Terminate(l *logger.Logger) { - func() { - m.mu.Lock() - defer m.mu.Unlock() - for _, cancel := range m.mu.cancelFns { - cancel() - } - }() - - doneCh := make(chan error) - go func() { - defer close(doneCh) - _ = m.group.Wait() - }() - - WaitForChannel(doneCh, "tasks", l) +func (t *group) cancelAll() { + t.mu.Lock() + defer t.mu.Unlock() + for _, cancel := range t.mu.cancelFns { + cancel() + } + for _, g := range t.mu.groups { + g.cancelAll() + } } -func (m *manager) CompletedEvents() <-chan Event { - return m.events +func (t *group) Wait() { + t.mu.Lock() + defer t.mu.Unlock() + _ = t.ctxGroup.Wait() + for _, g := range t.mu.groups { + g.Wait() + } } diff --git a/pkg/cmd/roachtest/roachtestutil/task/manager_test.go b/pkg/cmd/roachtest/roachtestutil/task/manager_test.go index bacb13f5c122..0f09e5bdc854 100644 --- a/pkg/cmd/roachtest/roachtestutil/task/manager_test.go +++ b/pkg/cmd/roachtest/roachtestutil/task/manager_test.go @@ -135,6 +135,53 @@ func TestTerminate(t *testing.T) { require.Equal(t, uint32(numTasks), counter.Load()) } +func TestGroups(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + m := NewManager(ctx, nilLogger()) + + numTasks := 10 + g := m.NewGroup() + channels := make([]chan struct{}, 10) + + // Start tasks. + for i := 0; i < numTasks; i++ { + channels[i] = make(chan struct{}) + g.Go(func(ctx context.Context, l *logger.Logger) error { + <-channels[i] + return nil + }) + } + + // Start a goroutine that waits for all tasks in the group to complete. + done := make(chan struct{}) + go func() { + g.Wait() + close(done) + }() + + // Close channel one by one to complete all tasks, and ensure the group is not + // done yet. + for i := 0; i < numTasks; i++ { + select { + case <-done: + t.Fatal("group should not be done yet") + default: + } + close(channels[i]) + <-m.CompletedEvents() + } + + select { + case <-done: + case <-time.After(30 * time.Second): + t.Fatal("group should be done") + } + + m.Terminate(nilLogger()) + <-done +} + func nilLogger() *logger.Logger { lcfg := logger.Config{ Stdout: io.Discard, diff --git a/pkg/cmd/roachtest/roachtestutil/task/tasker.go b/pkg/cmd/roachtest/roachtestutil/task/task.go similarity index 86% rename from pkg/cmd/roachtest/roachtestutil/task/tasker.go rename to pkg/cmd/roachtest/roachtestutil/task/task.go index e5cf6c67311e..d4da364ba2ad 100644 --- a/pkg/cmd/roachtest/roachtestutil/task/tasker.go +++ b/pkg/cmd/roachtest/roachtestutil/task/task.go @@ -13,9 +13,9 @@ import ( type Func func(context.Context, *logger.Logger) error -// Tasker is an interface for executing tasks (goroutines). It is intended for +// Task is an interface for executing tasks (goroutines). It is intended for // use in tests, enabling the test framework to manage panics and errors. -type Tasker interface { +type Task interface { // Go runs the given function in a goroutine. Go(fn Func, opts ...Option) // GoWithCancel runs the given function in a goroutine and returns a diff --git a/pkg/cmd/roachtest/tests/mixed_version_backup.go b/pkg/cmd/roachtest/tests/mixed_version_backup.go index a0d6e9f31666..bd86e8895042 100644 --- a/pkg/cmd/roachtest/tests/mixed_version_backup.go +++ b/pkg/cmd/roachtest/tests/mixed_version_backup.go @@ -1826,7 +1826,7 @@ func (d *BackupRestoreTestDriver) saveContents( func (d *BackupRestoreTestDriver) runBackup( ctx context.Context, l *logger.Logger, - tasker task.Tasker, + tasker task.Task, rng *rand.Rand, nodes option.NodeListOption, pauseProbability float64, @@ -2016,7 +2016,7 @@ func (mvb *mixedVersionBackup) createBackupCollection( func (d *BackupRestoreTestDriver) createBackupCollection( ctx context.Context, l *logger.Logger, - tasker task.Tasker, + task task.Task, rng *rand.Rand, fullBackupSpec backupSpec, incBackupSpec backupSpec, @@ -2032,7 +2032,7 @@ func (d *BackupRestoreTestDriver) createBackupCollection( if err := d.testUtils.runJobOnOneOf(ctx, l, fullBackupSpec.Execute.Nodes, func() error { var err error collection, fullBackupEndTime, err = d.runBackup( - ctx, l, tasker, rng, fullBackupSpec.Plan.Nodes, fullBackupSpec.PauseProbability, + ctx, l, task, rng, fullBackupSpec.Plan.Nodes, fullBackupSpec.PauseProbability, fullBackup{backupNamePrefix}, internalSystemJobs, isMultitenant, ) return err @@ -2054,7 +2054,7 @@ func (d *BackupRestoreTestDriver) createBackupCollection( if err := d.testUtils.runJobOnOneOf(ctx, l, incBackupSpec.Execute.Nodes, func() error { var err error collection, latestIncBackupEndTime, err = d.runBackup( - ctx, l, tasker, rng, incBackupSpec.Plan.Nodes, incBackupSpec.PauseProbability, + ctx, l, task, rng, incBackupSpec.Plan.Nodes, incBackupSpec.PauseProbability, incrementalBackup{collection: collection, incNum: i + 1}, internalSystemJobs, isMultitenant, ) return err