Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optbuilder: relax a check for SQL routines calling SPs with OUT params #122936

Merged
merged 1 commit into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions pkg/ccl/logictestccl/testdata/logic_test/procedure_plpgsql
Original file line number Diff line number Diff line change
Expand Up @@ -631,3 +631,68 @@ SELECT * FROM temp;

statement ok
DROP PROCEDURE p(INOUT int);

subtest nested_call

# It's ok for a SQL routine to call a PLpgSQL routine that calls an SP with OUT
# parameters.

statement ok
CREATE PROCEDURE p_inner_o(OUT param INTEGER) AS $$ SELECT 1; $$ LANGUAGE SQL;
CREATE PROCEDURE p_inner_io(INOUT param INTEGER) AS $$ SELECT 1; $$ LANGUAGE SQL;

skipif config local-mixed-23.2
statement ok
CREATE PROCEDURE p_nested() AS $$
DECLARE
a INT;
BEGIN
CALL p_inner_o(a);
RAISE NOTICE 'a: %', a;
CALL p_inner_io(a);
RAISE NOTICE 'a: %', a;
END
$$ LANGUAGE PLpgSQL;

skipif config local-mixed-23.2
statement ok
CREATE FUNCTION f() RETURNS VOID AS $$ CALL p_nested(); $$ LANGUAGE SQL;

skipif config local-mixed-23.2
query T noticetrace
SELECT f();
----
NOTICE: a: 1
NOTICE: a: 1

skipif config local-mixed-23.2
statement ok
DROP FUNCTION f;

statement ok
DROP PROCEDURE IF EXISTS p;

skipif config local-mixed-23.2
statement ok
CREATE PROCEDURE p() AS $$ CALL p_nested(); $$ LANGUAGE SQL;

skipif config local-mixed-23.2
query T noticetrace
CALL p();
----
NOTICE: a: 1
NOTICE: a: 1

skipif config local-mixed-23.2
statement ok
DROP PROCEDURE p;

skipif config local-mixed-23.2
statement ok
DROP PROCEDURE p_nested;

statement ok
DROP PROCEDURE p_inner_o;
DROP PROCEDURE p_inner_io;

