diff --git a/pkg/ccl/backupccl/backup_test.go b/pkg/ccl/backupccl/backup_test.go index 7d30c7d9a7fc..8572f872e3ff 100644 --- a/pkg/ccl/backupccl/backup_test.go +++ b/pkg/ccl/backupccl/backup_test.go @@ -676,7 +676,7 @@ func TestBackupRestoreSystemJobs(t *testing.T) { fullDir := sanitizedFullDir + "moarSecretsHere" backupDatabaseID := sqlutils.QueryDatabaseID(t, conn, "data") - backupTableID := sqlutils.QueryTableID(t, conn, "data", "bank") + backupTableID := sqlutils.QueryTableID(t, conn, "data", "public", "bank") sqlDB.Exec(t, `CREATE DATABASE restoredb`) restoreDatabaseID := sqlutils.QueryDatabaseID(t, conn, "restoredb") @@ -775,7 +775,7 @@ func checkInProgressBackupRestore( sqlDB.Exec(t, `CREATE DATABASE restoredb`) - backupTableID := sqlutils.QueryTableID(t, conn, "data", "bank") + backupTableID := sqlutils.QueryTableID(t, conn, "data", "public", "bank") do := func(query string, check inProgressChecker) { jobDone := make(chan error) diff --git a/pkg/ccl/changefeedccl/changefeed_test.go b/pkg/ccl/changefeedccl/changefeed_test.go index 012529fa0c59..9ab7a69600da 100644 --- a/pkg/ccl/changefeedccl/changefeed_test.go +++ b/pkg/ccl/changefeedccl/changefeed_test.go @@ -804,7 +804,7 @@ func fetchDescVersionModificationTime( Key: tblKey, EndKey: tblKey.PrefixEnd(), } - dropColTblID := sqlutils.QueryTableID(t, db, `d`, tableName) + dropColTblID := sqlutils.QueryTableID(t, db, `d`, "public", tableName) req := &roachpb.ExportRequest{ RequestHeader: header, MVCCFilter: roachpb.MVCCFilter_All, diff --git a/pkg/ccl/importccl/load.go b/pkg/ccl/importccl/load.go index 2b60f5048e25..6a14eb478670 100644 --- a/pkg/ccl/importccl/load.go +++ b/pkg/ccl/importccl/load.go @@ -54,12 +54,22 @@ func getDescriptorFromDB( // Due to the namespace migration, the row may not exist in system.namespace // so a fallback to system.namespace_deprecated is required. // TODO(sqlexec): In 20.2, this logic can be removed. - for _, tableName := range []string{"system.namespace", "system.namespace_deprecated"} { - if err := db.QueryRow(fmt.Sprintf(`SELECT + for _, t := range []struct { + tableName string + extraClause string + }{ + {"system.namespace", `AND n."parentSchemaID" = 0`}, + {"system.namespace_deprecated", ""}, + } { + if err := db.QueryRow( + fmt.Sprintf(`SELECT d.descriptor FROM %s n INNER JOIN system.descriptor d ON n.id = d.id - WHERE n."parentID" = $1 - AND n.name = $2`, tableName), + WHERE n."parentID" = $1 %s + AND n.name = $2`, + t.tableName, + t.extraClause, + ), keys.RootNamespaceID, dbName, ).Scan(&dbDescBytes); err != nil { diff --git a/pkg/ccl/partitionccl/zone_test.go b/pkg/ccl/partitionccl/zone_test.go index da85a1d3455e..028eddc0b93b 100644 --- a/pkg/ccl/partitionccl/zone_test.go +++ b/pkg/ccl/partitionccl/zone_test.go @@ -53,7 +53,7 @@ func TestValidIndexPartitionSetShowZones(t *testing.T) { partialZoneOverride.GC = &zonepb.GCPolicy{TTLSeconds: 42} dbID := sqlutils.QueryDatabaseID(t, db, "d") - tableID := sqlutils.QueryTableID(t, db, "d", "t") + tableID := sqlutils.QueryTableID(t, db, "d", "public", "t") defaultRow := sqlutils.ZoneRow{ ID: keys.RootNamespaceID, diff --git a/pkg/sql/ambiguous_commit_test.go b/pkg/sql/ambiguous_commit_test.go index 484f9113a523..765f974dbfc1 100644 --- a/pkg/sql/ambiguous_commit_test.go +++ b/pkg/sql/ambiguous_commit_test.go @@ -155,7 +155,7 @@ func TestAmbiguousCommit(t *testing.T) { t.Fatal(err) } - tableID := sqlutils.QueryTableID(t, sqlDB, "test", "t") + tableID := sqlutils.QueryTableID(t, sqlDB, "test", "public", "t") tableStartKey.Store(keys.MakeTablePrefix(tableID)) // Wait for new table to split & replication. diff --git a/pkg/sql/privileged_accessor.go b/pkg/sql/privileged_accessor.go index 135f0840048b..4895a5312e07 100644 --- a/pkg/sql/privileged_accessor.go +++ b/pkg/sql/privileged_accessor.go @@ -22,14 +22,23 @@ import ( ) // LookupNamespaceID implements tree.PrivilegedAccessor. +// TODO(sqlexec): make this work for any arbitrary schema. +// This currently only works for public schemas and databases. func (p *planner) LookupNamespaceID( ctx context.Context, parentID int64, name string, ) (tree.DInt, bool, error) { var r tree.Datums - for _, tableName := range []string{"system.namespace", "system.namespace_deprecated"} { + for _, t := range []struct { + tableName string + extraClause string + }{ + {"system.namespace", `AND "parentSchemaID" IN (0, 29)`}, + {"system.namespace_deprecated", ""}, + } { query := fmt.Sprintf( - `SELECT id FROM %s WHERE "parentID" = $1 AND name = $2`, - tableName, + `SELECT id FROM %s WHERE "parentID" = $1 AND name = $2 %s`, + t.tableName, + t.extraClause, ) var err error r, err = p.ExtendedEvalContext().ExecCfg.InternalExecutor.QueryRowEx( diff --git a/pkg/sql/zone_test.go b/pkg/sql/zone_test.go index 8d561989b96b..eb5795a2eff4 100644 --- a/pkg/sql/zone_test.go +++ b/pkg/sql/zone_test.go @@ -66,7 +66,7 @@ func TestValidSetShowZones(t *testing.T) { } dbID := sqlutils.QueryDatabaseID(t, db, "d") - tableID := sqlutils.QueryTableID(t, db, "d", "t") + tableID := sqlutils.QueryTableID(t, db, "d", "public", "t") dbRow := sqlutils.ZoneRow{ ID: dbID, @@ -246,7 +246,7 @@ func TestZoneInheritField(t *testing.T) { } newReplicationFactor := 10 - tableID := sqlutils.QueryTableID(t, db, "d", "t") + tableID := sqlutils.QueryTableID(t, db, "d", "public", "t") newDefCfg := s.(*server.TestServer).Cfg.DefaultZoneConfig newDefCfg.NumReplicas = proto.Int32(int32(newReplicationFactor)) diff --git a/pkg/testutils/sqlutils/table_id.go b/pkg/testutils/sqlutils/table_id.go index 8009d9f5de9a..50bc4cf7d0fd 100644 --- a/pkg/testutils/sqlutils/table_id.go +++ b/pkg/testutils/sqlutils/table_id.go @@ -18,7 +18,10 @@ import ( // QueryDatabaseID returns the database ID of the specified database using the // system.namespace table. func QueryDatabaseID(t testing.TB, sqlDB DBHandle, dbName string) uint32 { - dbIDQuery := `SELECT id FROM system.namespace WHERE name = $1 AND "parentID" = 0` + dbIDQuery := ` + SELECT id FROM system.namespace + WHERE name = $1 AND "parentSchemaID" = 0 AND "parentID" = 0 + ` var dbID uint32 result := sqlDB.QueryRowContext(context.Background(), dbIDQuery, dbName) if err := result.Scan(&dbID); err != nil { @@ -29,14 +32,22 @@ func QueryDatabaseID(t testing.TB, sqlDB DBHandle, dbName string) uint32 { // QueryTableID returns the table ID of the specified database.table // using the system.namespace table. -func QueryTableID(t testing.TB, sqlDB DBHandle, dbName, tableName string) uint32 { +func QueryTableID( + t testing.TB, sqlDB DBHandle, dbName, schemaName string, tableName string, +) uint32 { tableIDQuery := ` SELECT tables.id FROM system.namespace tables JOIN system.namespace dbs ON dbs.id = tables."parentID" - WHERE dbs.name = $1 AND tables.name = $2 + JOIN system.namespace schemas ON schemas.id = tables."parentSchemaID" + WHERE dbs.name = $1 AND schemas.name = $2 AND tables.name = $3 ` var tableID uint32 - result := sqlDB.QueryRowContext(context.Background(), tableIDQuery, dbName, tableName) + result := sqlDB.QueryRowContext( + context.Background(), + tableIDQuery, dbName, + schemaName, + tableName, + ) if err := result.Scan(&tableID); err != nil { t.Fatal(err) }