From 2f7a944bfa45cc70e5f52a663e5e5c53285edccb Mon Sep 17 00:00:00 2001 From: Drew Kimball Date: Fri, 12 Apr 2024 15:54:27 -0600 Subject: [PATCH 1/3] plpgsql: prevent inlining for block-exit continuation This commit prevents inlining for the continuation that transitions out of a PL/pgSQL block with an exception handler. This is necessary to ensure that the statements following the nested block are considered part of the parent block, not the nested block. Otherwise, an error thrown after the nested block might still be caught by the nested block's exception handler, which is incorrect behavior. Informs #122278 Release note: None --- .../testdata/logic_test/plpgsql_block | 34 ++++++++++ pkg/sql/opt/memo/testdata/logprops/tail-calls | 2 +- pkg/sql/opt/optbuilder/plpgsql.go | 38 +++++++---- .../opt/optbuilder/testdata/procedure_plpgsql | 16 ++--- pkg/sql/opt/optbuilder/testdata/udf_plpgsql | 68 +++++++++---------- 5 files changed, 102 insertions(+), 56 deletions(-) diff --git a/pkg/ccl/logictestccl/testdata/logic_test/plpgsql_block b/pkg/ccl/logictestccl/testdata/logic_test/plpgsql_block index dfcaa96ecf85..2143df8e9869 100644 --- a/pkg/ccl/logictestccl/testdata/logic_test/plpgsql_block +++ b/pkg/ccl/logictestccl/testdata/logic_test/plpgsql_block @@ -434,6 +434,40 @@ NOTICE: inner block: 1 NOTICE: inner handler: 2 NOTICE: outer handler: 3 +# Regression test for #122278 - calling f(1) should result in the second +# exception handler being triggered. +statement ok +CREATE TABLE t122278(x INT); + +statement ok +DROP FUNCTION f; +CREATE FUNCTION f(n INT) RETURNS INT AS $$ + BEGIN + BEGIN + IF n = 0 THEN + RETURN 1 // 0; + END IF; + EXCEPTION + WHEN division_by_zero THEN + RETURN (SELECT 100 + count(*) FROM t122278); + END; + RETURN 1 // 0; + EXCEPTION + WHEN division_by_zero THEN + RETURN (SELECT 200 + count(*) FROM t122278); + END +$$ LANGUAGE PLpgSQL; + +query I +SELECT f(0); +---- +100 + +query I +SELECT f(1); +---- +200 + subtest error statement ok diff --git a/pkg/sql/opt/memo/testdata/logprops/tail-calls b/pkg/sql/opt/memo/testdata/logprops/tail-calls index 14adda5af6f2..dc4f5eafe7e3 100644 --- a/pkg/sql/opt/memo/testdata/logprops/tail-calls +++ b/pkg/sql/opt/memo/testdata/logprops/tail-calls @@ -902,7 +902,7 @@ values └── body └── values └── tuple - └── udf: exception_block_5 + └── udf: nested_block_5 ├── tail-call ├── body │ └── values diff --git a/pkg/sql/opt/optbuilder/plpgsql.go b/pkg/sql/opt/optbuilder/plpgsql.go index ec9749939847..3299366d70f3 100644 --- a/pkg/sql/opt/optbuilder/plpgsql.go +++ b/pkg/sql/opt/optbuilder/plpgsql.go @@ -372,17 +372,18 @@ func (b *plpgsqlBuilder) buildBlock(astBlock *ast.Block, s *scope) *scope { } // Build the exception handler. This has to happen after building the variable // declarations, since the exception handler can reference the block's vars. - if exceptions := b.buildExceptions(astBlock); exceptions != nil { + if len(astBlock.Exceptions) > 0 { + exceptionBlock := b.buildExceptions(astBlock) + block.hasExceptionHandler = true + // There is an implicit block around the body statements, with an optional // exception handler. Note that the variable declarations are not in block // scope, and exceptions thrown during variable declaration are not caught. // - // The routine is volatile to prevent inlining. Only the block and - // variable-assignment routines need to be volatile; see the buildExceptions + // The routine is volatile to prevent inlining; see the buildExceptions // comment for details. - block.hasExceptionHandler = true - blockCon := b.makeContinuation("exception_block") - blockCon.def.ExceptionBlock = exceptions + blockCon := b.makeContinuation("nested_block") + blockCon.def.ExceptionBlock = exceptionBlock blockCon.def.Volatility = volatility.Volatile b.appendPlpgSQLStmts(&blockCon, astBlock.Body) return b.callContinuation(&blockCon, s) @@ -407,7 +408,15 @@ func (b *plpgsqlBuilder) buildPLpgSQLStatements(stmts []ast.Statement, s *scope) // For a nested block, push a continuation with the remaining statements // before calling recursively into buildBlock. The continuation will be // called when the control flow within the nested block terminates. - blockCon := b.makeContinuationWithTyp("nested_block", t.Label, continuationBlockExit) + blockCon := b.makeContinuationWithTyp("post_nested_block", t.Label, continuationBlockExit) + if len(t.Exceptions) > 0 { + // If the block has an exception handler, mark the continuation as + // volatile to prevent inlining. This is necessary to ensure that + // transitions out of a PL/pgSQL block are correctly tracked during + // execution. The transition *into* the block is marked volatile for the + // same reason; see also buildBlock and buildExceptions. + blockCon.def.Volatility = volatility.Volatile + } b.appendPlpgSQLStmts(&blockCon, stmts[i+1:]) b.pushContinuation(blockCon) return b.buildBlock(t, s) @@ -1375,13 +1384,16 @@ func (b *plpgsqlBuilder) makeRaiseFormatMessage( // cannot throw an exception, and so the "i := 2" assignment will never become // visible. // -// The block and assignment continuations must be volatile to prevent inlining. -// The presence of an exception handler does not impose restrictions on inlining -// for other continuations. +// The block entry/exit and assignment continuations for a block with an +// exception handler must be volatile to prevent inlining. The presence of an +// exception handler does not impose restrictions on inlining for other types of +// continuations. +// +// Inlining is disabled for the block-exit continuation to ensure that the +// statements following the nested block are correctly handled as part of the +// parent block. Otherwise, an error thrown from the parent block could +// incorrectly be caught by the exception handler of the nested block. func (b *plpgsqlBuilder) buildExceptions(block *ast.Block) *memo.ExceptionBlock { - if len(block.Exceptions) == 0 { - return nil - } codes := make([]pgcode.Code, 0, len(block.Exceptions)) handlers := make([]*memo.UDFDefinition, 0, len(block.Exceptions)) addHandler := func(codeStr string, handler *memo.UDFDefinition) { diff --git a/pkg/sql/opt/optbuilder/testdata/procedure_plpgsql b/pkg/sql/opt/optbuilder/testdata/procedure_plpgsql index 21e37cde0c78..46fa3bf806a5 100644 --- a/pkg/sql/opt/optbuilder/testdata/procedure_plpgsql +++ b/pkg/sql/opt/optbuilder/testdata/procedure_plpgsql @@ -368,11 +368,11 @@ call │ │ ├── const: '' │ │ └── const: '00000' │ └── project - │ ├── columns: nested_block_3:25 + │ ├── columns: post_nested_block_3:25 │ ├── values │ │ └── tuple │ └── projections - │ └── udf: nested_block_3 [as=nested_block_3:25] + │ └── udf: post_nested_block_3 [as=post_nested_block_3:25] │ ├── tail-call │ ├── args │ │ ├── variable: x:21 @@ -502,11 +502,11 @@ call │ │ ├── const: '' │ │ └── const: '00000' │ └── project - │ ├── columns: exception_block_7:17 + │ ├── columns: nested_block_7:17 │ ├── values │ │ └── tuple │ └── projections - │ └── udf: exception_block_7 [as=exception_block_7:17] + │ └── udf: nested_block_7 [as=nested_block_7:17] │ ├── tail-call │ ├── args │ │ └── variable: x:2 @@ -532,11 +532,11 @@ call │ │ │ ├── const: 1 │ │ │ └── const: 0 │ │ └── project - │ │ ├── columns: nested_block_3:15 + │ │ ├── columns: post_nested_block_3:15 │ │ ├── values │ │ │ └── tuple │ │ └── projections - │ │ └── udf: nested_block_3 [as=nested_block_3:15] + │ │ └── udf: post_nested_block_3 [as=post_nested_block_3:15] │ │ ├── tail-call │ │ ├── args │ │ │ └── variable: x:13 @@ -581,7 +581,7 @@ call │ └── exception-handler │ └── SQLSTATE '22012' │ └── project - │ ├── columns: nested_block_3:11 + │ ├── columns: post_nested_block_3:11 │ ├── barrier │ │ ├── columns: x:10!null │ │ └── project @@ -591,7 +591,7 @@ call │ │ └── projections │ │ └── const: 100 [as=x:10] │ └── projections - │ └── udf: nested_block_3 [as=nested_block_3:11] + │ └── udf: post_nested_block_3 [as=post_nested_block_3:11] │ ├── args │ │ └── variable: x:10 │ ├── params: x:4 diff --git a/pkg/sql/opt/optbuilder/testdata/udf_plpgsql b/pkg/sql/opt/optbuilder/testdata/udf_plpgsql index 313bc9e28100..4d1e51ff1cbe 100644 --- a/pkg/sql/opt/optbuilder/testdata/udf_plpgsql +++ b/pkg/sql/opt/optbuilder/testdata/udf_plpgsql @@ -3717,13 +3717,13 @@ project └── udf: f [as=f:6] └── body └── limit - ├── columns: exception_block_7:5 + ├── columns: nested_block_7:5 ├── project - │ ├── columns: exception_block_7:5 + │ ├── columns: nested_block_7:5 │ ├── values │ │ └── tuple │ └── projections - │ └── udf: exception_block_7 [as=exception_block_7:5] + │ └── udf: nested_block_7 [as=nested_block_7:5] │ ├── body │ │ └── project │ │ ├── columns: stmt_return_8:4!null @@ -3786,13 +3786,13 @@ project ├── params: i:1 └── body └── limit - ├── columns: exception_block_3:13 + ├── columns: nested_block_3:13 ├── project - │ ├── columns: exception_block_3:13 + │ ├── columns: nested_block_3:13 │ ├── values │ │ └── tuple │ └── projections - │ └── udf: exception_block_3 [as=exception_block_3:13] + │ └── udf: nested_block_3 [as=nested_block_3:13] │ ├── args │ │ └── variable: i:1 │ ├── params: i:4 @@ -3875,13 +3875,13 @@ project ├── params: i:1 └── body └── limit - ├── columns: exception_block_5:15 + ├── columns: nested_block_5:15 ├── project - │ ├── columns: exception_block_5:15 + │ ├── columns: nested_block_5:15 │ ├── values │ │ └── tuple │ └── projections - │ └── udf: exception_block_5 [as=exception_block_5:15] + │ └── udf: nested_block_5 [as=nested_block_5:15] │ ├── args │ │ └── variable: i:1 │ ├── params: i:6 @@ -3960,13 +3960,13 @@ project └── udf: f [as=f:8] └── body └── limit - ├── columns: exception_block_7:7 + ├── columns: nested_block_7:7 ├── project - │ ├── columns: exception_block_7:7 + │ ├── columns: nested_block_7:7 │ ├── values │ │ └── tuple │ └── projections - │ └── udf: exception_block_7 [as=exception_block_7:7] + │ └── udf: nested_block_7 [as=nested_block_7:7] │ ├── body │ │ └── project │ │ ├── columns: "_stmt_raise_8":6 @@ -4045,9 +4045,9 @@ project ├── params: n:1 └── body └── limit - ├── columns: exception_block_3:9 + ├── columns: nested_block_3:9 ├── project - │ ├── columns: exception_block_3:9 + │ ├── columns: nested_block_3:9 │ ├── barrier │ │ ├── columns: i:2 │ │ └── project @@ -4059,7 +4059,7 @@ project │ │ ├── const: 100 │ │ └── variable: n:1 │ └── projections - │ └── udf: exception_block_3 [as=exception_block_3:9] + │ └── udf: nested_block_3 [as=nested_block_3:9] │ ├── args │ │ ├── variable: n:1 │ │ └── variable: i:2 @@ -4114,9 +4114,9 @@ project ├── params: i:1 j:2 k:3 └── body └── limit - ├── columns: exception_block_3:33 + ├── columns: nested_block_3:33 ├── project - │ ├── columns: exception_block_3:33 + │ ├── columns: nested_block_3:33 │ ├── barrier │ │ ├── columns: x:4!null │ │ └── project @@ -4126,7 +4126,7 @@ project │ │ └── projections │ │ └── const: 0 [as=x:4] │ └── projections - │ └── udf: exception_block_3 [as=exception_block_3:33] + │ └── udf: nested_block_3 [as=nested_block_3:33] │ ├── args │ │ ├── variable: i:1 │ │ ├── variable: j:2 @@ -4250,9 +4250,9 @@ project ├── params: i:1 └── body └── limit - ├── columns: exception_block_3:30 + ├── columns: nested_block_3:30 ├── project - │ ├── columns: exception_block_3:30 + │ ├── columns: nested_block_3:30 │ ├── barrier │ │ ├── columns: x:2 │ │ └── project @@ -4263,7 +4263,7 @@ project │ │ └── cast: INT8 [as=x:2] │ │ └── null │ └── projections - │ └── udf: exception_block_3 [as=exception_block_3:30] + │ └── udf: nested_block_3 [as=nested_block_3:30] │ ├── args │ │ ├── variable: i:1 │ │ └── variable: x:2 @@ -4443,9 +4443,9 @@ project ├── params: n:1 a:2 └── body └── limit - ├── columns: exception_block_3:44 + ├── columns: nested_block_3:44 ├── project - │ ├── columns: exception_block_3:44 + │ ├── columns: nested_block_3:44 │ ├── barrier │ │ ├── columns: x:3 i:4!null │ │ └── project @@ -4460,7 +4460,7 @@ project │ │ └── projections │ │ └── const: 0 [as=i:4] │ └── projections - │ └── udf: exception_block_3 [as=exception_block_3:44] + │ └── udf: nested_block_3 [as=nested_block_3:44] │ ├── args │ │ ├── variable: n:1 │ │ ├── variable: a:2 @@ -5494,9 +5494,9 @@ project └── udf: f [as=f:19] └── body └── limit - ├── columns: exception_block_5:18 + ├── columns: nested_block_5:18 ├── project - │ ├── columns: exception_block_5:18 + │ ├── columns: nested_block_5:18 │ ├── barrier │ │ ├── columns: curs:1!null │ │ └── project @@ -5506,7 +5506,7 @@ project │ │ └── projections │ │ └── const: 'foo' [as=curs:1] │ └── projections - │ └── udf: exception_block_5 [as=exception_block_5:18] + │ └── udf: nested_block_5 [as=nested_block_5:18] │ ├── args │ │ └── variable: curs:1 │ ├── params: curs:10 @@ -5942,13 +5942,13 @@ project └── udf: f [as=f:6] └── body └── limit - ├── columns: exception_block_7:5 + ├── columns: nested_block_7:5 ├── project - │ ├── columns: exception_block_7:5 + │ ├── columns: nested_block_7:5 │ ├── values │ │ └── tuple │ └── projections - │ └── udf: exception_block_7 [as=exception_block_7:5] + │ └── udf: nested_block_7 [as=nested_block_7:5] │ ├── body │ │ └── project │ │ ├── columns: stmt_return_8:4!null @@ -6531,11 +6531,11 @@ project │ │ ├── const: '' │ │ └── const: '00000' │ └── project - │ ├── columns: nested_block_3:17 + │ ├── columns: post_nested_block_3:17 │ ├── values │ │ └── tuple │ └── projections - │ └── udf: nested_block_3 [as=nested_block_3:17] + │ └── udf: post_nested_block_3 [as=post_nested_block_3:17] │ ├── tail-call │ ├── args │ │ └── variable: outer_quantity:14 @@ -6751,11 +6751,11 @@ project │ │ └── null [type=unknown] │ └── subquery [type=tuple{int, unknown, decimal}] │ └── project - │ ├── columns: nested_block_6:7(tuple{int, unknown, decimal}) + │ ├── columns: post_nested_block_6:7(tuple{int, unknown, decimal}) │ ├── values │ │ └── tuple [type=tuple] │ └── projections - │ └── udf: nested_block_6 [as=nested_block_6:7, type=tuple{int, unknown, decimal}] + │ └── udf: post_nested_block_6 [as=post_nested_block_6:7, type=tuple{int, unknown, decimal}] │ └── body │ └── project │ ├── columns: stmt_if_1:5(tuple{int, unknown, decimal}) From 47c90ee64ce08b4ffcb69c61fea53882f724972e Mon Sep 17 00:00:00 2001 From: Drew Kimball Date: Fri, 12 Apr 2024 16:00:21 -0600 Subject: [PATCH 2/3] plpgsql: keep track of the subroutine that begins a PL/pgSQL block This commit adds logic to keep track of the PL/pgSQL sub-routine that logically transitions into a PL/pgSQL block with an exception handler. This is necessary to ensure that the state shared between sub-routines within the same block is correctly initialized. Previously, the block state was only initialized once, but this is incorrect for loops, which need to re-initialize the state on each iteration. Fixes #122278 Release note: None --- .../testdata/logic_test/plpgsql_block | 43 ++++++++++++++++++- pkg/sql/opt/exec/execbuilder/relational.go | 1 + pkg/sql/opt/exec/execbuilder/scalar.go | 5 +++ pkg/sql/opt/memo/expr.go | 8 +++- pkg/sql/opt/optbuilder/plpgsql.go | 1 + pkg/sql/routine.go | 10 ++--- pkg/sql/sem/tree/routine.go | 7 +++ 7 files changed, 67 insertions(+), 8 deletions(-) diff --git a/pkg/ccl/logictestccl/testdata/logic_test/plpgsql_block b/pkg/ccl/logictestccl/testdata/logic_test/plpgsql_block index 2143df8e9869..5db10d8f53ec 100644 --- a/pkg/ccl/logictestccl/testdata/logic_test/plpgsql_block +++ b/pkg/ccl/logictestccl/testdata/logic_test/plpgsql_block @@ -114,9 +114,50 @@ NOTICE: 2, 1 NOTICE: final j: 2 NOTICE: final i: 3 -subtest nested_block_cursors +# Regression test for #122278 - a nested block with an exception handler inside +# a loop should only rollback mutations from the current iteration. +statement ok +CREATE TABLE t122278(x INT); + statement ok DROP PROCEDURE p; +CREATE PROCEDURE p() AS $$ + DECLARE + i INT := 0; + BEGIN + WHILE i < 5 LOOP + i := i + 1; + BEGIN + INSERT INTO t122278 VALUES (i); + IF i = 3 THEN + SELECT 1 // 0; + END IF; + EXCEPTION WHEN division_by_zero THEN + RAISE NOTICE 'saw exception'; + END; + END LOOP; + END; +$$ LANGUAGE PLpgSQL; + +query T noticetrace +CALL p(); +---- +NOTICE: saw exception + +query I rowsort +SELECT * FROM t122278; +---- +1 +2 +4 +5 + +statement ok +DROP TABLE t122278 CASCADE; + +subtest nested_block_cursors + +statement ok CREATE PROCEDURE p() AS $$ DECLARE curs1 CURSOR FOR SELECT 1 FROM generate_series(1, 10); diff --git a/pkg/sql/opt/exec/execbuilder/relational.go b/pkg/sql/opt/exec/execbuilder/relational.go index 7ce3389c957b..f58618db95ce 100644 --- a/pkg/sql/opt/exec/execbuilder/relational.go +++ b/pkg/sql/opt/exec/execbuilder/relational.go @@ -3392,6 +3392,7 @@ func (b *Builder) buildCall(c *memo.CallExpr) (_ execPlan, outputCols colOrdMap, udf.Def.SetReturning, false, /* tailCall */ true, /* procedure */ + false, /* blockStart */ nil, /* blockState */ nil, /* cursorDeclaration */ ) diff --git a/pkg/sql/opt/exec/execbuilder/scalar.go b/pkg/sql/opt/exec/execbuilder/scalar.go index fb327c8599c2..7b0ec2be500d 100644 --- a/pkg/sql/opt/exec/execbuilder/scalar.go +++ b/pkg/sql/opt/exec/execbuilder/scalar.go @@ -702,6 +702,7 @@ func (b *Builder) buildExistsSubquery( false, /* generator */ false, /* tailCall */ false, /* procedure */ + false, /* blockStart */ nil, /* blockState */ nil, /* cursorDeclaration */ ), @@ -822,6 +823,7 @@ func (b *Builder) buildSubquery( false, /* generator */ false, /* tailCall */ false, /* procedure */ + false, /* blockStart */ nil, /* blockState */ nil, /* cursorDeclaration */ ), nil @@ -881,6 +883,7 @@ func (b *Builder) buildSubquery( false, /* generator */ false, /* tailCall */ false, /* procedure */ + false, /* blockStart */ nil, /* blockState */ nil, /* cursorDeclaration */ ), nil @@ -992,6 +995,7 @@ func (b *Builder) buildUDF(ctx *buildScalarCtx, scalar opt.ScalarExpr) (tree.Typ udf.Def.SetReturning, tailCall, false, /* procedure */ + udf.Def.BlockStart, blockState, udf.Def.CursorDeclaration, ), nil @@ -1047,6 +1051,7 @@ func (b *Builder) initRoutineExceptionHandler( action.SetReturning, false, /* tailCall */ false, /* procedure */ + false, /* blockStart */ nil, /* blockState */ nil, /* cursorDeclaration */ ) diff --git a/pkg/sql/opt/memo/expr.go b/pkg/sql/opt/memo/expr.go index 8d5035c6abb3..79d13cfec5f4 100644 --- a/pkg/sql/opt/memo/expr.go +++ b/pkg/sql/opt/memo/expr.go @@ -709,10 +709,16 @@ type UDFDefinition struct { // data source. MultiColDataSource bool - // IsRecursive indicates whether the UDF recursively calls itself. This + // IsRecursive indicates whether the routine recursively calls itself. This // applies to direct as well as indirect recursive calls (mutual recursion). IsRecursive bool + // BlockStart indicates whether the routine marks the start of a PL/pgSQL + // block with an exception handler. This is used to determine when to + // initialize the common state held between sub-routines within the same + // block. + BlockStart bool + // RoutineType indicates whether this routine is a UDF, stored procedure, or // builtin function. RoutineType tree.RoutineType diff --git a/pkg/sql/opt/optbuilder/plpgsql.go b/pkg/sql/opt/optbuilder/plpgsql.go index 3299366d70f3..49991f1cbaf6 100644 --- a/pkg/sql/opt/optbuilder/plpgsql.go +++ b/pkg/sql/opt/optbuilder/plpgsql.go @@ -385,6 +385,7 @@ func (b *plpgsqlBuilder) buildBlock(astBlock *ast.Block, s *scope) *scope { blockCon := b.makeContinuation("nested_block") blockCon.def.ExceptionBlock = exceptionBlock blockCon.def.Volatility = volatility.Volatile + blockCon.def.BlockStart = true b.appendPlpgSQLStmts(&blockCon, astBlock.Body) return b.callContinuation(&blockCon, s) } diff --git a/pkg/sql/routine.go b/pkg/sql/routine.go index 573da708a57c..e665dbfcfa56 100644 --- a/pkg/sql/routine.go +++ b/pkg/sql/routine.go @@ -439,19 +439,17 @@ func (g *routineGenerator) closeCursors(blockState *tree.BlockState) error { return err } -// maybeInitBlockState creates a savepoint if all the following are true: -// 1. The current routine is within a PLpgSQL exception block. -// 2. The current block has an exception handler -// 3. The savepoint hasn't already been created for this block. +// maybeInitBlockState creates a savepoint for a routine that marks a transition +// into a PL/pgSQL block with an exception handler. // // Note that it is not necessary to explicitly release the savepoint at any // point, because it does not add any overhead. func (g *routineGenerator) maybeInitBlockState(ctx context.Context) error { blockState := g.expr.BlockState - if blockState == nil { + if blockState == nil || !g.expr.BlockStart { return nil } - if blockState.ExceptionHandler != nil && blockState.SavepointTok == nil { + if blockState.ExceptionHandler != nil { // Drop down a savepoint for the current scope. var err error if blockState.SavepointTok, err = g.p.Txn().CreateSavepoint(ctx); err != nil { diff --git a/pkg/sql/sem/tree/routine.go b/pkg/sql/sem/tree/routine.go index 2fef6325d766..5e3c3cf5400a 100644 --- a/pkg/sql/sem/tree/routine.go +++ b/pkg/sql/sem/tree/routine.go @@ -128,6 +128,11 @@ type RoutineExpr struct { // Procedure is true if the routine is a procedure being invoked by CALL. Procedure bool + // BlockStart is true if this routine marks the start of a PL/pgSQL block with + // an exception handler. It determines when to initialize the state shared + // between sub-routines for the block. + BlockStart bool + // BlockState holds the information needed to coordinate error-handling // between the sub-routines that make up a PLpgSQL exception block. BlockState *BlockState @@ -149,6 +154,7 @@ func NewTypedRoutineExpr( generator bool, tailCall bool, procedure bool, + blockStart bool, blockState *BlockState, cursorDeclaration *RoutineOpenCursor, ) *RoutineExpr { @@ -163,6 +169,7 @@ func NewTypedRoutineExpr( Generator: generator, TailCall: tailCall, Procedure: procedure, + BlockStart: blockStart, BlockState: blockState, CursorDeclaration: cursorDeclaration, } From e7292421705d7ed89fd2e0e6d6bc41a3b5fb4acc Mon Sep 17 00:00:00 2001 From: Drew Kimball Date: Fri, 12 Apr 2024 16:03:13 -0600 Subject: [PATCH 3/3] plpgsql: roll back all cursors within a block that handles an exception This commit fixes handling for cursors opened within the scope of a block with an exception handler that has nested blocks or nested routine calls. Previously, if a PL/pgSQL block caught an exception, only the cursors opened directly by that block would be rolled back. Any cursors opened by a nested block or routine call would remain open. Now, the block state tracks the timestamp when the block's execution began. Once an exception is caught, all cursors with a timestamp later than the block's start are rolled back. Fixes #121078 Release note: None --- .../testdata/logic_test/plpgsql_block | 79 ++++++++++++++++++- pkg/sql/routine.go | 38 ++++----- pkg/sql/sem/tree/routine.go | 11 ++- 3 files changed, 102 insertions(+), 26 deletions(-) diff --git a/pkg/ccl/logictestccl/testdata/logic_test/plpgsql_block b/pkg/ccl/logictestccl/testdata/logic_test/plpgsql_block index 5db10d8f53ec..df8a722d1d8d 100644 --- a/pkg/ccl/logictestccl/testdata/logic_test/plpgsql_block +++ b/pkg/ccl/logictestccl/testdata/logic_test/plpgsql_block @@ -236,6 +236,83 @@ NOTICE: a2 NOTICE: a3 NOTICE: a4 +# Regression test for #122278 - all cursors within the scope of a block +# (including those in nested blocks or routines) should be closed when the block +# catches an exception. +statement ok +CREATE PROCEDURE p_nested(curs REFCURSOR) AS $$ + BEGIN + OPEN curs FOR SELECT -100; + END; +$$ LANGUAGE PLpgSQL; + +statement ok +DROP FUNCTION f; +CREATE FUNCTION f(n INT) RETURNS INT AS $$ + DECLARE + x REFCURSOR; + y REFCURSOR; + BEGIN + OPEN x FOR SELECT 100; + BEGIN + OPEN y FOR SELECT 200; + IF n = 0 THEN + RETURN 1 // 0; + END IF; + CALL p_nested('foo'); + IF n = 1 THEN + RETURN 1 // 0; + END IF; + EXCEPTION + WHEN division_by_zero THEN + RETURN (SELECT count(*) FROM pg_cursors); + END; + CALL p_nested('bar'); + IF n = 2 THEN + RETURN 1 // 0; + END IF; + RETURN (SELECT count(*) FROM pg_cursors); + EXCEPTION + WHEN division_by_zero THEN + RETURN (SELECT count(*) FROM pg_cursors); + END +$$ LANGUAGE PLpgSQL; + +statement ok +CLOSE ALL; + +query I +SELECT f(0); +---- +1 + +statement ok +CLOSE ALL; + +query I +SELECT f(1); +---- +1 + +statement ok +CLOSE ALL; + +query I +SELECT f(2); +---- +0 + +statement ok +CLOSE ALL; + +query I +SELECT f(3); +---- +4 + +statement ok +CLOSE ALL; + subtest nested_block_exceptions # Don't catch an exception thrown from the variable declarations. @@ -382,7 +459,7 @@ CALL p(1); ---- NOTICE: 1 -statement error pgcode 34000 pq: cursor \"\" does not exist +statement error pgcode 34000 pq: cursor \"\" does not exist CALL p(2); query T noticetrace diff --git a/pkg/sql/routine.go b/pkg/sql/routine.go index e665dbfcfa56..ee78bbe7ebcc 100644 --- a/pkg/sql/routine.go +++ b/pkg/sql/routine.go @@ -322,7 +322,7 @@ func (g *routineGenerator) startInternal(ctx context.Context, txn *kv.Txn) (err return err } if openCursor { - return cursorHelper.createCursor(g.p, g.expr.BlockState) + return cursorHelper.createCursor(g.p) } return nil }) @@ -392,12 +392,12 @@ func (g *routineGenerator) handleException(ctx context.Context, err error) error cursErr := g.closeCursors(blockState) if cursErr != nil { // This error is unexpected, so return immediately. - return errors.CombineErrors(err, cursErr) + return errors.CombineErrors(err, errors.WithAssertionFailure(cursErr)) } spErr := g.p.Txn().RollbackToSavepoint(ctx, blockState.SavepointTok.(kv.SavepointToken)) if spErr != nil { // This error is unexpected, so return immediately. - return errors.CombineErrors(err, spErr) + return errors.CombineErrors(err, errors.WithAssertionFailure(spErr)) } // Truncate the arguments using the number of variables in scope for the // current block. This is necessary because the error may originate from @@ -422,25 +422,26 @@ func (g *routineGenerator) handleException(ctx context.Context, err error) error // closeCursors closes any cursors that were opened within the scope of the // current block. It is used for PLpgSQL exception handling. func (g *routineGenerator) closeCursors(blockState *tree.BlockState) error { - if blockState == nil { + if blockState == nil || blockState.CursorTimestamp == nil { return nil } + blockStart := *blockState.CursorTimestamp + blockState.CursorTimestamp = nil var err error - for _, name := range blockState.Cursors { - if g.p.sqlCursors.getCursor(name) == nil { - // This cursor has already been closed. - continue - } - if curErr := g.p.sqlCursors.closeCursor(name); curErr != nil { - // Attempt to close all cursors in the block, even if one throws an error. - err = errors.CombineErrors(err, curErr) + for name, cursor := range g.p.sqlCursors.list() { + if cursor.created.After(blockStart) { + if curErr := g.p.sqlCursors.closeCursor(name); curErr != nil { + // Try to close all cursors in the scope, even if one throws an error. + err = errors.CombineErrors(err, curErr) + } } } return err } // maybeInitBlockState creates a savepoint for a routine that marks a transition -// into a PL/pgSQL block with an exception handler. +// into a PL/pgSQL block with an exception handler. It also tracks the current +// timestamp in order to correctly roll back cursors opened within the block. // // Note that it is not necessary to explicitly release the savepoint at any // point, because it does not add any overhead. @@ -455,6 +456,10 @@ func (g *routineGenerator) maybeInitBlockState(ctx context.Context) error { if blockState.SavepointTok, err = g.p.Txn().CreateSavepoint(ctx); err != nil { return err } + // Save the current timestamp, so that cursors opened from now on can be + // rolled back by the exception handler. + curTime := timeutil.Now() + blockState.CursorTimestamp = &curTime } return nil } @@ -599,7 +604,7 @@ type plpgsqlCursorHelper struct { rowsAffected int } -func (h *plpgsqlCursorHelper) createCursor(p *planner, blockState *tree.BlockState) error { +func (h *plpgsqlCursorHelper) createCursor(p *planner) error { h.iter = newRowContainerIterator(h.ctx, h.container) cursor := &sqlCursor{ Rows: h, @@ -616,11 +621,6 @@ func (h *plpgsqlCursorHelper) createCursor(p *planner, blockState *tree.BlockSta if err := p.sqlCursors.addCursor(h.cursorName, cursor); err != nil { return err } - if blockState != nil { - // Add the cursor name to the block's state. This allows the exception handler - // to close it, if necessary. - blockState.Cursors = append(blockState.Cursors, h.cursorName) - } h.addedCursor = true return nil } diff --git a/pkg/sql/sem/tree/routine.go b/pkg/sql/sem/tree/routine.go index 5e3c3cf5400a..0f7a27de21d5 100644 --- a/pkg/sql/sem/tree/routine.go +++ b/pkg/sql/sem/tree/routine.go @@ -12,6 +12,7 @@ package tree import ( "context" + "time" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/types" @@ -270,12 +271,10 @@ type BlockState struct { // kv.SavepointToken to avoid import cycles. SavepointTok interface{} - // Cursors is a list of the names of cursors that have been opened within the - // current block. If the exception handler catches an exception, these cursors - // must be closed before the handler can proceed. - // TODO(111139): Once we support nested routine calls, we may have to track - // newly opened cursors differently. - Cursors []Name + // CursorTimestamp is the timestamp at which control transitioned into this + // PL/pgSQL block. It is used to close (only) cursors which were opened within + // the scope of the block when an exception is caught. + CursorTimestamp *time.Time } // StoredProcTxnOp indicates whether a stored procedure has requested that the