Skip to content

Commit

Permalink
roachtest: task manager
Browse files Browse the repository at this point in the history
This change provides a task manager implementation that will drive the logic
running behind the Tasker interface, as well as provide the implementation of
it. The manager is meant to be used by the test framework(s), while the Tasker
interface is supplied to tests. The framework will do additional logic such as
terminating tasks at the end of a test, logging uncaught errors or issues and
providing the appropriate panic and error handlers.

Informs: cockroachdb#118214

Epic: None
Release note: None
  • Loading branch information
herkolategan committed Oct 23, 2024
1 parent d18a104 commit 0f181d8
Show file tree
Hide file tree
Showing 3 changed files with 331 additions and 0 deletions.
147 changes: 147 additions & 0 deletions pkg/cmd/roachtest/roachtestutil/task/manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Copyright 2024 The Cockroach Authors.
//
// Use of this software is governed by the CockroachDB Software License
// included in the /LICENSE file.

package task

import (
"context"
"fmt"
"sync/atomic"

"github.com/cockroachdb/cockroach/pkg/roachprod/logger"
"github.com/cockroachdb/cockroach/pkg/util/ctxgroup"
)

type (
// Manager is responsible for managing a group of tasks initiated during
// tests. The interface is designed for the test framework to control tasks.
// Typically, tests will only interact, and be provided with the smaller
// Tasker interface to start tasks.
Manager interface {
Tasker
Terminate()
CompletedEvents() <-chan Event
}

// Event represents the result of a task execution.
Event struct {
Name string
Err error
ExpectedCancel bool
}

manager struct {
group ctxgroup.Group
ctx context.Context
logger *logger.Logger
events chan Event
id atomic.Uint32
cancelFns []context.CancelFunc
}
)

func NewManager(ctx context.Context, l *logger.Logger) Manager {
g := ctxgroup.WithContext(ctx)
return &manager{
group: g,
ctx: ctx,
logger: l,
events: make(chan Event),
}
}

func (m *manager) defaultOptions() []Option {
// The default panic handler simply returns the panic as an error.
defaultPanicHandlerFn := func(_ context.Context, l *logger.Logger, r interface{}) error {
return r.(error)
}
// The default error handler simply returns the error as is.
defaultErrorHandlerFn := func(_ context.Context, l *logger.Logger, err error) error {
return err
}
return []Option{
Name(fmt.Sprintf("task-%d", m.id.Add(1))),
Logger(m.logger),
PanicHandler(defaultPanicHandlerFn),
ErrorHandler(defaultErrorHandlerFn),
}
}

func (m *manager) GoWithCancel(fn Func, opts ...Option) context.CancelFunc {
opt := CombineOptions(OptionList(m.defaultOptions()...), OptionList(opts...))
groupCtx, cancel := context.WithCancel(m.ctx)
var expectedContextCancellation bool

// internalFunc is a wrapper around the user-provided function that
// handles panics and errors.
internalFunc := func(l *logger.Logger) (retErr error) {
defer func() {
if r := recover(); r != nil {
retErr = opt.PanicHandler(groupCtx, l, r)
}
retErr = opt.ErrorHandler(groupCtx, l, retErr)
}()
retErr = fn(groupCtx, l)
return retErr
}

m.group.Go(func() error {
l, err := opt.L()
if err != nil {
return err
}
err = internalFunc(l)
event := Event{
Name: opt.Name,
Err: err,
ExpectedCancel: err != nil && IsContextCanceled(groupCtx) && expectedContextCancellation,
}
select {
case m.events <- event:
// exit goroutine
case <-m.ctx.Done():
// Parent context already finished, exit goroutine.
return nil
}
return err
})

taskCancelFn := func() {
expectedContextCancellation = true
cancel()
}
// Collect all taskCancelFns so that we can explicitly stop all
// tasks when the tasker is terminated.
m.cancelFns = append(m.cancelFns, taskCancelFn)
return taskCancelFn
}

func (m *manager) Go(fn Func, opts ...Option) {
_ = m.GoWithCancel(fn, opts...)
}

// Terminate will call the stop functions for every background function
// started during the test. This includes background functions created
// during test runtime (using `helper.Background()`), as well as
// background steps declared in the test setup (using
// `BackgroundFunc`, `Workload`, et al.). Returns when all background
// functions have returned.
func (m *manager) Terminate() {
for _, cancel := range m.cancelFns {
cancel()
}

doneCh := make(chan error)
go func() {
defer close(doneCh)
_ = m.group.Wait()
}()

WaitForChannel(doneCh, "tasks", m.logger)
}

