From 4197e9b6f6fb19e68d06131760ac774e4199909d Mon Sep 17 00:00:00 2001 From: Jane Xing Date: Sun, 12 Mar 2023 22:41:57 -0400 Subject: [PATCH] sql: add restrictions for pausable portals This commits add the following restrictions for pausable portals: 1. Not an internal queries 2. Read-only queries 3. No sub-quereis or post-queries 4. Local plan only This is because the current changes to the consumer-receiver model only consider the local push-based case. Release note: None --- pkg/sql/conn_executor_exec.go | 18 +++++++++++++++++- pkg/sql/conn_executor_prepare.go | 9 +++++---- pkg/sql/conn_io.go | 11 +++++++++++ pkg/sql/internal.go | 2 +- pkg/sql/pgwire/command_result.go | 10 ++++++++++ pkg/sql/prepared_stmt.go | 3 ++- 6 files changed, 46 insertions(+), 7 deletions(-) diff --git a/pkg/sql/conn_executor_exec.go b/pkg/sql/conn_executor_exec.go index 5a8672744653..61f131d47c4a 100644 --- a/pkg/sql/conn_executor_exec.go +++ b/pkg/sql/conn_executor_exec.go @@ -1431,9 +1431,15 @@ func (ex *connExecutor) dispatchToExecutionEngine( } ex.sessionTracing.TracePlanCheckStart(ctx) + + distSQLMode := ex.sessionData().DistSQLMode + // We only allow non-distributed plan for pausable portals. + if planner.portal != nil { + distSQLMode = sessiondatapb.DistSQLOff + } distributePlan := getPlanDistribution( ctx, planner.Descriptors().HasUncommittedTypes(), - ex.sessionData().DistSQLMode, planner.curPlan.main, + distSQLMode, planner.curPlan.main, ) ex.sessionTracing.TracePlanCheckEnd(ctx, nil, distributePlan.WillDistribute()) @@ -1918,6 +1924,16 @@ func (ex *connExecutor) execWithDistSQLEngine( factoryEvalCtx.SessionID = planner.ExtendedEvalContext().SessionID return factoryEvalCtx } + // We don't sub / post queries for pausable portal. Set it back to an + // un-pausable (normal) portal. + if planCtx.getPortalPauseInfo() != nil { + // With pauseInfo is nil, no cleanup function will be added to the stack + // and all clean-up steps will be performed as for normal portals. + planCtx.planner.portal.pauseInfo = nil + // We need this so that the result consumption for this portal cannot be + // paused either. + res.UnsetForPausablePortal() + } } err = ex.server.cfg.DistSQLPlanner.PlanAndRunAll(ctx, evalCtx, planCtx, planner, recv, evalCtxFactory) } diff --git a/pkg/sql/conn_executor_prepare.go b/pkg/sql/conn_executor_prepare.go index cb3603adec46..0fa51958a1c0 100644 --- a/pkg/sql/conn_executor_prepare.go +++ b/pkg/sql/conn_executor_prepare.go @@ -472,7 +472,7 @@ func (ex *connExecutor) execBind( } // Create the new PreparedPortal. - if err := ex.addPortal(ctx, portalName, ps, qargs, columnFormatCodes); err != nil { + if err := ex.addPortal(ctx, portalName, ps, qargs, bindCmd.isInternal, columnFormatCodes); err != nil { return retErr(err) } @@ -493,16 +493,17 @@ func (ex *connExecutor) addPortal( portalName string, stmt *PreparedStatement, qargs tree.QueryArguments, + isInternal bool, outFormats []pgwirebase.FormatCode, ) error { if _, ok := ex.extraTxnState.prepStmtsNamespace.portals[portalName]; ok { - panic(errors.AssertionFailedf("portal already exists: %q", portalName)) + return nil } if cursor := ex.getCursorAccessor().getCursor(tree.Name(portalName)); cursor != nil { - panic(errors.AssertionFailedf("portal already exists as cursor: %q", portalName)) + return nil } - portal, err := ex.makePreparedPortal(ctx, portalName, stmt, qargs, outFormats) + portal, err := ex.makePreparedPortal(ctx, portalName, stmt, qargs, isInternal, outFormats) if err != nil { return err } diff --git a/pkg/sql/conn_io.go b/pkg/sql/conn_io.go index 9875334d861f..a54e8153b03e 100644 --- a/pkg/sql/conn_io.go +++ b/pkg/sql/conn_io.go @@ -269,6 +269,9 @@ type BindStmt struct { // inferred types should reflect that). // If internalArgs is specified, Args and ArgFormatCodes are ignored. internalArgs []tree.Datum + + // isInternal is set ture when the bind stmt is from an internal executor. + isInternal bool } // command implements the Command interface. @@ -811,6 +814,10 @@ type RestrictedCommandResult interface { // GetBulkJobId returns the id of the job for the query, if the query is // IMPORT, BACKUP or RESTORE. GetBulkJobId() uint64 + + // UnsetForPausablePortal is to set the forPausablePortal field to false for + // pgwire.limitedCommandResult, so that the portal becomes un-pausable. + UnsetForPausablePortal() } // DescribeResult represents the result of a Describe command (for either @@ -965,6 +972,10 @@ type streamingCommandResult struct { var _ RestrictedCommandResult = &streamingCommandResult{} var _ CommandResultClose = &streamingCommandResult{} +// UnsetForPausablePortal is part of the sql.RestrictedCommandResult interface. +func (r *streamingCommandResult) UnsetForPausablePortal() { +} + // SetColumns is part of the RestrictedCommandResult interface. func (r *streamingCommandResult) SetColumns(ctx context.Context, cols colinfo.ResultColumns) { // The interface allows for cols to be nil, yet the iterator result expects diff --git a/pkg/sql/internal.go b/pkg/sql/internal.go index 7387a84548ec..399061e44687 100644 --- a/pkg/sql/internal.go +++ b/pkg/sql/internal.go @@ -937,7 +937,7 @@ func (ie *InternalExecutor) execInternal( return nil, err } - if err := stmtBuf.Push(ctx, BindStmt{internalArgs: datums}); err != nil { + if err := stmtBuf.Push(ctx, BindStmt{internalArgs: datums, isInternal: true}); err != nil { return nil, err } diff --git a/pkg/sql/pgwire/command_result.go b/pkg/sql/pgwire/command_result.go index 0658c53ace19..98e53f0579b3 100644 --- a/pkg/sql/pgwire/command_result.go +++ b/pkg/sql/pgwire/command_result.go @@ -118,6 +118,9 @@ type paramStatusUpdate struct { var _ sql.CommandResult = &commandResult{} +// UnsetForPausablePortal is part of the sql.RestrictedCommandResult interface. +func (r *commandResult) UnsetForPausablePortal() {} + // Close is part of the sql.RestrictedCommandResult interface. func (r *commandResult) Close(ctx context.Context, t sql.TransactionStatusIndicator) { r.assertNotReleased() @@ -452,6 +455,8 @@ type limitedCommandResult struct { forPausablePortal bool } +var _ sql.RestrictedCommandResult = &limitedCommandResult{} + // AddRow is part of the sql.RestrictedCommandResult interface. func (r *limitedCommandResult) AddRow(ctx context.Context, row tree.Datums) error { if err := r.commandResult.AddRow(ctx, row); err != nil { @@ -481,6 +486,11 @@ func (r *limitedCommandResult) AddRow(ctx context.Context, row tree.Datums) erro return nil } +// UnsetForPausablePortal is part of the sql.RestrictedCommandResult interface. +func (r *limitedCommandResult) UnsetForPausablePortal() { + r.forPausablePortal = false +} + // SupportsAddBatch is part of the sql.RestrictedCommandResult interface. // TODO(yuzefovich): implement limiting behavior for AddBatch. func (r *limitedCommandResult) SupportsAddBatch() bool { diff --git a/pkg/sql/prepared_stmt.go b/pkg/sql/prepared_stmt.go index 23f8890087a6..8192c45841c4 100644 --- a/pkg/sql/prepared_stmt.go +++ b/pkg/sql/prepared_stmt.go @@ -150,6 +150,7 @@ func (ex *connExecutor) makePreparedPortal( name string, stmt *PreparedStatement, qargs tree.QueryArguments, + isInternal bool, outFormats []pgwirebase.FormatCode, ) (PreparedPortal, error) { portal := PreparedPortal{ @@ -162,7 +163,7 @@ func (ex *connExecutor) makePreparedPortal( // TODO(janexing): maybe we should also add telemetry for the stmt that the // portal hooks on. telemetry.Inc(sqltelemetry.MultipleActivePortalCounter) - if tree.IsReadOnly(stmt.AST) { + if tree.IsReadOnly(stmt.AST) && !isInternal { portal.pauseInfo = &portalPauseInfo{} } }