Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
122936: optbuilder: relax a check for SQL routines calling SPs with OUT params r=yuzefovich a=yuzefovich

This commit relaxes the check that we have for prohibiting SQL routines to call SPs with OUT parameters. In particular, it's ok for a SQL routine to call a PLpgSQL routine that calls an SP with OUT parameters.

Fixes: cockroachdb#122268.

Release note: None

Co-authored-by: Yahor Yuzefovich <[email protected]>
  • Loading branch information
craig[bot] and yuzefovich committed Apr 24, 2024
2 parents 18b3ab4 + 1cc9910 commit 0fbef53
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 49 deletions.
65 changes: 65 additions & 0 deletions pkg/ccl/logictestccl/testdata/logic_test/procedure_plpgsql
Original file line number Diff line number Diff line change
Expand Up @@ -631,3 +631,68 @@ SELECT * FROM temp;

statement ok
DROP PROCEDURE p(INOUT int);

subtest nested_call

# It's ok for a SQL routine to call a PLpgSQL routine that calls an SP with OUT
# parameters.

statement ok
CREATE PROCEDURE p_inner_o(OUT param INTEGER) AS $$ SELECT 1; $$ LANGUAGE SQL;
CREATE PROCEDURE p_inner_io(INOUT param INTEGER) AS $$ SELECT 1; $$ LANGUAGE SQL;

skipif config local-mixed-23.2
statement ok
CREATE PROCEDURE p_nested() AS $$
DECLARE
a INT;
BEGIN
CALL p_inner_o(a);
RAISE NOTICE 'a: %', a;
CALL p_inner_io(a);
RAISE NOTICE 'a: %', a;
END
$$ LANGUAGE PLpgSQL;

skipif config local-mixed-23.2
statement ok
CREATE FUNCTION f() RETURNS VOID AS $$ CALL p_nested(); $$ LANGUAGE SQL;

skipif config local-mixed-23.2
query T noticetrace
SELECT f();
----
NOTICE: a: 1
NOTICE: a: 1

skipif config local-mixed-23.2
statement ok
DROP FUNCTION f;

statement ok
DROP PROCEDURE IF EXISTS p;

skipif config local-mixed-23.2
statement ok
CREATE PROCEDURE p() AS $$ CALL p_nested(); $$ LANGUAGE SQL;

skipif config local-mixed-23.2
query T noticetrace
CALL p();
----
NOTICE: a: 1
NOTICE: a: 1

skipif config local-mixed-23.2
statement ok
DROP PROCEDURE p;

skipif config local-mixed-23.2
statement ok
DROP PROCEDURE p_nested;

statement ok
DROP PROCEDURE p_inner_o;
DROP PROCEDURE p_inner_io;