func (m *manager) CompletedEvents() <-chan Event {
return m.events
}
141 changes: 141 additions & 0 deletions pkg/cmd/roachtest/roachtestutil/task/manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Copyright 2024 The Cockroach Authors.
//
// Use of this software is governed by the CockroachDB Software License
// included in the /LICENSE file.

package task

import (
"context"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/cockroachdb/cockroach/pkg/roachprod/logger"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/require"
)

func TestPanicHandler(t *testing.T) {
defer leaktest.AfterTest(t)()
ctx := context.Background()
m := NewManager(ctx, nil)

panicErr := errors.New("panic")
panicHandlerFn := func(_ context.Context, l *logger.Logger, r interface{}) (err error) {
return r.(error)
}
m.Go(func(ctx context.Context, l *logger.Logger) error {
panic(panicErr)
return nil
}, PanicHandler(panicHandlerFn), Name("abc"))

select {
case e := <-m.CompletedEvents():
require.ErrorIs(t, e.Err, panicErr)
require.Equal(t, "abc", e.Name)
}
m.Terminate()
}

func TestErrorHandler(t *testing.T) {
defer leaktest.AfterTest(t)()
ctx := context.Background()
m := NewManager(ctx, nil)

var wrapErr error
errorHandlerFn := func(_ context.Context, l *logger.Logger, err error) error {
wrapErr = errors.Wrapf(err, "wrapped")
return wrapErr
}

m.Go(func(ctx context.Context, l *logger.Logger) error {
return errors.New("error")
}, ErrorHandler(errorHandlerFn), Name("def"))

select {
case e := <-m.CompletedEvents():
require.ErrorIs(t, e.Err, wrapErr)
require.Equal(t, "def", e.Name)
}
m.Terminate()
}

func TestContextCancel(t *testing.T) {
t.Run("cancel main context", func(t *testing.T) {
defer leaktest.AfterTest(t)()
ctx, cancel := context.WithCancel(context.Background())
m := NewManager(ctx, nil)

wg := sync.WaitGroup{}
wg.Add(1)
m.Go(func(ctx context.Context, l *logger.Logger) error {
defer wg.Done()
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(30 * time.Second):
t.Fatal("expected context to be canceled")
}
return nil
})
cancel()
wg.Wait()
m.Terminate()
})

t.Run("cancel task context", func(t *testing.T) {
defer leaktest.AfterTest(t)()
ctx := context.Background()
m := NewManager(ctx, nil)

wg := sync.WaitGroup{}
wg.Add(1)
cancel := m.GoWithCancel(func(ctx context.Context, l *logger.Logger) error {
defer wg.Done()
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(30 * time.Second):
t.Fatal("expected context to be canceled")
}
return nil
})
cancel()
wg.Wait()
select {
case e := <-m.CompletedEvents():
require.ErrorIs(t, e.Err, context.Canceled)
}
m.Terminate()
})
}

func TestTerminate(t *testing.T) {
defer leaktest.AfterTest(t)()
ctx := context.Background()
m := NewManager(ctx, nil)
numTasks := 10
var counter atomic.Uint32
for i := 0; i < numTasks; i++ {
m.Go(func(ctx context.Context, l *logger.Logger) error {
defer func() {
counter.Add(1)
}()
<-ctx.Done()
return nil
})
}
go func() {
for i := 0; i < numTasks; i++ {
select {
case e := <-m.CompletedEvents():
require.NoError(t, e.Err)
}
}
}()
m.Terminate()
require.Equal(t, uint32(numTasks), counter.Load())
}
43 changes: 43 additions & 0 deletions pkg/cmd/roachtest/roachtestutil/task/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2024 The Cockroach Authors.
//
// Use of this software is governed by the CockroachDB Software License
// included in the /LICENSE file.

package task

import (
"context"
"time"

"github.com/cockroachdb/cockroach/pkg/roachprod/logger"
)

// IsContextCanceled returns a boolean indicating whether the context
// passed is canceled.
func IsContextCanceled(ctx context.Context) bool {
select {
case <-ctx.Done():
return true
default:
return false
}
}

// WaitForChannel waits for the given channel `ch` to close; returns
// when that happens. If the channel does not close within 5 minutes,
// the function logs a message and returns.
//
// The main use-case for this function is waiting for user-provided
// hooks to return after the context passed to them is canceled. We
// want to allow some time for them to finish, but we also don't want
// to block indefinitely if a function inadvertently ignores context
// cancellation.
func WaitForChannel(ch chan error, desc string, l *logger.Logger) {
maxWait := 5 * time.Minute
select {
case <-ch:
// return
case <-time.After(maxWait):
l.Printf("waited for %s for %s to finish, giving up", maxWait, desc)
}
}

0 comments on commit 0f181d8

Please sign in to comment.