diff --git a/pkg/sql/catalog/BUILD.bazel b/pkg/sql/catalog/BUILD.bazel index d7ec8859ab3a..8f5015cd0a73 100644 --- a/pkg/sql/catalog/BUILD.bazel +++ b/pkg/sql/catalog/BUILD.bazel @@ -43,10 +43,12 @@ go_test( srcs = [ "dep_test.go", "descriptor_test.go", + "table_col_map_test.go", "table_col_set_test.go", ], embed = [":catalog"], deps = [ + "//pkg/sql/catalog/colinfo", "//pkg/sql/catalog/dbdesc", "//pkg/sql/catalog/descpb", "//pkg/sql/catalog/schemadesc", @@ -54,6 +56,7 @@ go_test( "//pkg/testutils/buildutil", "//pkg/util", "//pkg/util/leaktest", + "//pkg/util/randutil", "@com_github_cockroachdb_redact//:redact", "@com_github_stretchr_testify//require", "@in_gopkg_yaml_v2//:yaml_v2", diff --git a/pkg/sql/catalog/catalog.go b/pkg/sql/catalog/catalog.go index 89b544b02ed4..ad8b213bfd1a 100644 --- a/pkg/sql/catalog/catalog.go +++ b/pkg/sql/catalog/catalog.go @@ -11,6 +11,8 @@ package catalog import ( + "math" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" ) @@ -103,3 +105,13 @@ func (p ResolvedObjectPrefix) NamePrefix() tree.ObjectNamePrefix { } return n } + +// NumSystemColumns defines the number of supported system columns and must be +// equal to len(colinfo.AllSystemColumnDescs) (enforced in colinfo package to +// avoid an import cycle). +const NumSystemColumns = 2 + +// SmallestSystemColumnColumnID is a descpb.ColumnID with the smallest value +// among all system columns (enforced in colinfo package to avoid an import +// cycle). +const SmallestSystemColumnColumnID = math.MaxUint32 - 1 diff --git a/pkg/sql/catalog/colinfo/system_columns.go b/pkg/sql/catalog/colinfo/system_columns.go index 5dc904f7cac7..5fb2ec979386 100644 --- a/pkg/sql/catalog/colinfo/system_columns.go +++ b/pkg/sql/catalog/colinfo/system_columns.go @@ -13,6 +13,7 @@ package colinfo import ( "math" + "github.com/cockroachdb/cockroach/pkg/sql/catalog" "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/sql/types" ) @@ -44,6 +45,17 @@ const MVCCTimestampColumnID = math.MaxUint32 // TableOIDColumnID is the ID of the tableoid system column. const TableOIDColumnID = MVCCTimestampColumnID - 1 +func init() { + if len(AllSystemColumnDescs) != catalog.NumSystemColumns { + panic("need to update catalog.NumSystemColumns") + } + for _, desc := range AllSystemColumnDescs { + if desc.ID < catalog.SmallestSystemColumnColumnID { + panic("need to update catalog.SmallestSystemColumnColumnID") + } + } +} + // MVCCTimestampColumnDesc is a column descriptor for the MVCC system column. var MVCCTimestampColumnDesc = descpb.ColumnDescriptor{ Name: MVCCTimestampColumnName, diff --git a/pkg/sql/catalog/table_col_map.go b/pkg/sql/catalog/table_col_map.go index 5ce8299ba89a..29e723bc3ff8 100644 --- a/pkg/sql/catalog/table_col_map.go +++ b/pkg/sql/catalog/table_col_map.go @@ -11,6 +11,9 @@ package catalog import ( + "bytes" + "fmt" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" "github.com/cockroachdb/cockroach/pkg/util" ) @@ -18,23 +21,64 @@ import ( // TableColMap is a map from descpb.ColumnID to int. It is typically used to // store a mapping from column id to ordinal position within a row, but can be // used for any similar purpose. +// +// It stores the mapping for ColumnIDs of the system columns separately since +// those IDs are very large and incur an allocation in util.FastIntMap all the +// time. type TableColMap struct { m util.FastIntMap + // systemColMap maps all system columns to their values. Columns here are + // in increasing order of their IDs (in other words, since we started giving + // out IDs from math.MaxUint32 and are going down, the newer system columns + // appear here earlier). + systemColMap [NumSystemColumns]int + // systemColIsSet indicates whether a value has been set for the + // corresponding system column in systemColMap (needed in order to + // differentiate between unset 0 and set 0). + systemColIsSet [NumSystemColumns]bool } // Set maps a key to the given value. -func (s *TableColMap) Set(col descpb.ColumnID, val int) { s.m.Set(int(col), val) } +func (s *TableColMap) Set(col descpb.ColumnID, val int) { + if col < SmallestSystemColumnColumnID { + s.m.Set(int(col), val) + } else { + pos := col - SmallestSystemColumnColumnID + s.systemColMap[pos] = val + s.systemColIsSet[pos] = true + } +} // Get returns the current value mapped to key, or ok=false if the // key is unmapped. -func (s *TableColMap) Get(col descpb.ColumnID) (val int, ok bool) { return s.m.Get(int(col)) } +func (s *TableColMap) Get(col descpb.ColumnID) (val int, ok bool) { + if col < SmallestSystemColumnColumnID { + return s.m.Get(int(col)) + } + pos := col - SmallestSystemColumnColumnID + return s.systemColMap[pos], s.systemColIsSet[pos] +} // GetDefault returns the current value mapped to key, or 0 if the key is // unmapped. -func (s *TableColMap) GetDefault(col descpb.ColumnID) (val int) { return s.m.GetDefault(int(col)) } +func (s *TableColMap) GetDefault(col descpb.ColumnID) (val int) { + if col < SmallestSystemColumnColumnID { + return s.m.GetDefault(int(col)) + } + pos := col - SmallestSystemColumnColumnID + return s.systemColMap[pos] +} // Len returns the number of keys in the map. -func (s *TableColMap) Len() (val int) { return s.m.Len() } +func (s *TableColMap) Len() (val int) { + l := s.m.Len() + for _, isSet := range s.systemColIsSet { + if isSet { + l++ + } + } + return l +} // ForEach calls the given function for each key/value pair in the map (in // arbitrary order). @@ -42,11 +86,32 @@ func (s *TableColMap) ForEach(f func(colID descpb.ColumnID, returnIndex int)) { s.m.ForEach(func(k, v int) { f(descpb.ColumnID(k), v) }) + for pos, isSet := range s.systemColIsSet { + if isSet { + id := SmallestSystemColumnColumnID + pos + f(descpb.ColumnID(id), s.systemColMap[pos]) + } + } } // String prints out the contents of the map in the following format: // map[key1:val1 key2:val2 ...] // The keys are in ascending order. func (s *TableColMap) String() string { - return s.m.String() + var buf bytes.Buffer + buf.WriteString("map[") + s.m.ContentsIntoBuffer(&buf) + first := buf.Len() == len("map[") + for pos, isSet := range s.systemColIsSet { + if isSet { + if !first { + buf.WriteByte(' ') + } + first = false + id := SmallestSystemColumnColumnID + pos + fmt.Fprintf(&buf, "%d:%d", id, s.systemColMap[pos]) + } + } + buf.WriteByte(']') + return buf.String() } diff --git a/pkg/sql/catalog/table_col_map_test.go b/pkg/sql/catalog/table_col_map_test.go new file mode 100644 index 000000000000..11337a3e87e0 --- /dev/null +++ b/pkg/sql/catalog/table_col_map_test.go @@ -0,0 +1,82 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package catalog_test + +import ( + "math/rand" + "sort" + "testing" + + "github.com/cockroachdb/cockroach/pkg/sql/catalog" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/colinfo" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" + "github.com/cockroachdb/cockroach/pkg/util" + "github.com/cockroachdb/cockroach/pkg/util/randutil" + "github.com/stretchr/testify/require" +) + +func TestTableColMap(t *testing.T) { + var m catalog.TableColMap + var oracle util.FastIntMap + rng, _ := randutil.NewTestRand() + + var columnIDs []descpb.ColumnID + for i := 0; i < 5; i++ { + columnIDs = append(columnIDs, descpb.ColumnID(i)) + } + for _, systemColumnDesc := range colinfo.AllSystemColumnDescs { + columnIDs = append(columnIDs, systemColumnDesc.ID) + } + rand.Shuffle(len(columnIDs), func(i, j int) { + columnIDs[i], columnIDs[j] = columnIDs[j], columnIDs[i] + }) + + // Use each column ID with 50% probability. + for i, columnID := range columnIDs { + if rng.Float64() < 0.5 { + m.Set(columnID, i) + oracle.Set(int(columnID), i) + } + } + + // First, check the length. + require.Equal(t, oracle.Len(), m.Len()) + + // Check that Get and GetDefault return the same results. + for _, columnID := range columnIDs { + actual, actualOk := m.Get(columnID) + expected, expectedOk := oracle.Get(int(columnID)) + require.Equal(t, expectedOk, actualOk) + if actualOk { + require.Equal(t, expected, actual) + } + actual = m.GetDefault(columnID) + expected = oracle.GetDefault(int(columnID)) + require.Equal(t, expected, actual) + } + + // Verify ForEach. We don't bother storing the column IDs here since sorting + // them below would be mildly annoying. + var actualValues, expectedValues []int + m.ForEach(func(_ descpb.ColumnID, returnIndex int) { + actualValues = append(actualValues, returnIndex) + }) + oracle.ForEach(func(_ int, val int) { + expectedValues = append(expectedValues, val) + }) + // Since the order of iteration is not defined, we have to sort all slices. + sort.Ints(actualValues) + sort.Ints(expectedValues) + require.Equal(t, expectedValues, actualValues) + + // Check that stringification matches too. + require.Equal(t, oracle.String(), m.String()) +} diff --git a/pkg/util/fast_int_map.go b/pkg/util/fast_int_map.go index 7f915b57d153..290f9a0fea90 100644 --- a/pkg/util/fast_int_map.go +++ b/pkg/util/fast_int_map.go @@ -196,14 +196,12 @@ func (m FastIntMap) ForEach(fn func(key, val int)) { } } -// String prints out the contents of the map in the following format: -// map[key1:val1 key2:val2 ...] +// ContentsIntoBuffer writes the contents of the map into the provided buffer in +// the following format: +// key1:val1 key2:val2 ... // The keys are in ascending order. -func (m FastIntMap) String() string { - var buf bytes.Buffer - buf.WriteString("map[") +func (m FastIntMap) ContentsIntoBuffer(buf *bytes.Buffer) { first := true - if m.large != nil { keys := make([]int, 0, len(m.large)) for k := range m.large { @@ -215,7 +213,7 @@ func (m FastIntMap) String() string { buf.WriteByte(' ') } first = false - fmt.Fprintf(&buf, "%d:%d", k, m.large[k]) + fmt.Fprintf(buf, "%d:%d", k, m.large[k]) } } else { for i := 0; i < numVals; i++ { @@ -224,10 +222,19 @@ func (m FastIntMap) String() string { buf.WriteByte(' ') } first = false - fmt.Fprintf(&buf, "%d:%d", i, val) + fmt.Fprintf(buf, "%d:%d", i, val) } } } +} + +// String prints out the contents of the map in the following format: +// map[key1:val1 key2:val2 ...] +// The keys are in ascending order. +func (m FastIntMap) String() string { + var buf bytes.Buffer + buf.WriteString("map[") + m.ContentsIntoBuffer(&buf) buf.WriteByte(']') return buf.String() }