Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
32402: sql: refactor context and txn out of resolver flags r=vivekmenezes a=vivekmenezes

Because flags should just be flags!

Release note: None

Co-authored-by: Vivek Menezes <[email protected]>
  • Loading branch information
craig[bot] and vivekmenezes committed Nov 15, 2018
2 parents 3b256be + 5fd6141 commit eb8345b
Show file tree
Hide file tree
Showing 18 changed files with 96 additions and 89 deletions.
6 changes: 2 additions & 4 deletions pkg/ccl/importccl/import_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,14 +318,12 @@ func (r fkResolver) CurrentSearchPath() sessiondata.SearchPath {
}

// Implements the sql.SchemaResolver interface.
func (r fkResolver) CommonLookupFlags(ctx context.Context, required bool) sql.CommonLookupFlags {
func (r fkResolver) CommonLookupFlags(required bool) sql.CommonLookupFlags {
return sql.CommonLookupFlags{}
}

// Implements the sql.SchemaResolver interface.
func (r fkResolver) ObjectLookupFlags(
ctx context.Context, required bool, requireMutable bool,
) sql.ObjectLookupFlags {
func (r fkResolver) ObjectLookupFlags(required bool, requireMutable bool) sql.ObjectLookupFlags {
return sql.ObjectLookupFlags{}
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ func (p *planner) MemberOfWithAdminOption(
ctx context.Context, member string,
) (map[string]bool, error) {
// Lookup table version.
objDesc, _, err := p.PhysicalSchemaAccessor().GetObjectDesc(&roleMembersTableName,
p.ObjectLookupFlags(ctx, true /*required*/, false /*requireMutable*/))
objDesc, _, err := p.PhysicalSchemaAccessor().GetObjectDesc(ctx, p.txn, &roleMembersTableName,
p.ObjectLookupFlags(true /*required*/, false /*requireMutable*/))
if err != nil {
return nil, err
}
Expand Down
6 changes: 2 additions & 4 deletions pkg/sql/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ func (sc *SchemaChanger) runBackfill(
func (sc *SchemaChanger) getTableVersion(
ctx context.Context, txn *client.Txn, tc *TableCollection, version sqlbase.DescriptorVersion,
) (*sqlbase.TableDescriptor, error) {
flags := ObjectLookupFlags{CommonLookupFlags{txn: txn}, false /*requireMutable*/}
tableDesc, err := tc.getTableVersionByID(ctx, sc.tableID, flags)
tableDesc, err := tc.getTableVersionByID(ctx, txn, sc.tableID, ObjectLookupFlags{})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -417,9 +416,8 @@ func (sc *SchemaChanger) distBackfill(
return err
}

flags := ObjectLookupFlags{CommonLookupFlags{txn: txn}, false /*requireMutable*/}
for k := range fkTables {
table, err := tc.getTableVersionByID(ctx, k, flags)
table, err := tc.getTableVersionByID(ctx, txn, k, ObjectLookupFlags{})
if err != nil {
return err
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/sql/data_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,9 @@ func (p *planner) getTableScanByRef(
scanVisibility scanVisibility,
) (planDataSource, error) {
flags := ObjectLookupFlags{CommonLookupFlags: CommonLookupFlags{
txn: p.txn,
avoidCached: p.avoidCachedDescriptors,
}}
desc, err := p.Tables().getTableVersionByID(ctx, sqlbase.ID(tref.TableID), flags)
desc, err := p.Tables().getTableVersionByID(ctx, p.txn, sqlbase.ID(tref.TableID), flags)
if err != nil {
return planDataSource{}, errors.Wrapf(err, "%s", tree.ErrString(tref))
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ func (dc *databaseCache) getDatabaseDesc(
if desc == nil {
if err := txnRunner(ctx, func(ctx context.Context, txn *client.Txn) error {
a := UncachedPhysicalAccessor{}
desc, err = a.GetDatabaseDesc(name,
DatabaseLookupFlags{ctx: ctx, txn: txn, required: required})
desc, err = a.GetDatabaseDesc(ctx, txn, name,
DatabaseLookupFlags{required: required})
return err
}); err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/drop_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (p *planner) DropDatabase(ctx context.Context, n *tree.DropDatabase) (planN
return nil, err
}

tbNames, err := GetObjectNames(ctx, p, dbDesc, tree.PublicSchema, true /*explicitPrefix*/)
tbNames, err := GetObjectNames(ctx, p.txn, p, dbDesc, tree.PublicSchema, true /*explicitPrefix*/)
if err != nil {
return nil, err
}
Expand Down
15 changes: 11 additions & 4 deletions pkg/sql/logical_schema_accessors.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
package sql

import (
"context"

"github.com/cockroachdb/cockroach/pkg/internal/client"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
Expand Down Expand Up @@ -45,7 +48,11 @@ func (l *LogicalSchemaAccessor) IsValidSchema(dbDesc *DatabaseDescriptor, scName

// GetObjectNames implements the DatabaseLister interface.
func (l *LogicalSchemaAccessor) GetObjectNames(
dbDesc *DatabaseDescriptor, scName string, flags DatabaseListFlags,
ctx context.Context,
txn *client.Txn,
dbDesc *DatabaseDescriptor,
scName string,
flags DatabaseListFlags,
) (TableNames, error) {
if entry, ok := l.vt.getVirtualSchemaEntry(scName); ok {
names := make(TableNames, len(entry.orderedDefNames))
Expand All @@ -60,12 +67,12 @@ func (l *LogicalSchemaAccessor) GetObjectNames(
}

// Fallthrough.
return l.SchemaAccessor.GetObjectNames(dbDesc, scName, flags)
return l.SchemaAccessor.GetObjectNames(ctx, txn, dbDesc, scName, flags)
}

// GetObjectDesc implements the ObjectAccessor interface.
func (l *LogicalSchemaAccessor) GetObjectDesc(
name *ObjectName, flags ObjectLookupFlags,
ctx context.Context, txn *client.Txn, name *ObjectName, flags ObjectLookupFlags,
) (ObjectDescriptor, *DatabaseDescriptor, error) {
if scEntry, ok := l.vt.getVirtualSchemaEntry(name.Schema()); ok {
tableName := name.Table()
Expand All @@ -87,5 +94,5 @@ func (l *LogicalSchemaAccessor) GetObjectDesc(
}

// Fallthrough.
return l.SchemaAccessor.GetObjectDesc(name, flags)
return l.SchemaAccessor.GetObjectDesc(ctx, txn, name, flags)
}
36 changes: 21 additions & 15 deletions pkg/sql/physical_schema_accessors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ package sql

import (
"bytes"
"context"

"github.com/cockroachdb/cockroach/pkg/internal/client"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
"github.com/cockroachdb/cockroach/pkg/util/encoding"
Expand Down Expand Up @@ -47,10 +49,10 @@ var _ SchemaAccessor = UncachedPhysicalAccessor{}

// GetDatabaseDesc implements the SchemaAccessor interface.
func (a UncachedPhysicalAccessor) GetDatabaseDesc(
name string, flags DatabaseLookupFlags,
ctx context.Context, txn *client.Txn, name string, flags DatabaseLookupFlags,
) (desc *DatabaseDescriptor, err error) {
desc = &sqlbase.DatabaseDescriptor{}
found, err := getDescriptor(flags.ctx, flags.txn, databaseKey{name}, desc)
found, err := getDescriptor(ctx, txn, databaseKey{name}, desc)
if err != nil {
return nil, err
}
Expand All @@ -74,7 +76,11 @@ func (a UncachedPhysicalAccessor) IsValidSchema(dbDesc *DatabaseDescriptor, scNa

// GetObjectNames implements the SchemaAccessor interface.
func (a UncachedPhysicalAccessor) GetObjectNames(
dbDesc *DatabaseDescriptor, scName string, flags DatabaseListFlags,
ctx context.Context,
txn *client.Txn,
dbDesc *DatabaseDescriptor,
scName string,
flags DatabaseListFlags,
) (TableNames, error) {
if ok := a.IsValidSchema(dbDesc, scName); !ok {
if flags.required {
Expand All @@ -84,9 +90,9 @@ func (a UncachedPhysicalAccessor) GetObjectNames(
return nil, nil
}

log.Eventf(flags.ctx, "fetching list of objects for %q", dbDesc.Name)
log.Eventf(ctx, "fetching list of objects for %q", dbDesc.Name)
prefix := sqlbase.MakeNameMetadataKey(dbDesc.ID, "")
sr, err := flags.txn.Scan(flags.ctx, prefix, prefix.PrefixEnd(), 0)
sr, err := txn.Scan(ctx, prefix, prefix.PrefixEnd(), 0)
if err != nil {
return nil, err
}
Expand All @@ -108,7 +114,7 @@ func (a UncachedPhysicalAccessor) GetObjectNames(

// GetObjectDesc implements the SchemaAccessor interface.
func (a UncachedPhysicalAccessor) GetObjectDesc(
name *ObjectName, flags ObjectLookupFlags,
ctx context.Context, txn *client.Txn, name *ObjectName, flags ObjectLookupFlags,
) (ObjectDescriptor, *DatabaseDescriptor, error) {
// At this point, only the public schema is recognized.
if name.Schema() != tree.PublicSchema {
Expand All @@ -119,7 +125,7 @@ func (a UncachedPhysicalAccessor) GetObjectDesc(
}

// Look up the database.
dbDesc, err := a.GetDatabaseDesc(name.Catalog(), flags.CommonLookupFlags)
dbDesc, err := a.GetDatabaseDesc(ctx, txn, name.Catalog(), flags.CommonLookupFlags)
if dbDesc == nil || err != nil {
// dbDesc can be nil if the object is not required and the
// database was not found.
Expand All @@ -128,7 +134,7 @@ func (a UncachedPhysicalAccessor) GetObjectDesc(

// Look up the table using the discovered database descriptor.
desc := &sqlbase.TableDescriptor{}
found, err := getDescriptor(flags.ctx, flags.txn,
found, err := getDescriptor(ctx, txn,
tableKey{parentID: dbDesc.ID, name: name.Table()}, desc)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -166,7 +172,7 @@ var _ SchemaAccessor = &CachedPhysicalAccessor{}

// GetDatabaseDesc implements the SchemaAccessor interface.
func (a *CachedPhysicalAccessor) GetDatabaseDesc(
name string, flags DatabaseLookupFlags,
ctx context.Context, txn *client.Txn, name string, flags DatabaseLookupFlags,
) (desc *DatabaseDescriptor, err error) {
isSystemDB := name == sqlbase.SystemDB.Name
if !(flags.avoidCached || isSystemDB || testDisableTableLeases) {
Expand All @@ -178,7 +184,7 @@ func (a *CachedPhysicalAccessor) GetDatabaseDesc(
if dbID != 0 {
// Some database ID was found in the list of uncommitted DB changes.
// Use that to get the descriptor.
desc, err := a.tc.databaseCache.getDatabaseDescByID(flags.ctx, flags.txn, dbID)
desc, err := a.tc.databaseCache.getDatabaseDescByID(ctx, txn, dbID)
if desc == nil && flags.required {
return nil, sqlbase.NewUndefinedDatabaseError(name)
}
Expand All @@ -187,27 +193,27 @@ func (a *CachedPhysicalAccessor) GetDatabaseDesc(

// The database was not known in the uncommitted list. Have the db
// cache look it up by name for us.
return a.tc.databaseCache.getDatabaseDesc(flags.ctx,
return a.tc.databaseCache.getDatabaseDesc(ctx,
a.tc.leaseMgr.execCfg.DB.Txn, name, flags.required)
}

// We avoided the cache. Go lower.
return a.SchemaAccessor.GetDatabaseDesc(name, flags)
return a.SchemaAccessor.GetDatabaseDesc(ctx, txn, name, flags)
}

// GetObjectDesc implements the SchemaAccessor interface.
func (a *CachedPhysicalAccessor) GetObjectDesc(
name *ObjectName, flags ObjectLookupFlags,
ctx context.Context, txn *client.Txn, name *ObjectName, flags ObjectLookupFlags,
) (ObjectDescriptor, *DatabaseDescriptor, error) {
if flags.requireMutable {
table, db, err := a.tc.getMutableTableDescriptor(flags.ctx, name, flags)
table, db, err := a.tc.getMutableTableDescriptor(ctx, txn, name, flags)
if table == nil {
// return nil interface.
return nil, db, err
}
return table, db, err
}
table, db, err := a.tc.getTableVersion(flags.ctx, name, flags)
table, db, err := a.tc.getTableVersion(ctx, txn, name, flags)
if table == nil {
// return nil interface.
return nil, db, err
Expand Down
5 changes: 2 additions & 3 deletions pkg/sql/planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,8 @@ func (p *planner) ResolveTableName(ctx context.Context, tn *tree.TableName) erro
func (p *planner) LookupTableByID(
ctx context.Context, tableID sqlbase.ID,
) (row.TableLookup, error) {
flags := ObjectLookupFlags{
CommonLookupFlags{txn: p.txn, avoidCached: p.avoidCachedDescriptors}, false /*requireMutable*/}
table, err := p.Tables().getTableVersionByID(ctx, tableID, flags)
flags := ObjectLookupFlags{CommonLookupFlags: CommonLookupFlags{avoidCached: p.avoidCachedDescriptors}}
table, err := p.Tables().getTableVersionByID(ctx, p.txn, tableID, flags)
if err != nil {
if err == errTableAdding {
return row.TableLookup{IsAdding: true}, nil
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/rename_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ func (n *renameDatabaseNode) startExec(params runParams) error {
// name. Rather than trying to rewrite them with the changed DB name, we
// simply disallow such renames for now.
phyAccessor := p.PhysicalSchemaAccessor()
lookupFlags := p.CommonLookupFlags(ctx, true /*required*/)
lookupFlags := p.CommonLookupFlags(true /*required*/)
// DDL statements bypass the cache.
lookupFlags.avoidCached = true
tbNames, err := phyAccessor.GetObjectNames(
dbDesc, tree.PublicSchema, DatabaseListFlags{
ctx, p.txn, dbDesc, tree.PublicSchema, DatabaseListFlags{
CommonLookupFlags: lookupFlags,
explicitPrefix: true,
})
Expand All @@ -90,7 +90,7 @@ func (n *renameDatabaseNode) startExec(params runParams) error {
}
lookupFlags.required = false
for i := range tbNames {
objDesc, _, err := phyAccessor.GetObjectDesc(&tbNames[i],
objDesc, _, err := phyAccessor.GetObjectDesc(ctx, p.txn, &tbNames[i],
ObjectLookupFlags{CommonLookupFlags: lookupFlags})
if err != nil {
return err
Expand Down
Loading

0 comments on commit eb8345b

Please sign in to comment.