diff --git a/pkg/cmd/roachtest/cluster.go b/pkg/cmd/roachtest/cluster.go
index acc480ce719c..4ef56ba60aa7 100644
--- a/pkg/cmd/roachtest/cluster.go
+++ b/pkg/cmd/roachtest/cluster.go
@@ -869,6 +869,7 @@ func (f *clusterFactory) clusterMock(cfg clusterConfig) *clusterImpl {
 		name:       f.genName(cfg),
 		expiration: timeutil.Now().Add(24 * time.Hour),
 		r:          f.r,
+		spec:       cfg.spec,
 	}
 }
 
diff --git a/pkg/cmd/roachtest/test_runner.go b/pkg/cmd/roachtest/test_runner.go
index 861dada4e6d5..4c02009ed1ec 100644
--- a/pkg/cmd/roachtest/test_runner.go
+++ b/pkg/cmd/roachtest/test_runner.go
@@ -50,6 +50,12 @@ import (
 	"github.com/petermattis/goid"
 )
 
+func init() {
+	pollPreemptionInterval.Lock()
+	defer pollPreemptionInterval.Unlock()
+	pollPreemptionInterval.interval = 5 * time.Minute
+}
+
 var (
 	errTestsFailed = fmt.Errorf("some tests failed")
 
@@ -1288,6 +1294,9 @@ func (r *testRunner) runTest(
 		}()
 
 		grafanaAnnotateTestStart(runCtx, t, c)
+		// Actively poll for VM preemptions, so we can bail out of tests early and
+		// avoid situations where a test times out and the flake assignment logic fails.
+		monitorForPreemptedVMs(runCtx, t, c, l)
 		// This is the call to actually run the test.
 		s.Run(runCtx, t, c)
 	}()
@@ -1394,7 +1403,7 @@ func getVMNames(fullVMNames []string) string {
 // getPreemptedVMNames returns a comma separated list of preempted VM
 // names, or an empty string if no VM was preempted or an error was found.
 func getPreemptedVMNames(ctx context.Context, c *clusterImpl, l *logger.Logger) string {
-	preemptedVMs, err := c.GetPreemptedVMs(ctx, l)
+	preemptedVMs, err := getPreemptedVMsHook(c, ctx, l)
 	if err != nil {
 		l.Printf("failed to check preempted VMs:\n%+v", err)
 		return ""
@@ -2076,3 +2085,49 @@ func getTestParameters(t *testImpl, c *clusterImpl, createOpts *vm.CreateOpts) m
 
 	return clusterParams
 }
+
+// getPreemptedVMsHook is a hook for unit tests to inject their own c.GetPreemptedVMs
+// implementation.
+var getPreemptedVMsHook = func(c cluster.Cluster, ctx context.Context, l *logger.Logger) ([]vm.PreemptedVM, error) {
+	return c.GetPreemptedVMs(ctx, l)
+}
+
+// pollPreemptionInterval is how often to poll for preempted VMs. We use a
+// mutex protected struct to allow for unit tests to safely modify it.
+// Interval defaults to 5 minutes if not set.
+var pollPreemptionInterval struct {
+	syncutil.Mutex
+	interval time.Duration
+}
+
+func monitorForPreemptedVMs(ctx context.Context, t test.Test, c cluster.Cluster, l *logger.Logger) {
+	if c.IsLocal() || !c.Spec().UseSpotVMs {
+		return
+	}
+
+	pollPreemptionInterval.Lock()
+	defer pollPreemptionInterval.Unlock()
+	interval := pollPreemptionInterval.interval
+
+	go func() {
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-time.After(interval):
+				preemptedVMs, err := getPreemptedVMsHook(c, ctx, l)
+				if err != nil {
+					l.Printf("WARN: monitorForPreemptedVMs: failed to check preempted VMs:\n%+v", err)
+					continue
+				}
+
+				// If we find any preemptions, fail the test. Note that we will recheck for
+				// preemptions in post failure processing, which will correctly assign this
+				// failure as an infra flake.
+				if len(preemptedVMs) != 0 {
+					t.Errorf("monitorForPreemptedVMs: Preempted VMs detected: %s", preemptedVMs)
+				}
+			}
+		}
+	}()
+}
diff --git a/pkg/cmd/roachtest/test_test.go b/pkg/cmd/roachtest/test_test.go
index c384550d08ff..fc8914f3a45a 100644
--- a/pkg/cmd/roachtest/test_test.go
+++ b/pkg/cmd/roachtest/test_test.go
@@ -58,6 +58,25 @@ func nilLogger() *logger.Logger {
 	return l
 }
 
+func defaultClusterOpt() clustersOpt {
+	return clustersOpt{
+		typ:       roachprodCluster,
+		user:      "test_user",
+		cpuQuota:  1000,
+		debugMode: NoDebug,
+	}
+}
+
+func defaultLoggingOpt(buf *syncedBuffer) loggingOpt {
+	return loggingOpt{
+		l:            nilLogger(),
+		tee:          logger.NoTee,
+		stdout:       buf,
+		stderr:       buf,
+		artifactsDir: "",
+	}
+}
+
 func TestRunnerRun(t *testing.T) {
 	ctx := context.Background()
 
@@ -234,12 +253,7 @@ func setupRunnerTest(t *testing.T, r testRegistryImpl, testFilters []string) *ru
 		stderr:       &stderr,
 		artifactsDir: "",
 	}