subtest end
53 changes: 28 additions & 25 deletions pkg/sql/opt/optbuilder/create_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,11 @@ func (b *Builder) buildCreateFunction(cf *tree.CreateRoutine, inScope *scope) (o
targetVolatility := tree.GetRoutineVolatility(cf.Options)
fmtCtx := tree.NewFmtCtx(tree.FmtSerializable)

defer func(origValue bool) {
b.insideSQLRoutine = origValue
}(b.insideSQLRoutine)
b.insideSQLRoutine = language == tree.RoutineLangSQL

// Validate each statement and collect the dependencies.
var stmtScope *scope
switch language {
Expand All @@ -334,31 +339,29 @@ func (b *Builder) buildCreateFunction(cf *tree.CreateRoutine, inScope *scope) (o
if err != nil {
panic(err)
}
b.withinSQLRoutine(func() {
for i, stmt := range stmts {
// Add statement ast into CreateRoutine node for logging purpose, and set
// the annotations for this statement so names can be resolved.
cf.BodyStatements = append(cf.BodyStatements, stmt.AST)
ann := tree.MakeAnnotations(stmt.NumAnnotations)
cf.BodyAnnotations = append(cf.BodyAnnotations, &ann)

// The defer logic will reset the annotations to the old value.
b.semaCtx.Annotations = ann
b.evalCtx.Annotations = &ann

// We need to disable stable function folding because we want to catch the
// volatility of stable functions. If folded, we only get a scalar and
// lose the volatility.
b.factory.FoldingControl().TemporarilyDisallowStableFolds(func() {
stmtScope = b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope)
})
checkStmtVolatility(targetVolatility, stmtScope, stmt.AST)

// Format the statements with qualified datasource names.
formatFuncBodyStmt(fmtCtx, stmt.AST, language, i > 0 /* newLine */)
afterBuildStmt()
}
})
for i, stmt := range stmts {
// Add statement ast into CreateRoutine node for logging purpose, and set
// the annotations for this statement so names can be resolved.
cf.BodyStatements = append(cf.BodyStatements, stmt.AST)
ann := tree.MakeAnnotations(stmt.NumAnnotations)
cf.BodyAnnotations = append(cf.BodyAnnotations, &ann)

// The defer logic will reset the annotations to the old value.
b.semaCtx.Annotations = ann
b.evalCtx.Annotations = &ann

// We need to disable stable function folding because we want to catch the
// volatility of stable functions. If folded, we only get a scalar and
// lose the volatility.
b.factory.FoldingControl().TemporarilyDisallowStableFolds(func() {
stmtScope = b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope)
})
checkStmtVolatility(targetVolatility, stmtScope, stmt.AST)

// Format the statements with qualified datasource names.
formatFuncBodyStmt(fmtCtx, stmt.AST, language, i > 0 /* newLine */)
afterBuildStmt()
}
case tree.RoutineLangPLpgSQL:
if cf.ReturnType != nil && cf.ReturnType.SetOf {
panic(unimplemented.NewWithIssueDetail(105240,
Expand Down
40 changes: 16 additions & 24 deletions pkg/sql/opt/optbuilder/routine.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,15 +352,17 @@ func (b *Builder) buildRoutine(
// for the schema changer we only need depth 1. Also keep track of when
// we have are executing inside a UDF, and whether the routine is used as a
// data source (this could be nested, so we need to track the previous state).
defer func(trackSchemaDeps, insideUDF, insideDataSource bool) {
defer func(trackSchemaDeps, insideUDF, insideDataSource, insideSQLRoutine bool) {
b.trackSchemaDeps = trackSchemaDeps
b.insideUDF = insideUDF
b.insideDataSource = insideDataSource
}(b.trackSchemaDeps, b.insideUDF, b.insideDataSource)
b.insideSQLRoutine = insideSQLRoutine
}(b.trackSchemaDeps, b.insideUDF, b.insideDataSource, b.insideSQLRoutine)
oldInsideDataSource := b.insideDataSource
b.insideDataSource = false
b.trackSchemaDeps = false
b.insideUDF = true
b.insideSQLRoutine = o.Language == tree.RoutineLangSQL
isSetReturning := o.Class == tree.GeneratorClass

// Build an expression for each statement in the function body.
Expand Down Expand Up @@ -392,21 +394,19 @@ func (b *Builder) buildRoutine(
body = make([]memo.RelExpr, len(stmts))
bodyProps = make([]*physical.Required, len(stmts))

b.withinSQLRoutine(func() {
for i := range stmts {
stmtScope := b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope)
expr, physProps := stmtScope.expr, stmtScope.makePhysicalProps()

// The last statement produces the output of the UDF.
if i == len(stmts)-1 {
expr, physProps = b.finishBuildLastStmt(
stmtScope, bodyScope, inScope, isSetReturning, oldInsideDataSource, f,
)
}
body[i] = expr
bodyProps[i] = physProps
for i := range stmts {
stmtScope := b.buildStmtAtRootWithScope(stmts[i].AST, nil /* desiredTypes */, bodyScope)
expr, physProps := stmtScope.expr, stmtScope.makePhysicalProps()

// The last statement produces the output of the UDF.
if i == len(stmts)-1 {
expr, physProps = b.finishBuildLastStmt(
stmtScope, bodyScope, inScope, isSetReturning, oldInsideDataSource, f,
)
}
})
body[i] = expr
bodyProps[i] = physProps
}

if b.verboseTracing {
bodyStmts = make([]string, len(stmts))
Expand Down Expand Up @@ -677,11 +677,3 @@ func (b *Builder) maybeAddRoutineAssignmentCasts(
}
return b.constructProject(expr, stmtScope.cols), stmtScope.makePhysicalProps()
}

func (b *Builder) withinSQLRoutine(fn func()) {
defer func(origValue bool) {
b.insideSQLRoutine = origValue
}(b.insideSQLRoutine)
b.insideSQLRoutine = true
fn()
}
Loading