From 422785c6d7ccb879a6b7d8d54713576fa4efc647 Mon Sep 17 00:00:00 2001 From: DarrylWong Date: Fri, 15 Nov 2024 13:51:25 -0500 Subject: [PATCH] roachtest: fix TestVMPreemptionPolling data race This change switches to pollPreemptionInterval to be a mutex protected struct instead, as multiple unit tests modify it and can lead to a data race without. Fixes: #135267 Epic: none Release note: none --- pkg/cmd/roachtest/test_runner.go | 21 ++++++++++++++++++--- pkg/cmd/roachtest/test_test.go | 12 +++++++++--- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/pkg/cmd/roachtest/test_runner.go b/pkg/cmd/roachtest/test_runner.go index bb2abe6d4ecd..971fe8110aad 100644 --- a/pkg/cmd/roachtest/test_runner.go +++ b/pkg/cmd/roachtest/test_runner.go @@ -50,6 +50,12 @@ import ( "github.com/petermattis/goid" ) +func init() { + pollPreemptionInterval.Lock() + defer pollPreemptionInterval.Unlock() + pollPreemptionInterval.interval = 5 * time.Minute +} + var ( errTestsFailed = fmt.Errorf("some tests failed") @@ -1992,20 +1998,29 @@ var getPreemptedVMsHook = func(c cluster.Cluster, ctx context.Context, l *logger return c.GetPreemptedVMs(ctx, l) } -// pollPreemptionInterval is how often to poll for preempted VMs. -var pollPreemptionInterval = 5 * time.Minute +// pollPreemptionInterval is how often to poll for preempted VMs. We use a +// mutex protected struct to allow for unit tests to safely modify it. +// Interval defaults to 5 minutes if not set. +var pollPreemptionInterval struct { + syncutil.Mutex + interval time.Duration +} func monitorForPreemptedVMs(ctx context.Context, t test.Test, c cluster.Cluster, l *logger.Logger) { if c.IsLocal() || !c.Spec().UseSpotVMs { return } + pollPreemptionInterval.Lock() + defer pollPreemptionInterval.Unlock() + interval := pollPreemptionInterval.interval + go func() { for { select { case <-ctx.Done(): return - case <-time.After(pollPreemptionInterval): + case <-time.After(interval): preemptedVMs, err := getPreemptedVMsHook(c, ctx, l) if err != nil { l.Printf("WARN: monitorForPreemptedVMs: failed to check preempted VMs:\n%+v", err) diff --git a/pkg/cmd/roachtest/test_test.go b/pkg/cmd/roachtest/test_test.go index 91f7f636575a..fc8914f3a45a 100644 --- a/pkg/cmd/roachtest/test_test.go +++ b/pkg/cmd/roachtest/test_test.go @@ -606,6 +606,12 @@ func TestVMPreemptionPolling(t *testing.T) { }, } + setPollPreemptionInterval := func(interval time.Duration) { + pollPreemptionInterval.Lock() + defer pollPreemptionInterval.Unlock() + pollPreemptionInterval.interval = interval + } + getPreemptedVMsHook = func(c cluster.Cluster, ctx context.Context, l *logger.Logger) ([]vm.PreemptedVM, error) { preemptedVMs := []vm.PreemptedVM{{ Name: "test_node", @@ -618,13 +624,13 @@ func TestVMPreemptionPolling(t *testing.T) { getPreemptedVMsHook = func(c cluster.Cluster, ctx context.Context, l *logger.Logger) ([]vm.PreemptedVM, error) { return c.GetPreemptedVMs(ctx, l) } - pollPreemptionInterval = 5 * time.Minute + setPollPreemptionInterval(5 * time.Minute) }() // Test that if a VM is preempted, the VM preemption monitor will catch // it and cancel the test before it times out. t.Run("polling cancels test", func(t *testing.T) { - pollPreemptionInterval = 50 * time.Millisecond + setPollPreemptionInterval(50 * time.Millisecond) err := runner.Run(ctx, []registry.TestSpec{mockTest}, 1, /* count */ defaultParallelism, copt, testOpts{}, lopt) @@ -637,7 +643,7 @@ func TestVMPreemptionPolling(t *testing.T) { // test finished first, the post failure checks will check again and mark it as a flake. t.Run("polling doesn't catch preemption", func(t *testing.T) { // Set the interval very high so we don't poll for preemptions. - pollPreemptionInterval = 1 * time.Hour + setPollPreemptionInterval(1 * time.Hour) mockTest.Run = func(ctx context.Context, t test.Test, c cluster.Cluster) { t.Error("Should be ignored")