diff --git a/pkg/sql/schemachanger/sctest/cumulative.go b/pkg/sql/schemachanger/sctest/cumulative.go index c1b4d21a9c88..788abf9a2258 100644 --- a/pkg/sql/schemachanger/sctest/cumulative.go +++ b/pkg/sql/schemachanger/sctest/cumulative.go @@ -64,8 +64,8 @@ const ( stageExecuteStmt stageExecType = 2 ) -// stageExecStmt a statement that will be executed during a given stage, -// including any expected errors from this statement or any schema change +// stageExecStmt represents statements that will be executed during a given +// stage, including any expected errors from this statement or any schema change // running concurrently. type stageExecStmt struct { execType stageExecType @@ -153,13 +153,12 @@ func makeStageExecStmtMap() *stageExecStmtMap { } } -// getExecStmts gets the statements to be used for a given phase and range -// of stages. +// getExecStmts gets the statements to be used for a particular phase and a +// particular stage. func (m *stageExecStmtMap) getExecStmts(targetKey stageKey) []*stageExecStmt { var stmts []*stageExecStmt if targetKey.minOrdinal != targetKey.maxOrdinal { - panic(fmt.Sprintf("only a single ordinal key can be looked up %v ", - targetKey)) + panic(fmt.Sprintf("only a single ordinal key can be looked up %v ", targetKey)) } for _, key := range m.entries { if key.stageKey.phase == targetKey.phase && @@ -172,8 +171,8 @@ func (m *stageExecStmtMap) getExecStmts(targetKey stageKey) []*stageExecStmt { return stmts } -// AssertMapIsUsed asserts that all injected DML statements injected -// at various stages. +// AssertMapIsUsed asserts that all DML statements are injected at various +// stages. func (m *stageExecStmtMap) AssertMapIsUsed(t *testing.T) { // If there is any rollback error, then not all stages will be used. 
for _, e := range m.entries { @@ -191,10 +190,16 @@ func (m *stageExecStmtMap) AssertMapIsUsed(t *testing.T) { require.Equal(t, len(m.usedMap), len(m.entries), "All declared entries was not used") } -// GetInjectionRanges gets a set of ranges that should have DML statements injected. -// This function will return the set of ranges where we will generate spans of -// statements that can be executed without any errors, until the first -// schema change error is hit. +// GetInjectionRuns returns a set of stage keys, where each key corresponds to a +// separate run of the schema change statement with the DML statements injected +// "appropriately". That is, for each stageKey (phase:startStage:endStage) in +// the return, we run the schema change statement and during each stage `s` within +// [startStage, endStage], we run DML statements that are requested to be injected +// in phase:`s`. +// +// The rule to generate the return is to aggregate all stages where injected DML +// statements do not cause the schema change to error, until the first schema change +// error is hit. // For example if we have the following DML statements concurrently with // schema changes: // 1) Statement A from phases 1:14 that will fail. @@ -206,7 +211,7 @@ func (m *stageExecStmtMap) AssertMapIsUsed(t *testing.T) { // // For each run the original DDL (schema change will be executed) with the // DML statements from the given ranges. -func (m *stageExecStmtMap) GetInjectionRanges( +func (m *stageExecStmtMap) GetInjectionRuns( totalPostCommit int, totalPostCommitNonRevertible int, ) []stageKey { var start stageKey @@ -214,8 +219,8 @@ func (m *stageExecStmtMap) GetInjectionRanges( var result []stageKey // First split any ranges that have schema change errors to have their own // entries. i.e. If stages 1 to N will generate schema change errors due to - // some statement, we need to have one entry for each one. Additionally, convert - // any latest ordinal values, to actual values. 
+ // some statement, we need to have one entry for each one. Additionally, + // convert any latest ordinal values, to actual values. var forcedSplitEntries []stageKeyEntry for _, key := range m.entries { if key.stmt.execType != stageExecuteStmt { @@ -264,21 +269,16 @@ func (m *stageExecStmtMap) GetInjectionRanges( } result = append(result, keyRange) } - if !key.stmt.HasSchemaChangeError() && - start.IsEmpty() { + if !key.stmt.HasSchemaChangeError() && start.IsEmpty() { // If we see a schema change error, and no other statements were executed // earlier, then this is an entry on its own. start = key.stageKey end = key.stageKey - } else if !start.IsEmpty() && - (key.stmt.HasSchemaChangeError() || - (key.phase != start.phase)) { + } else if !start.IsEmpty() && (key.stmt.HasSchemaChangeError() || key.phase != start.phase) { // If we already have a start, and we either hit a schema change error // or separate phase, then we need to emit a new entry. setStart := true - if (key.phase == start.phase && - !key.stmt.HasSchemaChangeError()) || - start.IsEmpty() { + if (key.phase == start.phase && !key.stmt.HasSchemaChangeError()) || start.IsEmpty() { setStart = false end = key.stageKey if start.IsEmpty() { @@ -327,19 +327,26 @@ func (m *stageExecStmtMap) ParseStageExec(t *testing.T, d *datadriven.TestData) m.parseStageCommon(t, d, stageExecuteStmt) } -// parseStageCommon common fields between stage-exec and stage-query, which -// support the following keys: -// - phase - The phase in which this statement/query should be injected, of the -// string scop.Phase. Note: PreCommitPhase with stage 1 can be used to -// inject failures that will only happen for DML injection testing. -// - stage / stageStart / stageEnd - The ordinal for the stage where this -// statement should be injected. stageEnd accepts the special value -// latest which will map to the highest observed stage. 
-// - schemaChangeExecError a schema change execution error will be encountered -// by injecting at this stage. -// - schemaChangeExecErrorForRollback a schema change execution error that will -// be encountered at a future stage, leading to a rollback. -// - statements can refer to builtin variable names with a $ +// parseStageCommon processes common arguments for "stage-exec"-directive and +// "stage-query" directive, which include: +// - "phase": The phase in which this statement/query should be injected, of the +// string scop.Phase. +// - "stage": A range of stage ordinals, of the form "stage=x:y", in the +// specified phase where this statement should be injected. +// Note: There are a few shorthands for the notation. If we want to inject +// only at one particular stage, we can use "stage=x". If we want to inject +// the statement in all stages, we can use "stage=:". If we want to inject +// the statement in all stages starting from the x-th stage, we can use +// "stage=x:". +// Note: PreCommitPhase with stage 1 can be used to inject failures that will +// only happen for DML injection testing. +// - "schemaChangeExecError": assert that the DML injection will cause the +// schema change to fail with a particular error at the same stage. +// - "schemaChangeExecErrorForRollback": assert that the DML injection will cause +// the schema change to fail with a particular error in a future stage. +// - "rollback": mark that the injection happens during rollback. 
+// +// Note: statements can refer to builtin variable names with a dollar sign ($): // - $stageKey - A unique identifier for stages and phases // - $successfulStageCount - Number of stages of the that have been successfully // executed with injections @@ -355,25 +362,25 @@ func (m *stageExecStmtMap) parseStageCommon( if stmts[len(stmts)-1] == "" { stmts = stmts[0 : len(stmts)-1] } - for _, cmdArgs := range d.CmdArgs { - switch cmdArgs.Key { + for _, cmdArg := range d.CmdArgs { + switch cmdArg.Key { case "phase": found := false for i := scop.EarliestPhase; i <= scop.LatestPhase; i++ { - if cmdArgs.Vals[0] == i.String() { + if cmdArg.Vals[0] == i.String() { key.phase = i found = true break } } - require.Truef(t, found, "invalid phase name %s", cmdArgs.Key) + require.Truef(t, found, "invalid phase name %s", cmdArg.Key) if !found { panic("phase not mapped") } case "stage": // Detect ranges, otherwise we are looking at single value. - if strings.Contains(cmdArgs.Vals[0], ":") { - rangeVals := strings.Split(cmdArgs.Vals[0], ":") + if strings.Contains(cmdArg.Vals[0], ":") { + rangeVals := strings.Split(cmdArg.Vals[0], ":") key.minOrdinal = 1 key.maxOrdinal = stageKeyOrdinalLatest if len(rangeVals) >= 1 && len(rangeVals[0]) > 0 { @@ -390,24 +397,24 @@ func (m *stageExecStmtMap) parseStageCommon( key.maxOrdinal = ordinal } } else { - ordinal, err := strconv.Atoi(cmdArgs.Vals[0]) + ordinal, err := strconv.Atoi(cmdArg.Vals[0]) require.Greater(t, ordinal, 0, "minimum ordinal is zero") require.NoError(t, err) key.minOrdinal = ordinal key.maxOrdinal = ordinal } case "schemaChangeExecError": - schemaChangeErrorRegex = regexp.MustCompile(strings.Join(cmdArgs.Vals, " ")) + schemaChangeErrorRegex = regexp.MustCompile(strings.Join(cmdArg.Vals, " ")) require.Nil(t, schemaChangeErrorRegexRollback, "future and current stage errors cannot be set concurrently") case "schemaChangeExecErrorForRollback": - schemaChangeErrorRegexRollback = regexp.MustCompile(strings.Join(cmdArgs.Vals, " ")) 
+ schemaChangeErrorRegexRollback = regexp.MustCompile(strings.Join(cmdArg.Vals, " ")) require.Nil(t, schemaChangeErrorRegex, "rollback and current stage errors cannot be set concurrently") case "rollback": - rollback, err := strconv.ParseBool(cmdArgs.Vals[0]) + rollback, err := strconv.ParseBool(cmdArg.Vals[0]) require.NoError(t, err) key.rollback = rollback default: - require.Failf(t, "unknown key encountered", "key was %s", cmdArgs.Key) + require.Failf(t, "unknown key encountered", "key was %s", cmdArg.Key) } } entry := stageKeyEntry{ @@ -459,32 +466,28 @@ func (e *stageExecStmt) Exec( } return "" }) - if e.execType == stageExecuteStmt { + switch e.execType { + case stageExecuteStmt: _, err := runner.DB.ExecContext(context.Background(), boundSQL) - if (e.expectedOutput == "" || - idx != len(e.stmts)-1) && err != nil { - if !rewrite { - t.Fatalf("unexpected error executing query %v", err) - } - e.expectedOutput = err.Error() - } else if err != nil { - errorMatches := testutils.IsError(err, strings.TrimSuffix(e.expectedOutput, "\n")) - if !errorMatches { - if !rewrite { - require.Truef(t, - errorMatches, - "unexpected error got: %v expected %v", - err, - e.expectedOutput) + if err != nil { + if idx != len(e.stmts)-1 { + // We require that only the last statement in a stage-exec block can cause an error. + t.Fatalf("unexpected error encountered; only the last statement can cause an error") + } else { + // Fail the test unless the error is expected (from e.expectedOutput), or + // "rewrite" is set, in which case we record the error and proceed. 
+ errorMatches := testutils.IsError(err, strings.TrimSuffix(e.expectedOutput, "\n")) + if !errorMatches { + if !rewrite { + t.Fatalf("unexpected error: got: %v, expected: %v", err, e.expectedOutput) + } + e.expectedOutput = err.Error() } - e.expectedOutput = err.Error() } } - } else { + case stageExecuteQuery: var expectedQueryResult [][]string - for _, expectedRow := range strings.Split( - strings.TrimSuffix(e.expectedOutput, "\n"), - "\n") { + for _, expectedRow := range strings.Split(strings.TrimSuffix(e.expectedOutput, "\n"), "\n") { expectRowArray := strings.Split(expectedRow, ",") expectedQueryResult = append(expectedQueryResult, expectRowArray) } @@ -496,15 +499,15 @@ func (e *stageExecStmt) Exec( } e.expectedOutput = sqlutils.MatrixToStr(results) } + default: + t.Fatal("unknown execType") } } } // GetInjectionCallback gets call back that will inject statements based on a // given stage. -func (m *stageExecStmtMap) GetInjectionCallback( - t *testing.T, rewrite bool, -) (execInjectionCallback, error) { +func (m *stageExecStmtMap) GetInjectionCallback(t *testing.T, rewrite bool) execInjectionCallback { return func(stage stageKey, runner *sqlutils.SQLRunner, successfulStageCount int) []*stageExecStmt { execStmts := m.getExecStmts(stage) for _, execStmt := range execStmts { @@ -515,7 +518,7 @@ func (m *stageExecStmtMap) GetInjectionCallback( }, rewrite) } return execStmts - }, nil + } } // cumulativeTest is a foundational helper for building tests over the @@ -550,20 +553,25 @@ func cumulativeTest( numTestStmts := 0 // First pass collect stage-exec/stage-query/setup commands and execute them - // test once the test command is encountered. Only a single test command is - // allowed via an assertion which guarantees all others appear first. + // once the test command is encountered. Only a single test command is allowed + // via an assertion which guarantees all others appear first. 
+ // This pass does no "checking" in the sense that "it does not compare any + // actual to any expected" (look: it always returns d.Expected!). Its only + // purpose is to run the "test"-ed statement, inject DMLs as specified, and + // collect output of those DML injections, so they can be used to rewrite the + // expected output of those DML injected in the second pass. datadriven.RunTest(t, path, func(t *testing.T, d *datadriven.TestData) string { - // Assert that only one test statement shows up and nothing can follow it - // afterwards. + // Assert that only one "test"-directive statement shows up and nothing can + // follow it afterwards. require.Zero(t, numTestStmts, "only one test command per-test, "+ "and it must be the last one.") switch d.Cmd { case "setup": + // Store setup stmts into `setup` slice (without executing them). stmts, err := parser.Parse(d.Input) setup = append(setup, stmts...) require.NoError(t, err) require.NotEmpty(t, stmts) - // no-op case "stage-exec": // DML injected statements will only be executed on cumalative tests, // for end-to-end tests these are fully ignored. @@ -586,21 +594,17 @@ func cumulativeTest( } return d.Expected }) - // Run through and recover the observed output for statements. + + // Second pass is reserved to rewrite expected output of DML injections. For + // all other directives, this pass effectively ignores them by returning + // d.Expected. if rewrite { datadriven.RunTest(t, path, func(t *testing.T, d *datadriven.TestData) string { - switch d.Cmd { - // Retrieve the generated output from the previous execution in the - // rewrite mode of DML injection. We are going to the store the - // observed output based on the line number and file names for the - // stage-exec and stage-query commands. - case "stage-exec": - fallthrough - case "stage-query": + if d.Cmd == "stage-exec" || d.Cmd == "stage-query" { + // Retrieve the actual output of each DML injection block (from first + // pass), indexed by file:line. 
return stageExecMap.GetExpectedOutputForPos(d.Pos) } - // cumlativeTest will rewrite only the stage-exec and stage-query commands, - // all others are rewritten by the end-to-end tests. return d.Expected }) } @@ -830,6 +834,8 @@ func ExecuteWithDMLInjection(t *testing.T, relPath string, newCluster NewCluster ) var injectionFunc execInjectionCallback testFunc := func(t *testing.T, _ string, rewrite bool, setup, stmts []statements.Statement[tree.Statement], execMap *stageExecStmtMap) { + // Count number of stages in PostCommit and PostCommitNonRevertible phase + // for running `stmts` after properly running `setup`. var postCommit, nonRevertible int processPlanInPhase(t, newCluster, setup, stmts, scop.PostCommitPhase, func( p scplan.Plan, @@ -838,13 +844,11 @@ func ExecuteWithDMLInjection(t *testing.T, relPath string, newCluster NewCluster nonRevertible = len(p.Stages) - postCommit }, nil) - injectionFunc, _ = execMap.GetInjectionCallback(t, rewrite) - injectionRanges := execMap.GetInjectionRanges(postCommit, nonRevertible) + injectionFunc = execMap.GetInjectionCallback(t, rewrite) + injectionRanges := execMap.GetInjectionRuns(postCommit, nonRevertible) defer execMap.AssertMapIsUsed(t) - injectPreCommits := []bool{true} - if execMap.getExecStmts(makeStageKey(scop.PreCommitPhase, - 1, - false)) != nil { + injectPreCommits := []bool{false} + if execMap.getExecStmts(makeStageKey(scop.PreCommitPhase, 1, false)) != nil { injectPreCommits = []bool{false, true} } // Test both happy and unhappy paths with pre-commit injection, this @@ -862,20 +866,27 @@ func ExecuteWithDMLInjection(t *testing.T, relPath string, newCluster NewCluster } } testDMLInjectionCase = func(t *testing.T, setup, stmts []statements.Statement[tree.Statement], injection stageKey, injectPreCommit bool) { + // Create a new cluster with the `BeforeStage` knob properly set for the DML injection framework. 
var schemaChangeErrorRegex *regexp.Regexp var lastRollbackStageKey *stageKey usedStages := make(map[int]struct{}) successfulStages := 0 + var clusterCreated atomic.Bool var tdb *sqlutils.SQLRunner _, db, cleanup := newCluster(t, &scexec.TestingKnobs{ BeforeStage: func(p scplan.Plan, stageIdx int) error { + if !clusterCreated.Load() { + // Do nothing if cluster creation isn't finished. Certain schema + // changes are run during cluster creation (e.g. `CREATE DATABASE + // defaultdb`) and we don't want those to hijack this knob. + return nil + } s := p.Stages[stageIdx] if (injection.phase == p.Stages[stageIdx].Phase && p.Stages[stageIdx].Ordinal >= injection.minOrdinal && p.Stages[stageIdx].Ordinal <= injection.maxOrdinal) || (p.InRollback || p.CurrentState.InRollback) || /* Rollbacks are always injected */ - (p.Stages[stageIdx].Phase == scop.PreCommitPhase && - injectPreCommit) { + (p.Stages[stageIdx].Phase == scop.PreCommitPhase && injectPreCommit) { jobErrorMutex.Lock() defer jobErrorMutex.Unlock() key := makeStageKey(s.Phase, s.Ordinal, p.InRollback || p.CurrentState.InRollback) @@ -890,8 +901,7 @@ func ExecuteWithDMLInjection(t *testing.T, relPath string, newCluster NewCluster injectStmts := injectionFunc(key, tdb, successfulStages) regexSetOnce := false for _, injectStmt := range injectStmts { - if injectStmt != nil && - injectStmt.HasAnySchemaChangeError() != nil { + if injectStmt != nil && injectStmt.HasAnySchemaChangeError() != nil { require.Falsef(t, regexSetOnce, "multiple statements are expecting errors in the same phase.") schemaChangeErrorRegex = injectStmt.HasAnySchemaChangeError() regexSetOnce = true @@ -909,11 +919,15 @@ func ExecuteWithDMLInjection(t *testing.T, relPath string, newCluster NewCluster }, }) defer cleanup() + clusterCreated.Store(true) tdb = sqlutils.MakeSQLRunner(db) + + // Now run the schema change and the `BeforeStage` knob will inject DMLs + // as specified in `injection`. 
errorDetected := false onError := func(err error) error { - if schemaChangeErrorRegex != nil && - schemaChangeErrorRegex.MatchString(err.Error()) { + // Mute the error if it matches what the DML injection specifies. + if schemaChangeErrorRegex != nil && schemaChangeErrorRegex.MatchString(err.Error()) { errorDetected = true return nil } @@ -1462,7 +1476,7 @@ SELECT name } // processPlanInPhase will call processFunc with the plan as of the first -// stage in the requested phase. The function will be called at most once. +// stage in the requested phase. processFunc will be called at most once. func processPlanInPhase( t *testing.T, newCluster NewClusterFunc, diff --git a/pkg/sql/schemachanger/testdata/end_to_end/add_column_with_stored b/pkg/sql/schemachanger/testdata/end_to_end/add_column_with_stored index c6808bda00a2..180bc0d791c9 100644 --- a/pkg/sql/schemachanger/testdata/end_to_end/add_column_with_stored +++ b/pkg/sql/schemachanger/testdata/end_to_end/add_column_with_stored @@ -14,12 +14,6 @@ stage-exec phase=PreCommitPhase stage=1 schemaChangeExecErrorForRollback=(.*vali UPDATE db.public.tbl SET k=NULL WHERE i = -7; ---- -# Each insert will be injected twice per stage, plus 1 extra. -stage-query phase=PostCommitPhase stage=: rollback=true -SELECT count(*)=($successfulStageCount*2)+4 FROM db.public.tbl; ----- -true - # Each insert will be injected twice per stage, plus 1 extra. 
stage-query phase=PostCommitNonRevertiblePhase stage=: rollback=true SELECT count(*)=($successfulStageCount*2)+4 FROM db.public.tbl; diff --git a/pkg/sql/schemachanger/testdata/end_to_end/add_column_with_stored_family b/pkg/sql/schemachanger/testdata/end_to_end/add_column_with_stored_family index 006b572e826f..d3dcf1f67e0e 100644 --- a/pkg/sql/schemachanger/testdata/end_to_end/add_column_with_stored_family +++ b/pkg/sql/schemachanger/testdata/end_to_end/add_column_with_stored_family @@ -14,12 +14,6 @@ stage-exec phase=PreCommitPhase stage=1 schemaChangeExecErrorForRollback=(.*vali UPDATE db.public.tbl SET k=NULL WHERE i = -7; ---- -# Each insert will be injected twice per stage, plus 1 extra. -stage-query phase=PostCommitPhase stage=: rollback=true -SELECT count(*)=($successfulStageCount*2)+4 FROM db.public.tbl; ----- -true - # Each insert will be injected twice per stage, plus 1 extra. stage-query phase=PostCommitNonRevertiblePhase stage=: rollback=true SELECT count(*)=($successfulStageCount*2)+4 FROM db.public.tbl;