diff --git a/pkg/cmd/roachtest/cluster_test.go b/pkg/cmd/roachtest/cluster_test.go index b0e31bac009e..401463fedd23 100644 --- a/pkg/cmd/roachtest/cluster_test.go +++ b/pkg/cmd/roachtest/cluster_test.go @@ -144,6 +144,10 @@ func (t testWrapper) Go(_ task.Func, _ ...task.Option) { panic("implement me") } +func (t testWrapper) NewGroup() task.Group { + panic("implement me") +} + var _ test2.Test = testWrapper{} // ArtifactsDir is part of the test.Test interface. diff --git a/pkg/cmd/roachtest/clusterstats/mock_test_generated_test.go b/pkg/cmd/roachtest/clusterstats/mock_test_generated_test.go index 35ee7ca580fe..f78dee987155 100644 --- a/pkg/cmd/roachtest/clusterstats/mock_test_generated_test.go +++ b/pkg/cmd/roachtest/clusterstats/mock_test_generated_test.go @@ -343,6 +343,20 @@ func (mr *MockTestMockRecorder) Name() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockTest)(nil).Name)) } +// NewGroup mocks base method. +func (m *MockTest) NewGroup() task.Group { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewGroup") + ret0, _ := ret[0].(task.Group) + return ret0 +} + +// NewGroup indicates an expected call of NewGroup. +func (mr *MockTestMockRecorder) NewGroup() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewGroup", reflect.TypeOf((*MockTest)(nil).NewGroup)) +} + // PerfArtifactsDir mocks base method. func (m *MockTest) PerfArtifactsDir() string { m.ctrl.T.Helper() diff --git a/pkg/cmd/roachtest/roachtestutil/mixedversion/helper.go b/pkg/cmd/roachtest/roachtestutil/mixedversion/helper.go index b52c5dce46f7..04fb8583fb71 100644 --- a/pkg/cmd/roachtest/roachtestutil/mixedversion/helper.go +++ b/pkg/cmd/roachtest/roachtestutil/mixedversion/helper.go @@ -224,8 +224,9 @@ func (h *Helper) ExecWithGateway( return h.DefaultService().ExecWithGateway(rng, nodes, query, args...) } -// GoWithCancel implements the Tasker interface. -func (h *Helper) GoWithCancel(fn task.Func, opts ...task.Option) context.CancelFunc { +// defaultTaskOptions returns the default options that are passed to all tasks +// started by the helper. +func (h *Helper) defaultTaskOptions() []task.Option { loggerFuncOpt := task.LoggerFunc(func(name string) (*logger.Logger, error) { bgLogger, err := h.loggerFor(name) if err != nil { @@ -246,8 +247,13 @@ func (h *Helper) GoWithCancel(fn task.Func, opts ...task.Option) context.CancelF } return nil }) + return []task.Option{loggerFuncOpt, panicOpt, errHandlerOpt} +} + +// GoWithCancel implements the Tasker interface. +func (h *Helper) GoWithCancel(fn task.Func, opts ...task.Option) context.CancelFunc { return h.runner.background.GoWithCancel( - fn, task.OptionList(opts...), loggerFuncOpt, panicOpt, errHandlerOpt, + fn, task.OptionList(h.defaultTaskOptions()...), task.OptionList(opts...), ) } @@ -256,6 +262,11 @@ func (h *Helper) Go(fn task.Func, opts ...task.Option) { h.GoWithCancel(fn, opts...) } +// NewGroup implements the Group interface. +func (h *Helper) NewGroup() task.Group { + return h.runner.background.NewGroup(h.defaultTaskOptions()...) +} + // GoCommand has the same semantics of `GoWithCancel()`; the command passed will // run and the test will fail if the command is not successful. The task name is // derived from the command passed. diff --git a/pkg/cmd/roachtest/roachtestutil/task/BUILD.bazel b/pkg/cmd/roachtest/roachtestutil/task/BUILD.bazel index 82a5d94c5017..693aa588c776 100644 --- a/pkg/cmd/roachtest/roachtestutil/task/BUILD.bazel +++ b/pkg/cmd/roachtest/roachtestutil/task/BUILD.bazel @@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "task", srcs = [ + "group.go", "manager.go", "options.go", "tasker.go", diff --git a/pkg/cmd/roachtest/roachtestutil/task/group.go b/pkg/cmd/roachtest/roachtestutil/task/group.go new file mode 100644 index 000000000000..e13abf2c8261 --- /dev/null +++ b/pkg/cmd/roachtest/roachtestutil/task/group.go @@ -0,0 +1,25 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package task + +// Group is an interface for managing a group of tasks. It is intended for use +// in roachtests, for creating a group and waiting for all tasks in the group to +// complete. +type Group interface { + Tasker + // Wait waits for all tasks in the group to complete. Errors from tasks are reported to the + // test framework automatically and will cause the test to fail, which also + // cancels the context passed to the group. + Wait() +} + +// GroupProvider is an interface for creating new Group(s). Generally, the test +// framework will supply a GroupProvider to tests. +type GroupProvider interface { + // NewGroup creates a new Group to manage tasks. Any options passed to this + // function will be applied to all tasks started by the group. + NewGroup(opts ...Option) Group +} diff --git a/pkg/cmd/roachtest/roachtestutil/task/manager.go b/pkg/cmd/roachtest/roachtestutil/task/manager.go index de4bbd0e3c37..234f98859ce4 100644 --- a/pkg/cmd/roachtest/roachtestutil/task/manager.go +++ b/pkg/cmd/roachtest/roachtestutil/task/manager.go @@ -18,10 +18,11 @@ import ( type ( // Manager is responsible for managing a group of tasks initiated during // tests. The interface is designed for the test framework to control tasks. - // Typically, tests will only interact, and be provided with the smaller - // Tasker interface to start tasks. + // Typically, tests will only interact, and be provided with the smaller Group + // and Tasker interfaces to start tasks or wait on groups of tasks. Manager interface { Tasker + GroupProvider Terminate(*logger.Logger) CompletedEvents() <-chan Event } @@ -34,26 +35,42 @@ type ( } manager struct { - group ctxgroup.Group ctx context.Context logger *logger.Logger events chan Event id atomic.Uint32 - mu struct { + group *group + } + + group struct { + manager *manager + options []Option + ctxGroup ctxgroup.Group + groupMu struct { + syncutil.Mutex + groups []*group + } + cancelMu struct { syncutil.Mutex cancelFns []context.CancelFunc } } ) +// NewManager creates a new Manager. The context passed to the manager is used +// to control the lifetime of all tasks started by the manager. The logger is +// the default logger used by all tasks started by the manager. func NewManager(ctx context.Context, l *logger.Logger) Manager { - g := ctxgroup.WithContext(ctx) - return &manager{ - group: g, + m := &manager{ ctx: ctx, logger: l, events: make(chan Event), } + m.group = &group{ + manager: m, + ctxGroup: ctxgroup.WithContext(ctx), + } + return m } func (m *manager) defaultOptions() []Option { @@ -73,9 +90,65 @@ func (m *manager) defaultOptions() []Option { } } +// Terminate will call the stop functions for every task started during the +// test. Returns when all task functions have returned, or after a 5-minute +// timeout, whichever comes first. If the timeout is reached, the function logs +// a warning message and returns. +func (m *manager) Terminate(l *logger.Logger) { + m.group.cancelAll() + + doneCh := make(chan error) + go func() { + defer close(doneCh) + m.group.Wait() + }() + + WaitForChannel(doneCh, "tasks", l) +} + +// CompletedEvents returns a channel that will receive events for all tasks +// started by the manager. +func (m *manager) CompletedEvents() <-chan Event { + return m.events +} + +// NewGroup creates a new group of tasks as a subgroup under the manager's +// default group. +func (m *manager) NewGroup(opts ...Option) Group { + return m.group.NewGroup(opts...) +} + +// GoWithCancel runs GoWithCancel on the manager's default group. func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc { - opt := CombineOptions(OptionList(m.defaultOptions()...), OptionList(opts...)) - groupCtx, cancel := context.WithCancel(m.ctx) + return m.group.GoWithCancel(fn, opts...) +} + +// Go runs Go on the manager's default group. +func (m *manager) Go(fn Func, opts ...Option) { + _ = m.group.GoWithCancel(fn, opts...) +} + +func (t *group) NewGroup(opts ...Option) Group { + subgroup := &group{ + manager: t.manager, + options: opts, + ctxGroup: ctxgroup.WithContext(t.manager.ctx), + } + t.groupMu.Lock() + defer t.groupMu.Unlock() + t.groupMu.groups = append(t.groupMu.groups, subgroup) + return subgroup +} + +func (t *group) GoWithCancel(fn Func, opts ...Option) context.CancelFunc { + // Combine options in order of precedence: default options, task options, and + // options passed to GoWithCancel. + opt := CombineOptions( + OptionList(t.manager.defaultOptions()...), + OptionList(t.options...), + OptionList(opts...), + ) + groupCtx, cancel := context.WithCancel(t.manager.ctx) var expectedContextCancellation atomic.Bool // internalFunc is a wrapper around the user-provided function that @@ -91,7 +164,7 @@ func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc { return retErr } - m.group.Go(func() error { + t.ctxGroup.Go(func() error { l, err := opt.L(opt.Name) if err != nil { return err @@ -114,10 +187,10 @@ func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc { // already aware of the cancelation and sending an event would be redundant. // For instance, a call to test.Fatal would already have captured the error // and canceled the context. - if IsContextCanceled(m.ctx) { + if IsContextCanceled(t.manager.ctx) { return nil } - m.events <- event + t.manager.events <- event return err }) @@ -127,38 +200,36 @@ func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc { } // Collect all taskCancelFn(s) so that we can explicitly stop all tasks when // the tasker is terminated. - m.mu.Lock() - defer m.mu.Unlock() - m.mu.cancelFns = append(m.mu.cancelFns, taskCancelFn) + t.cancelMu.Lock() + defer t.cancelMu.Unlock() + t.cancelMu.cancelFns = append(t.cancelMu.cancelFns, taskCancelFn) return taskCancelFn } -func (m *manager) Go(fn Func, opts ...Option) { - _ = m.GoWithCancel(fn, opts...) +func (t *group) Go(fn Func, opts ...Option) { + _ = t.GoWithCancel(fn, opts...) } -// Terminate will call the stop functions for every task started during the -// test. Returns when all task functions have returned, or after a 5-minute -// timeout, whichever comes first. If the timeout is reached, the function logs -// a warning message and returns. -func (m *manager) Terminate(l *logger.Logger) { +func (t *group) cancelAll() { func() { - m.mu.Lock() - defer m.mu.Unlock() - for _, cancel := range m.mu.cancelFns { + t.cancelMu.Lock() + defer t.cancelMu.Unlock() + for _, cancel := range t.cancelMu.cancelFns { cancel() } }() - - doneCh := make(chan error) - go func() { - defer close(doneCh) - _ = m.group.Wait() - }() - - WaitForChannel(doneCh, "tasks", l) + t.groupMu.Lock() + defer t.groupMu.Unlock() + for _, g := range t.groupMu.groups { + g.cancelAll() + } } -func (m *manager) CompletedEvents() <-chan Event { - return m.events +func (t *group) Wait() { + t.groupMu.Lock() + defer t.groupMu.Unlock() + _ = t.ctxGroup.Wait() + for _, g := range t.groupMu.groups { + g.Wait() + } } diff --git a/pkg/cmd/roachtest/roachtestutil/task/manager_test.go b/pkg/cmd/roachtest/roachtestutil/task/manager_test.go index bacb13f5c122..b452dbc91cc4 100644 --- a/pkg/cmd/roachtest/roachtestutil/task/manager_test.go +++ b/pkg/cmd/roachtest/roachtestutil/task/manager_test.go @@ -11,7 +11,6 @@ import ( "sync" "sync/atomic" "testing" - "time" "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/cockroach/pkg/util/leaktest" @@ -70,13 +69,8 @@ func TestContextCancel(t *testing.T) { wg.Add(1) m.Go(func(ctx context.Context, l *logger.Logger) error { defer wg.Done() - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(30 * time.Second): - t.Fatal("expected context to be canceled") - } - return nil + <-ctx.Done() + return ctx.Err() }) cancel() wg.Wait() @@ -92,13 +86,8 @@ func TestContextCancel(t *testing.T) { wg.Add(1) cancel := m.GoWithCancel(func(ctx context.Context, l *logger.Logger) error { defer wg.Done() - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(30 * time.Second): - t.Fatal("expected context to be canceled") - } - return nil + <-ctx.Done() + return ctx.Err() }) cancel() wg.Wait() @@ -128,13 +117,90 @@ func TestTerminate(t *testing.T) { for i := 0; i < numTasks; i++ { e := <-m.CompletedEvents() require.NoError(t, e.Err) - } }() m.Terminate(nilLogger()) require.Equal(t, uint32(numTasks), counter.Load()) } +func TestGroups(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + m := NewManager(ctx, nilLogger()) + + numTasks := 10 + g := m.NewGroup() + channels := make([]chan struct{}, numTasks) + + // Start tasks. + for i := 0; i < numTasks; i++ { + channels[i] = make(chan struct{}) + g.Go(func(ctx context.Context, l *logger.Logger) error { + <-channels[i] + return nil + }) + } + + // Start a goroutine that waits for all tasks in the group to complete. + done := make(chan struct{}) + go func() { + g.Wait() + close(done) + }() + + // Close channels one by one to complete all tasks, and ensure the group is + // not done yet. + for i := 0; i < numTasks; i++ { + select { + case <-done: + t.Fatal("group should not be done yet") + default: + } + // Close the channel and wait for the completed event. + close(channels[i]) + <-m.CompletedEvents() + } + + // Ensure the group is done. + <-done + m.Terminate(nilLogger()) +} + +func TestTerminateGroups(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + m := NewManager(ctx, nilLogger()) + + numTasks := 3 + g := m.NewGroup() + + // Start tasks. + for i := 0; i < numTasks; i++ { + g.Go(func(ctx context.Context, l *logger.Logger) error { + <-ctx.Done() + return nil + }) + } + + // Start a goroutine that waits for all tasks in the group to complete. + done := make(chan struct{}) + go func() { + g.Wait() + close(done) + }() + + // Consume all completed events. + go func() { + for i := 0; i < numTasks; i++ { + e := <-m.CompletedEvents() + require.NoError(t, e.Err) + } + }() + + m.Terminate(nilLogger()) + <-done +} + func nilLogger() *logger.Logger { lcfg := logger.Config{ Stdout: io.Discard, diff --git a/pkg/cmd/roachtest/test/test_interface.go b/pkg/cmd/roachtest/test/test_interface.go index e5b3c9836b7d..302851aab810 100644 --- a/pkg/cmd/roachtest/test/test_interface.go +++ b/pkg/cmd/roachtest/test/test_interface.go @@ -83,6 +83,7 @@ type Test interface { Go(task.Func, ...task.Option) GoWithCancel(task.Func, ...task.Option) context.CancelFunc + NewGroup() task.Group // DeprecatedWorkload returns the path to the workload binary. // Don't use this, invoke `./cockroach workload` instead. diff --git a/pkg/cmd/roachtest/test_impl.go b/pkg/cmd/roachtest/test_impl.go index 9c99a275e470..0c6104bfa6c7 100644 --- a/pkg/cmd/roachtest/test_impl.go +++ b/pkg/cmd/roachtest/test_impl.go @@ -688,14 +688,21 @@ func (t *testImpl) IsBuildVersion(minVersion string) bool { return t.BuildVersion().AtLeast(vers) } -func panicHandler(_ context.Context, name string, l *logger.Logger, r interface{}) error { - return fmt.Errorf("test task %s panicked: %v", name, r) +// defaultTaskOptions returns the default options for a task started by the test. +func defaultTaskOptions() []task.Option { + return []task.Option{ + task.PanicHandler(func(_ context.Context, name string, l *logger.Logger, r interface{}) error { + return fmt.Errorf("test task %s panicked: %v", name, r) + }), + } } // GoWithCancel runs the given function in a goroutine and returns a // CancelFunc that can be used to cancel the function. func (t *testImpl) GoWithCancel(fn task.Func, opts ...task.Option) context.CancelFunc { - return t.taskManager.GoWithCancel(fn, task.PanicHandler(panicHandler), task.OptionList(opts...)) + return t.taskManager.GoWithCancel( + fn, task.OptionList(defaultTaskOptions()...), task.OptionList(opts...), + ) } // Go is like GoWithCancel but without a cancel function. @@ -703,6 +710,11 @@ func (t *testImpl) Go(fn task.Func, opts ...task.Option) { _ = t.GoWithCancel(fn, task.OptionList(opts...)) } +// NewGroup starts a new task group. +func (t *testImpl) NewGroup() task.Group { + return t.taskManager.NewGroup(defaultTaskOptions()...) +} + // TeamCityEscape escapes a string for use as in a key='' attribute // in TeamCity build output marker. // See https://www.jetbrains.com/help/teamcity/2023.05/service-messages.html#Escaped+Values diff --git a/pkg/cmd/roachtest/tests/multitenant_upgrade.go b/pkg/cmd/roachtest/tests/multitenant_upgrade.go index 8715ba417384..34344b98278c 100644 --- a/pkg/cmd/roachtest/tests/multitenant_upgrade.go +++ b/pkg/cmd/roachtest/tests/multitenant_upgrade.go @@ -278,23 +278,19 @@ func runMultitenantUpgrade(ctx context.Context, t test.Test, c cluster.Cluster) }, ) - var wg sync.WaitGroup - wg.Add(2) // tpcc worklaod and upgrade finalization - - h.Go(func(_ context.Context, l *logger.Logger) error { - defer wg.Done() + g := h.NewGroup() + g.Go(func(_ context.Context, l *logger.Logger) error { <-tpccFinished l.Printf("tpcc workload finished running on tenants") return nil }) - h.Go(func(_ context.Context, l *logger.Logger) error { - defer wg.Done() + g.Go(func(_ context.Context, l *logger.Logger) error { <-upgradeFinished l.Printf("tenant upgrades finished") return nil }) + g.Wait() - wg.Wait() return nil }, )