diff --git a/pkg/cmd/roachtest/roachtestutil/task/manager.go b/pkg/cmd/roachtest/roachtestutil/task/manager.go new file mode 100644 index 000000000000..531f48385271 --- /dev/null +++ b/pkg/cmd/roachtest/roachtestutil/task/manager.go @@ -0,0 +1,147 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package task + +import ( + "context" + "fmt" + "sync/atomic" + + "github.com/cockroachdb/cockroach/pkg/roachprod/logger" + "github.com/cockroachdb/cockroach/pkg/util/ctxgroup" +) + +type ( + // Manager is responsible for managing a group of tasks initiated during + // tests. The interface is designed for the test framework to control tasks. + // Typically, tests will only interact, and be provided with the smaller + // Tasker interface to start tasks. + Manager interface { + Tasker + Terminate() + CompletedEvents() <-chan Event + } + + // Event represents the result of a task execution. + Event struct { + Name string + Err error + ExpectedCancel bool + } + + manager struct { + group ctxgroup.Group + ctx context.Context + logger *logger.Logger + events chan Event + id atomic.Uint32 + cancelFns []context.CancelFunc + } +) + +func NewManager(ctx context.Context, l *logger.Logger) Manager { + g := ctxgroup.WithContext(ctx) + return &manager{ + group: g, + ctx: ctx, + logger: l, + events: make(chan Event), + } +} + +func (m *manager) defaultOptions() []Option { + // The default panic handler simply returns the panic as an error. + defaultPanicHandlerFn := func(_ context.Context, l *logger.Logger, r interface{}) error { + return r.(error) + } + // The default error handler simply returns the error as is. + defaultErrorHandlerFn := func(_ context.Context, l *logger.Logger, err error) error { + return err + } + return []Option{ + Name(fmt.Sprintf("task-%d", m.id.Add(1))), + Logger(m.logger), + PanicHandler(defaultPanicHandlerFn), + ErrorHandler(defaultErrorHandlerFn), + } +} + +func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc { + opt := CombineOptions(OptionList(m.defaultOptions()...), OptionList(opts...)) + groupCtx, cancel := context.WithCancel(m.ctx) + var expectedContextCancellation bool + + // internalFunc is a wrapper around the user-provided function that + // handles panics and errors. + internalFunc := func(l *logger.Logger) (retErr error) { + defer func() { + if r := recover(); r != nil { + retErr = opt.PanicHandler(groupCtx, l, r) + } + retErr = opt.ErrorHandler(groupCtx, l, retErr) + }() + retErr = fn(groupCtx, l) + return retErr + } + + m.group.Go(func() error { + l, err := opt.L() + if err != nil { + return err + } + err = internalFunc(l) + event := Event{ + Name: opt.Name, + Err: err, + ExpectedCancel: err != nil && IsContextCanceled(groupCtx) && expectedContextCancellation, + } + select { + case m.events <- event: + // exit goroutine + case <-m.ctx.Done(): + // Parent context already finished, exit goroutine. + return nil + } + return err + }) + + taskCancelFn := func() { + expectedContextCancellation = true + cancel() + } + // Collect all taskCancelFns so that we can explicitly stop all + // tasks when the tasker is terminated. + m.cancelFns = append(m.cancelFns, taskCancelFn) + return taskCancelFn +} + +func (m *manager) Go(fn Func, opts ...Option) { + _ = m.GoWithCancel(fn, opts...) +} + +// Terminate will call the stop functions for every background function +// started during the test. This includes background functions created +// during test runtime (using `helper.Background()`), as well as +// background steps declared in the test setup (using +// `BackgroundFunc`, `Workload`, et al.). Returns when all background +// functions have returned. +func (m *manager) Terminate() { + for _, cancel := range m.cancelFns { + cancel() + } + + doneCh := make(chan error) + go func() { + defer close(doneCh) + _ = m.group.Wait() + }() + + WaitForChannel(doneCh, "tasks", m.logger) +} + +func (m *manager) CompletedEvents() <-chan Event { + return m.events +} diff --git a/pkg/cmd/roachtest/roachtestutil/task/manager_test.go b/pkg/cmd/roachtest/roachtestutil/task/manager_test.go new file mode 100644 index 000000000000..7459a7f7dcfa --- /dev/null +++ b/pkg/cmd/roachtest/roachtestutil/task/manager_test.go @@ -0,0 +1,141 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package task + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/roachprod/logger" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" +) + +func TestPanicHandler(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + m := NewManager(ctx, nil) + + panicErr := errors.New("panic") + panicHandlerFn := func(_ context.Context, l *logger.Logger, r interface{}) (err error) { + return r.(error) + } + m.Go(func(ctx context.Context, l *logger.Logger) error { + panic(panicErr) + return nil + }, PanicHandler(panicHandlerFn), Name("abc")) + + select { + case e := <-m.CompletedEvents(): + require.ErrorIs(t, e.Err, panicErr) + require.Equal(t, "abc", e.Name) + } + m.Terminate() +} + +func TestErrorHandler(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + m := NewManager(ctx, nil) + + var wrapErr error + errorHandlerFn := func(_ context.Context, l *logger.Logger, err error) error { + wrapErr = errors.Wrapf(err, "wrapped") + return wrapErr + } + + m.Go(func(ctx context.Context, l *logger.Logger) error { + return errors.New("error") + }, ErrorHandler(errorHandlerFn), Name("def")) + + select { + case e := <-m.CompletedEvents(): + require.ErrorIs(t, e.Err, wrapErr) + require.Equal(t, "def", e.Name) + } + m.Terminate() +} + +func TestContextCancel(t *testing.T) { + t.Run("cancel main context", func(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx, cancel := context.WithCancel(context.Background()) + m := NewManager(ctx, nil) + + wg := sync.WaitGroup{} + wg.Add(1) + m.Go(func(ctx context.Context, l *logger.Logger) error { + defer wg.Done() + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(30 * time.Second): + t.Fatal("expected context to be canceled") + } + return nil + }) + cancel() + wg.Wait() + m.Terminate() + }) + + t.Run("cancel task context", func(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + m := NewManager(ctx, nil) + + wg := sync.WaitGroup{} + wg.Add(1) + cancel := m.GoWithCancel(func(ctx context.Context, l *logger.Logger) error { + defer wg.Done() + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(30 * time.Second): + t.Fatal("expected context to be canceled") + } + return nil + }) + cancel() + wg.Wait() + select { + case e := <-m.CompletedEvents(): + require.ErrorIs(t, e.Err, context.Canceled) + } + m.Terminate() + }) +} + +func TestTerminate(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + m := NewManager(ctx, nil) + numTasks := 10 + var counter atomic.Uint32 + for i := 0; i < numTasks; i++ { + m.Go(func(ctx context.Context, l *logger.Logger) error { + defer func() { + counter.Add(1) + }() + <-ctx.Done() + return nil + }) + } + go func() { + for i := 0; i < numTasks; i++ { + select { + case e := <-m.CompletedEvents(): + require.NoError(t, e.Err) + } + } + }() + m.Terminate() + require.Equal(t, uint32(numTasks), counter.Load()) +} diff --git a/pkg/cmd/roachtest/roachtestutil/task/utils.go b/pkg/cmd/roachtest/roachtestutil/task/utils.go new file mode 100644 index 000000000000..4902f79efb98 --- /dev/null +++ b/pkg/cmd/roachtest/roachtestutil/task/utils.go @@ -0,0 +1,43 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package task + +import ( + "context" + "time" + + "github.com/cockroachdb/cockroach/pkg/roachprod/logger" +) + +// IsContextCanceled returns a boolean indicating whether the context +// passed is canceled. +func IsContextCanceled(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + return false + } +} + +// WaitForChannel waits for the given channel `ch` to close; returns +// when that happens. If the channel does not close within 5 minutes, +// the function logs a message and returns. +// +// The main use-case for this function is waiting for user-provided +// hooks to return after the context passed to them is canceled. We +// want to allow some time for them to finish, but we also don't want +// to block indefinitely if a function inadvertently ignores context +// cancellation. +func WaitForChannel(ch chan error, desc string, l *logger.Logger) { + maxWait := 5 * time.Minute + select { + case <-ch: + // return + case <-time.After(maxWait): + l.Printf("waited for %s for %s to finish, giving up", maxWait, desc) + } +}