subtest end
53 changes: 28 additions & 25 deletions pkg/sql/opt/optbuilder/create_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,11 @@ func (b *Builder) buildCreateFunction(cf *tree.CreateRoutine, inScope *scope) (o
targetVolatility := tree.GetRoutineVolatility(cf.Options)
fmtCtx := tree.NewFmtCtx(tree.FmtSerializable)

defer func(origValue bool) {
b.insideSQLRoutine = origValue
}(b.insideSQLRoutine)
b.insideSQLRoutine = language == tree.RoutineLangSQL

// Validate each statement and collect the dependencies.
var stmtScope *scope
switch language {
Expand All @@ -334,31 +339,29 @@ func (b *Builder) buildCreateFunction(cf *tree.CreateRoutine, inScope *scope) (o
if err != nil {
panic(err)
}
b.withinSQLRoutine(func() {
for i, stmt := range stmts {
// Add statement ast into CreateRoutine node for logging purpose, and set
// the annotations for this statement so names can be resolved.
cf.BodyStatements = append(cf.BodyStatements, stmt.AST)
ann := tree.MakeAnnotations(stmt.NumAnnotations)
cf.BodyAnnotations = append(cf.BodyAnnotations, &ann)

// The defer logic will reset the annotations to the old value.
b.semaCtx.Annotations = ann
b.evalCtx.Annotations = &ann

// We need to disable stable function folding because we want to catch the
// volatility of stable functions. If folded, we only get a scalar and
// lose the volatility.
b.factory.FoldingControl().TemporarilyDisallowStableFolds(func() {
stmtScope = b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope)
})
checkStmtVolatility(targetVolatility, stmtScope, stmt.AST)

// Format the statements with qualified datasource names.
formatFuncBodyStmt(fmtCtx, stmt.AST, language, i > 0 /* newLine */)
afterBuildStmt()
}
})
for i, stmt := range stmts {
// Add statement ast into CreateRoutine node for logging purpose, and set
// the annotations for this statement so names can be resolved.
cf.BodyStatements = append(cf.BodyStatements, stmt.AST)
ann := tree.MakeAnnotations(stmt.NumAnnotations)
cf.BodyAnnotations = append(cf.BodyAnnotations, &ann)

// The defer logic will reset the annotations to the old value.
b.semaCtx.Annotations = ann
b.evalCtx.Annotations = &ann

// We need to disable stable function folding because we want to catch the
// volatility of stable functions. If folded, we only get a scalar and
// lose the volatility.
b.factory.FoldingControl().TemporarilyDisallowStableFolds(func() {
stmtScope = b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope)
})
checkStmtVolatility(targetVolatility, stmtScope, stmt.AST)

// Format the statements with qualified datasource names.
formatFuncBodyStmt(fmtCtx, stmt.AST, language, i > 0 /* newLine */)
afterBuildStmt()
}
case tree.RoutineLangPLpgSQL:
if cf.ReturnType != nil && cf.ReturnType.SetOf {
panic(unimplemented.NewWithIssueDetail(105240,
Expand Down
40 changes: 16 additions & 24 deletions pkg/sql/opt/optbuilder/routine.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,15 +352,17 @@ func (b *Builder) buildRoutine(
// for the schema changer we only need depth 1. Also keep track of when
// we have are executing inside a UDF, and whether the routine is used as a
// data source (this could be nested, so we need to track the previous state).
defer func(trackSchemaDeps, insideUDF, insideDataSource bool) {
defer func(trackSchemaDeps, insideUDF, insideDataSource, insideSQLRoutine bool) {
b.trackSchemaDeps = trackSchemaDeps
b.insideUDF = insideUDF
b.insideDataSource = insideDataSource
}(b.trackSchemaDeps, b.insideUDF, b.insideDataSource)
b.insideSQLRoutine = insideSQLRoutine
}(b.trackSchemaDeps, b.insideUDF, b.insideDataSource, b.insideSQLRoutine)
oldInsideDataSource := b.insideDataSource
b.insideDataSource = false
b.trackSchemaDeps = false
b.insideUDF = true
b.insideSQLRoutine = o.Language == tree.RoutineLangSQL
isSetReturning := o.Class == tree.GeneratorClass

// Build an expression for each statement in the function body.
Expand Down Expand Up @@ -392,21 +394,19 @@ func (b *Builder) buildRoutine(
body = make([]memo.RelExpr, len(stmts))
bodyProps = make([]*physical.Required, len(stmts))

b.withinSQLRoutine(func() {
for i := range stmts {
stmtScope := b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope)
expr, physProps := stmtScope.expr, stmtScope.makePhysicalProps()

// The last statement produces the output of the UDF.
if i == len(stmts)-1 {
expr, physProps = b.finishBuildLastStmt(
stmtScope, bodyScope, inScope, isSetReturning, oldInsideDataSource, f,
)
}
body[i] = expr
bodyProps[i] = physProps
for i := range stmts {
stmtScope := b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope)
expr, physProps := stmtScope.expr, stmtScope.makePhysicalProps()

// The last statement produces the output of the UDF.
if i == len(stmts)-1 {
expr, physProps = b.finishBuildLastStmt(
stmtScope, bodyScope, inScope, isSetReturning, oldInsideDataSource, f,
)
}
})
body[i] = expr
bodyProps[i] = physProps
}

if b.verboseTracing {
bodyStmts = make([]string, len(stmts))
Expand Down Expand Up @@ -677,11 +677,3 @@ func (b *Builder) maybeAddRoutineAssignmentCasts(
}
return b.constructProject(expr, stmtScope.cols), stmtScope.makePhysicalProps()
}

func (b *Builder) withinSQLRoutine(fn func()) {
defer func(origValue bool) {
b.insideSQLRoutine = origValue
}(b.insideSQLRoutine)
b.insideSQLRoutine = true
fn()
}

0 comments on commit 0fbef53

Please sign in to comment.