diff --git a/pkg/ccl/logictestccl/testdata/logic_test/procedure_plpgsql b/pkg/ccl/logictestccl/testdata/logic_test/procedure_plpgsql index e063c19cc4ff..030eeb22f32e 100644 --- a/pkg/ccl/logictestccl/testdata/logic_test/procedure_plpgsql +++ b/pkg/ccl/logictestccl/testdata/logic_test/procedure_plpgsql @@ -631,3 +631,68 @@ SELECT * FROM temp; statement ok DROP PROCEDURE p(INOUT int); + +subtest nested_call + +# It's ok for a SQL routine to call a PLpgSQL routine that calls an SP with OUT +# parameters. + +statement ok +CREATE PROCEDURE p_inner_o(OUT param INTEGER) AS $$ SELECT 1; $$ LANGUAGE SQL; +CREATE PROCEDURE p_inner_io(INOUT param INTEGER) AS $$ SELECT 1; $$ LANGUAGE SQL; + +skipif config local-mixed-23.2 +statement ok +CREATE PROCEDURE p_nested() AS $$ + DECLARE + a INT; + BEGIN + CALL p_inner_o(a); + RAISE NOTICE 'a: %', a; + CALL p_inner_io(a); + RAISE NOTICE 'a: %', a; + END +$$ LANGUAGE PLpgSQL; + +skipif config local-mixed-23.2 +statement ok +CREATE FUNCTION f() RETURNS VOID AS $$ CALL p_nested(); $$ LANGUAGE SQL; + +skipif config local-mixed-23.2 +query T noticetrace +SELECT f(); +---- +NOTICE: a: 1 +NOTICE: a: 1 + +skipif config local-mixed-23.2 +statement ok +DROP FUNCTION f; + +statement ok +DROP PROCEDURE IF EXISTS p; + +skipif config local-mixed-23.2 +statement ok +CREATE PROCEDURE p() AS $$ CALL p_nested(); $$ LANGUAGE SQL; + +skipif config local-mixed-23.2 +query T noticetrace +CALL p(); +---- +NOTICE: a: 1 +NOTICE: a: 1 + +skipif config local-mixed-23.2 +statement ok +DROP PROCEDURE p; + +skipif config local-mixed-23.2 +statement ok +DROP PROCEDURE p_nested; + +statement ok +DROP PROCEDURE p_inner_o; +DROP PROCEDURE p_inner_io; + +subtest end diff --git a/pkg/sql/opt/optbuilder/create_function.go b/pkg/sql/opt/optbuilder/create_function.go index 308fd9a19404..437f09edd54f 100644 --- a/pkg/sql/opt/optbuilder/create_function.go +++ b/pkg/sql/opt/optbuilder/create_function.go @@ -325,6 +325,11 @@ func (b *Builder) buildCreateFunction(cf *tree.CreateRoutine, inScope *scope) (o targetVolatility := tree.GetRoutineVolatility(cf.Options) fmtCtx := tree.NewFmtCtx(tree.FmtSerializable) + defer func(origValue bool) { + b.insideSQLRoutine = origValue + }(b.insideSQLRoutine) + b.insideSQLRoutine = language == tree.RoutineLangSQL + // Validate each statement and collect the dependencies. var stmtScope *scope switch language { @@ -334,31 +339,29 @@ func (b *Builder) buildCreateFunction(cf *tree.CreateRoutine, inScope *scope) (o if err != nil { panic(err) } - b.withinSQLRoutine(func() { - for i, stmt := range stmts { - // Add statement ast into CreateRoutine node for logging purpose, and set - // the annotations for this statement so names can be resolved. - cf.BodyStatements = append(cf.BodyStatements, stmt.AST) - ann := tree.MakeAnnotations(stmt.NumAnnotations) - cf.BodyAnnotations = append(cf.BodyAnnotations, &ann) - - // The defer logic will reset the annotations to the old value. - b.semaCtx.Annotations = ann - b.evalCtx.Annotations = &ann - - // We need to disable stable function folding because we want to catch the - // volatility of stable functions. If folded, we only get a scalar and - // lose the volatility. - b.factory.FoldingControl().TemporarilyDisallowStableFolds(func() { - stmtScope = b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope) - }) - checkStmtVolatility(targetVolatility, stmtScope, stmt.AST) - - // Format the statements with qualified datasource names. - formatFuncBodyStmt(fmtCtx, stmt.AST, language, i > 0 /* newLine */) - afterBuildStmt() - } - }) + for i, stmt := range stmts { + // Add statement ast into CreateRoutine node for logging purpose, and set + // the annotations for this statement so names can be resolved. + cf.BodyStatements = append(cf.BodyStatements, stmt.AST) + ann := tree.MakeAnnotations(stmt.NumAnnotations) + cf.BodyAnnotations = append(cf.BodyAnnotations, &ann) + + // The defer logic will reset the annotations to the old value. + b.semaCtx.Annotations = ann + b.evalCtx.Annotations = &ann + + // We need to disable stable function folding because we want to catch the + // volatility of stable functions. If folded, we only get a scalar and + // lose the volatility. + b.factory.FoldingControl().TemporarilyDisallowStableFolds(func() { + stmtScope = b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope) + }) + checkStmtVolatility(targetVolatility, stmtScope, stmt.AST) + + // Format the statements with qualified datasource names. + formatFuncBodyStmt(fmtCtx, stmt.AST, language, i > 0 /* newLine */) + afterBuildStmt() + } case tree.RoutineLangPLpgSQL: if cf.ReturnType != nil && cf.ReturnType.SetOf { panic(unimplemented.NewWithIssueDetail(105240, diff --git a/pkg/sql/opt/optbuilder/routine.go b/pkg/sql/opt/optbuilder/routine.go index 447a83256fbd..d2afc542f9a2 100644 --- a/pkg/sql/opt/optbuilder/routine.go +++ b/pkg/sql/opt/optbuilder/routine.go @@ -352,15 +352,17 @@ func (b *Builder) buildRoutine( // for the schema changer we only need depth 1. Also keep track of when // we have are executing inside a UDF, and whether the routine is used as a // data source (this could be nested, so we need to track the previous state). - defer func(trackSchemaDeps, insideUDF, insideDataSource bool) { + defer func(trackSchemaDeps, insideUDF, insideDataSource, insideSQLRoutine bool) { b.trackSchemaDeps = trackSchemaDeps b.insideUDF = insideUDF b.insideDataSource = insideDataSource - }(b.trackSchemaDeps, b.insideUDF, b.insideDataSource) + b.insideSQLRoutine = insideSQLRoutine + }(b.trackSchemaDeps, b.insideUDF, b.insideDataSource, b.insideSQLRoutine) oldInsideDataSource := b.insideDataSource b.insideDataSource = false b.trackSchemaDeps = false b.insideUDF = true + b.insideSQLRoutine = o.Language == tree.RoutineLangSQL isSetReturning := o.Class == tree.GeneratorClass // Build an expression for each statement in the function body. @@ -392,21 +394,19 @@ func (b *Builder) buildRoutine( body = make([]memo.RelExpr, len(stmts)) bodyProps = make([]*physical.Required, len(stmts)) - b.withinSQLRoutine(func() { - for i := range stmts { - stmtScope := b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope) - expr, physProps := stmtScope.expr, stmtScope.makePhysicalProps() - - // The last statement produces the output of the UDF. - if i == len(stmts)-1 { - expr, physProps = b.finishBuildLastStmt( - stmtScope, bodyScope, inScope, isSetReturning, oldInsideDataSource, f, - ) - } - body[i] = expr - bodyProps[i] = physProps + for i := range stmts { + stmtScope := b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope) + expr, physProps := stmtScope.expr, stmtScope.makePhysicalProps() + + // The last statement produces the output of the UDF. + if i == len(stmts)-1 { + expr, physProps = b.finishBuildLastStmt( + stmtScope, bodyScope, inScope, isSetReturning, oldInsideDataSource, f, + ) } - }) + body[i] = expr + bodyProps[i] = physProps + } if b.verboseTracing { bodyStmts = make([]string, len(stmts)) @@ -677,11 +677,3 @@ func (b *Builder) maybeAddRoutineAssignmentCasts( } return b.constructProject(expr, stmtScope.cols), stmtScope.makePhysicalProps() } - -func (b *Builder) withinSQLRoutine(fn func()) { - defer func(origValue bool) { - b.insideSQLRoutine = origValue - }(b.insideSQLRoutine) - b.insideSQLRoutine = true - fn() -}