From 943a7fdf5cf18b06eacf75e190384d098f0be9f2 Mon Sep 17 00:00:00 2001 From: Xiang Gu Date: Wed, 2 Aug 2023 18:00:03 -0400 Subject: [PATCH] sctest: Parallelize declarative schema change tests Many of our declarative schema change tests follow the pattern of examining and testing certain behavior at each and every stage in the PostCommitPhase of the to-be-tested DDL(s). This includes our BACKUP/RESTORE, PAUSE/RESUME, ROLLBACK testing suites. Previously, each (sub)test for each stage is run sequentially, and this commit parallelize those (sub)tests. This can help with some of recent timeout failure we saw. Release note: None --- pkg/sql/schemachanger/sctest/backup.go | 27 ++++- pkg/sql/schemachanger/sctest/framework.go | 135 ++++++++++++---------- 2 files changed, 94 insertions(+), 68 deletions(-) diff --git a/pkg/sql/schemachanger/sctest/backup.go b/pkg/sql/schemachanger/sctest/backup.go index b6d99054f251..f5357256f004 100644 --- a/pkg/sql/schemachanger/sctest/backup.go +++ b/pkg/sql/schemachanger/sctest/backup.go @@ -13,7 +13,9 @@ package sctest import ( "context" gosql "database/sql" + "flag" "fmt" + "math/rand" "strings" "sync/atomic" "testing" @@ -70,10 +72,6 @@ func BackupSuccessMixedVersion(t *testing.T, path string, factory TestServerFact // and at least as expensive to run. skip.UnderShort(t) - if strings.Contains(path, "alter_table_add_primary_key_drop_rowid") { - skip.WithIssue(t, 107552, "flaky test") - } - factory = factory.WithMixedVersion() cumulativeTestForEachPostCommitStage(t, path, factory, func(t *testing.T, cs CumulativeTestCaseSpec) { backupSuccess(t, factory, cs) @@ -96,7 +94,27 @@ func BackupRollbacksMixedVersion(t *testing.T, path string, factory TestServerFa }) } +// runAllBackups runs all the backup tests, disabling the random skipping. +var runAllBackups = flag.Bool( + "run-all-backups", false, + "if true, run all backups instead of a random subset", +) + +// If the number of stages in the same phase exceeds skipThreshold, we enable +// skipping such that the backup test for each stage is skipped with probability +// skipRate. +// Set runAllBackups to true to disable skipping altogether. +const skipThreshold = 10 +const skipRate = .5 + +func maybeRandomlySkip(t *testing.T, stageCountInPhase int) { + if !*runAllBackups && stageCountInPhase > skipThreshold && rand.Float64() < skipRate { + skip.IgnoreLint(t, "skipping due to randomness") + } +} + func backupSuccess(t *testing.T, factory TestServerFactory, cs CumulativeTestCaseSpec) { + maybeRandomlySkip(t, cs.StagesCount) ctx := context.Background() url := fmt.Sprintf("userfile://backups.public.userfiles_$user/data_%s_%d", cs.Phase, cs.StageOrdinal) @@ -205,6 +223,7 @@ func backupRollbacks(t *testing.T, factory TestServerFactory, cs CumulativeTestC if cs.Phase != scop.PostCommitPhase { return } + maybeRandomlySkip(t, cs.StagesCount) ctx := context.Background() var urls atomic.Value var dbForBackup atomic.Pointer[gosql.DB] diff --git a/pkg/sql/schemachanger/sctest/framework.go b/pkg/sql/schemachanger/sctest/framework.go index 4f7e3c0a2c1a..ed565317ec65 100644 --- a/pkg/sql/schemachanger/sctest/framework.go +++ b/pkg/sql/schemachanger/sctest/framework.go @@ -644,83 +644,90 @@ func (cs CumulativeTestCaseSpec) run(t *testing.T, fn func(t *testing.T)) bool { return t.Run(fmt.Sprintf("%s_stage_%d_of_%d", prefix, cs.StageOrdinal, cs.StagesCount), fn) } +// cumulativeTestForEachPostCommitStage invokes `tf` once for each stage in the +// PostCommitPhase. These invocation are run in parallel. func cumulativeTestForEachPostCommitStage( t *testing.T, relTestCaseDir string, factory TestServerFactory, tf func(t *testing.T, spec CumulativeTestCaseSpec), ) { - testFunc := func(t *testing.T, spec CumulativeTestSpec) { - // Skip this test if any of the stmts is not fully supported. - if err := areStmtsFullySupportedAtClusterVersion(t, spec, factory); err != nil { - skip.IgnoreLint(t, "test is skipped because", err.Error()) - } - var postCommitCount, postCommitNonRevertibleCount int - var after [][]string - var dbName string - prepfn := func(db *gosql.DB, p scplan.Plan) { - for _, s := range p.Stages { - switch s.Phase { - case scop.PostCommitPhase: - postCommitCount++ - case scop.PostCommitNonRevertiblePhase: - postCommitNonRevertibleCount++ + // Grouping the parallel subtests into a non-parallel subtest allows any defer + // calls to work as expected. + t.Run("group", func(t *testing.T) { + testFunc := func(t *testing.T, spec CumulativeTestSpec) { + // Skip this test if any of the stmts is not fully supported. + if err := areStmtsFullySupportedAtClusterVersion(t, spec, factory); err != nil { + skip.IgnoreLint(t, "test is skipped because", err.Error()) + } + var postCommitCount, postCommitNonRevertibleCount int + var after [][]string + var dbName string + prepfn := func(db *gosql.DB, p scplan.Plan) { + for _, s := range p.Stages { + switch s.Phase { + case scop.PostCommitPhase: + postCommitCount++ + case scop.PostCommitNonRevertiblePhase: + postCommitNonRevertibleCount++ + } } + tdb := sqlutils.MakeSQLRunner(db) + var ok bool + dbName, ok = maybeGetDatabaseForIDs(t, tdb, screl.AllTargetStateDescIDs(p.TargetState)) + if ok { + tdb.Exec(t, fmt.Sprintf("USE %q", dbName)) + } + after = tdb.QueryStr(t, fetchDescriptorStateQuery) } - tdb := sqlutils.MakeSQLRunner(db) - var ok bool - dbName, ok = maybeGetDatabaseForIDs(t, tdb, screl.AllTargetStateDescIDs(p.TargetState)) - if ok { - tdb.Exec(t, fmt.Sprintf("USE %q", dbName)) + withPostCommitPlanAfterSchemaChange(t, spec, factory, prepfn) + if postCommitCount+postCommitNonRevertibleCount == 0 { + skip.IgnoreLint(t, "test case has no post-commit stages") + return } - after = tdb.QueryStr(t, fetchDescriptorStateQuery) - } - withPostCommitPlanAfterSchemaChange(t, spec, factory, prepfn) - if postCommitCount+postCommitNonRevertibleCount == 0 { - skip.IgnoreLint(t, "test case has no post-commit stages") - return - } - if dbName == "" { - skip.IgnoreLint(t, "test case has no usable database") - return - } - var testCases []CumulativeTestCaseSpec - for stageOrdinal := 1; stageOrdinal <= postCommitCount; stageOrdinal++ { - testCases = append(testCases, CumulativeTestCaseSpec{ - CumulativeTestSpec: spec, - Phase: scop.PostCommitPhase, - StageOrdinal: stageOrdinal, - StagesCount: postCommitCount, - After: after, - DatabaseName: dbName, - }) - } - for stageOrdinal := 1; stageOrdinal <= postCommitNonRevertibleCount; stageOrdinal++ { - testCases = append(testCases, CumulativeTestCaseSpec{ - CumulativeTestSpec: spec, - Phase: scop.PostCommitNonRevertiblePhase, - StageOrdinal: stageOrdinal, - StagesCount: postCommitNonRevertibleCount, - After: after, - DatabaseName: dbName, - }) - } - var hasFailed bool - for _, tc := range testCases { - fn := func(t *testing.T) { - tf(t, tc) + if dbName == "" { + skip.IgnoreLint(t, "test case has no usable database") + return } - if hasFailed { - fn = func(t *testing.T) { - skip.IgnoreLint(t, "skipping test cases subsequent to earlier failure") - } + var testCases []CumulativeTestCaseSpec + for stageOrdinal := 1; stageOrdinal <= postCommitCount; stageOrdinal++ { + testCases = append(testCases, CumulativeTestCaseSpec{ + CumulativeTestSpec: spec, + Phase: scop.PostCommitPhase, + StageOrdinal: stageOrdinal, + StagesCount: postCommitCount, + After: after, + DatabaseName: dbName, + }) } - if !tc.run(t, fn) { - hasFailed = true + for stageOrdinal := 1; stageOrdinal <= postCommitNonRevertibleCount; stageOrdinal++ { + testCases = append(testCases, CumulativeTestCaseSpec{ + CumulativeTestSpec: spec, + Phase: scop.PostCommitNonRevertiblePhase, + StageOrdinal: stageOrdinal, + StagesCount: postCommitNonRevertibleCount, + After: after, + DatabaseName: dbName, + }) + } + var hasFailed bool + for _, tc := range testCases { + fn := func(t *testing.T) { + t.Parallel() // SAFE FOR TESTING + tf(t, tc) + } + if hasFailed { + fn = func(t *testing.T) { + skip.IgnoreLint(t, "skipping test cases subsequent to earlier failure") + } + } + if !tc.run(t, fn) { + hasFailed = true + } } } - } - cumulativeTest(t, relTestCaseDir, testFunc) + cumulativeTest(t, relTestCaseDir, testFunc) + }) } // fetchDescriptorStateQuery returns the CREATE statements for all descriptors