Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: use memo metadata to add routines and UDTs to statement bundles #132147

Merged
merged 3 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 42 additions & 68 deletions pkg/sql/explain_bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,20 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/opt/exec/explain"
"github.com/cockroachdb/cockroach/pkg/sql/opt/memo"
"github.com/cockroachdb/cockroach/pkg/sql/sem/catconstants"
"github.com/cockroachdb/cockroach/pkg/sql/sem/catid"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sessiondata"
"github.com/cockroachdb/cockroach/pkg/sql/stmtdiagnostics"
"github.com/cockroachdb/cockroach/pkg/util/buildutil"
"github.com/cockroachdb/cockroach/pkg/util/intsets"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/memzipper"
"github.com/cockroachdb/cockroach/pkg/util/pretty"
"github.com/cockroachdb/cockroach/pkg/util/tracing"
"github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/redact"
"github.com/lib/pq/oid"
)

const noPlan = "no plan"
Expand Down Expand Up @@ -600,12 +603,11 @@ func (b *stmtBundleBuilder) addEnv(ctx context.Context) {
// Note: we do not shortcut out of this function if there is no table/sequence/view to report:
// the bundle analysis tool require schema.sql to always be present, even if it's empty.

first := true
blankLine := func() {
if !first {
if buf.Len() > 0 {
// Don't add newlines to the beginning of the file.
buf.WriteByte('\n')
}
first = false
}
blankLine()
c.printCreateAllDatabases(&buf, dbNames)
Expand All @@ -618,35 +620,29 @@ func (b *stmtBundleBuilder) addEnv(ctx context.Context) {
b.printError(fmt.Sprintf("-- error getting schema for sequence %s: %v", sequences[i].FQString(), err), &buf)
}
}
// Get all user-defined types. If redaction is a
blankLine()
if err := c.PrintCreateEnum(&buf, b.flags.RedactValues); err != nil {
b.printError(fmt.Sprintf("-- error getting schema for enums: %v", err), &buf)
}
if mem.Metadata().HasUserDefinedFunctions() {
// Get all relevant user-defined functions.
// Get all relevant user-defined types.
for _, t := range mem.Metadata().AllUserDefinedTypes() {
blankLine()
err = c.PrintRelevantCreateRoutine(
&buf, strings.ToLower(b.stmt), b.flags.RedactValues, &b.errorStrings, false, /* procedure */
)
if err != nil {
b.printError(fmt.Sprintf("-- error getting schema for udfs: %v", err), &buf)
if err := c.PrintCreateUDT(&buf, t.Oid(), b.flags.RedactValues); err != nil {
b.printError(fmt.Sprintf("-- error getting schema for type %s: %v", t.SQLStringForError(), err), &buf)
}
}
if call, ok := mem.RootExpr().(*memo.CallExpr); ok {
// Currently, a stored procedure can only be called from a CALL statement,
// which can only be the root expression.
if proc, ok := call.Proc.(*memo.UDFCallExpr); ok {
if mem.Metadata().HasUserDefinedRoutines() {
// Get all relevant user-defined routines.
var ids intsets.Fast
isProcedure := make(map[oid.Oid]bool)
mem.Metadata().ForEachUserDefinedRoutine(func(ol *tree.Overload) {
ids.Add(int(ol.Oid))
isProcedure[ol.Oid] = ol.Type == tree.ProcedureRoutine
})
ids.ForEach(func(id int) {
blankLine()
err = c.PrintRelevantCreateRoutine(
&buf, strings.ToLower(proc.Def.Name), b.flags.RedactValues, &b.errorStrings, true, /* procedure */
)
routineOid := oid.Oid(id)
err = c.PrintCreateRoutine(&buf, routineOid, b.flags.RedactValues, isProcedure[routineOid])
if err != nil {
b.printError(fmt.Sprintf("-- error getting schema for procedure: %v", err), &buf)
b.printError(fmt.Sprintf("-- error getting schema for routine with ID %d: %v", id, err), &buf)
}
} else {
b.printError("-- unexpected input expression for CALL statement", &buf)
}
})
}
for i := range tables {
blankLine()
Expand Down Expand Up @@ -1014,13 +1010,13 @@ func (c *stmtEnvCollector) PrintCreateSequence(w io.Writer, tn *tree.TableName)
return nil
}

func (c *stmtEnvCollector) PrintCreateEnum(w io.Writer, redactValues bool) error {
qry := "SELECT create_statement FROM [SHOW CREATE ALL TYPES]"
func (c *stmtEnvCollector) PrintCreateUDT(w io.Writer, id oid.Oid, redactValues bool) error {
descID := catid.UserDefinedOIDToID(id)
query := fmt.Sprintf("SELECT create_statement FROM crdb_internal.create_type_statements WHERE descriptor_id = %d", descID)
if redactValues {
qry = "SELECT crdb_internal.redact(crdb_internal.redactable_sql_constants(create_statement)) FROM [SHOW CREATE ALL TYPES]"

query = fmt.Sprintf("SELECT crdb_internal.redact(crdb_internal.redactable_sql_constants(create_statement)) FROM (%s)", query)
}
createStatement, err := c.queryRows(qry)
createStatement, err := c.queryRows(query)
if err != nil {
return err
}
Expand All @@ -1030,50 +1026,28 @@ func (c *stmtEnvCollector) PrintCreateEnum(w io.Writer, redactValues bool) error
return nil
}

func (c *stmtEnvCollector) PrintRelevantCreateRoutine(
w io.Writer, stmt string, redactValues bool, errorStrings *[]string, procedure bool,
func (c *stmtEnvCollector) PrintCreateRoutine(
w io.Writer, id oid.Oid, redactValues bool, procedure bool,
) error {
// The select function_name returns a DOidWrapper,
// we need to cast it to string for queryRows function to process.
// TODO(#104976): consider getting the udf sql body statements from the memo metadata.
var routineTypeName, routineNameQuery string
var createRoutineQuery string
descID := catid.UserDefinedOIDToID(id)
queryTemplate := "SELECT create_statement FROM crdb_internal.create_%[1]s_statements WHERE %[1]s_id = %[2]d"
if procedure {
routineTypeName = "PROCEDURE"
routineNameQuery = "SELECT procedure_name::STRING as procedure_name_str FROM [SHOW PROCEDURES]"
createRoutineQuery = fmt.Sprintf(queryTemplate, "procedure", descID)
} else {
routineTypeName = "FUNCTION"
routineNameQuery = "SELECT function_name::STRING as function_name_str FROM [SHOW FUNCTIONS]"
createRoutineQuery = fmt.Sprintf(queryTemplate, "function", descID)
}
if redactValues {
createRoutineQuery = fmt.Sprintf(
"SELECT crdb_internal.redact(crdb_internal.redactable_sql_constants(create_statement)) FROM (%s)",
createRoutineQuery,
)
}
routineNames, err := c.queryRows(routineNameQuery)
createStatement, err := c.query(createRoutineQuery)
if err != nil {
return err
}
for _, name := range routineNames {
if strings.Contains(stmt, name) {
createRoutineQuery := fmt.Sprintf(
"SELECT create_statement FROM [ SHOW CREATE %s \"%s\" ]", routineTypeName, name,
)
if redactValues {
createRoutineQuery = fmt.Sprintf(
"SELECT crdb_internal.redact(crdb_internal.redactable_sql_constants(create_statement)) FROM [ SHOW CREATE %s \"%s\" ]",
routineTypeName, name,
)
}
createStatement, err := c.query(createRoutineQuery)
if err != nil {
var errString string
if procedure {
errString = fmt.Sprintf("-- error getting stored procedure %s: %s", name, err)
} else {
errString = fmt.Sprintf("-- error getting user defined function %s: %s", name, err)
}
fmt.Fprint(w, errString+"\n")
*errorStrings = append(*errorStrings, errString)
continue
}
fmt.Fprintf(w, "%s\n", createStatement)
}
}
fmt.Fprintf(w, "%s;\n", createStatement)
return nil
}

Expand Down
82 changes: 76 additions & 6 deletions pkg/sql/explain_bundle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,13 +367,21 @@ CREATE TABLE users(id UUID DEFAULT gen_random_uuid() PRIMARY KEY, promo_id INT R
t.Run("types", func(t *testing.T) {
r.Exec(t, "CREATE TYPE test_type1 AS ENUM ('hello','world');")
r.Exec(t, "CREATE TYPE test_type2 AS ENUM ('goodbye','earth');")
rows := r.QueryStr(t, "EXPLAIN ANALYZE (DEBUG) SELECT 1;")
rows := r.QueryStr(t, "EXPLAIN ANALYZE (DEBUG) SELECT 'hello'::test_type1;")
checkBundle(
t, fmt.Sprint(rows), "test_type1", nil, false, /* expectErrors */
base, plans, "distsql.html vec.txt vec-v.txt",
)
checkBundle(
t, fmt.Sprint(rows), "test_type2", nil, false, /* expectErrors */
t, fmt.Sprint(rows), "test_type1", func(name, contents string) error {
if name == "schema.sql" {
reg := regexp.MustCompile("test_type1")
if reg.FindString(contents) == "" {
return errors.Errorf("could not find definition for 'test_type1' type in schema.sql")
}
reg = regexp.MustCompile("test_type2")
if reg.FindString(contents) != "" {
return errors.Errorf("Found irrelevant user defined type 'test_type2' in schema.sql")
}
}
return nil
}, false, /* expectErrors */
base, plans, "distsql.html vec.txt vec-v.txt",
)
})
Expand Down Expand Up @@ -420,6 +428,68 @@ CREATE TABLE users(id UUID DEFAULT gen_random_uuid() PRIMARY KEY, promo_id INT R
"distsql.html vec-v.txt vec.txt")
})

t.Run("different schema UDF", func(t *testing.T) {
r.Exec(t, "CREATE FUNCTION foo() RETURNS INT LANGUAGE SQL AS 'SELECT count(*) FROM abc, s.a';")
r.Exec(t, "CREATE FUNCTION s.foo() RETURNS INT LANGUAGE SQL AS 'SELECT count(*) FROM abc, s.a';")
rows := r.QueryStr(t, "EXPLAIN ANALYZE (DEBUG) SELECT s.foo();")
checkBundle(
t, fmt.Sprint(rows), "s.foo", func(name, contents string) error {
if name == "schema.sql" {
reg := regexp.MustCompile("s.foo")
if reg.FindString(contents) == "" {
return errors.Errorf("could not find definition for 's.foo' function in schema.sql")
}
reg = regexp.MustCompile("^foo")
if reg.FindString(contents) != "" {
return errors.Errorf("found irrelevant function 'foo' in schema.sql")
}
reg = regexp.MustCompile("s.a")
if reg.FindString(contents) == "" {
return errors.Errorf("could not find definition for relation 's.a' in schema.sql")
}
reg = regexp.MustCompile("abc")
if reg.FindString(contents) == "" {
return errors.Errorf("could not find definition for relation 'abc' in schema.sql")
}
}
return nil
},
false /* expectErrors */, base, plans,
"stats-defaultdb.public.abc.sql stats-defaultdb.s.a.sql distsql.html vec-v.txt vec.txt",
)
})

t.Run("different schema procedure", func(t *testing.T) {
r.Exec(t, "CREATE PROCEDURE bar() LANGUAGE SQL AS 'SELECT count(*) FROM abc, s.a';")
r.Exec(t, "CREATE PROCEDURE s.bar() LANGUAGE SQL AS 'SELECT count(*) FROM abc, s.a';")
rows := r.QueryStr(t, "EXPLAIN ANALYZE (DEBUG) CALL s.bar();")
checkBundle(
t, fmt.Sprint(rows), "s.bar", func(name, contents string) error {
if name == "schema.sql" {
reg := regexp.MustCompile("s.bar")
if reg.FindString(contents) == "" {
return errors.Errorf("could not find definition for 's.bar' procedure in schema.sql")
}
reg = regexp.MustCompile("^bar")
if reg.FindString(contents) != "" {
return errors.Errorf("Found irrelevant procedure 'bar' in schema.sql")
}
reg = regexp.MustCompile("s.a")
if reg.FindString(contents) == "" {
return errors.Errorf("could not find definition for relation 's.a' in schema.sql")
}
reg = regexp.MustCompile("abc")
if reg.FindString(contents) == "" {
return errors.Errorf("could not find definition for relation 'abc' in schema.sql")
}
}
return nil
},
false /* expectErrors */, base, plans,
"stats-defaultdb.public.abc.sql stats-defaultdb.s.a.sql distsql.html vec-v.txt vec.txt",
)
})

t.Run("permission error", func(t *testing.T) {
r.Exec(t, "CREATE USER test")
r.Exec(t, "SET ROLE test")
Expand Down
Loading
Loading