diff --git a/pkg/sql/opt/BUILD.bazel b/pkg/sql/opt/BUILD.bazel index 90d46f5b66d6..1e417a36c73c 100644 --- a/pkg/sql/opt/BUILD.bazel +++ b/pkg/sql/opt/BUILD.bazel @@ -61,6 +61,7 @@ go_test( embed = [":opt"], deps = [ "//pkg/settings/cluster", + "//pkg/sql/catalog/descpb", "//pkg/sql/opt/cat", "//pkg/sql/opt/memo", "//pkg/sql/opt/norm", @@ -70,6 +71,7 @@ go_test( "//pkg/sql/sem/tree", "//pkg/sql/types", "//pkg/util", + "@com_github_stretchr_testify//require", ], ) diff --git a/pkg/sql/opt/metadata_test.go b/pkg/sql/opt/metadata_test.go index dacc6649575e..57495fadb5a8 100644 --- a/pkg/sql/opt/metadata_test.go +++ b/pkg/sql/opt/metadata_test.go @@ -17,6 +17,7 @@ import ( "testing" "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/opt" "github.com/cockroachdb/cockroach/pkg/sql/opt/cat" "github.com/cockroachdb/cockroach/pkg/sql/opt/memo" @@ -26,6 +27,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/stretchr/testify/require" ) func TestMetadata(t *testing.T) { @@ -409,3 +411,47 @@ func TestDuplicateTable(t *testing.T) { t.Errorf("expected partial index predicate to reference new column ID %d, got %d", dupB, col) } } + +// TestTableMeta_GetRegionsInDatabases exercises the multiregion.RegionConfig +// annotation. +func TestTableMeta_GetRegionsInDatabase(t *testing.T) { + cat := testcat.New() + _, err := cat.ExecuteDDL("CREATE TABLE a (b BOOL, b2 BOOL, INDEX (b2) WHERE b)") + if err != nil { + t.Fatal(err) + } + + var md opt.Metadata + tn := tree.NewUnqualifiedTableName("a") + tab := cat.Table(tn) + tab.DatabaseID = 1 // must be non-zero to trigger the region lookup + a := md.AddTable(tab, tn) + tabMeta := md.TableMeta(a) + + p := &fakeGetMultiregionConfigPlanner{} + // Call the function once, make sure our planner method gets invoked. + { + _, exists := tabMeta.GetRegionsInDatabase(p) + require.False(t, exists) + require.Equal(t, 1, p.getMultiregionConfigCalled) + } + // Call the function again, make sure that our planner method doesn't + // get invoked again. + { + _, exists := tabMeta.GetRegionsInDatabase(p) + require.False(t, exists) + require.Equal(t, 1, p.getMultiregionConfigCalled) + } +} + +type fakeGetMultiregionConfigPlanner struct { + eval.Planner + getMultiregionConfigCalled int +} + +func (f *fakeGetMultiregionConfigPlanner) GetMultiregionConfig( + databaseID descpb.ID, +) (interface{}, bool) { + f.getMultiregionConfigCalled++ + return nil, false +} diff --git a/pkg/sql/opt/table_meta.go b/pkg/sql/opt/table_meta.go index ddee17f5de87..70552d5a4d10 100644 --- a/pkg/sql/opt/table_meta.go +++ b/pkg/sql/opt/table_meta.go @@ -445,11 +445,11 @@ func (tm *TableMeta) VirtualComputedColumns() ColSet { } // GetRegionsInDatabase finds the full set of regions in the multiregion -// database owning the table described by `tm`, or returns ok=false if not -// multiregion. The result is cached in TableMeta. +// database owning the table described by `tm`, or returns hasRegionName=false +// if not multiregion. The result is cached in TableMeta. func (tm *TableMeta) GetRegionsInDatabase( planner eval.Planner, -) (regionNames catpb.RegionNames, ok bool) { +) (regionNames catpb.RegionNames, hasRegionNames bool) { multiregionConfig, ok := tm.TableAnnotation(regionConfigAnnID).(*multiregion.RegionConfig) if ok { if multiregionConfig == nil { @@ -458,14 +458,22 @@ func (tm *TableMeta) GetRegionsInDatabase( return multiregionConfig.Regions(), true } dbID := tm.Table.GetDatabaseID() + defer func() { + if !hasRegionNames { + tm.SetTableAnnotation( + regionConfigAnnID, + // Use a nil pointer to a RegionConfig, which is distinct from the + // untyped nil and will be detected in the type assertion above. + (*multiregion.RegionConfig)(nil), + ) + } + }() + if dbID == 0 { - tm.SetTableAnnotation(regionConfigAnnID, nil) return nil /* regionNames */, false } - regionConfig, ok := planner.GetMultiregionConfig(dbID) if !ok { - tm.SetTableAnnotation(regionConfigAnnID, nil) return nil /* regionNames */, false } multiregionConfig, _ = regionConfig.(*multiregion.RegionConfig) @@ -494,7 +502,9 @@ func (tm *TableMeta) GetDatabaseSurvivalGoal( dbID := tm.Table.GetDatabaseID() regionConfig, ok := planner.GetMultiregionConfig(dbID) if !ok { - tm.SetTableAnnotation(regionConfigAnnID, nil) + // Use a nil pointer to a RegionConfig, which is distinct from the + // untyped nil and will be detected in the type assertion above. + tm.SetTableAnnotation(regionConfigAnnID, (*multiregion.RegionConfig)(nil)) return descpb.SurvivalGoal_ZONE_FAILURE /* survivalGoal */, false } multiregionConfig, _ = regionConfig.(*multiregion.RegionConfig) diff --git a/pkg/sql/opt/testutils/testcat/test_catalog.go b/pkg/sql/opt/testutils/testcat/test_catalog.go index 6c673fa54ab6..8bac83967009 100644 --- a/pkg/sql/opt/testutils/testcat/test_catalog.go +++ b/pkg/sql/opt/testutils/testcat/test_catalog.go @@ -667,6 +667,7 @@ func (tv *View) CollectTypes(ord int) (descpb.IDs, error) { // Table implements the cat.Table interface for testing purposes. type Table struct { TabID cat.StableID + DatabaseID descpb.ID TabVersion int TabName tree.TableName Columns []cat.Column @@ -869,7 +870,7 @@ func (tt *Table) HomeRegionColName() (colName string, ok bool) { // GetDatabaseID is part of the cat.Table interface. func (tt *Table) GetDatabaseID() descpb.ID { - return 0 + return tt.DatabaseID } // FindOrdinal returns the ordinal of the column with the given name.