diff --git a/pkg/cmd/roachtest/roachtestutil/task/BUILD.bazel b/pkg/cmd/roachtest/roachtestutil/task/BUILD.bazel
index 82a5d94c5017..5cb630643424 100644
--- a/pkg/cmd/roachtest/roachtestutil/task/BUILD.bazel
+++ b/pkg/cmd/roachtest/roachtestutil/task/BUILD.bazel
@@ -3,9 +3,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
 go_library(
     name = "task",
     srcs = [
+        "group.go",
         "manager.go",
         "options.go",
-        "tasker.go",
+        "task.go",
         "utils.go",
     ],
     importpath = "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/roachtestutil/task",
diff --git a/pkg/cmd/roachtest/roachtestutil/task/group.go b/pkg/cmd/roachtest/roachtestutil/task/group.go
new file mode 100644
index 000000000000..b3e6814feccb
--- /dev/null
+++ b/pkg/cmd/roachtest/roachtestutil/task/group.go
@@ -0,0 +1,24 @@
+// Copyright 2024 The Cockroach Authors.
+//
+// Use of this software is governed by the CockroachDB Software License
+// included in the /LICENSE file.
+
+package task
+
+// Group is an interface for managing a group of tasks. It is intended for use in
+// tests, for creating a group and waiting for all tasks in the group to complete.
+type Group interface {
+	Tasker
+	// Wait waits for all tasks in the group to complete. Errors from tasks are
+	// reported to the test framework automatically and will cause the test to
+	// fail, which also cancels the context passed to the group.
+	Wait()
+}
+
+// GroupProvider is an interface for creating new Group(s). Generally, the test
+// framework will supply a GroupProvider to tests.
+type GroupProvider interface {
+	// NewGroup creates a new Group to manage tasks. Any options passed to this
+	// function will be applied to all tasks started by the group.
+	NewGroup(opts ...Option) Group
+}
diff --git a/pkg/cmd/roachtest/roachtestutil/task/manager.go b/pkg/cmd/roachtest/roachtestutil/task/manager.go
index de4bbd0e3c37..89e692c70627 100644
--- a/pkg/cmd/roachtest/roachtestutil/task/manager.go
+++ b/pkg/cmd/roachtest/roachtestutil/task/manager.go
@@ -19,9 +19,10 @@ type (
 	// Manager is responsible for managing a group of tasks initiated during
 	// tests. The interface is designed for the test framework to control tasks.
 	// Typically, tests will only interact, and be provided with the smaller
-	// Tasker interface to start tasks.
+	// Group interface to start tasks.
 	Manager interface {
 		Tasker
+		GroupProvider
 		Terminate(*logger.Logger)
 		CompletedEvents() <-chan Event
 	}
@@ -34,26 +35,41 @@ type (
 	}
 
 	manager struct {
-		group  ctxgroup.Group
 		ctx    context.Context
 		logger *logger.Logger
 		events chan Event
 		id     atomic.Uint32
-		mu     struct {
+		group  *group
+	}
+
+	// group implements Group. It runs its tasks on a ctxgroup and tracks the
+	// cancel functions of its tasks as well as any subgroups created from it.
+	group struct {
+		manager  *manager
+		options  []Option
+		ctxGroup ctxgroup.Group
+		mu       struct {
 			syncutil.Mutex
 			cancelFns []context.CancelFunc
+			groups    []*group
 		}
 	}
 )
 
+// NewManager creates a new Manager. The context passed to the manager is used
+// to control the lifetime of all tasks started by the manager. The logger is
+// the default logger used by all tasks started by the manager.
 func NewManager(ctx context.Context, l *logger.Logger) Manager {
-	g := ctxgroup.WithContext(ctx)
-	return &manager{
-		group:  g,
+	m := &manager{
 		ctx:    ctx,
 		logger: l,
 		events: make(chan Event),
 	}
+	m.group = &group{
+		manager:  m,
+		ctxGroup: ctxgroup.WithContext(ctx),
+	}
+	return m
 }
 
 func (m *manager) defaultOptions() []Option {
@@ -73,9 +89,70 @@ func (m *manager) defaultOptions() []Option {
 	}
 }
 
+// Terminate will call the stop functions for every task started during the
+// test. Returns when all task functions have returned, or after a 5-minute
+// timeout, whichever comes first. If the timeout is reached, the function logs
+// a warning message and returns.
+func (m *manager) Terminate(l *logger.Logger) {
+	m.group.cancelAll()
+
+	doneCh := make(chan error)
+	go func() {
+		defer close(doneCh)
+		m.group.Wait()
+	}()
+
+	WaitForChannel(doneCh, "tasks", l)
+}
+
+// CompletedEvents returns a channel that will receive events for all tasks
+// started by the manager.
+func (m *manager) CompletedEvents() <-chan Event {
+	return m.events
+}
+
+// NewGroup creates a new group of tasks as a subgroup under the manager's
+// default group.
+func (m *manager) NewGroup(opts ...Option) Group {
+	return m.group.NewGroup(opts...)
+}
+
+// GoWithCancel runs GoWithCancel on the manager's default group.
 func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc {
-	opt := CombineOptions(OptionList(m.defaultOptions()...), OptionList(opts...))
-	groupCtx, cancel := context.WithCancel(m.ctx)
+	return m.group.GoWithCancel(fn, opts...)
+}
+
+// Go runs Go on the manager's default group.
+func (m *manager) Go(fn Func, opts ...Option) {
+	_ = m.group.GoWithCancel(fn, opts...)
+}
+
+// NewGroup creates a new subgroup that shares this group's manager. The
+// subgroup is registered with its parent so that cancelAll and Wait on the
+// parent (and therefore Manager.Terminate) also cover the subgroup's tasks.
+func (t *group) NewGroup(opts ...Option) Group {
+	subgroup := &group{
+		manager:  t.manager,
+		options:  opts,
+		ctxGroup: ctxgroup.WithContext(t.manager.ctx),
+	}
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	t.mu.groups = append(t.mu.groups, subgroup)
+	return subgroup
+}
+
+// GoWithCancel starts fn as a task in this group and returns the cancel
+// function for that task.
+func (t *group) GoWithCancel(fn Func, opts ...Option) context.CancelFunc {
+	// Combine options in order of precedence: default options, group options,
+	// and options passed to GoWithCancel.
+	opt := CombineOptions(
+		OptionList(t.manager.defaultOptions()...),
+		OptionList(t.options...),
+		OptionList(opts...),
+	)
+	groupCtx, cancel := context.WithCancel(t.manager.ctx)
 	var expectedContextCancellation atomic.Bool
 
 	// internalFunc is a wrapper around the user-provided function that
@@ -91,7 +168,7 @@ func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc {
 		return retErr
 	}
 
-	m.group.Go(func() error {
+	t.ctxGroup.Go(func() error {
 		l, err := opt.L(opt.Name)
 		if err != nil {
 			return err
@@ -114,10 +191,10 @@ func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc {
 		// already aware of the cancelation and sending an event would be redundant.
 		// For instance, a call to test.Fatal would already have captured the error
 		// and canceled the context.
-		if IsContextCanceled(m.ctx) {
+		if IsContextCanceled(t.manager.ctx) {
 			return nil
 		}
-		m.events <- event
+		t.manager.events <- event
 		return err
 	})
 
@@ -127,38 +204,44 @@ func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc {
 	}
 	// Collect all taskCancelFn(s) so that we can explicitly stop all tasks when
 	// the tasker is terminated.
-	m.mu.Lock()
-	defer m.mu.Unlock()
-	m.mu.cancelFns = append(m.mu.cancelFns, taskCancelFn)
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	t.mu.cancelFns = append(t.mu.cancelFns, taskCancelFn)
 	return taskCancelFn
 }
 
-func (m *manager) Go(fn Func, opts ...Option) {
-	_ = m.GoWithCancel(fn, opts...)
+// Go is like GoWithCancel, but discards the returned cancel function.
+func (t *group) Go(fn Func, opts ...Option) {
+	_ = t.GoWithCancel(fn, opts...)
 }
 
-// Terminate will call the stop functions for every task started during the
-// test. Returns when all task functions have returned, or after a 5-minute
-// timeout, whichever comes first. If the timeout is reached, the function logs
-// a warning message and returns.
-func (m *manager) Terminate(l *logger.Logger) {
-	func() {
-		m.mu.Lock()
-		defer m.mu.Unlock()
-		for _, cancel := range m.mu.cancelFns {
-			cancel()
-		}
-	}()
-
-	doneCh := make(chan error)
-	go func() {
-		defer close(doneCh)
-		_ = m.group.Wait()
-	}()
-
-	WaitForChannel(doneCh, "tasks", l)
+// cancelAll invokes the cancel function of every task started in this group,
+// then recursively does the same for every registered subgroup.
+func (t *group) cancelAll() {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	for _, cancel := range t.mu.cancelFns {
+		cancel()
+	}
+	for _, g := range t.mu.groups {
+		g.cancelAll()
+	}
 }
 
-func (m *manager) CompletedEvents() <-chan Event {
-	return m.events
+// Wait blocks until every task in this group, and in every registered
+// subgroup, has returned.
+func (t *group) Wait() {
+	// The error is intentionally dropped: GoWithCancel already publishes task
+	// results on the manager's events channel.
+	_ = t.ctxGroup.Wait()
+
+	// Snapshot the subgroups and release the mutex before blocking on them.
+	// Holding the mutex across a wait would stall GoWithCancel, NewGroup and
+	// cancelAll on this group until all of its tasks had returned.
+	t.mu.Lock()
+	subgroups := append([]*group(nil), t.mu.groups...)
+	t.mu.Unlock()
+	for _, g := range subgroups {
+		g.Wait()
+	}
 }
diff --git a/pkg/cmd/roachtest/roachtestutil/task/manager_test.go b/pkg/cmd/roachtest/roachtestutil/task/manager_test.go
index bacb13f5c122..0f09e5bdc854 100644
--- a/pkg/cmd/roachtest/roachtestutil/task/manager_test.go
+++ b/pkg/cmd/roachtest/roachtestutil/task/manager_test.go
@@ -135,6 +135,56 @@ func TestTerminate(t *testing.T) {
 	require.Equal(t, uint32(numTasks), counter.Load())
 }
 
+func TestGroups(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	ctx := context.Background()
+	m := NewManager(ctx, nilLogger())
+
+	numTasks := 10
+	g := m.NewGroup()
+	channels := make([]chan struct{}, numTasks)
+
+	// Start tasks.
+	for i := 0; i < numTasks; i++ {
+		// Capture the channel directly rather than indexing through the loop
+		// variable from inside the task.
+		ch := make(chan struct{})
+		channels[i] = ch
+		g.Go(func(ctx context.Context, l *logger.Logger) error {
+			<-ch
+			return nil
+		})
+	}
+
+	// Start a goroutine that waits for all tasks in the group to complete.
+	done := make(chan struct{})
+	go func() {
+		g.Wait()
+		close(done)
+	}()
+
+	// Close channel one by one to complete all tasks, and ensure the group is not
+	// done yet.
+	for i := 0; i < numTasks; i++ {
+		select {
+		case <-done:
+			t.Fatal("group should not be done yet")
+		default:
+		}
+		close(channels[i])
+		<-m.CompletedEvents()
+	}
+
+	select {
+	case <-done:
+	case <-time.After(30 * time.Second):
+		t.Fatal("group should be done")
+	}
+
+	m.Terminate(nilLogger())
+	<-done
+}
+
 func nilLogger() *logger.Logger {
 	lcfg := logger.Config{
 		Stdout: io.Discard,
diff --git a/pkg/cmd/roachtest/tests/mixed_version_backup.go b/pkg/cmd/roachtest/tests/mixed_version_backup.go
index a0d6e9f31666..af17041eaf73 100644
--- a/pkg/cmd/roachtest/tests/mixed_version_backup.go
+++ b/pkg/cmd/roachtest/tests/mixed_version_backup.go
@@ -2016,7 +2016,7 @@ func (mvb *mixedVersionBackup) createBackupCollection(
 func (d *BackupRestoreTestDriver) createBackupCollection(
 	ctx context.Context,
 	l *logger.Logger,
-	tasker task.Tasker,
+	task task.Tasker,
 	rng *rand.Rand,
 	fullBackupSpec backupSpec,
 	incBackupSpec backupSpec,
@@ -2032,7 +2032,7 @@ func (d *BackupRestoreTestDriver) createBackupCollection(
 	if err := d.testUtils.runJobOnOneOf(ctx, l, fullBackupSpec.Execute.Nodes, func() error {
 		var err error
 		collection, fullBackupEndTime, err = d.runBackup(
-			ctx, l, tasker, rng, fullBackupSpec.Plan.Nodes, fullBackupSpec.PauseProbability,
+			ctx, l, task, rng, fullBackupSpec.Plan.Nodes, fullBackupSpec.PauseProbability,
 			fullBackup{backupNamePrefix}, internalSystemJobs, isMultitenant,
 		)
 		return err
@@ -2054,7 +2054,7 @@ func (d *BackupRestoreTestDriver) createBackupCollection(
 		if err := d.testUtils.runJobOnOneOf(ctx, l, incBackupSpec.Execute.Nodes, func() error {
 			var err error
 			collection, latestIncBackupEndTime, err = d.runBackup(
-				ctx, l, tasker, rng, incBackupSpec.Plan.Nodes, incBackupSpec.PauseProbability,
+				ctx, l, task, rng, incBackupSpec.Plan.Nodes, incBackupSpec.PauseProbability,
 				incrementalBackup{collection: collection, incNum: i + 1}, internalSystemJobs, isMultitenant,
 			)
 			return err