From b6d81dbbdddef00548ede451e75319ec4af3aab4 Mon Sep 17 00:00:00 2001 From: "Eric.Yang" Date: Sun, 7 May 2023 15:20:49 -0700 Subject: [PATCH] Statement bundles now include the CREATE statements for related udfs in schema.sql. Epic: None Fixes: #102044 Release note: None --- pkg/sql/explain_bundle.go | 39 ++++++++++++++++++++++++++++++++++ pkg/sql/explain_bundle_test.go | 24 ++++++++++++++++++++- pkg/sql/opt/metadata.go | 5 +++++ 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/pkg/sql/explain_bundle.go b/pkg/sql/explain_bundle.go index ad7bfed67b6d..b4bd979abb9f 100644 --- a/pkg/sql/explain_bundle.go +++ b/pkg/sql/explain_bundle.go @@ -459,6 +459,13 @@ func (b *stmtBundleBuilder) addEnv(ctx context.Context) { fmt.Fprintf(&buf, "-- error getting schema for sequence %s: %v\n", sequences[i].String(), err) } } + if len(mem.Metadata().AllUserDefinedFunctions()) != 0 { + // Get all relevant user-defined functions. + blankLine() + if err := c.PrintRelevantCreateUdf(&buf, strings.ToLower(b.stmt), b.flags.RedactValues); err != nil { + fmt.Fprintf(&buf, "-- error getting schema for udfs: %v\n", err) + } + } for i := range tables { blankLine() if err := c.PrintCreateTable(&buf, &tables[i], b.flags.RedactValues); err != nil { @@ -796,6 +803,38 @@ func (c *stmtEnvCollector) PrintCreateSequence(w io.Writer, tn *tree.TableName) return nil } +func (c *stmtEnvCollector) PrintRelevantCreateUdf( + w io.Writer, stmt string, redactValues bool, +) error { + // The select function_name returns a DOidWrapper, + // we need to cast it to string for queryRows function to process. + // TODO: consider getting the udf sql body statements from the memo metadata. + functionNameQuery := "SELECT function_name::STRING as function_name_str FROM [SHOW FUNCTIONS]" + udfNames, err := c.queryRows(functionNameQuery) + if err != nil { + return err + } + for _, name := range udfNames { + if strings.Contains(stmt, name) { + createFunctionQuery := fmt.Sprintf( + "SELECT create_statement FROM [ SHOW CREATE FUNCTION \"%s\" ]", name, + ) + if redactValues { + createFunctionQuery = fmt.Sprintf( + "SELECT crdb_internal.redact(crdb_internal.redactable_sql_constants(create_statement)) FROM [ SHOW CREATE FUNCTION \"%s\" ]", name, + ) + } + createStatement, err := c.query(createFunctionQuery) + if err != nil { + fmt.Fprintf(w, "-- error getting user defined function %s: %s\n", name, err) + continue + } + fmt.Fprintf(w, "%s\n", createStatement) + } + } + return nil +} + func (c *stmtEnvCollector) PrintCreateView( w io.Writer, tn *tree.TableName, redactValues bool, ) error { diff --git a/pkg/sql/explain_bundle_test.go b/pkg/sql/explain_bundle_test.go index 41b1cf76d107..7932e0cee83d 100644 --- a/pkg/sql/explain_bundle_test.go +++ b/pkg/sql/explain_bundle_test.go @@ -287,8 +287,9 @@ CREATE TABLE users(id UUID DEFAULT gen_random_uuid() PRIMARY KEY, promo_id INT R r.Exec(t, "CREATE TABLE pterosaur (cardholder STRING PRIMARY KEY, cardno INT, INDEX (cardno));") r.Exec(t, "INSERT INTO pterosaur VALUES ('pterodactyl', 5555555555554444);") r.Exec(t, "CREATE STATISTICS jurassic FROM pterosaur;") + r.Exec(t, "CREATE FUNCTION test_redact() RETURNS STRING AS $body$ SELECT 'pterodactyl' $body$ LANGUAGE sql;") rows := r.QueryStr(t, - "EXPLAIN ANALYZE (DEBUG, REDACT) SELECT max(cardno) FROM pterosaur WHERE cardholder = 'pterodactyl'", + "EXPLAIN ANALYZE (DEBUG, REDACT) SELECT max(cardno), test_redact() FROM pterosaur WHERE cardholder = 'pterodactyl'", ) verboten := []string{"pterodactyl", "5555555555554444", fmt.Sprintf("%x", 5555555555554444)} checkBundle( @@ -304,6 +305,27 @@ CREATE TABLE users(id UUID DEFAULT gen_random_uuid() PRIMARY KEY, promo_id INT R plans, "statement.sql stats-defaultdb.public.pterosaur.sql env.sql vec.txt vec-v.txt", ) }) + + t.Run("udfs", func(t *testing.T) { + r.Exec(t, "CREATE FUNCTION add(a INT, b INT) RETURNS INT IMMUTABLE LEAKPROOF LANGUAGE SQL AS 'SELECT a + b';") + r.Exec(t, "CREATE FUNCTION subtract(a INT, b INT) RETURNS INT IMMUTABLE LEAKPROOF LANGUAGE SQL AS 'SELECT a - b';") + rows := r.QueryStr(t, "EXPLAIN ANALYZE (DEBUG) SELECT add(3, 4);") + checkBundle( + t, fmt.Sprint(rows), "add", func(name, contents string) error { + if name == "schema.sql" { + reg := regexp.MustCompile("add") + if reg.FindString(contents) == "" { + return errors.Errorf("could not find definition for 'add' function in schema.sql") + } + reg = regexp.MustCompile("subtract") + if reg.FindString(contents) != "" { + return errors.Errorf("Found irrelevant user defined function 'substract' in schema.sql") + } + } + return nil + }, base, plans, + "distsql-1-subquery.html distsql-2-main-query.html vec-1-subquery-v.txt vec-1-subquery.txt vec-2-main-query-v.txt vec-2-main-query.txt") + }) } // checkBundle searches text strings for a bundle URL and then verifies that the diff --git a/pkg/sql/opt/metadata.go b/pkg/sql/opt/metadata.go index 15ee20cf01dc..40974ed7cdfa 100644 --- a/pkg/sql/opt/metadata.go +++ b/pkg/sql/opt/metadata.go @@ -540,6 +540,11 @@ func (md *Metadata) AllUserDefinedTypes() []*types.T { return md.userDefinedTypesSlice } +// AllUserDefinedFunctions returns all user defined functions used in this query. +func (md *Metadata) AllUserDefinedFunctions() map[cat.StableID]*tree.Overload { + return md.udfDeps +} + // AddUserDefinedFunction adds a user-defined function to the metadata for this // query. If the function was resolved by name, the name will also be tracked. func (md *Metadata) AddUserDefinedFunction(