Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: add descriptor validation on write #60552

Merged
merged 3 commits into from
Feb 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/ccl/backupccl/restore_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ func WriteDescriptors(
}

for _, db := range databases {
if err := db.Validate(); err != nil {
if err := db.Validate(ctx, dg); err != nil {
return errors.Wrapf(err,
"validate database %d", errors.Safe(db.GetID()))
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/ccl/changefeedccl/avro_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func parseTableDesc(createTableStmt string) (catalog.TableDescriptor, error) {
if err != nil {
return nil, err
}
return mutDesc, mutDesc.ValidateTable(ctx)
return mutDesc, mutDesc.ValidateSelf(ctx)
}

func parseValues(tableDesc catalog.TableDescriptor, values string) ([]rowenc.EncDatumRow, error) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/ccl/partitionccl/partition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (pt *partitioningTest) parse() error {
return err
}
pt.parsed.tableDesc = mutDesc
if err := pt.parsed.tableDesc.ValidateTable(ctx); err != nil {
if err := pt.parsed.tableDesc.ValidateSelf(ctx); err != nil {
return err
}
}
Expand Down
26 changes: 6 additions & 20 deletions pkg/sql/catalog/catalogkv/catalogkv.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,21 +280,6 @@ func (t *oneLevelUncachedDescGetter) GetDescs(

var _ catalog.DescGetter = (*oneLevelUncachedDescGetter)(nil)

func validateDescriptor(ctx context.Context, dg catalog.DescGetter, desc catalog.Descriptor) error {
switch desc := desc.(type) {
case catalog.TableDescriptor:
return desc.Validate(ctx, dg)
case catalog.DatabaseDescriptor:
return desc.Validate()
case catalog.TypeDescriptor:
return desc.Validate(ctx, dg)
case catalog.SchemaDescriptor:
return nil
default:
return errors.AssertionFailedf("unknown descriptor type %T", desc)
}
}

// unwrapDescriptor takes a descriptor retrieved using a transaction and unwraps
// it into an immutable implementation of Descriptor. It ensures that
// the ModificationTime is set properly and will validate the descriptor if
Expand Down Expand Up @@ -327,7 +312,7 @@ func unwrapDescriptor(
return nil, nil
}
if validate {
if err := validateDescriptor(ctx, dg, unwrapped); err != nil {
if err := unwrapped.Validate(ctx, dg); err != nil {
return nil, err
}
}
Expand All @@ -350,13 +335,13 @@ func unwrapDescriptorMutable(
if err != nil {
return nil, err
}
if err := mutTable.ValidateTable(ctx); err != nil {
if err := mutTable.ValidateSelf(ctx); err != nil {
return nil, err
}
return mutTable, nil
case database != nil:
dbDesc := dbdesc.NewExistingMutable(*database)
if err := dbDesc.Validate(); err != nil {
if err := dbDesc.Validate(ctx, dg); err != nil {
return nil, err
}
return dbDesc, nil
Expand Down Expand Up @@ -432,7 +417,7 @@ func GetAllDescriptors(
dg[desc.GetID()] = desc
}
for _, desc := range descs {
if err := validateDescriptor(ctx, dg, desc); err != nil {
if err := desc.Validate(ctx, dg); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -603,6 +588,7 @@ func getDescriptorsFromIDs(
if err := txn.Run(ctx, b); err != nil {
return nil, err
}
dg := NewOneLevelUncachedDescGetter(txn, codec)
results := make([]catalog.Descriptor, 0, len(ids))
for i := range b.Results {
result := &b.Results[i]
Expand All @@ -624,7 +610,7 @@ func getDescriptorsFromIDs(
var catalogDesc catalog.Descriptor
if desc.Union != nil {
var err error
catalogDesc, err = unwrapDescriptor(ctx, nil /* descGetter */, result.Rows[0].Value.Timestamp, desc, true)
catalogDesc, err = unwrapDescriptor(ctx, dg, result.Rows[0].Value.Timestamp, desc, true)
if err != nil {
return nil, err
}
Expand Down
14 changes: 0 additions & 14 deletions pkg/sql/catalog/catalogkv/unwrap_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,6 @@ func (o oneLevelMapDescGetter) GetDesc(
return unwrapDescriptorMutable(ctx, nil, mt, &desc)
}

func (o oneLevelMapDescGetter) GetDescs(
ctx context.Context, reqs []descpb.ID,
) ([]catalog.Descriptor, error) {
resps := make([]catalog.Descriptor, len(reqs))
for i, r := range reqs {
var err error
resps[i], err = o.GetDesc(ctx, r)
if err != nil {
return nil, err
}
}
return resps, nil
}

func decodeDescriptorDSV(t *testing.T, descriptorCSVPath string) oneLevelMapDescGetter {
f, err := os.Open(descriptorCSVPath)
require.NoError(t, err)
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/catalog/dbdesc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ go_library(
"//pkg/sql/catalog/descpb",
"//pkg/sql/privilege",
"//pkg/util/hlc",
"//pkg/util/iterutil",
"//pkg/util/protoutil",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_redact//:redact",
Expand Down
32 changes: 30 additions & 2 deletions pkg/sql/catalog/dbdesc/database_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package dbdesc

import (
"context"
"fmt"

"github.com/cockroachdb/cockroach/pkg/keys"
Expand All @@ -21,6 +22,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb"
"github.com/cockroachdb/cockroach/pkg/sql/privilege"
"github.com/cockroachdb/cockroach/pkg/util/hlc"
"github.com/cockroachdb/cockroach/pkg/util/iterutil"
"github.com/cockroachdb/cockroach/pkg/util/protoutil"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/redact"
Expand Down Expand Up @@ -250,10 +252,26 @@ func (desc *Mutable) SetName(name string) {
desc.Name = name
}

// Validate validates that the database descriptor is well formed.
// ForEachSchemaInfo iterates f over each schema info mapping in the descriptor.
// iterutil.StopIteration is supported.
func (desc *Immutable) ForEachSchemaInfo(
f func(id descpb.ID, name string, isDropped bool) error,
) error {
for name, info := range desc.Schemas {
if err := f(info.ID, name, info.Dropped); err != nil {
if iterutil.Done(err) {
return nil
}
return err
}
}
return nil
}

// ValidateSelf validates that the database descriptor is well formed.
// Checks include validate the database name, and verifying that there
// is at least one read and write user.
func (desc *Immutable) Validate() error {
func (desc *Immutable) ValidateSelf(_ context.Context) error {
if err := catalog.ValidateName(desc.GetName(), "descriptor"); err != nil {
return err
}
Expand Down Expand Up @@ -294,6 +312,16 @@ func (desc *Immutable) Validate() error {
return desc.Privileges.Validate(desc.GetID(), privilege.Database)
}

// Validate punts to ValidateSelf.
func (desc *Immutable) Validate(ctx context.Context, _ catalog.DescGetter) error {
return desc.ValidateSelf(ctx)
}

// ValidateTxnCommit punts to Validate.
func (desc *Immutable) ValidateTxnCommit(ctx context.Context, descGetter catalog.DescGetter) error {
return desc.Validate(ctx, descGetter)
}

// SchemaMeta implements the tree.SchemaMeta interface.
// TODO (rohany): I don't want to keep this here, but it seems to be used
// by backup only for the fake resolution that occurs in backup. Is it possible
Expand Down
32 changes: 23 additions & 9 deletions pkg/sql/catalog/desc_getter.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,32 @@ import (
// is used to look up other descriptors during validation.
type DescGetter interface {
GetDesc(ctx context.Context, id descpb.ID) (Descriptor, error)
}

// BatchDescGetter is like DescGetter but retrieves batches of descriptors,
// which for some implementation may make more sense performance-wise.
type BatchDescGetter interface {
GetDescs(ctx context.Context, reqs []descpb.ID) ([]Descriptor, error)
}

// GetDescs retrieves multiple descriptors using a DescGetter.
// If the latter is also a BatchDescGetter, it will delegate to its GetDescs
// method.
func GetDescs(ctx context.Context, descGetter DescGetter, reqs []descpb.ID) ([]Descriptor, error) {
if bdg, ok := descGetter.(BatchDescGetter); ok {
return bdg.GetDescs(ctx, reqs)
}
ret := make([]Descriptor, len(reqs))
for i, id := range reqs {
desc, err := descGetter.GetDesc(ctx, id)
if err != nil {
return nil, err
}
ret[i] = desc
}
return ret, nil
}

// GetTypeDescFromID retrieves the type descriptor for the type ID passed
// in using an existing descGetter. It returns an error if the descriptor
// doesn't exist or if it exists and is not a type descriptor.
Expand Down Expand Up @@ -62,12 +85,3 @@ func (m MapDescGetter) GetDesc(ctx context.Context, id descpb.ID) (Descriptor, e
desc := m[id]
return desc, nil
}

// GetDescs implements the catalog.DescGetter interface.
func (m MapDescGetter) GetDescs(ctx context.Context, ids []descpb.ID) ([]Descriptor, error) {
ret := make([]Descriptor, len(ids))
for i, id := range ids {
ret[i], _ = m.GetDesc(ctx, id)
}
return ret, nil
}
14 changes: 10 additions & 4 deletions pkg/sql/catalog/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ type Descriptor interface {

// DescriptorProto prepares this descriptor for serialization.
DescriptorProto() *descpb.Descriptor

// ValidateSelf checks the internal consistency of the descriptor.
ValidateSelf(ctx context.Context) error

// Validate is like ValidateSelf but with additional cross-reference checks.
Validate(ctx context.Context, descGetter DescGetter) error

// ValidateTxnCommit is like Validate but with additional pre-commit checks.
ValidateTxnCommit(ctx context.Context, descGetter DescGetter) error
}

// DatabaseDescriptor will eventually be called dbdesc.Descriptor.
Expand All @@ -90,8 +99,8 @@ type DatabaseDescriptor interface {
RegionNames() (descpb.RegionNames, error)
IsMultiRegion() bool
PrimaryRegionName() (descpb.RegionName, error)
Validate() error
MultiRegionEnumID() (descpb.ID, error)
ForEachSchemaInfo(func(id descpb.ID, name string, isDropped bool) error) error
}

// SchemaDescriptor will eventually be called schemadesc.Descriptor.
Expand Down Expand Up @@ -245,8 +254,6 @@ type TableDescriptor interface {
databaseDesc DatabaseDescriptor, getType func(descpb.ID) (TypeDescriptor, error),
) (descpb.IDs, error)

Validate(ctx context.Context, txn DescGetter) error

ForeachDependedOnBy(f func(dep *descpb.TableDescriptor_Reference) error) error
GetDependsOn() []descpb.ID
GetConstraintInfoWithLookup(fn TableLookupFn) (map[string]descpb.ConstraintDetail, error)
Expand Down Expand Up @@ -493,7 +500,6 @@ type TypeDescriptor interface {

PrimaryRegionName() (descpb.RegionName, error)
RegionNames() (descpb.RegionNames, error)
Validate(ctx context.Context, dg DescGetter) error
}

// TypeDescriptorResolver is an interface used during hydration of type
Expand Down
2 changes: 2 additions & 0 deletions pkg/sql/catalog/descs/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ go_library(
"//pkg/base",
"//pkg/keys",
"//pkg/kv",
"//pkg/settings",
"//pkg/settings/cluster",
"//pkg/sql/catalog",
"//pkg/sql/catalog/bootstrap",
Expand Down Expand Up @@ -65,6 +66,7 @@ go_test(
"//pkg/sql/catalog/tabledesc",
"//pkg/sql/sem/tree",
"//pkg/sql/sqlutil",
"//pkg/sql/types",
"//pkg/testutils/serverutils",
"//pkg/testutils/sqlutils",
"//pkg/testutils/testcluster",
Expand Down
52 changes: 51 additions & 1 deletion pkg/sql/catalog/descs/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/cockroachdb/cockroach/pkg/keys"
"github.com/cockroachdb/cockroach/pkg/kv"
"github.com/cockroachdb/cockroach/pkg/settings"
"github.com/cockroachdb/cockroach/pkg/settings/cluster"
"github.com/cockroachdb/cockroach/pkg/sql/catalog"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/bootstrap"
Expand Down Expand Up @@ -1342,13 +1343,25 @@ func (tc *Collection) addUncommittedDescriptor(
return ud, nil
}

// validateOnWriteEnabled is the cluster setting used to enable or disable
// validating descriptors prior to writing.
var validateOnWriteEnabled = settings.RegisterBoolSetting(
"sql.catalog.descs.validate_on_write.enabled",
"set to true to validate descriptors prior to writing, false to disable; default is true",
true, /* defaultValue */
)

// WriteDescToBatch calls MaybeIncrementVersion, adds the descriptor to the
// collection as an uncommitted descriptor, and writes it into b.
func (tc *Collection) WriteDescToBatch(
ctx context.Context, kvTrace bool, desc catalog.MutableDescriptor, b *kv.Batch,
) error {
desc.MaybeIncrementVersion()
// TODO(ajwerner): Add validation here.
if validateOnWriteEnabled.Get(&tc.settings.SV) {
if err := desc.ValidateSelf(ctx); err != nil {
return err
}
}
if err := tc.AddUncommittedDescriptor(desc); err != nil {
return err
}
Expand Down Expand Up @@ -1392,6 +1405,43 @@ func (tc *Collection) GetUncommittedTables() (tables []catalog.TableDescriptor)
return tables
}

type collectionDescGetter struct {
tc *Collection
txn *kv.Txn
}

var _ catalog.DescGetter = collectionDescGetter{}

func (cdg collectionDescGetter) GetDesc(
ctx context.Context, id descpb.ID,
) (catalog.Descriptor, error) {
flags := tree.CommonLookupFlags{
Required: true,
// Include everything, we want to cast the net as wide as we can.
IncludeOffline: true,
IncludeDropped: true,
// Avoid leased descriptors, if we're leasing the previous version then this
// older version may be returned and this may cause validation to fail.
AvoidCached: true,
}
return cdg.tc.getDescriptorByID(ctx, cdg.txn, id, flags, false /* mutable */)
}

// ValidateUncommittedDescriptors validates all uncommitted descriptors
func (tc *Collection) ValidateUncommittedDescriptors(ctx context.Context, txn *kv.Txn) error {
if !validateOnWriteEnabled.Get(&tc.settings.SV) {
return nil
}
cdg := collectionDescGetter{tc: tc, txn: txn}
for i, n := 0, len(tc.uncommittedDescriptors); i < n; i++ {
desc := tc.uncommittedDescriptors[i].immutable
if err := desc.ValidateTxnCommit(ctx, cdg); err != nil {
return err
}
}
return nil
}

// User defined type accessors.

// GetMutableTypeVersionByID is the equivalent of GetMutableTableDescriptorByID
Expand Down
Loading