Skip to content

Commit

Permalink
roachtest: fix TestVMPreemptionPolling data race
Browse files Browse the repository at this point in the history
This change switches to pollPreemptionInterval to be a
mutex protected struct instead, as multiple unit tests
modify it and can lead to a data race without.

Fixes: cockroachdb#135267
Epic: none
Release note: none
  • Loading branch information
DarrylWong committed Nov 19, 2024
1 parent 321e919 commit 422785c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
21 changes: 18 additions & 3 deletions pkg/cmd/roachtest/test_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ import (
"github.com/petermattis/goid"
)

func init() {
pollPreemptionInterval.Lock()
defer pollPreemptionInterval.Unlock()
pollPreemptionInterval.interval = 5 * time.Minute
}

var (
errTestsFailed = fmt.Errorf("some tests failed")

Expand Down Expand Up @@ -1992,20 +1998,29 @@ var getPreemptedVMsHook = func(c cluster.Cluster, ctx context.Context, l *logger
return c.GetPreemptedVMs(ctx, l)
}

// pollPreemptionInterval is how often to poll for preempted VMs.
var pollPreemptionInterval = 5 * time.Minute
// pollPreemptionInterval is how often to poll for preempted VMs. We use a
// mutex protected struct to allow for unit tests to safely modify it.
// Interval defaults to 5 minutes if not set.
var pollPreemptionInterval struct {
syncutil.Mutex
interval time.Duration
}

func monitorForPreemptedVMs(ctx context.Context, t test.Test, c cluster.Cluster, l *logger.Logger) {
if c.IsLocal() || !c.Spec().UseSpotVMs {
return
}

pollPreemptionInterval.Lock()
defer pollPreemptionInterval.Unlock()
interval := pollPreemptionInterval.interval

go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(pollPreemptionInterval):
case <-time.After(interval):
preemptedVMs, err := getPreemptedVMsHook(c, ctx, l)
if err != nil {
l.Printf("WARN: monitorForPreemptedVMs: failed to check preempted VMs:\n%+v", err)
Expand Down
12 changes: 9 additions & 3 deletions pkg/cmd/roachtest/test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,12 @@ func TestVMPreemptionPolling(t *testing.T) {
},
}

setPollPreemptionInterval := func(interval time.Duration) {
pollPreemptionInterval.Lock()
defer pollPreemptionInterval.Unlock()
pollPreemptionInterval.interval = interval
}

getPreemptedVMsHook = func(c cluster.Cluster, ctx context.Context, l *logger.Logger) ([]vm.PreemptedVM, error) {
preemptedVMs := []vm.PreemptedVM{{
Name: "test_node",
Expand All @@ -618,13 +624,13 @@ func TestVMPreemptionPolling(t *testing.T) {
getPreemptedVMsHook = func(c cluster.Cluster, ctx context.Context, l *logger.Logger) ([]vm.PreemptedVM, error) {
return c.GetPreemptedVMs(ctx, l)
}
pollPreemptionInterval = 5 * time.Minute
setPollPreemptionInterval(5 * time.Minute)
}()

// Test that if a VM is preempted, the VM preemption monitor will catch
// it and cancel the test before it times out.
t.Run("polling cancels test", func(t *testing.T) {
pollPreemptionInterval = 50 * time.Millisecond
setPollPreemptionInterval(50 * time.Millisecond)

err := runner.Run(ctx, []registry.TestSpec{mockTest}, 1, /* count */
defaultParallelism, copt, testOpts{}, lopt)
Expand All @@ -637,7 +643,7 @@ func TestVMPreemptionPolling(t *testing.T) {
// test finished first, the post failure checks will check again and mark it as a flake.
t.Run("polling doesn't catch preemption", func(t *testing.T) {
// Set the interval very high so we don't poll for preemptions.
pollPreemptionInterval = 1 * time.Hour
setPollPreemptionInterval(1 * time.Hour)

mockTest.Run = func(ctx context.Context, t test.Test, c cluster.Cluster) {
t.Error("Should be ignored")
Expand Down

0 comments on commit 422785c

Please sign in to comment.