diff --git a/pkg/sql/schemachanger/sctest/backup.go b/pkg/sql/schemachanger/sctest/backup.go index b6d99054f251..f5357256f004 100644 --- a/pkg/sql/schemachanger/sctest/backup.go +++ b/pkg/sql/schemachanger/sctest/backup.go @@ -13,7 +13,9 @@ package sctest import ( "context" gosql "database/sql" + "flag" "fmt" + "math/rand" "strings" "sync/atomic" "testing" @@ -70,10 +72,6 @@ func BackupSuccessMixedVersion(t *testing.T, path string, factory TestServerFact // and at least as expensive to run. skip.UnderShort(t) - if strings.Contains(path, "alter_table_add_primary_key_drop_rowid") { - skip.WithIssue(t, 107552, "flaky test") - } - factory = factory.WithMixedVersion() cumulativeTestForEachPostCommitStage(t, path, factory, func(t *testing.T, cs CumulativeTestCaseSpec) { backupSuccess(t, factory, cs) @@ -96,7 +94,27 @@ func BackupRollbacksMixedVersion(t *testing.T, path string, factory TestServerFa }) } +// runAllBackups runs all the backup tests, disabling the random skipping. +var runAllBackups = flag.Bool( + "run-all-backups", false, + "if true, run all backups instead of a random subset", +) + +// If the number of stages in the same phase exceeds skipThreshold, we enable +// skipping such that the backup test for each stage is skipped with probability +// skipRate. +// Set runAllBackups to true to disable skipping altogether. +const skipThreshold = 10 +const skipRate = .5 + +func maybeRandomlySkip(t *testing.T, stageCountInPhase int) { + if !*runAllBackups && stageCountInPhase > skipThreshold && rand.Float64() < skipRate { + skip.IgnoreLint(t, "skipping due to randomness") + } +} + func backupSuccess(t *testing.T, factory TestServerFactory, cs CumulativeTestCaseSpec) { + maybeRandomlySkip(t, cs.StagesCount) ctx := context.Background() url := fmt.Sprintf("userfile://backups.public.userfiles_$user/data_%s_%d", cs.Phase, cs.StageOrdinal) @@ -205,6 +223,7 @@ func backupRollbacks(t *testing.T, factory TestServerFactory, cs CumulativeTestC if cs.Phase != scop.PostCommitPhase { return } + maybeRandomlySkip(t, cs.StagesCount) ctx := context.Background() var urls atomic.Value var dbForBackup atomic.Pointer[gosql.DB] diff --git a/pkg/sql/schemachanger/sctest/framework.go b/pkg/sql/schemachanger/sctest/framework.go index 4f7e3c0a2c1a..ed565317ec65 100644 --- a/pkg/sql/schemachanger/sctest/framework.go +++ b/pkg/sql/schemachanger/sctest/framework.go @@ -644,83 +644,90 @@ func (cs CumulativeTestCaseSpec) run(t *testing.T, fn func(t *testing.T)) bool { return t.Run(fmt.Sprintf("%s_stage_%d_of_%d", prefix, cs.StageOrdinal, cs.StagesCount), fn) } +// cumulativeTestForEachPostCommitStage invokes `tf` once for each stage in the +// PostCommitPhase. These invocation are run in parallel. func cumulativeTestForEachPostCommitStage( t *testing.T, relTestCaseDir string, factory TestServerFactory, tf func(t *testing.T, spec CumulativeTestCaseSpec), ) { - testFunc := func(t *testing.T, spec CumulativeTestSpec) { - // Skip this test if any of the stmts is not fully supported. - if err := areStmtsFullySupportedAtClusterVersion(t, spec, factory); err != nil { - skip.IgnoreLint(t, "test is skipped because", err.Error()) - } - var postCommitCount, postCommitNonRevertibleCount int - var after [][]string - var dbName string - prepfn := func(db *gosql.DB, p scplan.Plan) { - for _, s := range p.Stages { - switch s.Phase { - case scop.PostCommitPhase: - postCommitCount++ - case scop.PostCommitNonRevertiblePhase: - postCommitNonRevertibleCount++ + // Grouping the parallel subtests into a non-parallel subtest allows any defer + // calls to work as expected. + t.Run("group", func(t *testing.T) { + testFunc := func(t *testing.T, spec CumulativeTestSpec) { + // Skip this test if any of the stmts is not fully supported. + if err := areStmtsFullySupportedAtClusterVersion(t, spec, factory); err != nil { + skip.IgnoreLint(t, "test is skipped because", err.Error()) + } + var postCommitCount, postCommitNonRevertibleCount int + var after [][]string + var dbName string + prepfn := func(db *gosql.DB, p scplan.Plan) { + for _, s := range p.Stages { + switch s.Phase { + case scop.PostCommitPhase: + postCommitCount++ + case scop.PostCommitNonRevertiblePhase: + postCommitNonRevertibleCount++ + } } + tdb := sqlutils.MakeSQLRunner(db) + var ok bool + dbName, ok = maybeGetDatabaseForIDs(t, tdb, screl.AllTargetStateDescIDs(p.TargetState)) + if ok { + tdb.Exec(t, fmt.Sprintf("USE %q", dbName)) + } + after = tdb.QueryStr(t, fetchDescriptorStateQuery) } - tdb := sqlutils.MakeSQLRunner(db) - var ok bool - dbName, ok = maybeGetDatabaseForIDs(t, tdb, screl.AllTargetStateDescIDs(p.TargetState)) - if ok { - tdb.Exec(t, fmt.Sprintf("USE %q", dbName)) + withPostCommitPlanAfterSchemaChange(t, spec, factory, prepfn) + if postCommitCount+postCommitNonRevertibleCount == 0 { + skip.IgnoreLint(t, "test case has no post-commit stages") + return } - after = tdb.QueryStr(t, fetchDescriptorStateQuery) - } - withPostCommitPlanAfterSchemaChange(t, spec, factory, prepfn) - if postCommitCount+postCommitNonRevertibleCount == 0 { - skip.IgnoreLint(t, "test case has no post-commit stages") - return - } - if dbName == "" { - skip.IgnoreLint(t, "test case has no usable database") - return - } - var testCases []CumulativeTestCaseSpec - for stageOrdinal := 1; stageOrdinal <= postCommitCount; stageOrdinal++ { - testCases = append(testCases, CumulativeTestCaseSpec{ - CumulativeTestSpec: spec, - Phase: scop.PostCommitPhase, - StageOrdinal: stageOrdinal, - StagesCount: postCommitCount, - After: after, - DatabaseName: dbName, - }) - } - for stageOrdinal := 1; stageOrdinal <= postCommitNonRevertibleCount; stageOrdinal++ { - testCases = append(testCases, CumulativeTestCaseSpec{ - CumulativeTestSpec: spec, - Phase: scop.PostCommitNonRevertiblePhase, - StageOrdinal: stageOrdinal, - StagesCount: postCommitNonRevertibleCount, - After: after, - DatabaseName: dbName, - }) - } - var hasFailed bool - for _, tc := range testCases { - fn := func(t *testing.T) { - tf(t, tc) + if dbName == "" { + skip.IgnoreLint(t, "test case has no usable database") + return } - if hasFailed { - fn = func(t *testing.T) { - skip.IgnoreLint(t, "skipping test cases subsequent to earlier failure") - } + var testCases []CumulativeTestCaseSpec + for stageOrdinal := 1; stageOrdinal <= postCommitCount; stageOrdinal++ { + testCases = append(testCases, CumulativeTestCaseSpec{ + CumulativeTestSpec: spec, + Phase: scop.PostCommitPhase, + StageOrdinal: stageOrdinal, + StagesCount: postCommitCount, + After: after, + DatabaseName: dbName, + }) } - if !tc.run(t, fn) { - hasFailed = true + for stageOrdinal := 1; stageOrdinal <= postCommitNonRevertibleCount; stageOrdinal++ { + testCases = append(testCases, CumulativeTestCaseSpec{ + CumulativeTestSpec: spec, + Phase: scop.PostCommitNonRevertiblePhase, + StageOrdinal: stageOrdinal, + StagesCount: postCommitNonRevertibleCount, + After: after, + DatabaseName: dbName, + }) + } + var hasFailed bool + for _, tc := range testCases { + fn := func(t *testing.T) { + t.Parallel() // SAFE FOR TESTING + tf(t, tc) + } + if hasFailed { + fn = func(t *testing.T) { + skip.IgnoreLint(t, "skipping test cases subsequent to earlier failure") + } + } + if !tc.run(t, fn) { + hasFailed = true + } } } - } - cumulativeTest(t, relTestCaseDir, testFunc) + cumulativeTest(t, relTestCaseDir, testFunc) + }) } // fetchDescriptorStateQuery returns the CREATE statements for all descriptors