From 962b06c6d380c9ad00fa2472ad60aa3cd7d9401a Mon Sep 17 00:00:00 2001 From: MiguelNovelo Date: Mon, 1 Feb 2021 17:35:57 -0600 Subject: [PATCH] sql: code generation for missing pg catalog tables Previously, programmer had to add missing table manually, This was inadequate because postgress added a lot of tables and manual process can lead to coding mistakes or gaps To address this, diff tool now have the hability to generate the missing code Release note: None Fixes #58001 --- pkg/sql/BUILD.bazel | 1 + pkg/sql/pg_catalog_diff.go | 27 +++ pkg/sql/pg_catalog_test.go | 447 ++++++++++++++++++++++++++++++++++++- pkg/sql/virtual_schema.go | 2 + 4 files changed, 473 insertions(+), 4 deletions(-) diff --git a/pkg/sql/BUILD.bazel b/pkg/sql/BUILD.bazel index 4ce850cd01bf..dabd43b7002f 100644 --- a/pkg/sql/BUILD.bazel +++ b/pkg/sql/BUILD.bazel @@ -598,6 +598,7 @@ go_test( "@com_github_jackc_pgx//pgtype", "@com_github_jackc_pgx_v4//:pgx", "@com_github_lib_pq//:pq", + "@com_github_lib_pq//oid", "@com_github_pmezard_go_difflib//difflib", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", diff --git a/pkg/sql/pg_catalog_diff.go b/pkg/sql/pg_catalog_diff.go index f8c5a0702850..7a38bff167ba 100644 --- a/pkg/sql/pg_catalog_diff.go +++ b/pkg/sql/pg_catalog_diff.go @@ -17,6 +17,9 @@ import ( "encoding/json" "io" "os" + + "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/lib/pq/oid" ) // GetPGCatalogSQL is a query uses udt_name::regtype instead of data_type column because @@ -185,6 +188,30 @@ func (p PGCatalogTables) rewriteDiffs(diffFile string) error { return nil } +// getNotImplementedTables retrieves tables that are not yet part of crdb +func (p PGCatalogTables) getNotImplementedTables(source PGCatalogTables) PGCatalogTables { + notImplemented := make(PGCatalogTables) + for tableName := range p { + if len(p[tableName]) == 0 && len(source[tableName].getNotImplementedTypes()) == 0 { + notImplemented[tableName] = source[tableName] + } + } + return notImplemented +} + +//AreAllTypesImplemented verifies that all the types are implemented in cockroach db +func (c PGCatalogColumns) getNotImplementedTypes() map[oid.Oid]string { + notImplemented := make(map[oid.Oid]string) + for _, column := range c { + typeOid := oid.Oid(column.Oid) + if _, ok := types.OidToType[typeOid]; !ok || typeOid == oid.T_anyarray { + notImplemented[typeOid] = column.DataType + } + } + + return notImplemented +} + // Save have the purpose of storing all the data retrieved from postgres and useful information as postgres version func (f *PGCatalogFile) Save(writer io.Writer) { byteArray, err := json.MarshalIndent(f, "", " ") diff --git a/pkg/sql/pg_catalog_test.go b/pkg/sql/pg_catalog_test.go index 53cd4e57a735..81a2cff15e5b 100644 --- a/pkg/sql/pg_catalog_test.go +++ b/pkg/sql/pg_catalog_test.go @@ -28,32 +28,77 @@ package sql import ( + "bufio" "context" "encoding/json" "flag" "fmt" + "io" "io/ioutil" "os" "path/filepath" + "regexp" + "sort" + "strings" "testing" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/errors/oserror" + "github.com/lib/pq/oid" ) // Test data files const ( - pgCatalogDump = "pg_catalog_tables.json" // PostgreSQL pg_catalog schema - expectedDiffs = "pg_catalog_test_expected_diffs.json" // Contains expected difference between postgres and cockroach - testdata = "testdata" // testdata directory + pgCatalogDump = "pg_catalog_tables.json" // PostgreSQL pg_catalog schema + expectedDiffs = "pg_catalog_test_expected_diffs.json" // Contains expected difference between postgres and cockroach + testdata = "testdata" // testdata directory + catalogPkg = "catalog" + catconstantsPkg = "catconstants" + constantsGo = "constants.go" + vtablePkg = "vtable" + pgCatalogGo = "pg_catalog.go" +) + +// strings used on constants creations and text manipulation +const ( + pgCatalogPrefix = "PgCatalog" + pgCatalogIDConstant = "PgCatalogID" + tableIDSuffix = "TableID" + tableDefsDeclaration = `tableDefs: map[descpb.ID]virtualSchemaDef{` + tableDefsClosure = `},` + allTableNamesDeclaration = `allTableNames: buildStringSet(` + allTableNamesClose = `),` + virtualTablePosition = `// typOid is the only OID generation approach that does not use oidHasher, because` + virtualTableTempl = `var %s = virtualSchemaTable{ + comment: "%s is not implemented yet", + schema: vtable.%s, + populate: func(ctx context.Context, p *planner, _ *dbdesc.Immutable, addRow func(...tree.Datum) error) error { + return nil + }, +} + +` ) // When running test with -rewrite-diffs test will pass and re-create pg_catalog_test-diffs.json var rewriteFlag = flag.Bool("rewrite-diffs", false, "This will re-create the expected diffs file") +var addMissingTables = flag.Bool( + "add-missing-tables", + false, + "add-missing-tables will complete pg_catalog tables in the go code", +) + +var ( + tableFinderRE = regexp.MustCompile(`(?i)CREATE TABLE pg_catalog\.([^\s]+)\s`) +) + +var none = struct{}{} + // summary will keep accountability for any unexpected difference and report it in the log type summary struct { missingTables int @@ -141,7 +186,7 @@ func loadExpectedDiffs(t *testing.T) (diffs PGCatalogTables) { if err != nil { t.Fatal(err) } - defer f.Close() + defer dClose(f) bytes, err := ioutil.ReadAll(f) if err != nil { t.Fatal(err) @@ -170,9 +215,396 @@ func rewriteDiffs(t *testing.T, diffs PGCatalogTables, diffsFile string) { } } +// fixConstants helps to update catconstants that are needed for pgCatalog +func fixConstants(t *testing.T, notImplemented PGCatalogTables) { + constantsFileName := filepath.Join(".", catalogPkg, catconstantsPkg, constantsGo) + // pgConstants will contains all the pgCatalog tableID constant adding the new tables and preventing duplicates + pgConstants := getPgCatalogConstants(t, constantsFileName, notImplemented) + sort.Strings(pgConstants) + + // Rewrite will place all the pgConstants in alphabetical order after PgCatalogID + rewriteFile(constantsFileName, func(input *os.File, output outputFile) { + reader := bufio.NewScanner(input) + for reader.Scan() { + text := reader.Text() + trimText := strings.TrimSpace(text) + + // Skips PgCatalog constants (except PgCatalogID) as these will be written from pgConstants slice + if strings.HasPrefix(trimText, pgCatalogPrefix) && trimText != pgCatalogIDConstant { + continue + } + + output.appendString(text) + output.appendString("\n") + + if trimText == pgCatalogIDConstant { + for _, pgConstant := range pgConstants { + output.appendString("\t") + output.appendString(pgConstant) + output.appendString("\n") + } + } + } + }) +} + +// fixVtable adds missing table's create table constants +func fixVtable(t *testing.T, notImplemented PGCatalogTables) { + fileName := filepath.Join(vtablePkg, pgCatalogGo) + + // rewriteFile first will check existing create table constants to avoid duplicates + rewriteFile(fileName, func(input *os.File, output outputFile) { + existingTables := make(map[string]struct{}) + reader := bufio.NewScanner(input) + for reader.Scan() { + text := reader.Text() + output.appendString(text) + output.appendString("\n") + createTable := tableFinderRE.FindStringSubmatch(text) + if createTable != nil { + tableName := createTable[1] + existingTables[tableName] = none + } + } + + for tableName, columns := range notImplemented { + if _, ok := existingTables[tableName]; ok { + // Table already implemented + continue + } + createTable, err := createTableConstant(tableName, columns) + if err != nil { + // We can not implement this table as this uses types not implemented + t.Log(err) + continue + } + output.appendString(createTable) + } + }) +} + +// fixPgCatalogGo will update pgCatalog.allTableNames, pgCatalog.tableDefs and will add needed virtualSchemas +func fixPgCatalogGo(notImplemented PGCatalogTables) { + allTableNamesText := getAllTableNamesText(notImplemented) + tableDefinitionText := getTableDefinitionsText(pgCatalogGo, notImplemented) + + rewriteFile(pgCatalogGo, func(input *os.File, output outputFile) { + reader := bufio.NewScanner(input) + for reader.Scan() { + text := reader.Text() + trimText := strings.TrimSpace(text) + if trimText == virtualTablePosition { + //VirtualSchemas doesn't have a particular place to start we just print it before virtualTablePosition + output.appendString(printVirtualSchemas(notImplemented)) + } + output.appendString(text) + output.appendString("\n") + + switch trimText { + case tableDefsDeclaration: + printBeforeClosure(reader, output, tableDefsClosure, tableDefinitionText) + case allTableNamesDeclaration: + printBeforeClosure(reader, output, allTableNamesClose, allTableNamesText) + } + } + }) +} + +// printBeforeClosure will skip all the lines and print `s` text when finds the closure text +func printBeforeClosure(reader *bufio.Scanner, output outputFile, closure string, s string) { + for reader.Scan() { + text := reader.Text() + trimText := strings.TrimSpace(text) + + if strings.HasPrefix(trimText, "//") { + output.appendString(text) + output.appendString("\n") + continue + } + if trimText != closure { + continue + } + output.appendString(s) + output.appendString(text) + output.appendString("\n") + break + } +} + +// getPgCatalogConstants reads catconstant and retrieves all the constant with `PgCatalog` prefix +func getPgCatalogConstants( + t *testing.T, inputFileName string, notImplemented PGCatalogTables, +) []string { + pgConstantSet := make(map[string]struct{}) + f, err := os.Open(inputFileName) + if err != nil { + t.Logf("Problem getting pgCatalogConstants: %v", err) + t.Fatal(err) + } + defer dClose(f) + reader := bufio.NewScanner(f) + for reader.Scan() { + text := strings.TrimSpace(reader.Text()) + if strings.HasPrefix(text, pgCatalogPrefix) { + if text == pgCatalogIDConstant { + continue + } + pgConstantSet[text] = none + } + } + for tableName := range notImplemented { + pgConstantSet[constantName(tableName, tableIDSuffix)] = none + } + pgConstants := make([]string, 0, len(pgConstantSet)) + for pgConstantName := range pgConstantSet { + pgConstants = append(pgConstants, pgConstantName) + } + return pgConstants +} + +// outputFile wraps an *os.file to avoid explicit error checks on every WriteString +type outputFile struct { + f *os.File +} + +// appendString calls WriteString and panics on error +func (o outputFile) appendString(s string) { + if _, err := o.f.WriteString(s); err != nil { + panic(fmt.Errorf("error while writing string: %s: %v", s, err)) + } +} + +// rewriteFile recreate a file by using the f func, this creates a temporary file to place all the output first +// then it replaces the original file +func rewriteFile(fileName string, f func(*os.File, outputFile)) { + tmpName := fileName + ".tmp" + updateFile(fileName, tmpName, f) + defer func() { + if err := os.Remove(tmpName); err != nil { + panic(fmt.Errorf("problem removing temp file %s: %e", tmpName, err)) + } + }() + + updateFile(tmpName, fileName, func(input *os.File, output outputFile) { + if _, err := io.Copy(output.f, input); err != nil { + panic(fmt.Errorf("problem at rewriting file %s into %s: %v", tmpName, fileName, err)) + } + }) +} + +func updateFile(inputFileName, outputFileName string, f func(input *os.File, output outputFile)) { + input, err := os.Open(inputFileName) + if err != nil { + panic(fmt.Errorf("error opening file %s: %v", inputFileName, err)) + } + defer dClose(input) + + output, err := os.OpenFile(outputFileName, os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + panic(fmt.Errorf("error opening file %s: %v", outputFileName, err)) + } + defer dClose(output) + + f(input, outputFile{output}) +} + +// dClose is a helper that eliminates the need of error checking and defer the io.Closer Close() and pass lint checks +func dClose(f io.Closer) { + err := f.Close() + if err != nil { + panic(err) + } +} + +var acronyms = map[string]struct{}{ + "acl": none, + "id": none, +} + +// constantName create constant names for pg_catalog fixableTables following constant names standards +func constantName(tableName string, suffix string) string { + var sb strings.Builder + snakeWords := strings.Split(tableName, "_")[1:] + sb.WriteString("PgCatalog") + + for _, word := range snakeWords { + if _, ok := acronyms[word]; ok { + sb.WriteString(strings.ToUpper(word)) + } else { + sb.WriteString(strings.ToUpper(word[:1])) + sb.WriteString(word[1:]) + } + } + + sb.WriteString(suffix) + return sb.String() +} + +// createTableConstant formats the text for vtable constants +func createTableConstant(tableName string, columns PGCatalogColumns) (string, error) { + var sb strings.Builder + constName := constantName(tableName, "") + if notImplementedTypes := columns.getNotImplementedTypes(); len(notImplementedTypes) > 0 { + return "", fmt.Errorf("not all types are implemented %s: %v", tableName, notImplementedTypes) + } + + sb.WriteString("\n//") + sb.WriteString(constName) + sb.WriteString(" is an empty table in the pg_catalog that is not implemented yet\n") + sb.WriteString("const ") + sb.WriteString(constName) + sb.WriteString(" = `\n") + sb.WriteString("CREATE TABLE pg_catalog.") + sb.WriteString(tableName) + sb.WriteString(" (\n") + prefix := "" + for columnName, columnType := range columns { + formatColumn(&sb, prefix, columnName, columnType) + prefix = ",\n" + } + sb.WriteString("\n)`\n") + return sb.String(), nil +} + +func formatColumn(sb *strings.Builder, prefix, columnName string, columnType *PGCatalogColumn) { + typeOid := oid.Oid(columnType.Oid) + typeName := types.OidToType[typeOid].Name() + if !strings.HasPrefix(typeName, `"char"`) { + typeName = strings.ToUpper(typeName) + } + sb.WriteString(prefix) + sb.WriteString("\t") + sb.WriteString(columnName) + sb.WriteString(" ") + sb.WriteString(typeName) +} + +// printVirtualSchemas formats the golang code to create the virtualSchema structure +func printVirtualSchemas(newTableNameList PGCatalogTables) string { + var sb strings.Builder + for tableName := range newTableNameList { + variableName := "p" + constantName(tableName, "Table")[1:] + vTableName := constantName(tableName, "") + sb.WriteString(fmt.Sprintf(virtualTableTempl, variableName, tableName, vTableName)) + } + return sb.String() +} + +// getAllTableNamesText retrieves pgCatalog.allTableNames, then it merges the new table names and formats the +// replacement text +func getAllTableNamesText(notImplemented PGCatalogTables) string { + newTableNameSet := make(map[string]struct{}) + for tableName := range pgCatalog.allTableNames { + newTableNameSet[tableName] = none + } + for tableName := range notImplemented { + newTableNameSet[tableName] = none + } + newTableList := make([]string, 0, len(newTableNameSet)) + for tableName := range newTableNameSet { + newTableList = append(newTableList, tableName) + } + sort.Strings(newTableList) + return formatAllTableNamesText(newTableList) +} + +func formatAllTableNamesText(newTableNameList []string) string { + var sb strings.Builder + for _, tableName := range newTableNameList { + sb.WriteString("\t\t\"") + sb.WriteString(tableName) + sb.WriteString("\",\n") + } + return sb.String() +} + +// getTableDefinitionsText retrieves pgCatalog.tableDefs, then it merges the missing table in that object and formats +// the replacement text +func getTableDefinitionsText(fileName string, notImplemented PGCatalogTables) string { + tableDefs := make(map[string]string) + maxLength := 0 + f, err := os.Open(fileName) + if err != nil { + panic(fmt.Errorf("could not open file %s: %v", fileName, err)) + } + defer dClose(f) + reader := bufio.NewScanner(f) + for reader.Scan() { + text := strings.TrimSpace(reader.Text()) + if text == tableDefsDeclaration { + break + } + } + for reader.Scan() { + text := strings.TrimSpace(reader.Text()) + if text == tableDefsClosure { + break + } + def := strings.Split(text, ":") + defName := strings.TrimSpace(def[0]) + defValue := strings.TrimRight(strings.TrimSpace(def[1]), ",") + tableDefs[defName] = defValue + length := len(defName) + if length > maxLength { + maxLength = length + } + } + + for tableName := range notImplemented { + defName := "catconstants." + constantName(tableName, tableIDSuffix) + if _, ok := tableDefs[defName]; ok { + // Not overriding existing tableDefinitions + delete(notImplemented, tableName) + continue + } + defValue := "p" + constantName(tableName, "Table")[1:] + tableDefs[defName] = defValue + length := len(defName) + if length > maxLength { + maxLength = length + } + } + + return formatTableDefinitionText(tableDefs, maxLength) +} + +func formatTableDefinitionText(tableDefs map[string]string, maxLength int) string { + var sbAll strings.Builder + sortedDefKeys := getSortedDefKeys(tableDefs) + for _, defKey := range sortedDefKeys { + var sb strings.Builder + sb.WriteString("\t\t") + sb.WriteString(defKey) + sb.WriteString(":") + for sb.Len() < maxLength+4 { + sb.WriteString(" ") + } + sb.WriteString(tableDefs[defKey]) + sb.WriteString(",\n") + sbAll.WriteString(sb.String()) + } + return sbAll.String() +} + +func getSortedDefKeys(tableDefs map[string]string) []string { + keys := make([]string, 0, len(tableDefs)) + for constName := range tableDefs { + keys = append(keys, constName) + } + sort.Strings(keys) + return keys +} + // TestPGCatalog is the pg_catalog diff tool test which compares pg_catalog with postgres and cockroach func TestPGCatalog(t *testing.T) { defer leaktest.AfterTest(t)() + defer func() { + r := recover() + if err, ok := r.(error); ok { + t.Fatal(err) + } + }() + pgTables := loadTestData(t) crdbTables := loadCockroachPgCatalog(t) diffs := loadExpectedDiffs(t) @@ -218,4 +650,11 @@ func TestPGCatalog(t *testing.T) { sum.report(t) rewriteDiffs(t, diffs, filepath.Join(testdata, expectedDiffs)) + + if *addMissingTables { + notImplemented := diffs.getNotImplementedTables(pgTables) + fixConstants(t, notImplemented) + fixVtable(t, notImplemented) + fixPgCatalogGo(notImplemented) + } } diff --git a/pkg/sql/virtual_schema.go b/pkg/sql/virtual_schema.go index e5c94a12ee8d..3d6f6cbc78c5 100644 --- a/pkg/sql/virtual_schema.go +++ b/pkg/sql/virtual_schema.go @@ -35,6 +35,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sqlerrors" "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" "github.com/cockroachdb/cockroach/pkg/util/hlc" + "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/errors" ) @@ -175,6 +176,7 @@ func (t virtualSchemaTable) initVirtualTableDesc( tree.PersistencePermanent, ) if err != nil { + log.Errorf(ctx, "initVirtualDesc problem: %v\n%s", err, t.schema) return mutDesc.TableDescriptor, err } for _, index := range mutDesc.PublicNonPrimaryIndexes() {