diff --git a/pkg/ccl/backupccl/BUILD.bazel b/pkg/ccl/backupccl/BUILD.bazel index 44c5a8ce2e31..b85b2709bc4e 100644 --- a/pkg/ccl/backupccl/BUILD.bazel +++ b/pkg/ccl/backupccl/BUILD.bazel @@ -78,6 +78,7 @@ go_library( "//pkg/sql/catalog/descs", "//pkg/sql/catalog/desctestutils", "//pkg/sql/catalog/multiregion", + "//pkg/sql/catalog/nstree", "//pkg/sql/catalog/resolver", "//pkg/sql/catalog/schemadesc", "//pkg/sql/catalog/schemaexpr", @@ -96,6 +97,9 @@ go_library( "//pkg/sql/roleoption", "//pkg/sql/rowenc", "//pkg/sql/rowexec", + "//pkg/sql/schemachanger/scbackup", + "//pkg/sql/schemachanger/scpb", + "//pkg/sql/schemachanger/screl", "//pkg/sql/sem/builtins", "//pkg/sql/sem/tree", "//pkg/sql/sessiondata", diff --git a/pkg/ccl/backupccl/restore_job.go b/pkg/ccl/backupccl/restore_job.go index 28bf6cff574f..2fece94fc279 100644 --- a/pkg/ccl/backupccl/restore_job.go +++ b/pkg/ccl/backupccl/restore_job.go @@ -38,6 +38,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/catalog/descs" "github.com/cockroachdb/cockroach/pkg/sql/catalog/multiregion" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/nstree" "github.com/cockroachdb/cockroach/pkg/sql/catalog/schemadesc" "github.com/cockroachdb/cockroach/pkg/sql/catalog/tabledesc" "github.com/cockroachdb/cockroach/pkg/sql/catalog/typedesc" @@ -45,6 +46,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/privilege" + "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scbackup" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" "github.com/cockroachdb/cockroach/pkg/sql/sqlutil" @@ -1870,30 +1872,44 @@ func (r *restoreResumer) publishDescriptors( // Write the new descriptors and flip state over to public so they can be // accessed. - allMutDescs := make([]catalog.MutableDescriptor, 0, - len(details.TableDescs)+len(details.TypeDescs)+len(details.SchemaDescs)+len(details.DatabaseDescs)) + + // Pre-fetch all the descriptors into the collection to avoid doing + // round-trips per descriptor. + all, err := prefetchDescriptors(ctx, txn, descsCol, details) + if err != nil { + return err + } + // Create slices of raw descriptors for the restore job details. newTables := make([]*descpb.TableDescriptor, 0, len(details.TableDescs)) newTypes := make([]*descpb.TypeDescriptor, 0, len(details.TypeDescs)) newSchemas := make([]*descpb.SchemaDescriptor, 0, len(details.SchemaDescs)) newDBs := make([]*descpb.DatabaseDescriptor, 0, len(details.DatabaseDescs)) - checkVersion := func(read catalog.Descriptor, exp descpb.DescriptorVersion) error { - if read.GetVersion() == exp { - return nil + + // Go through the descriptors and find any declarative schema change jobs + // affecting them. + // + // If we're restoring all the descriptors, it means we're also restoring the + // jobs. + if details.DescriptorCoverage != tree.AllDescriptors { + if err := scbackup.CreateDeclarativeSchemaChangeJobs( + ctx, r.execCfg.JobRegistry, txn, all, + ); err != nil { + return err } - return errors.Errorf("version mismatch for descriptor %d, expected version %d, got %v", - read.GetID(), read.GetVersion(), exp) } + // Write the new TableDescriptors and flip state over to public so they can be // accessed. 
- for _, tbl := range details.TableDescs { - mutTable, err := descsCol.GetMutableTableVersionByID(ctx, tbl.GetID(), txn) - if err != nil { - return err - } - if err := checkVersion(mutTable, tbl.Version); err != nil { - return err + for i := range details.TableDescs { + mutTable := all.LookupDescriptorEntry(details.TableDescs[i].GetID()).(*tabledesc.Mutable) + // Note that we don't need to worry about the re-validated indexes for descriptors + // with a declarative schema change job. + if mutTable.GetDeclarativeSchemaChangerState() != nil { + newTables = append(newTables, mutTable.TableDesc()) + continue } + badIndexes := devalidateIndexes[mutTable.ID] for _, badIdx := range badIndexes { found, err := mutTable.FindIndexWithID(badIdx) @@ -1910,7 +1926,6 @@ func (r *restoreResumer) publishDescriptors( if err := mutTable.AllocateIDs(ctx, version); err != nil { return err } - allMutDescs = append(allMutDescs, mutTable) newTables = append(newTables, mutTable.TableDesc()) // For cluster restores, all the jobs are restored directly from the jobs // table, so there is no need to re-create ongoing schema change jobs, @@ -1927,17 +1942,11 @@ func (r *restoreResumer) publishDescriptors( } // For all of the newly created types, make type schema change jobs for any // type descriptors that were backed up in the middle of a type schema change. - for _, typDesc := range details.TypeDescs { - typ, err := descsCol.GetMutableTypeVersionByID(ctx, txn, typDesc.GetID()) - if err != nil { - return err - } - if err := checkVersion(typ, typDesc.Version); err != nil { - return err - } - allMutDescs = append(allMutDescs, typ) + for i := range details.TypeDescs { + typ := all.LookupDescriptorEntry(details.TypeDescs[i].GetID()).(catalog.TypeDescriptor) newTypes = append(newTypes, typ.TypeDesc()) - if typ.HasPendingSchemaChanges() && details.DescriptorCoverage != tree.AllDescriptors { + if typ.GetDeclarativeSchemaChangerState() == nil && + typ.HasPendingSchemaChanges() && details.DescriptorCoverage != tree.AllDescriptors { if err := createTypeChangeJobFromDesc( ctx, r.execCfg.JobRegistry, r.execCfg.Codec, txn, r.job.Payload().UsernameProto.Decode(), typ, ); err != nil { @@ -1945,49 +1954,24 @@ func (r *restoreResumer) publishDescriptors( } } } - for _, sc := range details.SchemaDescs { - mutDesc, err := descsCol.GetMutableDescriptorByID(ctx, txn, sc.ID) - if err != nil { - return err - } - if err := checkVersion(mutDesc, sc.Version); err != nil { - return err - } - mutSchema := mutDesc.(*schemadesc.Mutable) - allMutDescs = append(allMutDescs, mutSchema) - newSchemas = append(newSchemas, mutSchema.SchemaDesc()) + for i := range details.SchemaDescs { + sc := all.LookupDescriptorEntry(details.SchemaDescs[i].GetID()).(catalog.SchemaDescriptor) + newSchemas = append(newSchemas, sc.SchemaDesc()) } - for _, dbDesc := range details.DatabaseDescs { - // Jobs started before 20.2 upgrade finalization don't put databases in - // an offline state. - // TODO(lucy): Should we make this more explicit with a format version - // field in the details? - mutDesc, err := descsCol.GetMutableDescriptorByID(ctx, txn, dbDesc.ID) - if err != nil { - return err - } - if err := checkVersion(mutDesc, dbDesc.Version); err != nil { - return err - } - mutDB := mutDesc.(*dbdesc.Mutable) - // TODO(lucy,ajwerner): Remove this in 21.1. 
- if !mutDB.Offline() { - newDBs = append(newDBs, dbDesc) - } else { - allMutDescs = append(allMutDescs, mutDB) - newDBs = append(newDBs, mutDB.DatabaseDesc()) - } + for i := range details.DatabaseDescs { + db := all.LookupDescriptorEntry(details.DatabaseDescs[i].GetID()).(catalog.DatabaseDescriptor) + newDBs = append(newDBs, db.DatabaseDesc()) } b := txn.NewBatch() - for _, desc := range allMutDescs { - desc.SetPublic() - if err := descsCol.WriteDescToBatch( - ctx, false /* kvTrace */, desc, b, - ); err != nil { - return err - } + if err := all.ForEachDescriptorEntry(func(desc catalog.Descriptor) error { + d := desc.(catalog.MutableDescriptor) + d.SetPublic() + return descsCol.WriteDescToBatch( + ctx, false /* kvTrace */, d, b, + ) + }); err != nil { + return err } - if err := txn.Run(ctx, b); err != nil { return errors.Wrap(err, "publishing tables") } @@ -2011,6 +1995,57 @@ func (r *restoreResumer) publishDescriptors( return nil } +// prefetchDescriptors calculates the set of descriptors needed by looking +// at the relevant fields of the job details. It then fetches all of those +// descriptors in a batch using the descsCol. It packages up that set of +// descriptors into an nstree.Catalog for easy use. +// +// This function also takes care of asserting that the retrieved version +// matches the expectation. +func prefetchDescriptors( + ctx context.Context, txn *kv.Txn, descsCol *descs.Collection, details jobspb.RestoreDetails, +) (_ nstree.Catalog, _ error) { + var all nstree.MutableCatalog + var allDescIDs catalog.DescriptorIDSet + expVersion := map[descpb.ID]descpb.DescriptorVersion{} + for i := range details.TableDescs { + expVersion[details.TableDescs[i].GetID()] = details.TableDescs[i].GetVersion() + allDescIDs.Add(details.TableDescs[i].GetID()) + } + for i := range details.TypeDescs { + expVersion[details.TypeDescs[i].GetID()] = details.TypeDescs[i].GetVersion() + allDescIDs.Add(details.TypeDescs[i].GetID()) + } + for i := range details.SchemaDescs { + expVersion[details.SchemaDescs[i].GetID()] = details.SchemaDescs[i].GetVersion() + allDescIDs.Add(details.SchemaDescs[i].GetID()) + } + for i := range details.DatabaseDescs { + expVersion[details.DatabaseDescs[i].GetID()] = details.DatabaseDescs[i].GetVersion() + allDescIDs.Add(details.DatabaseDescs[i].GetID()) + } + // Note that no maximum size is put on the batch here because, + // in general, we assume that we can fit all of the descriptors + // in RAM (we have them in RAM as part of the details object, + // and we're going to write them to KV very soon as part of a + // single batch). + ids := allDescIDs.Ordered() + got, err := descsCol.GetMutableDescriptorsByID(ctx, txn, ids...) 
+ if err != nil { + return nstree.Catalog{}, errors.Wrap(err, "prefetch descriptors") + } + for i, id := range ids { + if got[i].GetVersion() != expVersion[id] { + return nstree.Catalog{}, errors.Errorf( + "version mismatch for descriptor %d, expected version %d, got %v", + got[i].GetID(), expVersion[id], got[i].GetVersion(), + ) + } + all.UpsertDescriptorEntry(got[i]) + } + return all.Catalog, nil +} + func emitRestoreJobEvent( ctx context.Context, p sql.JobExecContext, status jobs.Status, job *jobs.Job, ) { diff --git a/pkg/ccl/backupccl/restore_planning.go b/pkg/ccl/backupccl/restore_planning.go index 795585672ebe..9c24d4935362 100644 --- a/pkg/ccl/backupccl/restore_planning.go +++ b/pkg/ccl/backupccl/restore_planning.go @@ -53,6 +53,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/roleoption" + "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scpb" + "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/screl" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sqlerrors" "github.com/cockroachdb/cockroach/pkg/sql/types" @@ -1065,6 +1067,10 @@ func rewriteDatabaseDescs(databases []*dbdesc.Mutable, descriptorRewrites DescRe db.Version = 1 db.ModificationTime = hlc.Timestamp{} + if err := rewriteSchemaChangerState(db, descriptorRewrites); err != nil { + return err + } + // Rewrite the name-to-ID mapping for the database's child schemas. newSchemas := make(map[string]descpb.DatabaseDescriptor_SchemaInfo) err := db.ForEachNonDroppedSchema(func(id descpb.ID, name string) error { @@ -1130,6 +1136,10 @@ func rewriteTypeDescs(types []*typedesc.Mutable, descriptorRewrites DescRewriteM typ.Version = 1 typ.ModificationTime = hlc.Timestamp{} + if err := rewriteSchemaChangerState(typ, descriptorRewrites); err != nil { + return err + } + typ.ID = rewrite.ID typ.ParentSchemaID = rewrite.ParentSchemaID typ.ParentID = rewrite.ParentID @@ -1170,6 +1180,85 @@ func rewriteSchemaDescs(schemas []*schemadesc.Mutable, descriptorRewrites DescRe sc.ID = rewrite.ID sc.ParentID = rewrite.ParentID + + if err := rewriteSchemaChangerState(sc, descriptorRewrites); err != nil { + return err + } + } + return nil +} + +// rewriteSchemaChangerState handles rewriting any references to IDs stored in +// the descriptor's declarative schema changer state. +func rewriteSchemaChangerState( + d catalog.MutableDescriptor, descriptorRewrites DescRewriteMap, +) (err error) { + state := d.GetDeclarativeSchemaChangerState() + if state == nil { + return nil + } + defer func() { + if err != nil { + err = errors.Wrap(err, "rewriting declarative schema changer state") + } + }() + for i := 0; i < len(state.Targets); i++ { + t := &state.Targets[i] + if err := screl.WalkDescIDs(t.Element(), func(id *descpb.ID) error { + rewrite, ok := descriptorRewrites[*id] + if !ok { + return errors.Errorf("missing rewrite for id %d in %T", *id, t) + } + *id = rewrite.ID + return nil + }); err != nil { + // We'll permit this in the special case of a schema descriptor + // database entry. + // + // TODO(ajwerner,postamar): it's not obvious that this should be its own + // element as opposed to just an extension of the namespace table that only + // ops know about.
+ switch el := t.Element().(type) { + case *scpb.DatabaseSchemaEntry: + _, scExists := descriptorRewrites[el.SchemaID] + if !scExists && state.CurrentStatuses[i] == scpb.Status_ABSENT { + state.Targets = append(state.Targets[:i], state.Targets[i+1:]...) + state.CurrentStatuses = append(state.CurrentStatuses[:i], state.CurrentStatuses[i+1:]...) + state.TargetRanks = append(state.TargetRanks[:i], state.TargetRanks[i+1:]...) + i-- + continue + } + } + return errors.Wrap(err, "rewriting descriptor ids") + } + + if err := screl.WalkExpressions(t.Element(), func(expr *catpb.Expression) error { + if *expr == "" { + return nil + } + newExpr, err := rewriteTypesInExpr(string(*expr), descriptorRewrites) + if err != nil { + return errors.Wrapf(err, "rewriting expression type references: %q", *expr) + } + newExpr, err = rewriteSequencesInExpr(newExpr, descriptorRewrites) + if err != nil { + return errors.Wrapf(err, "rewriting expression sequence references: %q", newExpr) + } + *expr = catpb.Expression(newExpr) + return nil + }); err != nil { + return err + } + if err := screl.WalkTypes(t.Element(), func(t *types.T) error { + return rewriteIDsInTypesT(t, descriptorRewrites) + }); err != nil { + return errors.Wrap(err, "rewriting user-defined type references") + } + // TODO(ajwerner): Remember to rewrite views when the time comes. Currently + // views are not handled by the declarative schema changer. + } + if len(state.Targets) == 0 { + d.SetDeclarativeSchemaChangerState(nil) } return nil } @@ -1200,6 +1289,9 @@ func RewriteTableDescs( return err } } + if err := rewriteSchemaChangerState(table, descriptorRewrites); err != nil { + return err + } table.ID = tableRewrite.ID table.UnexposedParentSchemaID = tableRewrite.ParentSchemaID @@ -1365,6 +1457,7 @@ func RewriteTableDescs( // lease is obviously bogus (plus the nodeID is relative to backup cluster). 
table.Lease = nil } + return nil } diff --git a/pkg/ccl/schemachangerccl/BUILD.bazel b/pkg/ccl/schemachangerccl/BUILD.bazel index 860662b75868..f9bf6c64ecff 100644 --- a/pkg/ccl/schemachangerccl/BUILD.bazel +++ b/pkg/ccl/schemachangerccl/BUILD.bazel @@ -6,18 +6,23 @@ go_test( "main_test.go", "schemachanger_end_to_end_test.go", ], - data = glob(["testdata/**"]), + data = glob(["testdata/**"]) + [ + "//pkg/sql/schemachanger:testdata", + ], embed = [":schemachangerccl"], deps = [ "//pkg/base", + "//pkg/build/bazel", "//pkg/ccl", "//pkg/ccl/multiregionccl/multiregionccltestutils", "//pkg/ccl/utilccl", + "//pkg/jobs", "//pkg/security", "//pkg/security/securitytest", "//pkg/server", "//pkg/sql/schemachanger/scrun", "//pkg/sql/schemachanger/sctest", + "//pkg/testutils", "//pkg/testutils/serverutils", "//pkg/testutils/testcluster", "//pkg/util/leaktest", diff --git a/pkg/ccl/schemachangerccl/schemachanger_end_to_end_test.go b/pkg/ccl/schemachangerccl/schemachanger_end_to_end_test.go index 01f2b47aeee4..ad79d4ac0960 100644 --- a/pkg/ccl/schemachangerccl/schemachanger_end_to_end_test.go +++ b/pkg/ccl/schemachangerccl/schemachanger_end_to_end_test.go @@ -13,29 +13,65 @@ import ( "testing" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/build/bazel" "github.com/cockroachdb/cockroach/pkg/ccl/multiregionccl/multiregionccltestutils" + "github.com/cockroachdb/cockroach/pkg/jobs" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scrun" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/sctest" + "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" ) func newCluster(t *testing.T, knobs *scrun.TestingKnobs) (*gosql.DB, func()) { _, sqlDB, cleanup := multiregionccltestutils.TestingCreateMultiRegionCluster( - t, 3 /* numServers */, base.TestingKnobs{}, + t, 3 /* numServers */, base.TestingKnobs{ + SQLDeclarativeSchemaChanger: knobs, + JobsTestingKnobs: jobs.NewTestingKnobsWithShortIntervals(), + }, ) return sqlDB, cleanup } +func sharedTestdata(t *testing.T) string { + testdataDir := "../../sql/schemachanger/testdata/" + if bazel.BuiltWithBazel() { + runfile, err := bazel.Runfile("pkg/sql/schemachanger/testdata/") + if err != nil { + t.Fatal(err) + } + testdataDir = runfile + } + return testdataDir +} + func TestSchemaChangerSideEffects(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) - sctest.EndToEndSideEffects(t, newCluster) + sctest.EndToEndSideEffects(t, testutils.TestDataPath(t), newCluster) +} + +func TestBackupRestore(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + t.Run("ccl", func(t *testing.T) { + sctest.Backup(t, testutils.TestDataPath(t), newCluster) + }) + t.Run("non-ccl", func(t *testing.T) { + sctest.Backup(t, sharedTestdata(t), sctest.SingleNodeCluster) + }) } func TestRollback(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) - sctest.Rollback(t, newCluster) + sctest.Rollback(t, testutils.TestDataPath(t), newCluster) +} + +func TestPause(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + sctest.Pause(t, testutils.TestDataPath(t), newCluster) } diff --git a/pkg/ccl/schemachangerccl/testdata/drop_multiregion_database b/pkg/ccl/schemachangerccl/testdata/drop_multiregion_database index 6646ab899f9d..318b044784f6 100644 --- a/pkg/ccl/schemachangerccl/testdata/drop_multiregion_database +++ 
b/pkg/ccl/schemachangerccl/testdata/drop_multiregion_database @@ -13,10 +13,12 @@ DROP DATABASE multi_region_test_db CASCADE checking for feature: DROP DATABASE begin transaction #1 # begin StatementPhase -## StatementPhase stage 1 of 1 with 7 MutationType ops +## StatementPhase stage 1 of 1 with 9 MutationType ops delete comment for descriptor #106 of type SchemaCommentType delete comment for descriptor #104 of type DatabaseCommentType delete role settings for database on #104 +delete database namespace entry {0 0 multi_region_test_db} -> 104 +delete schema namespace entry {104 0 public} -> 106 # end StatementPhase # begin PreCommitPhase ## PreCommitPhase stage 1 of 1 with 12 MutationType ops @@ -228,15 +230,13 @@ commit transaction #1 begin transaction #2 commit transaction #2 begin transaction #3 -## PostCommitNonRevertiblePhase stage 1 of 1 with 15 MutationType ops +## PostCommitNonRevertiblePhase stage 1 of 1 with 13 MutationType ops create job #2: "GC for dropping descriptors and parent database 104" descriptor IDs: [] write *eventpb.DropDatabase to event log for descriptor #104: DROP DATABASE ‹multi_region_test_db› CASCADE write *eventpb.DropType to event log for descriptor #107: DROP DATABASE ‹multi_region_test_db› CASCADE update progress of schema change job #1 set schema change job #1 to non-cancellable -delete database namespace entry {0 0 multi_region_test_db} -> 104 -delete schema namespace entry {104 0 public} -> 106 upsert descriptor #104 database: - declarativeSchemaChangerState: diff --git a/pkg/sql/schemachanger/BUILD.bazel b/pkg/sql/schemachanger/BUILD.bazel index 1d6386b08384..34d54fedeafc 100644 --- a/pkg/sql/schemachanger/BUILD.bazel +++ b/pkg/sql/schemachanger/BUILD.bazel @@ -1,5 +1,11 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +filegroup( + name = "testdata", + srcs = glob(["testdata/**"]), + visibility = ["//visibility:public"], +) + go_test( name = "schemachanger_test", size = "medium", @@ -8,7 +14,7 @@ go_test( "main_test.go", "schemachanger_test.go", ], - data = glob(["testdata/**"]), + data = [":testdata"], deps = [ "//pkg/base", "//pkg/jobs", diff --git a/pkg/sql/schemachanger/end_to_end_test.go b/pkg/sql/schemachanger/end_to_end_test.go index 7e14c5044114..67e30af608e0 100644 --- a/pkg/sql/schemachanger/end_to_end_test.go +++ b/pkg/sql/schemachanger/end_to_end_test.go @@ -11,15 +11,10 @@ package schemachanger_test import ( - "context" - gosql "database/sql" "testing" - "github.com/cockroachdb/cockroach/pkg/base" - "github.com/cockroachdb/cockroach/pkg/jobs" - "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scrun" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/sctest" - "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" ) @@ -27,31 +22,19 @@ import ( func TestSchemaChangerSideEffects(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) - sctest.EndToEndSideEffects(t, newCluster) + sctest.EndToEndSideEffects(t, testutils.TestDataPath(t), sctest.SingleNodeCluster) } func TestRollback(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) - sctest.Rollback(t, newCluster) + sctest.Rollback(t, testutils.TestDataPath(t), sctest.SingleNodeCluster) } func TestPause(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) - sctest.Pause(t, newCluster) -} - -func newCluster(t *testing.T, knobs *scrun.TestingKnobs) (*gosql.DB, 
func()) { - s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ - Knobs: base.TestingKnobs{ - SQLDeclarativeSchemaChanger: knobs, - JobsTestingKnobs: jobs.NewTestingKnobsWithShortIntervals(), - }, - }) - return db, func() { - s.Stopper().Stop(context.Background()) - } + sctest.Pause(t, testutils.TestDataPath(t), sctest.SingleNodeCluster) } diff --git a/pkg/sql/schemachanger/scbackup/BUILD.bazel b/pkg/sql/schemachanger/scbackup/BUILD.bazel new file mode 100644 index 000000000000..b07134d20a6a --- /dev/null +++ b/pkg/sql/schemachanger/scbackup/BUILD.bazel @@ -0,0 +1,21 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "scbackup", + srcs = [ + "doc.go", + "job.go", + ], + importpath = "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scbackup", + visibility = ["//visibility:public"], + deps = [ + "//pkg/jobs", + "//pkg/kv", + "//pkg/sql/catalog", + "//pkg/sql/catalog/catpb", + "//pkg/sql/catalog/nstree", + "//pkg/sql/schemachanger/scexec", + "//pkg/sql/schemachanger/scpb", + "//pkg/sql/schemachanger/screl", + ], +) diff --git a/pkg/sql/schemachanger/scbackup/doc.go b/pkg/sql/schemachanger/scbackup/doc.go new file mode 100644 index 000000000000..9bcb5573d83c --- /dev/null +++ b/pkg/sql/schemachanger/scbackup/doc.go @@ -0,0 +1,13 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +// Package scbackup contains logic for interacting with schema changer state +// during backup and restore. +package scbackup diff --git a/pkg/sql/schemachanger/scbackup/job.go b/pkg/sql/schemachanger/scbackup/job.go new file mode 100644 index 000000000000..1e3a3e13d14b --- /dev/null +++ b/pkg/sql/schemachanger/scbackup/job.go @@ -0,0 +1,73 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package scbackup + +import ( + "context" + + "github.com/cockroachdb/cockroach/pkg/jobs" + "github.com/cockroachdb/cockroach/pkg/kv" + "github.com/cockroachdb/cockroach/pkg/sql/catalog" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/catpb" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/nstree" + "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scexec" + "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scpb" + "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/screl" +) + +// CreateDeclarativeSchemaChangeJobs is called during the last phase of a +// restore. The provided catalog should contain all descriptors being restored. +// The code here will iterate those descriptors and synthesize the appropriate +// jobs. +// +// It should only be called for backups which do not restore the jobs table +// directly. 
+func CreateDeclarativeSchemaChangeJobs( + ctx context.Context, registry *jobs.Registry, txn *kv.Txn, allMut nstree.Catalog, +) error { + byJobID := make(map[catpb.JobID][]catalog.MutableDescriptor) + _ = allMut.ForEachDescriptorEntry(func(d catalog.Descriptor) error { + if s := d.GetDeclarativeSchemaChangerState(); s != nil { + byJobID[s.JobID] = append(byJobID[s.JobID], d.(catalog.MutableDescriptor)) + } + return nil + }) + var records []*jobs.Record + for _, descs := range byJobID { + // TODO(ajwerner): Consider the need to trim elements or update + // descriptors in the face of restoring only some constituent + // descriptors of a larger change. One example where this needs + // to happen urgently is sequences. Others shouldn't be possible + // at this point. + newID := registry.MakeJobID() + var descriptorStates []*scpb.DescriptorState + for _, d := range descs { + ds := d.GetDeclarativeSchemaChangerState() + ds.JobID = newID + descriptorStates = append(descriptorStates, ds) + } + // TODO(ajwerner): Deal with rollback and revertibility. + currentState, err := scpb.MakeCurrentStateFromDescriptors( + descriptorStates, + ) + if err != nil { + return err + } + records = append(records, scexec.MakeDeclarativeSchemaChangeJobRecord( + newID, + currentState.Statements, + currentState.Authorization, + screl.GetDescIDs(currentState.TargetState).Ordered(), + )) + } + _, err := registry.CreateJobsWithTxn(ctx, txn, records) + return err +} diff --git a/pkg/sql/schemachanger/scexec/exec_mutation.go b/pkg/sql/schemachanger/scexec/exec_mutation.go index 7b48f3eede60..856d1e397764 100644 --- a/pkg/sql/schemachanger/scexec/exec_mutation.go +++ b/pkg/sql/schemachanger/scexec/exec_mutation.go @@ -411,11 +411,7 @@ func (mvs *mutationVisitorState) DeleteSchedule(scheduleID int64) { } func (mvs *mutationVisitorState) AddDrainedName(id descpb.ID, nameInfo descpb.NameInfo) { - if _, ok := mvs.drainedNames[id]; !ok { - mvs.drainedNames[id] = []descpb.NameInfo{nameInfo} - } else { - mvs.drainedNames[id] = append(mvs.drainedNames[id], nameInfo) - } + mvs.drainedNames[id] = append(mvs.drainedNames[id], nameInfo) } func (mvs *mutationVisitorState) AddNewGCJobForTable(table catalog.TableDescriptor) { @@ -447,11 +443,27 @@ func (mvs *mutationVisitorState) AddNewSchemaChangerJob( if mvs.schemaChangerJob != nil { return errors.AssertionFailedf("cannot create more than one new schema change job") } + mvs.schemaChangerJob = MakeDeclarativeSchemaChangeJobRecord(jobID, stmts, auth, descriptorIDs) + return nil +} + +// MakeDeclarativeSchemaChangeJobRecord is used to construct a declarative +// schema change job. The state of the schema change is stored in the descriptors +// themselves rather than the job state. During execution, the only state which +// is stored in the job itself pertains to backfill progress. +// +// Note that there's no way to construct a job in the reverting state. If the +// state of the schema change according to the descriptors is InRollback, then +// at the outset of the job, an error will be returned to move the job into +// the reverting state. 
+func MakeDeclarativeSchemaChangeJobRecord( + jobID jobspb.JobID, stmts []scpb.Statement, auth scpb.Authorization, descriptorIDs descpb.IDs, +) *jobs.Record { stmtStrs := make([]string, len(stmts)) for i, stmt := range stmts { stmtStrs[i] = stmt.Statement } - mvs.schemaChangerJob = &jobs.Record{ + rec := &jobs.Record{ JobID: jobID, Description: "schema change job", // TODO(ajwerner): use const Statements: stmtStrs, @@ -468,9 +480,9 @@ func (mvs *mutationVisitorState) AddNewSchemaChangerJob( // TODO(ajwerner): It'd be good to populate the RunningStatus at all times. RunningStatus: "", - NonCancelable: false, + NonCancelable: false, // TODO(ajwerner): Set this appropriately } - return nil + return rec } // createGCJobRecord creates the job record for a GC job, setting some diff --git a/pkg/sql/schemachanger/scpb/scpb.proto b/pkg/sql/schemachanger/scpb/scpb.proto index 0ba8a680ab87..0b980696f7a9 100644 --- a/pkg/sql/schemachanger/scpb/scpb.proto +++ b/pkg/sql/schemachanger/scpb/scpb.proto @@ -88,6 +88,29 @@ message DescriptorState { (gogoproto.customname) = "JobID", (gogoproto.casttype) = "github.com/cockroachdb/cockroach/pkg/sql/catalog/catpb.JobID"]; + // Revertible captures whether the job is currently revertible. + // This is important to facilitate constructing the job in the appropriate + // way upon restore. + bool revertible = 7; + + // InRollback captures whether the job is currently rolling back. + // This is important to ensure that the job can be moved to the proper + // failed state upon restore. + // + // Note, if this value is true, the targets have had their directions + // flipped already. + // + // TODO(ajwerner): Should we capture the error with which the job + // failed for a higher-fidelity restore? + // + // TODO(ajwerner): Do we need to track this state? In principle the + // answer seems like yes, and the reason is that if we've rolled back + // and have had the direction of the targets flipped, then we won't + // be able to know that and we might try to revert the revert, which + // would be confusing. + // + bool in_rollback = 8; + // Targets is the set of targets in the schema change belonging to this // descriptor. repeated Target targets = 1 [(gogoproto.nullable) = false]; diff --git a/pkg/sql/schemachanger/scpb/state.go b/pkg/sql/schemachanger/scpb/state.go index 095f65452c60..343869ff8023 100644 --- a/pkg/sql/schemachanger/scpb/state.go +++ b/pkg/sql/schemachanger/scpb/state.go @@ -11,6 +11,8 @@ package scpb import ( + "sort" + "github.com/cockroachdb/cockroach/pkg/util/protoutil" "github.com/cockroachdb/errors" ) @@ -20,6 +22,15 @@ import ( type CurrentState struct { TargetState Current []Status + + // InRollback captures whether the job is currently rolling back. + // This is important to ensure that the job can be moved to the proper + // failed state upon restore. + // + // Note, if this value is true, the targets have had their directions + // flipped already. + // + InRollback bool } // DeepCopy returns a deep copy of the receiver. @@ -30,6 +41,25 @@ func (s CurrentState) DeepCopy() CurrentState { } } +// Rollback idempotently marks the current state as InRollback. If the +// CurrentState was not previously marked as InRollback, it reverses the +// directions of all the targets. 
+func (s *CurrentState) Rollback() { + if s.InRollback { + return + } + for i := range s.Targets { + t := &s.Targets[i] + switch t.TargetStatus { + case Status_PUBLIC: + t.TargetStatus = Status_ABSENT + case Status_ABSENT: + t.TargetStatus = Status_PUBLIC + } + } + s.InRollback = true +} + // NumStatus is the number of values which Status may take on. var NumStatus = len(Status_name) @@ -78,3 +108,81 @@ func (m *DescriptorState) Clone() *DescriptorState { } return protoutil.Clone(m).(*DescriptorState) } + +// MakeCurrentStateFromDescriptors constructs a CurrentState object from a +// slice of DescriptorState objects into which the current state has been +// decomposed. +func MakeCurrentStateFromDescriptors(descriptorStates []*DescriptorState) (CurrentState, error) { + + var s CurrentState + var targetRanks []uint32 + var rollback bool + stmts := make(map[uint32]Statement) + for i, cs := range descriptorStates { + if i == 0 { + rollback = cs.InRollback + } else if rollback != cs.InRollback { + return CurrentState{}, errors.AssertionFailedf( + "job %d: conflicting rollback statuses between descriptors", + cs.JobID, + ) + } + s.Current = append(s.Current, cs.CurrentStatuses...) + s.Targets = append(s.Targets, cs.Targets...) + targetRanks = append(targetRanks, cs.TargetRanks...) + for _, stmt := range cs.RelevantStatements { + if existing, ok := stmts[stmt.StatementRank]; ok { + if existing.Statement != stmt.Statement.Statement { + return CurrentState{}, errors.AssertionFailedf( + "job %d: statement %q does not match %q for rank %d", + cs.JobID, + existing.Statement, + stmt.Statement, + stmt.StatementRank, + ) + } + } + stmts[stmt.StatementRank] = stmt.Statement + } + s.Authorization = cs.Authorization + } + sort.Sort(&stateAndRanks{CurrentState: &s, ranks: targetRanks}) + var sr stmtsAndRanks + for rank, stmt := range stmts { + sr.stmts = append(sr.stmts, stmt) + sr.ranks = append(sr.ranks, rank) + } + sort.Sort(&sr) + s.Statements = sr.stmts + s.InRollback = rollback + return s, nil +} + +type stateAndRanks struct { + *CurrentState + ranks []uint32 +} + +var _ sort.Interface = (*stateAndRanks)(nil) + +func (s *stateAndRanks) Len() int { return len(s.Targets) } +func (s *stateAndRanks) Less(i, j int) bool { return s.ranks[i] < s.ranks[j] } +func (s *stateAndRanks) Swap(i, j int) { + s.ranks[i], s.ranks[j] = s.ranks[j], s.ranks[i] + s.Targets[i], s.Targets[j] = s.Targets[j], s.Targets[i] + s.Current[i], s.Current[j] = s.Current[j], s.Current[i] +} + +type stmtsAndRanks struct { + stmts []Statement + ranks []uint32 +} + +func (s *stmtsAndRanks) Len() int { return len(s.stmts) } +func (s *stmtsAndRanks) Less(i, j int) bool { return s.ranks[i] < s.ranks[j] } +func (s stmtsAndRanks) Swap(i, j int) { + s.ranks[i], s.ranks[j] = s.ranks[j], s.ranks[i] + s.stmts[i], s.stmts[j] = s.stmts[j], s.stmts[i] +} + +var _ sort.Interface = (*stmtsAndRanks)(nil) diff --git a/pkg/sql/schemachanger/scplan/internal/scstage/build.go b/pkg/sql/schemachanger/scplan/internal/scstage/build.go index 376062013e2f..5e75863f0103 100644 --- a/pkg/sql/schemachanger/scplan/internal/scstage/build.go +++ b/pkg/sql/schemachanger/scplan/internal/scstage/build.go @@ -29,6 +29,7 @@ func BuildStages( init scpb.CurrentState, phase scop.Phase, g *scgraph.Graph, scJobIDSupplier func() jobspb.JobID, ) []Stage { c := buildContext{ + rollback: init.InRollback, g: g, scJobIDSupplier: scJobIDSupplier, isRevertibilityIgnored: true, @@ -50,6 +51,7 @@ func BuildStages( // buildContext contains the global constants for building the stages.
// Only the BuildStages function mutates it, it's read-only everywhere else. type buildContext struct { + rollback bool g *scgraph.Graph scJobIDSupplier func() jobspb.JobID isRevertibilityIgnored bool @@ -472,7 +474,9 @@ func (bc buildContext) nodes(current []scpb.Status) []*screl.Node { } func (bc buildContext) setJobStateOnDescriptorOps(initialize bool, after []scpb.Status) []scop.Op { - descIDs, states := makeDescriptorStates(bc.scJobIDSupplier(), bc.targetState, after) + descIDs, states := makeDescriptorStates( + bc.scJobIDSupplier(), bc.rollback, bc.targetState, after, + ) ops := make([]scop.Op, 0, descIDs.Len()) descIDs.ForEach(func(descID descpb.ID) { ops = append(ops, &scop.SetJobStateOnDescriptor{ @@ -485,7 +489,7 @@ func (bc buildContext) setJobStateOnDescriptorOps(initialize bool, after []scpb. } func makeDescriptorStates( - jobID jobspb.JobID, ts scpb.TargetState, statuses []scpb.Status, + jobID jobspb.JobID, inRollback bool, ts scpb.TargetState, statuses []scpb.Status, ) (catalog.DescriptorIDSet, map[descpb.ID]*scpb.DescriptorState) { descIDs := screl.GetDescIDs(ts) states := make(map[descpb.ID]*scpb.DescriptorState, descIDs.Len()) @@ -517,6 +521,7 @@ func makeDescriptorStates( state.Targets = append(state.Targets, t) state.TargetRanks = append(state.TargetRanks, uint32(i)) state.CurrentStatuses = append(state.CurrentStatuses, statuses[i]) + state.InRollback = inRollback } return descIDs, states } diff --git a/pkg/sql/schemachanger/scplan/plan.go b/pkg/sql/schemachanger/scplan/plan.go index c1ba5f4efa44..19da4cacbe68 100644 --- a/pkg/sql/schemachanger/scplan/plan.go +++ b/pkg/sql/schemachanger/scplan/plan.go @@ -25,6 +25,13 @@ import ( // Params holds the arguments for planning. type Params struct { + // InRollback indicates whether the schema change has already been reverted. + // Note that once in rollback there is no turning back: all remaining work is + // non-revertible. This is safe because the stages are carefully crafted to + // only allow entering rollback while it is still safe to do so. + InRollback bool + // ExecutionPhase indicates the phase that the plan should be constructed for.
ExecutionPhase scop.Phase diff --git a/pkg/sql/schemachanger/scrun/scrun.go b/pkg/sql/schemachanger/scrun/scrun.go index 67ff4a0c454a..795e7884b48d 100644 --- a/pkg/sql/schemachanger/scrun/scrun.go +++ b/pkg/sql/schemachanger/scrun/scrun.go @@ -12,7 +12,6 @@ package scrun import ( "context" - "sort" "github.com/cockroachdb/cockroach/pkg/jobs/jobspb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" @@ -86,7 +85,7 @@ func RunSchemaChangesInJob( descriptorIDs []descpb.ID, rollback bool, ) error { - state, err := makeState(ctx, settings, deps, descriptorIDs, rollback) + state, err := makeState(ctx, deps, descriptorIDs, rollback) if err != nil { return errors.Wrapf(err, "failed to construct state for job %d", jobID) } @@ -123,59 +122,22 @@ func executeStage( } } - log.Infof(ctx, "executing stage %d/%d in phase %v, %d ops of type %s", stage.Ordinal, stage.StagesInPhase, stage.Phase, len(stage.Ops()), stage.Ops()[0].Type()) + log.Infof(ctx, "executing stage %d/%d in phase %v, %d ops of type %s (rollback=%v)", + stage.Ordinal, stage.StagesInPhase, stage.Phase, len(stage.Ops()), stage.Ops()[0].Type(), p.InRollback) if err := scexec.ExecuteStage(ctx, deps, stage.Ops()); err != nil { return errors.Wrapf(p.DecorateErrorWithPlanDetails(err), "error executing %s", stage.String()) } return nil } -type stateAndRanks struct { - *scpb.CurrentState - ranks []uint32 -} - -var _ sort.Interface = (*stateAndRanks)(nil) - -func (s *stateAndRanks) Len() int { return len(s.Targets) } -func (s *stateAndRanks) Less(i, j int) bool { return s.ranks[i] < s.ranks[j] } -func (s *stateAndRanks) Swap(i, j int) { - s.ranks[i], s.ranks[j] = s.ranks[j], s.ranks[i] - s.Targets[i], s.Targets[j] = s.Targets[j], s.Targets[i] - s.Current[i], s.Current[j] = s.Current[j], s.Current[i] -} - -type stmtsAndRanks struct { - stmts []scpb.Statement - ranks []uint32 -} - -func (s *stmtsAndRanks) Len() int { return len(s.stmts) } -func (s *stmtsAndRanks) Less(i, j int) bool { return s.ranks[i] < s.ranks[j] } -func (s stmtsAndRanks) Swap(i, j int) { - s.ranks[i], s.ranks[j] = s.ranks[j], s.ranks[i] - s.stmts[i], s.stmts[j] = s.stmts[j], s.stmts[i] -} - -var _ sort.Interface = (*stmtsAndRanks)(nil) - func makeState( - ctx context.Context, - sv *cluster.Settings, - deps JobRunDependencies, - descriptorIDs []descpb.ID, - rollback bool, + ctx context.Context, deps JobRunDependencies, descriptorIDs []descpb.ID, rollback bool, ) (scpb.CurrentState, error) { - var s scpb.CurrentState - var targetRanks []uint32 - var stmts map[uint32]scpb.Statement - if err := deps.WithTxnInJob(ctx, func(ctx context.Context, txnDeps scexec.Dependencies) error { + var descriptorStates []*scpb.DescriptorState + if err := deps.WithTxnInJob(ctx, func(ctx context.Context, txnDeps scexec.Dependencies) error { + descriptorStates = nil // Reset for restarts. - s = scpb.CurrentState{} - targetRanks = nil - stmts = make(map[uint32]scpb.Statement) - descs, err := txnDeps.Catalog().MustReadImmutableDescriptors(ctx, descriptorIDs...) if err != nil { // TODO(ajwerner): It seems possible that a descriptor could be deleted @@ -193,50 +155,23 @@ func makeState( "descriptor %d does not contain schema changer state", desc.GetID(), ) } - s.Current = append(s.Current, cs.CurrentStatuses...) - s.Targets = append(s.Targets, cs.Targets...) - targetRanks = append(targetRanks, cs.TargetRanks...) 
- for _, stmt := range cs.RelevantStatements { - if existing, ok := stmts[stmt.StatementRank]; ok { - if existing.Statement != stmt.Statement.Statement { - return errors.AssertionFailedf( - "job %d: statement %q does not match %q for rank %d", - cs.JobID, - existing.Statement, - stmt.Statement, - stmt.StatementRank, - ) - } - } - stmts[stmt.StatementRank] = stmt.Statement - } - s.Authorization = cs.Authorization + descriptorStates = append(descriptorStates, cs) } return nil }); err != nil { return scpb.CurrentState{}, err } - sort.Sort(&stateAndRanks{ - CurrentState: &s, - ranks: targetRanks, - }) - var sr stmtsAndRanks - for rank, stmt := range stmts { - sr.stmts = append(sr.stmts, stmt) - sr.ranks = append(sr.ranks, rank) + state, err := scpb.MakeCurrentStateFromDescriptors(descriptorStates) + if err != nil { + return scpb.CurrentState{}, err } - sort.Sort(&sr) - s.Statements = sr.stmts - if rollback { - for i := range s.Targets { - t := &s.Targets[i] - switch t.TargetStatus { - case scpb.Status_PUBLIC: - t.TargetStatus = scpb.Status_ABSENT - case scpb.Status_ABSENT: - t.TargetStatus = scpb.Status_PUBLIC - } - } + if !rollback && state.InRollback { + return scpb.CurrentState{}, errors.Errorf( + "job in running state but schema change in rollback, " + + "returning an error to restart in the reverting state") + } + if rollback && !state.InRollback { + state.Rollback() } - return s, nil + return state, nil } diff --git a/pkg/sql/schemachanger/sctest/BUILD.bazel b/pkg/sql/schemachanger/sctest/BUILD.bazel index 797dba698d59..c2b2535657a4 100644 --- a/pkg/sql/schemachanger/sctest/BUILD.bazel +++ b/pkg/sql/schemachanger/sctest/BUILD.bazel @@ -9,9 +9,11 @@ go_library( importpath = "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/sctest", visibility = ["//visibility:public"], deps = [ + "//pkg/base", "//pkg/jobs", "//pkg/jobs/jobspb", "//pkg/keys", + "//pkg/sql/catalog", "//pkg/sql/parser", "//pkg/sql/schemachanger/scbuild", "//pkg/sql/schemachanger/scdeps/sctestdeps", @@ -19,14 +21,19 @@ go_library( "//pkg/sql/schemachanger/scop", "//pkg/sql/schemachanger/scpb", "//pkg/sql/schemachanger/scplan", + "//pkg/sql/schemachanger/screl", "//pkg/sql/schemachanger/scrun", "//pkg/sql/sessiondata", "//pkg/sql/sessiondatapb", "//pkg/testutils", + "//pkg/testutils/serverutils", + "//pkg/testutils/skip", "//pkg/testutils/sqlutils", "@com_github_cockroachdb_cockroach_go_v2//crdb", "@com_github_cockroachdb_datadriven//:datadriven", - "@com_github_pkg_errors//:errors", + "@com_github_cockroachdb_errors//:errors", + "@com_github_lib_pq//:pq", "@com_github_stretchr_testify//require", + "@org_golang_x_sync//errgroup", ], ) diff --git a/pkg/sql/schemachanger/sctest/cumulative.go b/pkg/sql/schemachanger/sctest/cumulative.go index a88a829db0c7..732ae97dcdee 100644 --- a/pkg/sql/schemachanger/sctest/cumulative.go +++ b/pkg/sql/schemachanger/sctest/cumulative.go @@ -23,15 +23,20 @@ import ( "github.com/cockroachdb/cockroach-go/v2/crdb" "github.com/cockroachdb/cockroach/pkg/jobs" + "github.com/cockroachdb/cockroach/pkg/sql/catalog" "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scop" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scplan" + "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/screl" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scrun" "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/skip" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" 
"github.com/cockroachdb/datadriven" - "github.com/pkg/errors" + "github.com/cockroachdb/errors" + "github.com/lib/pq" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) // cumulativeTest is a foundational helper for building tests over the @@ -39,8 +44,10 @@ import ( // the passed function for each test directive in the file. The setup // statements passed to the function will be all statements from all // previous test and setup blocks combined. -func cumulativeTest(t *testing.T, tf func(t *testing.T, setup, stmts []parser.Statement)) { - datadriven.Walk(t, testutils.TestDataPath(t), func(t *testing.T, path string) { +func cumulativeTest( + t *testing.T, dir string, tf func(t *testing.T, setup, stmts []parser.Statement), +) { + datadriven.Walk(t, dir, func(t *testing.T, path string) { var setup []parser.Statement datadriven.RunTest(t, path, func(t *testing.T, d *datadriven.TestData) string { @@ -68,18 +75,23 @@ func cumulativeTest(t *testing.T, tf func(t *testing.T, setup, stmts []parser.St }) } +// TODO(ajwerner): For all the non-rollback variants, we'd really actually +// like them to run over each of the rollback stages too. + // Rollback tests that the schema changer job rolls back properly. // This data-driven test uses the same input as EndToEndSideEffects // but ignores the expected output. -func Rollback(t *testing.T, newCluster NewClusterFunc) { +func Rollback(t *testing.T, dir string, newCluster NewClusterFunc) { countRevertiblePostCommitStages := func( t *testing.T, setup, stmts []parser.Statement, ) (n int) { - processPlanInPhase(t, newCluster, setup, stmts, scop.PostCommitPhase, func( - p scplan.Plan, - ) { - n = len(p.StagesForCurrentPhase()) - }) + processPlanInPhase( + t, newCluster, setup, stmts, scop.PostCommitPhase, + func(p scplan.Plan) { n = len(p.StagesForCurrentPhase()) }, + func(db *gosql.DB) { + + }, + ) return n } var testRollbackCase func( @@ -102,18 +114,6 @@ func Rollback(t *testing.T, newCluster NewClusterFunc) { } } - const fetchDescriptorStateQuery = ` -SELECT - create_statement -FROM - ( - SELECT descriptor_id, create_statement FROM crdb_internal.create_schema_statements - UNION ALL SELECT descriptor_id, create_statement FROM crdb_internal.create_statements - UNION ALL SELECT descriptor_id, create_statement FROM crdb_internal.create_type_statements - ) -ORDER BY - create_statement;` - testRollbackCase = func( t *testing.T, setup, stmts []parser.Statement, ord int, ) { @@ -140,10 +140,6 @@ ORDER BY beforeFunc := func() { before = tdb.QueryStr(t, fetchDescriptorStateQuery) } - resetTxnState := func() { - // Reset the counter in case the transaction is retried for whatever reason. - atomic.StoreUint32(&numInjectedFailures, 0) - } onError := func(err error) error { // If the statement execution failed, then we expect to end up in the same // state as when we started. 
@@ -151,7 +147,7 @@ ORDER BY return err } err := executeSchemaChangeTxn( - context.Background(), t, setup, stmts, db, beforeFunc, resetTxnState, onError, + context.Background(), t, setup, stmts, db, beforeFunc, nil, onError, ) if atomic.LoadUint32(&numInjectedFailures) == 0 { require.NoError(t, err) @@ -159,13 +155,27 @@ ORDER BY require.Regexp(t, fmt.Sprintf("boom %d", ord), err) } } - cumulativeTest(t, testFunc) + cumulativeTest(t, dir, testFunc) } +const fetchDescriptorStateQuery = ` +SELECT + create_statement +FROM + ( + SELECT descriptor_id, create_statement FROM crdb_internal.create_schema_statements + UNION ALL SELECT descriptor_id, create_statement FROM crdb_internal.create_statements + UNION ALL SELECT descriptor_id, create_statement FROM crdb_internal.create_type_statements + ) +WHERE descriptor_id IN (SELECT id FROM system.namespace) +ORDER BY + create_statement;` + // Pause tests that the schema changer can handle being paused and resumed // correctly. This data-driven test uses the same input as EndToEndSideEffects // but ignores the expected output. -func Pause(t *testing.T, newCluster NewClusterFunc) { +func Pause(t *testing.T, dir string, newCluster NewClusterFunc) { + skip.UnderRace(t) var postCommit, nonRevertible int countStages := func( t *testing.T, setup, stmts []parser.Statement, @@ -175,7 +185,7 @@ func Pause(t *testing.T, newCluster NewClusterFunc) { ) { postCommit = len(p.StagesForCurrentPhase()) nonRevertible = len(p.Stages) - postCommit - }) + }, nil) } var testPauseCase func( t *testing.T, setup, stmts []parser.Statement, ord int, @@ -190,7 +200,7 @@ func Pause(t *testing.T, newCluster NewClusterFunc) { t.Logf("test case has %d revertible post-commit stages", n) for i := 1; i <= n; i++ { if !t.Run( - fmt.Sprintf("rollback stage %d of %d", i, n), + fmt.Sprintf("pause stage %d of %d", i, n), func(t *testing.T) { testPauseCase(t, setup, stmts, i) }, ) { return @@ -199,10 +209,10 @@ func Pause(t *testing.T, newCluster NewClusterFunc) { } testPauseCase = func(t *testing.T, setup, stmts []parser.Statement, ord int) { var numInjectedFailures uint32 - resetTxnState := func() { - // Reset the counter in case the transaction is retried for whatever reason. - atomic.StoreUint32(&numInjectedFailures, 0) - } + // TODO(ajwerner): It'd be nice to assert something about the number of + // remaining stages before the pause and then after. It's not totally + // trivial, as we don't checkpoint during non-mutation stages, so we'd + // need to look back and find the last mutation phase. db, cleanup := newCluster(t, &scrun.TestingKnobs{ BeforeStage: func(p scplan.Plan, stageIdx int) error { if atomic.LoadUint32(&numInjectedFailures) > 0 { @@ -222,8 +232,9 @@ func Pause(t *testing.T, newCluster NewClusterFunc) { onError := func(err error) error { // Check that it's a pause error, with a job. // Resume the job and wait for the job. 
- //require.NoError(t, err) - re := regexp.MustCompile(`job (\d+) was paused before it completed with reason: boom (\d+)`) + re := regexp.MustCompile( + `job (\d+) was paused before it completed with reason: boom (\d+)`, + ) match := re.FindStringSubmatch(err.Error()) require.NotNil(t, match) idx, err := strconv.Atoi(match[2]) @@ -239,10 +250,247 @@ func Pause(t *testing.T, newCluster NewClusterFunc) { return nil } require.NoError(t, executeSchemaChangeTxn( - context.Background(), t, setup, stmts, db, nil, resetTxnState, onError, + context.Background(), t, setup, stmts, db, nil, nil, onError, )) + require.Equal(t, uint32(1), atomic.LoadUint32(&numInjectedFailures)) } - cumulativeTest(t, testFunc) + cumulativeTest(t, dir, testFunc) +} + +// Backup tests that the schema changer can handle being backed up and +// restored correctly. This data-driven test uses the same input as +// EndToEndSideEffects but ignores the expected output. Note that the +// cluster constructor needs to provide a cluster with CCL BACKUP/RESTORE +// functionality enabled. +func Backup(t *testing.T, dir string, newCluster NewClusterFunc) { + skip.UnderRace(t) + skip.UnderStress(t) + var after [][]string + var dbName string + countStages := func( + t *testing.T, setup, stmts []parser.Statement, + ) (postCommit, nonRevertible int) { + var pl scplan.Plan + processPlanInPhase(t, newCluster, setup, stmts, scop.PostCommitPhase, + func(p scplan.Plan) { + pl = p + postCommit = len(p.StagesForCurrentPhase()) + nonRevertible = len(p.Stages) - postCommit + }, func(db *gosql.DB) { + tdb := sqlutils.MakeSQLRunner(db) + var ok bool + dbName, ok = maybeGetDatabaseForIDs(t, tdb, screl.GetDescIDs(pl.TargetState)) + if ok { + tdb.Exec(t, fmt.Sprintf("USE %q", dbName)) + } + after = tdb.QueryStr(t, fetchDescriptorStateQuery) + }) + return postCommit, nonRevertible + } + var testBackupRestoreCase func( + t *testing.T, setup, stmts []parser.Statement, ord int, + ) + testFunc := func(t *testing.T, setup, stmts []parser.Statement) { + postCommit, nonRevertible := countStages(t, setup, stmts) + if nonRevertible > 0 { + postCommit++ + } + n := postCommit + t.Logf("test case has %d revertible post-commit stages", n) + for i := 1; i <= n; i++ { + if !t.Run( + fmt.Sprintf("backup/restore stage %d of %d", i, n), + func(t *testing.T) { testBackupRestoreCase(t, setup, stmts, i) }, + ) { + return + } + } + } + type stage struct { + p scplan.Plan + stageIdx int + resume chan error + } + mkStage := func(p scplan.Plan, stageIdx int) stage { + return stage{p: p, stageIdx: stageIdx, resume: make(chan error)} + } + testBackupRestoreCase = func( + t *testing.T, setup, stmts []parser.Statement, ord int, + ) { + stageChan := make(chan stage) + ctx, cancel := context.WithCancel(context.Background()) + db, cleanup := newCluster(t, &scrun.TestingKnobs{ + BeforeStage: func(p scplan.Plan, stageIdx int) error { + if p.Stages[stageIdx].Phase < scop.PostCommitPhase { + return nil + } + if stageChan != nil { + s := mkStage(p, stageIdx) + select { + case stageChan <- s: + case <-ctx.Done(): + return ctx.Err() + } + select { + case err := <-s.resume: + return err + case <-ctx.Done(): + return ctx.Err() + } + } + return nil + }, + }) + + // Start with full database backup/restore. 
+ defer cleanup() + defer cancel() + + conn, err := db.Conn(ctx) + require.NoError(t, err) + tdb := sqlutils.MakeSQLRunner(conn) + tdb.Exec(t, "create database backups") + var g errgroup.Group + var before [][]string + beforeFunc := func() { + tdb.Exec(t, fmt.Sprintf("USE %q", dbName)) + before = tdb.QueryStr(t, fetchDescriptorStateQuery) + } + g.Go(func() error { + return executeSchemaChangeTxn( + context.Background(), t, setup, stmts, db, beforeFunc, nil, nil, + ) + }) + type backup struct { + name string + isRollback bool + url string + s stage + } + var backups []backup + var done bool + var rollbackStage int + for i := 0; !done; i++ { + // We want to let the stages up to ord continue unscathed. Then, we'll + // start taking backups at ord. If ord corresponds to a revertible + // stage, we'll inject an error, forcing the schema change to revert. + // At each subsequent stage, we also take a backup. At the very end, + // we'll have one backup where things should succeed and N backups + // where we're reverting. In each case, we want to have the end state + // of the restored set of descriptors match what we have in the original + // cluster. + // + // Lastly, we'll hit an ord corresponding to the first non-revertible + // stage. At this point, we'll take a backup for each non-revertible + // stage and confirm that restoring them and letting the jobs run + // leaves the database in the right state. + s := <-stageChan + shouldFail := ord == i && + s.p.Stages[s.stageIdx].Phase != scop.PostCommitNonRevertiblePhase && + !s.p.InRollback + done = len(s.p.Stages) == s.stageIdx+1 && !shouldFail + t.Logf("stage %d/%d in %v (rollback=%v) %d %q %v", + s.stageIdx+1, len(s.p.Stages), s.p.Stages[s.stageIdx].Phase, s.p.InRollback, ord, dbName, done) + + // If the database has been dropped, there is nothing for + // us to do here. + var exists bool + tdb.QueryRow(t, + `SELECT count(*) > 0 FROM system.namespace WHERE "parentID" = 0 AND name = $1`, + dbName).Scan(&exists) + if !exists || (i < ord && !done) { + close(s.resume) + continue + } + + // This test assumes that all the descriptors being modified in the + // transaction are in the same database. + // + // TODO(ajwerner): Deal with trying to restore just some of the tables. + backupURL := fmt.Sprintf("userfile://backups.public.userfiles_$user/data%d", i) + tdb.Exec(t, fmt.Sprintf( + "BACKUP DATABASE %s INTO '%s'", dbName, backupURL)) + backups = append(backups, backup{ + name: dbName, + isRollback: rollbackStage > 0, + url: backupURL, + s: s, + }) + + if s.p.InRollback { + rollbackStage++ + } + if shouldFail { + s.resume <- errors.Newf("boom %d", i) + } else { + close(s.resume) + } + } + if err := g.Wait(); rollbackStage > 0 { + require.Regexp(t, fmt.Sprintf("boom %d", ord), err) + } else { + require.NoError(t, err) + } + stageChan = nil // allow the restored jobs to proceed + t.Logf("finished") + + for i, b := range backups { + t.Run("", func(t *testing.T) { + t.Logf("testing backup %d %v", i, b.isRollback) + tdb.Exec(t, fmt.Sprintf("DROP DATABASE IF EXISTS %q CASCADE", dbName)) + tdb.Exec(t, "SET experimental_use_new_schema_changer = 'off'") + tdb.Exec(t, fmt.Sprintf("RESTORE DATABASE %s FROM LATEST IN '%s'", dbName, b.url)) + tdb.Exec(t, fmt.Sprintf("USE %q", dbName)) + waitForSchemaChangesToFinish(t, tdb) + afterRestore := tdb.QueryStr(t, fetchDescriptorStateQuery) + if b.isRollback { + require.Equal(t, before, afterRestore) + } else { + require.Equal(t, after, afterRestore) + } + // Hack to deal with corrupt userfiles tables due to #76764. 
+ const validateQuery = ` +SELECT * FROM crdb_internal.invalid_objects WHERE database_name != 'backups' +` + tdb.CheckQueryResults(t, validateQuery, [][]string{}) + tdb.Exec(t, fmt.Sprintf("DROP DATABASE %q CASCADE", dbName)) + tdb.Exec(t, "USE backups") + tdb.CheckQueryResults(t, validateQuery, [][]string{}) + }) + } + } + cumulativeTest(t, dir, testFunc) +} + +func maybeGetDatabaseForIDs( + t *testing.T, tdb *sqlutils.SQLRunner, ids catalog.DescriptorIDSet, +) (dbName string, exists bool) { + err := tdb.DB.QueryRowContext(context.Background(), ` +SELECT name + FROM system.namespace + WHERE id + IN ( + SELECT DISTINCT + COALESCE( + d->'database'->>'id', + d->'schema'->>'parentId', + d->'type'->>'parentId', + d->'table'->>'parentId' + )::INT8 + FROM ( + SELECT crdb_internal.pb_to_json('desc', descriptor) AS d + FROM system.descriptor + WHERE id IN (SELECT * FROM ROWS FROM (unnest($1::INT8[]))) + ) + ) +`, pq.Array(ids.Ordered())). + Scan(&dbName) + if errors.Is(err, gosql.ErrNoRows) { + return "", false + } + + require.NoError(t, err) + return dbName, true } // processPlanInPhase will call processFunc with the plan as of the first @@ -253,6 +501,7 @@ func processPlanInPhase( setup, stmt []parser.Statement, phaseToProcess scop.Phase, processFunc func(p scplan.Plan), + after func(db *gosql.DB), ) { var processOnce sync.Once db, cleanup := newCluster(t, &scrun.TestingKnobs{ @@ -267,6 +516,9 @@ func processPlanInPhase( require.NoError(t, executeSchemaChangeTxn( context.Background(), t, setup, stmt, db, nil, nil, nil, )) + if after != nil { + after(db) + } } // executeSchemaChangeTxn spins up a test cluster, executes the setup @@ -291,7 +543,7 @@ func executeSchemaChangeTxn( for _, stmt := range setup { tdb.Exec(t, stmt.SQL) } - waitForSchemaChangesToComplete(t, tdb) + waitForSchemaChangesToSucceed(t, tdb) if before != nil { before() } @@ -299,11 +551,34 @@ func executeSchemaChangeTxn( // Execute the tested statements with the declarative schema changer and fail // the test if it all takes too long. This prevents the test suite from // hanging when a regression is introduced. - tdb.Exec(t, "SET experimental_use_new_schema_changer = 'unsafe_always'") { c := make(chan error, 1) go func() { - c <- crdb.ExecuteTx(ctx, db, nil, func(tx *gosql.Tx) error { + conn, err := db.Conn(ctx) + if err != nil { + c <- err + return + } + defer func() { _ = conn.Close() }() + c <- crdb.Execute(func() (err error) { + _, err = conn.ExecContext( + ctx, "SET experimental_use_new_schema_changer = 'unsafe_always'", + ) + if err != nil { + return err + } + var tx *gosql.Tx + tx, err = conn.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { + if err != nil { + err = errors.WithSecondaryError(err, tx.Rollback()) + } else { + err = tx.Commit() + } + }() if txnStartCallback != nil { txnStartCallback() } @@ -334,6 +609,6 @@ func executeSchemaChangeTxn( } // Ensure we're really done here. 
- waitForSchemaChangesToComplete(t, tdb) + waitForSchemaChangesToSucceed(t, tdb) return nil } diff --git a/pkg/sql/schemachanger/sctest/end_to_end.go b/pkg/sql/schemachanger/sctest/end_to_end.go index a68267ae72be..70acc72b4917 100644 --- a/pkg/sql/schemachanger/sctest/end_to_end.go +++ b/pkg/sql/schemachanger/sctest/end_to_end.go @@ -20,6 +20,8 @@ import ( "strings" "testing" + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/jobs" "github.com/cockroachdb/cockroach/pkg/jobs/jobspb" "github.com/cockroachdb/cockroach/pkg/keys" "github.com/cockroachdb/cockroach/pkg/sql/parser" @@ -31,7 +33,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scrun" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" "github.com/cockroachdb/cockroach/pkg/sql/sessiondatapb" - "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" "github.com/cockroachdb/datadriven" "github.com/stretchr/testify/require" @@ -43,15 +45,28 @@ type NewClusterFunc func( t *testing.T, knobs *scrun.TestingKnobs, ) (_ *gosql.DB, cleanup func()) +// SingleNodeCluster is a NewClusterFunc. +func SingleNodeCluster(t *testing.T, knobs *scrun.TestingKnobs) (*gosql.DB, func()) { + s, db, _ := serverutils.StartServer(t, base.TestServerArgs{ + Knobs: base.TestingKnobs{ + SQLDeclarativeSchemaChanger: knobs, + JobsTestingKnobs: jobs.NewTestingKnobsWithShortIntervals(), + }, + }) + return db, func() { + s.Stopper().Stop(context.Background()) + } +} + // EndToEndSideEffects is a data-driven test runner that executes DDL statements in the // declarative schema changer injected with test dependencies and compares the // accumulated side effects logs with expected results from the data-driven // test file. // // It shares a data-driven format with Rollback. -func EndToEndSideEffects(t *testing.T, newCluster NewClusterFunc) { +func EndToEndSideEffects(t *testing.T, dir string, newCluster NewClusterFunc) { ctx := context.Background() - datadriven.Walk(t, testutils.TestDataPath(t), func(t *testing.T, path string) { + datadriven.Walk(t, dir, func(t *testing.T, path string) { // Create a test cluster. 
db, cleanup := newCluster(t, nil /* knobs */) tdb := sqlutils.MakeSQLRunner(db) @@ -64,7 +79,7 @@ func EndToEndSideEffects(t *testing.T, newCluster NewClusterFunc) { for _, stmt := range stmts { tdb.Exec(t, stmt.SQL) } - waitForSchemaChangesToComplete(t, tdb) + waitForSchemaChangesToSucceed(t, tdb) } switch d.Cmd { @@ -190,12 +205,25 @@ func prettyNamespaceDump(t *testing.T, tdb *sqlutils.SQLRunner) string { return strings.Join(lines, "\n") } -func waitForSchemaChangesToComplete(t *testing.T, tdb *sqlutils.SQLRunner) { +func waitForSchemaChangesToSucceed(t *testing.T, tdb *sqlutils.SQLRunner) { + tdb.CheckQueryResultsRetry( + t, schemaChangeWaitQuery(`('succeeded')`), [][]string{}, + ) +} + +func waitForSchemaChangesToFinish(t *testing.T, tdb *sqlutils.SQLRunner) { + tdb.CheckQueryResultsRetry( + t, schemaChangeWaitQuery(`('succeeded', 'failed')`), [][]string{}, + ) +} + +func schemaChangeWaitQuery(statusInString string) string { q := fmt.Sprintf( - `SELECT count(*) FROM [SHOW JOBS] WHERE job_type IN ('%s', '%s', '%s') AND status <> 'succeeded'`, + `SELECT status, job_type, description FROM [SHOW JOBS] WHERE job_type IN ('%s', '%s', '%s') AND status NOT IN %s`, jobspb.TypeSchemaChange, jobspb.TypeTypeSchemaChange, jobspb.TypeNewSchemaChange, + statusInString, ) - tdb.CheckQueryResultsRetry(t, q, [][]string{{"0"}}) + return q }
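Editor's illustration (not part of the patch above): the direction flip performed by the new CurrentState.Rollback helper in pkg/sql/schemachanger/scpb/state.go can be pictured with the small, hypothetical Go snippet below. It uses only scpb identifiers that appear in this diff; the package and function names and the particular target values are made up for illustration, and the struct literals set only the fields relevant to the example.

package scpbexample // hypothetical package, for illustration only

import "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scpb"

// rollbackFlipSketch demonstrates, under the assumptions stated above, how
// CurrentState.Rollback inverts target directions exactly once: a target
// heading to PUBLIC starts heading to ABSENT and vice versa, and InRollback
// latches so a second call is a no-op.
func rollbackFlipSketch() scpb.CurrentState {
	s := scpb.CurrentState{
		TargetState: scpb.TargetState{
			Targets: []scpb.Target{
				{TargetStatus: scpb.Status_PUBLIC}, // e.g. an element being added
				{TargetStatus: scpb.Status_ABSENT}, // e.g. an element being dropped
			},
		},
		Current: []scpb.Status{scpb.Status_ABSENT, scpb.Status_PUBLIC},
	}
	s.Rollback()
	// Now s.InRollback == true, s.Targets[0].TargetStatus == scpb.Status_ABSENT,
	// and s.Targets[1].TargetStatus == scpb.Status_PUBLIC.
	s.Rollback() // Idempotent: InRollback is already set, so nothing changes.
	return s
}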