From 1cc991014c9893d19e269d1374f6ea8282071ebd Mon Sep 17 00:00:00 2001 From: Yahor Yuzefovich Date: Tue, 23 Apr 2024 14:40:52 -0700 Subject: [PATCH] optbuilder: relax a check for SQL routines calling SPs with OUT params This commit relaxes the check that we have for prohibiting SQL routines to call SPs with OUT parameters. In particular, it's ok for a SQL routine to call a PLpgSQL routine that calls an SP with OUT parameters. Release note: None --- .../testdata/logic_test/procedure_plpgsql | 65 +++++++++++++++++++ pkg/sql/opt/optbuilder/create_function.go | 53 ++++++++------- pkg/sql/opt/optbuilder/routine.go | 40 +++++------- 3 files changed, 109 insertions(+), 49 deletions(-) diff --git a/pkg/ccl/logictestccl/testdata/logic_test/procedure_plpgsql b/pkg/ccl/logictestccl/testdata/logic_test/procedure_plpgsql index e063c19cc4ff..030eeb22f32e 100644 --- a/pkg/ccl/logictestccl/testdata/logic_test/procedure_plpgsql +++ b/pkg/ccl/logictestccl/testdata/logic_test/procedure_plpgsql @@ -631,3 +631,68 @@ SELECT * FROM temp; statement ok DROP PROCEDURE p(INOUT int); + +subtest nested_call + +# It's ok for a SQL routine to call a PLpgSQL routine that calls an SP with OUT +# parameters. + +statement ok +CREATE PROCEDURE p_inner_o(OUT param INTEGER) AS $$ SELECT 1; $$ LANGUAGE SQL; +CREATE PROCEDURE p_inner_io(INOUT param INTEGER) AS $$ SELECT 1; $$ LANGUAGE SQL; + +skipif config local-mixed-23.2 +statement ok +CREATE PROCEDURE p_nested() AS $$ + DECLARE + a INT; + BEGIN + CALL p_inner_o(a); + RAISE NOTICE 'a: %', a; + CALL p_inner_io(a); + RAISE NOTICE 'a: %', a; + END +$$ LANGUAGE PLpgSQL; + +skipif config local-mixed-23.2 +statement ok +CREATE FUNCTION f() RETURNS VOID AS $$ CALL p_nested(); $$ LANGUAGE SQL; + +skipif config local-mixed-23.2 +query T noticetrace +SELECT f(); +---- +NOTICE: a: 1 +NOTICE: a: 1 + +skipif config local-mixed-23.2 +statement ok +DROP FUNCTION f; + +statement ok +DROP PROCEDURE IF EXISTS p; + +skipif config local-mixed-23.2 +statement ok +CREATE PROCEDURE p() AS $$ CALL p_nested(); $$ LANGUAGE SQL; + +skipif config local-mixed-23.2 +query T noticetrace +CALL p(); +---- +NOTICE: a: 1 +NOTICE: a: 1 + +skipif config local-mixed-23.2 +statement ok +DROP PROCEDURE p; + +skipif config local-mixed-23.2 +statement ok +DROP PROCEDURE p_nested; + +statement ok +DROP PROCEDURE p_inner_o; +DROP PROCEDURE p_inner_io; + +subtest end diff --git a/pkg/sql/opt/optbuilder/create_function.go b/pkg/sql/opt/optbuilder/create_function.go index 308fd9a19404..437f09edd54f 100644 --- a/pkg/sql/opt/optbuilder/create_function.go +++ b/pkg/sql/opt/optbuilder/create_function.go @@ -325,6 +325,11 @@ func (b *Builder) buildCreateFunction(cf *tree.CreateRoutine, inScope *scope) (o targetVolatility := tree.GetRoutineVolatility(cf.Options) fmtCtx := tree.NewFmtCtx(tree.FmtSerializable) + defer func(origValue bool) { + b.insideSQLRoutine = origValue + }(b.insideSQLRoutine) + b.insideSQLRoutine = language == tree.RoutineLangSQL + // Validate each statement and collect the dependencies. var stmtScope *scope switch language { @@ -334,31 +339,29 @@ func (b *Builder) buildCreateFunction(cf *tree.CreateRoutine, inScope *scope) (o if err != nil { panic(err) } - b.withinSQLRoutine(func() { - for i, stmt := range stmts { - // Add statement ast into CreateRoutine node for logging purpose, and set - // the annotations for this statement so names can be resolved. - cf.BodyStatements = append(cf.BodyStatements, stmt.AST) - ann := tree.MakeAnnotations(stmt.NumAnnotations) - cf.BodyAnnotations = append(cf.BodyAnnotations, &ann) - - // The defer logic will reset the annotations to the old value. - b.semaCtx.Annotations = ann - b.evalCtx.Annotations = &ann - - // We need to disable stable function folding because we want to catch the - // volatility of stable functions. If folded, we only get a scalar and - // lose the volatility. - b.factory.FoldingControl().TemporarilyDisallowStableFolds(func() { - stmtScope = b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope) - }) - checkStmtVolatility(targetVolatility, stmtScope, stmt.AST) - - // Format the statements with qualified datasource names. - formatFuncBodyStmt(fmtCtx, stmt.AST, language, i > 0 /* newLine */) - afterBuildStmt() - } - }) + for i, stmt := range stmts { + // Add statement ast into CreateRoutine node for logging purpose, and set + // the annotations for this statement so names can be resolved. + cf.BodyStatements = append(cf.BodyStatements, stmt.AST) + ann := tree.MakeAnnotations(stmt.NumAnnotations) + cf.BodyAnnotations = append(cf.BodyAnnotations, &ann) + + // The defer logic will reset the annotations to the old value. + b.semaCtx.Annotations = ann + b.evalCtx.Annotations = &ann + + // We need to disable stable function folding because we want to catch the + // volatility of stable functions. If folded, we only get a scalar and + // lose the volatility. + b.factory.FoldingControl().TemporarilyDisallowStableFolds(func() { + stmtScope = b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope) + }) + checkStmtVolatility(targetVolatility, stmtScope, stmt.AST) + + // Format the statements with qualified datasource names. + formatFuncBodyStmt(fmtCtx, stmt.AST, language, i > 0 /* newLine */) + afterBuildStmt() + } case tree.RoutineLangPLpgSQL: if cf.ReturnType != nil && cf.ReturnType.SetOf { panic(unimplemented.NewWithIssueDetail(105240, diff --git a/pkg/sql/opt/optbuilder/routine.go b/pkg/sql/opt/optbuilder/routine.go index 447a83256fbd..d2afc542f9a2 100644 --- a/pkg/sql/opt/optbuilder/routine.go +++ b/pkg/sql/opt/optbuilder/routine.go @@ -352,15 +352,17 @@ func (b *Builder) buildRoutine( // for the schema changer we only need depth 1. Also keep track of when // we have are executing inside a UDF, and whether the routine is used as a // data source (this could be nested, so we need to track the previous state). - defer func(trackSchemaDeps, insideUDF, insideDataSource bool) { + defer func(trackSchemaDeps, insideUDF, insideDataSource, insideSQLRoutine bool) { b.trackSchemaDeps = trackSchemaDeps b.insideUDF = insideUDF b.insideDataSource = insideDataSource - }(b.trackSchemaDeps, b.insideUDF, b.insideDataSource) + b.insideSQLRoutine = insideSQLRoutine + }(b.trackSchemaDeps, b.insideUDF, b.insideDataSource, b.insideSQLRoutine) oldInsideDataSource := b.insideDataSource b.insideDataSource = false b.trackSchemaDeps = false b.insideUDF = true + b.insideSQLRoutine = o.Language == tree.RoutineLangSQL isSetReturning := o.Class == tree.GeneratorClass // Build an expression for each statement in the function body. @@ -392,21 +394,19 @@ func (b *Builder) buildRoutine( body = make([]memo.RelExpr, len(stmts)) bodyProps = make([]*physical.Required, len(stmts)) - b.withinSQLRoutine(func() { - for i := range stmts { - stmtScope := b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope) - expr, physProps := stmtScope.expr, stmtScope.makePhysicalProps() - - // The last statement produces the output of the UDF. - if i == len(stmts)-1 { - expr, physProps = b.finishBuildLastStmt( - stmtScope, bodyScope, inScope, isSetReturning, oldInsideDataSource, f, - ) - } - body[i] = expr - bodyProps[i] = physProps + for i := range stmts { + stmtScope := b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope) + expr, physProps := stmtScope.expr, stmtScope.makePhysicalProps() + + // The last statement produces the output of the UDF. + if i == len(stmts)-1 { + expr, physProps = b.finishBuildLastStmt( + stmtScope, bodyScope, inScope, isSetReturning, oldInsideDataSource, f, + ) } - }) + body[i] = expr + bodyProps[i] = physProps + } if b.verboseTracing { bodyStmts = make([]string, len(stmts)) @@ -677,11 +677,3 @@ func (b *Builder) maybeAddRoutineAssignmentCasts( } return b.constructProject(expr, stmtScope.cols), stmtScope.makePhysicalProps() } - -func (b *Builder) withinSQLRoutine(fn func()) { - defer func(origValue bool) { - b.insideSQLRoutine = origValue - }(b.insideSQLRoutine) - b.insideSQLRoutine = true - fn() -}