From d12fe78d5f62922a3663eecd745f3d39b40adaed Mon Sep 17 00:00:00 2001 From: Drew Kimball Date: Tue, 26 Mar 2024 03:53:25 -0600 Subject: [PATCH] sql: defer tail-call identification until execbuilding This commit changes the way routine tail-calls are handled. Before, only PL/pgSQL sub-routines were considered as tail-calls, and this was determined by a `TailCall` property that was set during optbuilding. This approach was fragile, did not work for explicit tail-calls, and did not work well with nested routine calls in general. Now, tail-calls are determined after optimization, during execbuilding. This will allow explicit (user-specified) tail calls to be optimized. It also prevents inlining rules from causing correctness bugs, since the old `TailCall` property only applied to the original calling routine. See `ExtractTailCalls` for further details. The next commit will add additional testing. Informs #120916 Release note: None --- pkg/sql/opt/exec/execbuilder/builder.go | 5 ++ pkg/sql/opt/exec/execbuilder/relational.go | 8 +-- pkg/sql/opt/exec/execbuilder/scalar.go | 26 +++++++- pkg/sql/opt/memo/expr.go | 3 + pkg/sql/opt/memo/extract.go | 75 ++++++++++++++++++++++ pkg/sql/opt/norm/decorrelate_funcs.go | 10 ++- pkg/sql/opt/ops/scalar.opt | 5 -- pkg/sql/opt/optbuilder/plpgsql.go | 4 +- pkg/sql/opt/optbuilder/routine.go | 1 + 9 files changed, 121 insertions(+), 16 deletions(-) diff --git a/pkg/sql/opt/exec/execbuilder/builder.go b/pkg/sql/opt/exec/execbuilder/builder.go index 7f5c6c7c5fea..612619899955 100644 --- a/pkg/sql/opt/exec/execbuilder/builder.go +++ b/pkg/sql/opt/exec/execbuilder/builder.go @@ -111,6 +111,11 @@ type Builder struct { // subqueries for statements inside a UDF. planLazySubqueries bool + // tailCalls is used when building the last body statement of a routine. It + // identifies nested routines that are in tail-call position. This information + // is used to determine whether tail-call optimization is applicable. + tailCalls map[*memo.UDFCallExpr]struct{} + // -- output -- // flags tracks various properties of the plan accumulated while building. diff --git a/pkg/sql/opt/exec/execbuilder/relational.go b/pkg/sql/opt/exec/execbuilder/relational.go index 189236be45fd..043f85e18e71 100644 --- a/pkg/sql/opt/exec/execbuilder/relational.go +++ b/pkg/sql/opt/exec/execbuilder/relational.go @@ -3390,10 +3390,10 @@ func (b *Builder) buildCall(c *memo.CallExpr) (_ execPlan, outputCols colOrdMap, udf.Def.CalledOnNullInput, udf.Def.MultiColDataSource, udf.Def.SetReturning, - udf.TailCall, - true, /* procedure */ - nil, /* blockState */ - nil, /* cursorDeclaration */ + false, /* tailCall */ + true, /* procedure */ + nil, /* blockState */ + nil, /* cursorDeclaration */ ) var ep execPlan diff --git a/pkg/sql/opt/exec/execbuilder/scalar.go b/pkg/sql/opt/exec/execbuilder/scalar.go index aeeced100cce..1c7fcc2f635a 100644 --- a/pkg/sql/opt/exec/execbuilder/scalar.go +++ b/pkg/sql/opt/exec/execbuilder/scalar.go @@ -953,6 +953,14 @@ func (b *Builder) buildUDF(ctx *buildScalarCtx, scalar opt.ScalarExpr) (tree.Typ b.initRoutineExceptionHandler(blockState, udf.Def.ExceptionBlock) } + // Execution expects there to be more than one body statement if a cursor is + // opened. + if udf.Def.CursorDeclaration != nil && len(udf.Def.Body) <= 1 { + panic(errors.AssertionFailedf( + "expected more than one body statement for a routine that opens a cursor", + )) + } + // Create a tree.RoutinePlanFn that can plan the statements in the UDF body. // TODO(mgartner): Add support for WITH expressions inside UDF bodies. planGen := b.buildRoutinePlanGenerator( @@ -969,6 +977,10 @@ func (b *Builder) buildUDF(ctx *buildScalarCtx, scalar opt.ScalarExpr) (tree.Typ // statements. enableStepping := udf.Def.Volatility == volatility.Volatile + // The calling routine, if any, will have already determined whether this + // routine is in tail-call position. + _, tailCall := b.tailCalls[udf] + return tree.NewTypedRoutineExpr( udf.Def.Name, args, @@ -978,7 +990,7 @@ func (b *Builder) buildUDF(ctx *buildScalarCtx, scalar opt.ScalarExpr) (tree.Typ udf.Def.CalledOnNullInput, udf.Def.MultiColDataSource, udf.Def.SetReturning, - udf.TailCall, + tailCall, false, /* procedure */ blockState, udf.Def.CursorDeclaration, @@ -1173,12 +1185,23 @@ func (b *Builder) buildRoutinePlanGenerator( return err } + // Identify nested routines that are in tail-call position, and cache them + // in the Builder. When a nested routine is evaluated, this information + // may be used to enable tail-call optimization. + isFinalPlan := i == len(stmts)-1 + var tailCalls map[*memo.UDFCallExpr]struct{} + if isFinalPlan { + tailCalls = make(map[*memo.UDFCallExpr]struct{}) + memo.ExtractTailCalls(optimizedExpr, tailCalls) + } + // Build the memo into a plan. ef := ref.(exec.Factory) eb := New(ctx, ef, &o, f.Memo(), b.catalog, optimizedExpr, b.semaCtx, b.evalCtx, false /* allowAutoCommit */, b.IsANSIDML) eb.withExprs = withExprs eb.disableTelemetry = true eb.planLazySubqueries = true + eb.tailCalls = tailCalls plan, err := eb.Build() if err != nil { if errors.IsAssertionFailure(err) { @@ -1194,7 +1217,6 @@ func (b *Builder) buildRoutinePlanGenerator( if len(eb.subqueries) > 0 { return expectedLazyRoutineError("subquery") } - isFinalPlan := i == len(stmts)-1 var stmtForDistSQLDiagram string if i < len(stmtStr) { stmtForDistSQLDiagram = stmtStr[i] diff --git a/pkg/sql/opt/memo/expr.go b/pkg/sql/opt/memo/expr.go index 82e74d5f8b91..c085c5febb65 100644 --- a/pkg/sql/opt/memo/expr.go +++ b/pkg/sql/opt/memo/expr.go @@ -717,6 +717,9 @@ type UDFDefinition struct { // builtin function. RoutineType tree.RoutineType + // RoutineLang indicates the language of the routine (SQL or PL/pgSQL). + RoutineLang tree.RoutineLanguage + // Params is the list of columns representing parameters of the function. The // i-th column in the list corresponds to the i-th parameter of the function. // During execution of the UDF, these columns are replaced with the arguments diff --git a/pkg/sql/opt/memo/extract.go b/pkg/sql/opt/memo/extract.go index b23410562838..04feb5a66a53 100644 --- a/pkg/sql/opt/memo/extract.go +++ b/pkg/sql/opt/memo/extract.go @@ -450,3 +450,78 @@ func ExtractValueForConstColumn( } return nil } + +// ExtractTailCalls traverses the given expression tree, searching for routines +// that are in tail-call position relative to the (assumed) calling routine. +// ExtractTailCalls assumes that the given expression is the last body statement +// of the calling routine, and that the map is already initialized. +// +// In order for a nested routine to qualify as a tail-call, the following +// condition must be true: If the nested routine is evaluated, then the calling +// routine must return the result of the nested routine without further +// modification. This means even simple expressions like CAST are not allowed. +// +// ExtractTailCalls is best-effort, but is sufficient to identify the tail-calls +// produced among PL/pgSQL sub-routines. +// +// NOTE: ExtractTailCalls does not take into account whether the calling routine +// has an exception handler. The execution engine must take this into account +// before applying tail-call optimization. +func ExtractTailCalls(expr opt.Expr, tailCalls map[*UDFCallExpr]struct{}) { + switch t := expr.(type) { + case *ProjectExpr: + // * The cardinality cannot be greater than one: Otherwise, a nested routine + // will be evaluated more than once, and all evaluations other than the last + // are not tail-calls. + // + // * There must be a single projection: the execution does not provide + // guarantees about order of evaluation for projections (though it may in + // the future). + // + // * The passthrough set must be empty: Otherwise, the result of the nested + // routine cannot directly be used as the result of the calling routine. + // + // * No routine in the input of the project can be a tail-call, since the + // Project will perform work after the nested routine evaluates. + // Note: this condition is enforced by simply not calling ExtractTailCalls + // on the input of the Project. + if t.Relational().Cardinality.IsZeroOrOne() && + len(t.Projections) == 1 && t.Passthrough.Empty() { + ExtractTailCalls(t.Projections[0].Element, tailCalls) + } + + case *ValuesExpr: + // Allow only the case where the Values expression contains only a single + // expression. Note: it may be possible to make an explicit guarantee that + // expressions in a row are evaluated in order, in which case it would be + // sufficient to ensure that the nested routine is in the last column. + if len(t.Rows) == 1 && len(t.Rows[0].(*TupleExpr).Elems) == 1 { + ExtractTailCalls(t.Rows[0].(*TupleExpr).Elems[0], tailCalls) + } + + case *SubqueryExpr: + // A subquery within a routine is lazily evaluated and passes through a + // single input row. Similar to Project, we require that the input have only + // one row and one column, since otherwise work may happen after the nested + // routine evaluates. + if t.Input.Relational().Cardinality.IsZeroOrOne() && + t.Input.Relational().OutputCols.Len() == 1 { + ExtractTailCalls(t.Input, tailCalls) + } + + case *CaseExpr: + // Case expressions guarantee that exactly one branch is evaluated, and pass + // through the result of the chosen branch. Therefore, a routine within a + // CASE branch can be a tail-call. + for i := range t.Whens { + ExtractTailCalls(t.Whens[i].(*WhenExpr).Value, tailCalls) + } + ExtractTailCalls(t.OrElse, tailCalls) + + case *UDFCallExpr: + // If we reached a scalar UDFCall expression, it is a tail call. + if !t.Def.SetReturning { + tailCalls[t] = struct{}{} + } + } +} diff --git a/pkg/sql/opt/norm/decorrelate_funcs.go b/pkg/sql/opt/norm/decorrelate_funcs.go index 8359e1b640c4..588747a83b74 100644 --- a/pkg/sql/opt/norm/decorrelate_funcs.go +++ b/pkg/sql/opt/norm/decorrelate_funcs.go @@ -14,6 +14,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/opt" "github.com/cockroachdb/cockroach/pkg/sql/opt/memo" "github.com/cockroachdb/cockroach/pkg/sql/opt/props" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/errors" "github.com/cockroachdb/redact" @@ -65,9 +66,12 @@ func (c *CustomFuncs) deriveHasUnhoistableExpr(expr opt.Expr) bool { // cannot be reordered with other expressions. return true case *memo.UDFCallExpr: - if t.TailCall { - // A routine with the "tail-call" property cannot be reordered with other - // expressions, since it may then no longer be in tail-call position. + if t.Def.RoutineLang == tree.RoutineLangPLpgSQL { + // Hoisting a PL/pgSQL sub-routine could move it out of tail-call + // position, forcing inefficient nested execution. + // + // TODO(#119956): consider relaxing this for routines which aren't already + // in tail-call position. return true } } diff --git a/pkg/sql/opt/ops/scalar.opt b/pkg/sql/opt/ops/scalar.opt index 4d62fbfd16db..88c715ba7999 100644 --- a/pkg/sql/opt/ops/scalar.opt +++ b/pkg/sql/opt/ops/scalar.opt @@ -1272,11 +1272,6 @@ define UDFCall { define UDFCallPrivate { # Def points to the UDF SQL body. Def UDFDefinition - - # TailCall indicates whether the UDF is in tail-call position, meaning that - # it is nested in a parent routine which will not perform any additional - # processing once this call is evaluated. - TailCall bool } # TxnControl allows PL/pgSQL stored procedures to pause their execution, commit diff --git a/pkg/sql/opt/optbuilder/plpgsql.go b/pkg/sql/opt/optbuilder/plpgsql.go index ff14ad16938b..513860ab0079 100644 --- a/pkg/sql/opt/optbuilder/plpgsql.go +++ b/pkg/sql/opt/optbuilder/plpgsql.go @@ -1620,6 +1620,7 @@ func (b *plpgsqlBuilder) makeContinuation(conName string) continuation { CalledOnNullInput: true, BlockState: b.block().state, RoutineType: tree.UDFRoutine, + RoutineLang: tree.RoutineLangPLpgSQL, }, typ: continuationDefault, s: s, @@ -1671,9 +1672,8 @@ func (b *plpgsqlBuilder) callContinuation(con *continuation, s *scope) *scope { if con == nil { return b.handleEndOfFunction(s) } - // PLpgSQL continuation routines are always in tail-call position. args := b.makeContinuationArgs(con, s) - call := b.ob.factory.ConstructUDFCall(args, &memo.UDFCallPrivate{Def: con.def, TailCall: true}) + call := b.ob.factory.ConstructUDFCall(args, &memo.UDFCallPrivate{Def: con.def}) b.addBarrierIfVolatile(s, call) returnColName := scopeColName("").WithMetadataName(con.def.Name) diff --git a/pkg/sql/opt/optbuilder/routine.go b/pkg/sql/opt/optbuilder/routine.go index a9bab6e48169..129bdb1b8187 100644 --- a/pkg/sql/opt/optbuilder/routine.go +++ b/pkg/sql/opt/optbuilder/routine.go @@ -393,6 +393,7 @@ func (b *Builder) buildRoutine( CalledOnNullInput: o.CalledOnNullInput, MultiColDataSource: isMultiColDataSource, RoutineType: o.Type, + RoutineLang: o.Language, Body: body, BodyProps: bodyProps, BodyStmts: bodyStmts,