diff --git a/pkg/sql/BUILD.bazel b/pkg/sql/BUILD.bazel index 58fd6b9d9110..a3e9e6d6670d 100644 --- a/pkg/sql/BUILD.bazel +++ b/pkg/sql/BUILD.bazel @@ -459,6 +459,7 @@ go_library( "//pkg/sql/pgwire/pgwirecancel", "//pkg/sql/physicalplan", "//pkg/sql/physicalplan/replicaoracle", + "//pkg/sql/plpgsql/parser:plpgparser", "//pkg/sql/privilege", "//pkg/sql/protoreflect", "//pkg/sql/querycache", diff --git a/pkg/sql/crdb_internal.go b/pkg/sql/crdb_internal.go index bcaa71ad9783..60362cb869f9 100644 --- a/pkg/sql/crdb_internal.go +++ b/pkg/sql/crdb_internal.go @@ -70,6 +70,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice" + plpgsql "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/protoreflect" "github.com/cockroachdb/cockroach/pkg/sql/roleoption" @@ -3583,21 +3584,39 @@ func createRoutinePopulate( } for i := range treeNode.Options { if body, ok := treeNode.Options[i].(tree.RoutineBodyStr); ok { - typeReplacedBody, err := formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), string(body)) - if err != nil { - return err - } - seqReplacedBody, err := formatQuerySequencesForDisplay(ctx, &p.semaCtx, typeReplacedBody, true /* multiStmt */) - if err != nil { - return err + bodyStr := string(body) + switch fnDesc.GetLanguage() { + case catpb.Function_SQL: + bodyStr, err = formatFunctionQueryTypesForDisplay(ctx, &p.semaCtx, p.SessionData(), bodyStr) + if err != nil { + return err + } + bodyStr, err = formatQuerySequencesForDisplay(ctx, &p.semaCtx, bodyStr, true /* multiStmt */) + if err != nil { + return err + } + bodyStr = "\n" + bodyStr + "\n" + case catpb.Function_PLPGSQL: + // TODO(drewk): integrate this with the SQL case above. + plpgsqlStmt, err := plpgsql.Parse(bodyStr) + if err != nil { + return err + } + fmtCtx := tree.NewFmtCtx(tree.FmtParsable) + fmtCtx.FormatNode(plpgsqlStmt.AST) + bodyStr = "\n" + fmtCtx.CloseAndGetString() + default: + return errors.AssertionFailedf("unexpected function language: %s", fnDesc.GetLanguage()) } - stmtStrs := strings.Split(seqReplacedBody, "\n") + stmtStrs := strings.Split(bodyStr, "\n") for i := range stmtStrs { - stmtStrs[i] = "\t" + stmtStrs[i] + if stmtStrs[i] != "" { + stmtStrs[i] = "\t" + stmtStrs[i] + } } p := &treeNode.Options[i] // Add two new lines just for better formatting. - *p = "\n" + tree.RoutineBodyStr(strings.Join(stmtStrs, "\n")) + "\n" + *p = tree.RoutineBodyStr(strings.Join(stmtStrs, "\n")) } } diff --git a/pkg/sql/logictest/testdata/logic_test/show_create b/pkg/sql/logictest/testdata/logic_test/show_create index 4526471c12e2..1ade695574e3 100644 --- a/pkg/sql/logictest/testdata/logic_test/show_create +++ b/pkg/sql/logictest/testdata/logic_test/show_create @@ -292,3 +292,43 @@ r2 CREATE PROCEDURE sc.r2(IN s STRING) AS $$ SELECT 1; $$ + +# Regression test for #112134 - correctly parse and display PLpgSQL. +skipif config local-mixed-23.1 +statement ok +CREATE FUNCTION f112134() RETURNS INT AS $$ + DECLARE + x INT := 0; + i INT := 0; + BEGIN + WHILE i < 3 LOOP + x := x + i; + i := i + 1; + END LOOP; + RETURN x; + END +$$ LANGUAGE PLpgSQL; + +# TODO(112136): Fix the formatting. +skipif config local-mixed-23.1 +query TT +SHOW CREATE FUNCTION f112134; +---- +f112134 CREATE FUNCTION sc.f112134() + RETURNS INT8 + VOLATILE + NOT LEAKPROOF + CALLED ON NULL INPUT + LANGUAGE plpgsql + AS $$ + DECLARE + x INT8 := 0; + i INT8 := 0; + BEGIN + WHILE i < 3 LOOP + x := x + i; + i := i + 1; + END LOOP; + RETURN x; + END + $$