-	copt := clustersOpt{
-		typ:       roachprodCluster,
-		user:      "test_user",
-		cpuQuota:  1000,
-		debugMode: NoDebug,
-	}
+	copt := defaultClusterOpt()
 	return &runnerTest{
 		stdout: &stdout,
 		stderr: &stderr,
@@ -301,19 +315,8 @@ func TestRunnerTestTimeout(t *testing.T) {
 	runner := newUnitTestRunner(cr, stopper)
 
 	var buf syncedBuffer
-	lopt := loggingOpt{
-		l:            nilLogger(),
-		tee:          logger.NoTee,
-		stdout:       &buf,
-		stderr:       &buf,
-		artifactsDir: "",
-	}
-	copt := clustersOpt{
-		typ:       roachprodCluster,
-		user:      "test_user",
-		cpuQuota:  1000,
-		debugMode: NoDebug,
-	}
+	copt := defaultClusterOpt()
+	lopt := defaultLoggingOpt(&buf)
 	test := registry.TestSpec{
 		Name:  `timeout`,
 		Owner: OwnerUnitTest,
@@ -418,13 +421,8 @@ func runExitCodeTest(t *testing.T, injectedError error) error {
 	require.NoError(t, err)
 	tests, _ := testsToRun(r, tf, false, 1.0, true)
 
-	lopt := loggingOpt{
-		l:            nilLogger(),
-		tee:          logger.NoTee,
-		stdout:       io.Discard,
-		stderr:       io.Discard,
-		artifactsDir: "",
-	}
+	var buf syncedBuffer
+	lopt := defaultLoggingOpt(&buf)
 	return runner.Run(ctx, tests, 1, 1, clustersOpt{}, testOpts{}, lopt)
 }
 
@@ -537,19 +535,8 @@ func TestTransientErrorFallback(t *testing.T) {
 	runner := newUnitTestRunner(cr, stopper)
 
 	var buf syncedBuffer
-	lopt := loggingOpt{
-		l:            nilLogger(),
-		tee:          logger.NoTee,
-		stdout:       &buf,
-		stderr:       &buf,
-		artifactsDir: "",
-	}
-	copt := clustersOpt{
-		typ:       roachprodCluster,
-		user:      "test_user",
-		cpuQuota:  1000,
-		debugMode: NoDebug,
-	}
+	copt := defaultClusterOpt()
+	lopt := defaultLoggingOpt(&buf)
 
 	// Test that if a test fails with a transient error handled by the `require` package,
 	// the test runner will correctly still identify it as a flake and the run will have
@@ -594,3 +581,77 @@ func TestTransientErrorFallback(t *testing.T) {
 		}
 	})
 }
+
+func TestVMPreemptionPolling(t *testing.T) {
+	ctx := context.Background()
+	stopper := stop.NewStopper()
+	defer stopper.Stop(ctx)
+	cr := newClusterRegistry()
+	runner := newUnitTestRunner(cr, stopper)
+
+	var buf syncedBuffer
+	copt := defaultClusterOpt()
+	lopt := defaultLoggingOpt(&buf)
+
+	mockTest := registry.TestSpec{
+		Name:             `preemption`,
+		Owner:            OwnerUnitTest,
+		Cluster:          spec.MakeClusterSpec(0, spec.UseSpotVMs()),
+		CompatibleClouds: registry.AllExceptAWS,
+		Suites:           registry.Suites(registry.Nightly),
+		CockroachBinary:  registry.StandardCockroach,
+		Timeout:          10 * time.Second,
+		Run: func(ctx context.Context, t test.Test, c cluster.Cluster) {
+			<-ctx.Done()
+		},
+	}
+
+	setPollPreemptionInterval := func(interval time.Duration) {
+		pollPreemptionInterval.Lock()
+		defer pollPreemptionInterval.Unlock()
+		pollPreemptionInterval.interval = interval
+	}
+
+	getPreemptedVMsHook = func(c cluster.Cluster, ctx context.Context, l *logger.Logger) ([]vm.PreemptedVM, error) {
+		preemptedVMs := []vm.PreemptedVM{{
+			Name:        "test_node",
+			PreemptedAt: time.Now(),
+		}}
+		return preemptedVMs, nil
+	}
+
+	defer func() {
+		getPreemptedVMsHook = func(c cluster.Cluster, ctx context.Context, l *logger.Logger) ([]vm.PreemptedVM, error) {
+			return c.GetPreemptedVMs(ctx, l)
+		}
+		setPollPreemptionInterval(5 * time.Minute)
+	}()
+
+	// Test that if a VM is preempted, the VM preemption monitor will catch
+	// it and cancel the test before it times out.
+	t.Run("polling cancels test", func(t *testing.T) {
+		setPollPreemptionInterval(50 * time.Millisecond)
+
+		err := runner.Run(ctx, []registry.TestSpec{mockTest}, 1, /* count */
+			defaultParallelism, copt, testOpts{}, lopt)
+		// The preemption monitor should mark a VM as preempted and the test should
+		// be treated as a flake instead of a failed test.
+		require.NoError(t, err)
+	})
+
+	// Test that if a VM is preempted but the polling doesn't catch it because the
+	// test finished first, the post failure checks will check again and mark it as a flake.
+	t.Run("polling doesn't catch preemption", func(t *testing.T) {
+		// Set the interval very high so we don't poll for preemptions.
+		setPollPreemptionInterval(1 * time.Hour)
+
+		mockTest.Run = func(ctx context.Context, t test.Test, c cluster.Cluster) {
+			t.Error("Should be ignored")
+		}
+		err := runner.Run(ctx, []registry.TestSpec{mockTest}, 1, /* count */
+			defaultParallelism, copt, testOpts{}, lopt)
+		// The post test failure check should mark a VM as preempted and the test should
+		// be treated as a flake instead of a failed test.
+		require.NoError(t, err)
+	})
+}