diff --git a/pkg/sql/conn_executor.go b/pkg/sql/conn_executor.go index ae2d55648178..896455d1e5e3 100644 --- a/pkg/sql/conn_executor.go +++ b/pkg/sql/conn_executor.go @@ -2195,6 +2195,12 @@ func (ex *connExecutor) execCmd() (retErr error) { Values: portal.Qargs, } + // If this is the first-time execution of a portal without a limit set, + // it means all rows will be exhausted, so no need to pause this portal. + if tcmd.Limit == 0 && portal.pauseInfo != nil && portal.pauseInfo.curRes == nil { + portal.pauseInfo = nil + } + stmtRes := ex.clientComm.CreateStatementResult( portal.Stmt.AST, // The client is using the extended protocol, so no row description is diff --git a/pkg/sql/conn_executor_exec.go b/pkg/sql/conn_executor_exec.go index d865774a2624..c9d32d6275b1 100644 --- a/pkg/sql/conn_executor_exec.go +++ b/pkg/sql/conn_executor_exec.go @@ -1565,9 +1565,32 @@ func (ex *connExecutor) dispatchToExecutionEngine( } ex.sessionTracing.TracePlanCheckStart(ctx) + + distSQLMode := ex.sessionData().DistSQLMode + if planner.pausablePortal != nil { + if len(planner.curPlan.subqueryPlans) == 0 && + len(planner.curPlan.cascades) == 0 && + len(planner.curPlan.checkPlans) == 0 { + // We only allow non-distributed plan for pausable portals. + distSQLMode = sessiondatapb.DistSQLOff + } else { + telemetry.Inc(sqltelemetry.SubOrPostQueryStmtsTriedWithPausablePortals) + // We don't allow sub / post queries for pausable portal. Set it back to an + // un-pausable (normal) portal. + // With pauseInfo is nil, no cleanup function will be added to the stack + // and all clean-up steps will be performed as for normal portals. + planner.pausablePortal.pauseInfo = nil + // We need this so that the result consumption for this portal cannot be + // paused either. + if err := res.RevokePortalPausability(); err != nil { + res.SetError(err) + return nil + } + } + } distributePlan := getPlanDistribution( ctx, planner.Descriptors().HasUncommittedTypes(), - ex.sessionData().DistSQLMode, planner.curPlan.main, + distSQLMode, planner.curPlan.main, ) ex.sessionTracing.TracePlanCheckEnd(ctx, nil, distributePlan.WillDistribute()) diff --git a/pkg/sql/conn_io.go b/pkg/sql/conn_io.go index 7656884005e5..ab76542887e5 100644 --- a/pkg/sql/conn_io.go +++ b/pkg/sql/conn_io.go @@ -813,6 +813,11 @@ type RestrictedCommandResult interface { // released yet. It should be used only in clean-up stages of a pausable // portal. ErrAllowReleased() error + + // RevokePortalPausability is to make a portal un-pausable. It is called when + // we find the underlying query is not supported for a pausable portal. + // This method is implemented only by pgwire.limitedCommandResult. + RevokePortalPausability() error } // DescribeResult represents the result of a Describe command (for either @@ -986,6 +991,11 @@ func (r *streamingCommandResult) ErrAllowReleased() error { return r.err } +// RevokePortalPausability is part of the sql.RestrictedCommandResult interface. +func (r *streamingCommandResult) RevokePortalPausability() error { + return errors.AssertionFailedf("forPausablePortal is for limitedCommandResult only") +} + // SetColumns is part of the RestrictedCommandResult interface. func (r *streamingCommandResult) SetColumns(ctx context.Context, cols colinfo.ResultColumns) { // The interface allows for cols to be nil, yet the iterator result expects diff --git a/pkg/sql/pgwire/command_result.go b/pkg/sql/pgwire/command_result.go index 459fd27bbfaa..254e589ea2d0 100644 --- a/pkg/sql/pgwire/command_result.go +++ b/pkg/sql/pgwire/command_result.go @@ -118,6 +118,11 @@ type paramStatusUpdate struct { var _ sql.CommandResult = &commandResult{} +// RevokePortalPausability is part of the sql.RestrictedCommandResult interface. +func (r *commandResult) RevokePortalPausability() error { + return errors.AssertionFailedf("RevokePortalPausability is only implemented by limitedCommandResult only") +} + // Close is part of the sql.RestrictedCommandResult interface. func (r *commandResult) Close(ctx context.Context, t sql.TransactionStatusIndicator) { r.assertNotReleased() @@ -478,6 +483,8 @@ type limitedCommandResult struct { portalPausablity sql.PortalPausablity } +var _ sql.RestrictedCommandResult = &limitedCommandResult{} + // AddRow is part of the sql.RestrictedCommandResult interface. func (r *limitedCommandResult) AddRow(ctx context.Context, row tree.Datums) error { if err := r.commandResult.AddRow(ctx, row); err != nil { @@ -507,6 +514,12 @@ func (r *limitedCommandResult) AddRow(ctx context.Context, row tree.Datums) erro return nil } +// RevokePortalPausability is part of the sql.RestrictedCommandResult interface. +func (r *limitedCommandResult) RevokePortalPausability() error { + r.portalPausablity = sql.NotPausablePortalForUnsupportedStmt + return nil +} + // SupportsAddBatch is part of the sql.RestrictedCommandResult interface. // TODO(yuzefovich): implement limiting behavior for AddBatch. func (r *limitedCommandResult) SupportsAddBatch() bool { diff --git a/pkg/sql/prepared_stmt.go b/pkg/sql/prepared_stmt.go index 6f8cea3469c2..16ca739a4f55 100644 --- a/pkg/sql/prepared_stmt.go +++ b/pkg/sql/prepared_stmt.go @@ -185,11 +185,18 @@ func (ex *connExecutor) makePreparedPortal( OutFormats: outFormats, } - if ex.sessionData().MultipleActivePortalsEnabled { + if ex.sessionData().MultipleActivePortalsEnabled && ex.executorType != executorTypeInternal { telemetry.Inc(sqltelemetry.StmtsTriedWithPausablePortals) - portal.pauseInfo = &portalPauseInfo{} - portal.pauseInfo.dispatchToExecutionEngine.queryStats = &topLevelQueryStats{} - portal.portalPausablity = PausablePortal + if tree.IsAllowedToPause(stmt.AST) { + portal.pauseInfo = &portalPauseInfo{} + portal.pauseInfo.dispatchToExecutionEngine.queryStats = &topLevelQueryStats{} + portal.portalPausablity = PausablePortal + } else { + telemetry.Inc(sqltelemetry.NotReadOnlyStmtsTriedWithPausablePortals) + // We have set the session variable multiple_active_portals_enabled to + // true, but we don't support the underlying query for a pausable portal. + portal.portalPausablity = NotPausablePortalForUnsupportedStmt + } } return portal, portal.accountForCopy(ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, name) } diff --git a/pkg/sql/sem/tree/stmt.go b/pkg/sql/sem/tree/stmt.go index 723fb01f5b3d..686ccd1d1b0c 100644 --- a/pkg/sql/sem/tree/stmt.go +++ b/pkg/sql/sem/tree/stmt.go @@ -121,6 +121,30 @@ type canModifySchema interface { modifiesSchema() bool } +// IsAllowedToPause returns true if the stmt cannot either modify the schema or +// write data. +// This function is to gate the queries allowed for pausable portals. +// TODO(janexing): We should be more accurate about the stmt selection here. +// Now we only allow SELECT, but is it too strict? And how to filter out +// SELECT with data writes / schema changes? +func IsAllowedToPause(stmt Statement) bool { + if stmt != nil && !CanModifySchema(stmt) && !CanWriteData(stmt) { + switch t := stmt.(type) { + case *Select: + if t.With != nil { + ctes := t.With.CTEList + for _, cte := range ctes { + if !IsAllowedToPause(cte.Stmt) { + return false + } + } + } + return true + } + } + return false +} + // CanModifySchema returns true if the statement can modify // the database schema. func CanModifySchema(stmt Statement) bool { diff --git a/pkg/sql/sqltelemetry/pgwire.go b/pkg/sql/sqltelemetry/pgwire.go index 3cacbfad4cdd..fd766c88bbba 100644 --- a/pkg/sql/sqltelemetry/pgwire.go +++ b/pkg/sql/sqltelemetry/pgwire.go @@ -74,3 +74,15 @@ var FlushRequestCounter = telemetry.GetCounterOnce("pgwire.command.flush") // multiple_active_portals_enabled has been set to true. // The statement might not satisfy the restriction for a pausable portal. var StmtsTriedWithPausablePortals = telemetry.GetCounterOnce("pgwire.pausable_portal_stmts") + +// NotReadOnlyStmtsTriedWithPausablePortals is to be incremented every time +// there's a not-internal not-read-only statement executed with a pgwire portal +// and the session variable multiple_active_portals_enabled has been set to true. +// In this case the execution cannot be paused. +var NotReadOnlyStmtsTriedWithPausablePortals = telemetry.GetCounterOnce("pgwire.pausable_portal_not_read_only_stmts") + +// SubOrPostQueryStmtsTriedWithPausablePortals is to be incremented every time +// there's a not-internal statement with post or sub queries executed with a +// pgwire portal and the session variable multiple_active_portals_enabled has +// been set to true. In this case the execution cannot be paused. +var SubOrPostQueryStmtsTriedWithPausablePortals = telemetry.GetCounterOnce("pgwire.pausable_portal_stmts_with_sub_or_post_queries")