diff --git a/pkg/ccl/importccl/import_stmt.go b/pkg/ccl/importccl/import_stmt.go index b298314e6dad..19af30a81e03 100644 --- a/pkg/ccl/importccl/import_stmt.go +++ b/pkg/ccl/importccl/import_stmt.go @@ -2304,9 +2304,10 @@ func (r *importResumer) publishSchemas(ctx context.Context, execCfg *sql.Executo } // checkForUDTModification checks whether any of the types referenced by the -// table being imported into have been modified since they were read during -// import planning. If they have, it may be unsafe to continue with the import -// since we could be ingesting data that is no longer valid for the type. +// table being imported into have been modified incompatibly since they were +// read during import planning. If they have, it may be unsafe to continue +// with the import since we could be ingesting data that is no longer valid +// for the type. // // Egs: Renaming an enum value mid import could result in the import ingesting a // value that is no longer valid. @@ -2314,7 +2315,11 @@ func (r *importResumer) publishSchemas(ctx context.Context, execCfg *sql.Executo // TODO(SQL Schema): This method might be unnecessarily aggressive in failing // the import. The semantics of what concurrent type changes are/are not safe // during an IMPORT still need to be ironed out. Once they are, we can make this -// method more conservative in what it uses to deem a type change dangerous. +// method more selective about which type changes it deems dangerous. At +// the time of writing, changes to privileges and back-references are supported. +// Additions of new values could be supported but are not. Renaming of logical +// enum values or removal of enum values will need to forever remain +// incompatible.
func (r *importResumer) checkForUDTModification( ctx context.Context, execCfg *sql.ExecutorConfig, ) error { @@ -2322,23 +2327,76 @@ func (r *importResumer) checkForUDTModification( if details.Types == nil { return nil } - return sql.DescsTxn(ctx, execCfg, func(ctx context.Context, txn *kv.Txn, - col *descs.Collection) error { + // typeDescsAreEquivalent returns true if a and b are the same types save + // for the version, modification time, privileges, or the set of referencing + // descriptors. + typeDescsAreEquivalent := func(a, b *descpb.TypeDescriptor) (bool, error) { + clearIgnoredFields := func(d *descpb.TypeDescriptor) *descpb.TypeDescriptor { + d = protoutil.Clone(d).(*descpb.TypeDescriptor) + d.ModificationTime = hlc.Timestamp{} + d.Privileges = nil + d.Version = 0 + d.ReferencingDescriptorIDs = nil + return d + } + aData, err := protoutil.Marshal(clearIgnoredFields(a)) + if err != nil { + return false, err + } + bData, err := protoutil.Marshal(clearIgnoredFields(b)) + if err != nil { + return false, err + } + return bytes.Equal(aData, bData), nil + } + // checkTypeIsEquivalent checks that the current version of the type as + // retrieved from the collection is equivalent to the previously saved + // type descriptor used by the import. 
+ checkTypeIsEquivalent := func( + ctx context.Context, txn *kv.Txn, col *descs.Collection, + savedTypeDesc *descpb.TypeDescriptor, + ) error { + typeDesc, err := catalogkv.MustGetTypeDescByID( + ctx, txn, execCfg.Codec, savedTypeDesc.GetID(), + ) + if err != nil { + return errors.Wrap(err, "resolving type descriptor when checking version mismatch") + } + if typeDesc.GetModificationTime() == savedTypeDesc.GetModificationTime() { + return nil + } + equivalent, err := typeDescsAreEquivalent(typeDesc.TypeDesc(), savedTypeDesc) + if err != nil { + return errors.NewAssertionErrorWithWrappedErrf( + err, "failed to check for type descriptor equivalence for type %q (%d)", + typeDesc.GetName(), typeDesc.GetID()) + } + if equivalent { + return nil + } + return errors.WithHint( + errors.Newf( + "type descriptor %q (%d) has been modified, potentially incompatibly,"+ + " since import planning; aborting to avoid possible corruption", + typeDesc.GetName(), typeDesc.GetID(), + ), + "retrying the IMPORT operation may succeed if the operation concurrently"+ + " modifying the descriptor does not reoccur during the retry attempt", + ) + } + checkTypesAreEquivalent := func( + ctx context.Context, txn *kv.Txn, col *descs.Collection, + ) error { for _, savedTypeDesc := range details.Types { - typeDesc, err := catalogkv.MustGetTypeDescByID(ctx, txn, execCfg.Codec, - savedTypeDesc.Desc.GetID()) - if err != nil { - return errors.Wrap(err, "resolving type descriptor when checking version mismatch") - } - if typeDesc.GetModificationTime() != savedTypeDesc.Desc.GetModificationTime() { - return errors.Newf("type descriptor %d has a different modification time than what"+ - " was saved during import planning; unsafe to import since the type"+ - " has changed during the course of the import", - typeDesc.GetID()) + if err := checkTypeIsEquivalent( + ctx, txn, col, savedTypeDesc.Desc, + ); err != nil { + return err } } return nil - }) + } + return sql.DescsTxn(ctx, execCfg, checkTypesAreEquivalent) 
} // publishTables updates the status of imported tables from OFFLINE to PUBLIC. diff --git a/pkg/ccl/importccl/import_stmt_test.go b/pkg/ccl/importccl/import_stmt_test.go index c73b8360b9ba..3126585830ec 100644 --- a/pkg/ccl/importccl/import_stmt_test.go +++ b/pkg/ccl/importccl/import_stmt_test.go @@ -24,6 +24,7 @@ import ( "path/filepath" "regexp" "strings" + "sync/atomic" "testing" "time" @@ -6412,20 +6413,34 @@ func TestImportMultiRegion(t *testing.T) { defer log.Scope(t).Close(t) baseDir := filepath.Join("testdata") - _, sqlDB, cleanup := multiregionccltestutils.TestingCreateMultiRegionCluster( - t, 1 /* numServers */, base.TestingKnobs{}, multiregionccltestutils.WithBaseDirectory(baseDir), + tc, sqlDB, cleanup := multiregionccltestutils.TestingCreateMultiRegionCluster( + t, 2 /* numServers */, base.TestingKnobs{}, multiregionccltestutils.WithBaseDirectory(baseDir), ) defer cleanup() - _, err := sqlDB.Exec(`SET CLUSTER SETTING kv.bulk_ingest.batch_size = '10KB'`) - require.NoError(t, err) + // Set up a hook which we can set to run during the import. + // Importantly this happens before the final descriptors have been published. + var duringImportFunc atomic.Value + noopDuringImportFunc := func() error { return nil } + duringImportFunc.Store(noopDuringImportFunc) + for i := 0; i < tc.NumServers(); i++ { + tc.Server(i).JobRegistry().(*jobs.Registry). 
+ TestingResumerCreationKnobs = map[jobspb.Type]func(jobs.Resumer) jobs.Resumer{ + jobspb.TypeImport: func(resumer jobs.Resumer) jobs.Resumer { + resumer.(*importResumer).testingKnobs.afterImport = func(summary backupccl.RowCount) error { + return duringImportFunc.Load().(func() error)() + } + return resumer + }, + } + } - // Create the databases - _, err = sqlDB.Exec(`CREATE DATABASE foo`) - require.NoError(t, err) + tdb := sqlutils.MakeSQLRunner(sqlDB) + tdb.Exec(t, `SET CLUSTER SETTING kv.bulk_ingest.batch_size = '10KB'`) - _, err = sqlDB.Exec(`CREATE DATABASE multi_region PRIMARY REGION "us-east1"`) - require.NoError(t, err) + // Create the databases + tdb.Exec(t, `CREATE DATABASE foo`) + tdb.Exec(t, `CREATE DATABASE multi_region PRIMARY REGION "us-east1"`) simpleOcf := fmt.Sprintf("nodelocal://0/avro/%s", "simple.ocf") @@ -6469,30 +6484,23 @@ func TestImportMultiRegion(t *testing.T) { for _, tc := range viewsAndSequencesTestCases { t.Run(tc.desc, func(t *testing.T) { - _, err = sqlDB.Exec(`USE multi_region`) - require.NoError(t, err) - defer func() { - _, err := sqlDB.Exec(` + tdb.Exec(t, `USE multi_region`) + defer tdb.Exec(t, ` DROP TABLE IF EXISTS tbl; DROP SEQUENCE IF EXISTS s; DROP SEQUENCE IF EXISTS table_auto_inc; DROP VIEW IF EXISTS v`, - ) - require.NoError(t, err) - }() - - _, err = sqlDB.Exec(tc.importSQL) - require.NoError(t, err) - rows, err := sqlDB.Query("SELECT table_name, locality FROM [SHOW TABLES] ORDER BY table_name") - require.NoError(t, err) + ) + tdb.Exec(t, tc.importSQL) + rows := tdb.Query(t, "SELECT table_name, locality FROM [SHOW TABLES] ORDER BY table_name") results := make(map[string]string) for rows.Next() { - require.NoError(t, rows.Err()) var tableName, locality string require.NoError(t, rows.Scan(&tableName, &locality)) results[tableName] = locality } + require.NoError(t, rows.Err()) require.Equal(t, tc.expected, results) }) } @@ -6507,6 +6515,7 @@ DROP VIEW IF EXISTS v`, args []interface{} errString string data string + during 
string }{ { name: "import-create-using-multi-region-to-non-multi-region-database", @@ -6548,6 +6557,45 @@ DROP VIEW IF EXISTS v`, args: []interface{}{srv.URL}, data: "1,\"foo\",NULL,us-east1\n", }, + { + name: "import-into-multi-region-regional-by-row-to-multi-region-database-concurrent-table-add", + db: "multi_region", + table: "mr_regional_by_row", + create: "CREATE TABLE mr_regional_by_row (i INT8 PRIMARY KEY, s text, b bytea) LOCALITY REGIONAL BY ROW", + during: "CREATE TABLE mr_regional_by_row2 (i INT8 PRIMARY KEY) LOCALITY REGIONAL BY ROW", + sql: "IMPORT INTO mr_regional_by_row (i, s, b, crdb_region) CSV DATA ($1)", + args: []interface{}{srv.URL}, + data: "1,\"foo\",NULL,us-east1\n", + }, + { + name: "import-into-multi-region-regional-by-row-to-multi-region-database-concurrent-add-region", + db: "multi_region", + table: "mr_regional_by_row", + create: "CREATE TABLE mr_regional_by_row (i INT8 PRIMARY KEY, s text, b bytea) LOCALITY REGIONAL BY ROW", + sql: "IMPORT INTO mr_regional_by_row (i, s, b, crdb_region) CSV DATA ($1)", + during: `ALTER DATABASE multi_region ADD REGION "us-east2"`, + errString: `type descriptor "crdb_internal_region" \(54\) has been ` + + `modified, potentially incompatibly, since import planning; ` + + `aborting to avoid possible corruption`, + args: []interface{}{srv.URL}, + data: "1,\"foo\",NULL,us-east1\n", + }, + { + name: "import-into-multi-region-regional-by-row-with-enum-which-has-been-modified", + db: "multi_region", + table: "mr_regional_by_row", + create: ` +CREATE TYPE typ AS ENUM ('a'); +CREATE TABLE mr_regional_by_row (i INT8 PRIMARY KEY, s typ, b bytea) LOCALITY REGIONAL BY ROW; +`, + sql: "IMPORT INTO mr_regional_by_row (i, s, b, crdb_region) CSV DATA ($1)", + during: `ALTER TYPE typ ADD VALUE 'b'`, + errString: `type descriptor "typ" \(67\) has been ` + + `modified, potentially incompatibly, since import planning; ` + + `aborting to avoid possible corruption`, + args: []interface{}{srv.URL}, + data: 
"1,\"a\",NULL,us-east1\n", + }, { name: "import-into-multi-region-regional-by-row-to-multi-region-database-wrong-value", db: "multi_region", @@ -6570,33 +6618,34 @@ DROP VIEW IF EXISTS v`, for _, test := range tests { t.Run(test.name, func(t *testing.T) { - _, err = sqlDB.Exec(fmt.Sprintf(`SET DATABASE = %q`, test.db)) - require.NoError(t, err) - - _, err = sqlDB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %q CASCADE", test.table)) - require.NoError(t, err) + defer duringImportFunc.Store(noopDuringImportFunc) + if test.during != "" { + duringImportFunc.Store(func() error { + q := fmt.Sprintf(`SET DATABASE = %q; %s`, test.db, test.during) + _, err := sqlDB.Exec(q) + return err + }) + } + tdb.Exec(t, fmt.Sprintf(`SET DATABASE = %q`, test.db)) + tdb.Exec(t, fmt.Sprintf("DROP TABLE IF EXISTS %q CASCADE", test.table)) if test.data != "" { data = test.data } if test.create != "" { - _, err = sqlDB.Exec(test.create) - require.NoError(t, err) + tdb.Exec(t, test.create) } - _, err = sqlDB.ExecContext(context.Background(), test.sql, test.args...) + _, err := sqlDB.ExecContext(context.Background(), test.sql, test.args...) if test.errString != "" { - require.True(t, testutils.IsError(err, test.errString)) + require.Regexp(t, test.errString, err) } else { require.NoError(t, err) - res := sqlDB.QueryRow(fmt.Sprintf("SELECT count(*) FROM %q", test.table)) - require.NoError(t, res.Err()) - var numRows int - err = res.Scan(&numRows) - require.NoError(t, err) - + tdb.QueryRow( + t, fmt.Sprintf("SELECT count(*) FROM %q", test.table), + ).Scan(&numRows) if numRows == 0 { t.Error("expected some rows after import") } @@ -7240,6 +7289,18 @@ func TestUDTChangeDuringImport(t *testing.T) { "cannot drop type \"greeting\"", false, }, + { + "use-in-table", + "CREATE TABLE d.foo (i INT PRIMARY KEY, j d.greeting)", + "", + false, + }, + { + "grant", + "CREATE USER u; GRANT USAGE ON TYPE d.greeting TO u;", + "", + false, + }, } for _, test := range testCases {