Skip to content

Commit

Permalink
sql: add restrictions to pausable portals
Browse files Browse the repository at this point in the history
This commit adds several restrictions to pausable portals to ensure that they
work properly with the current changes to the consumer-receiver model.
Specifically, pausable portals must meet the following criteria:

1. Not be internal queries
2. Be read-only queries
3. Not contain sub-queries or post-queries
4. Only use local plans

These restrictions are necessary because the current changes to the
consumer-receiver model only consider the local push-based case.

Release note: None
  • Loading branch information
ZhouXing19 committed Apr 7, 2023
1 parent dcbcca9 commit aa6cd36
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 5 deletions.
6 changes: 6 additions & 0 deletions pkg/sql/conn_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2195,6 +2195,12 @@ func (ex *connExecutor) execCmd() (retErr error) {
Values: portal.Qargs,
}

// If this is the first-time execution of a portal without a limit set,
// it means all rows will be exhausted, so no need to pause this portal.
if tcmd.Limit == 0 && portal.pauseInfo != nil && portal.pauseInfo.curRes == nil {
portal.pauseInfo = nil
}

stmtRes := ex.clientComm.CreateStatementResult(
portal.Stmt.AST,
// The client is using the extended protocol, so no row description is
Expand Down
25 changes: 24 additions & 1 deletion pkg/sql/conn_executor_exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -1565,9 +1565,32 @@ func (ex *connExecutor) dispatchToExecutionEngine(
}

ex.sessionTracing.TracePlanCheckStart(ctx)

distSQLMode := ex.sessionData().DistSQLMode
if planner.pausablePortal != nil {
if len(planner.curPlan.subqueryPlans) == 0 &&
len(planner.curPlan.cascades) == 0 &&
len(planner.curPlan.checkPlans) == 0 {
// We only allow non-distributed plan for pausable portals.
distSQLMode = sessiondatapb.DistSQLOff
} else {
telemetry.Inc(sqltelemetry.SubOrPostQueryStmtsTriedWithPausablePortals)
// We don't allow sub / post queries for pausable portal. Set it back to an
// un-pausable (normal) portal.
// With pauseInfo is nil, no cleanup function will be added to the stack
// and all clean-up steps will be performed as for normal portals.
planner.pausablePortal.pauseInfo = nil
// We need this so that the result consumption for this portal cannot be
// paused either.
if err := res.RevokePortalPausability(); err != nil {
res.SetError(err)
return nil
}
}
}
distributePlan := getPlanDistribution(
ctx, planner.Descriptors().HasUncommittedTypes(),
ex.sessionData().DistSQLMode, planner.curPlan.main,
distSQLMode, planner.curPlan.main,
)
ex.sessionTracing.TracePlanCheckEnd(ctx, nil, distributePlan.WillDistribute())

Expand Down
10 changes: 10 additions & 0 deletions pkg/sql/conn_io.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,11 @@ type RestrictedCommandResult interface {
// released yet. It should be used only in clean-up stages of a pausable
// portal.
ErrAllowReleased() error

// RevokePortalPausability is to make a portal un-pausable. It is called when
// we find the underlying query is not supported for a pausable portal.
// This method is implemented only by pgwire.limitedCommandResult.
RevokePortalPausability() error
}

// DescribeResult represents the result of a Describe command (for either
Expand Down Expand Up @@ -986,6 +991,11 @@ func (r *streamingCommandResult) ErrAllowReleased() error {
return r.err
}

// RevokePortalPausability is part of the sql.RestrictedCommandResult interface.
func (r *streamingCommandResult) RevokePortalPausability() error {
return errors.AssertionFailedf("forPausablePortal is for limitedCommandResult only")
}

// SetColumns is part of the RestrictedCommandResult interface.
func (r *streamingCommandResult) SetColumns(ctx context.Context, cols colinfo.ResultColumns) {
// The interface allows for cols to be nil, yet the iterator result expects
Expand Down
13 changes: 13 additions & 0 deletions pkg/sql/pgwire/command_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ type paramStatusUpdate struct {

var _ sql.CommandResult = &commandResult{}

// RevokePortalPausability is part of the sql.RestrictedCommandResult interface.
func (r *commandResult) RevokePortalPausability() error {
return errors.AssertionFailedf("RevokePortalPausability is only implemented by limitedCommandResult only")
}

// Close is part of the sql.RestrictedCommandResult interface.
func (r *commandResult) Close(ctx context.Context, t sql.TransactionStatusIndicator) {
r.assertNotReleased()
Expand Down Expand Up @@ -478,6 +483,8 @@ type limitedCommandResult struct {
portalPausablity sql.PortalPausablity
}

var _ sql.RestrictedCommandResult = &limitedCommandResult{}

// AddRow is part of the sql.RestrictedCommandResult interface.
func (r *limitedCommandResult) AddRow(ctx context.Context, row tree.Datums) error {
if err := r.commandResult.AddRow(ctx, row); err != nil {
Expand Down Expand Up @@ -507,6 +514,12 @@ func (r *limitedCommandResult) AddRow(ctx context.Context, row tree.Datums) erro
return nil
}

// RevokePortalPausability is part of the sql.RestrictedCommandResult interface.
func (r *limitedCommandResult) RevokePortalPausability() error {
r.portalPausablity = sql.NotPausablePortalForUnsupportedStmt
return nil
}

// SupportsAddBatch is part of the sql.RestrictedCommandResult interface.
// TODO(yuzefovich): implement limiting behavior for AddBatch.
func (r *limitedCommandResult) SupportsAddBatch() bool {
Expand Down
15 changes: 11 additions & 4 deletions pkg/sql/prepared_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,18 @@ func (ex *connExecutor) makePreparedPortal(
OutFormats: outFormats,
}

if ex.sessionData().MultipleActivePortalsEnabled {
if ex.sessionData().MultipleActivePortalsEnabled && ex.executorType != executorTypeInternal {
telemetry.Inc(sqltelemetry.StmtsTriedWithPausablePortals)
portal.pauseInfo = &portalPauseInfo{}
portal.pauseInfo.dispatchToExecutionEngine.queryStats = &topLevelQueryStats{}
portal.portalPausablity = PausablePortal
if tree.IsAllowedToPause(stmt.AST) {
portal.pauseInfo = &portalPauseInfo{}
portal.pauseInfo.dispatchToExecutionEngine.queryStats = &topLevelQueryStats{}
portal.portalPausablity = PausablePortal
} else {
telemetry.Inc(sqltelemetry.NotReadOnlyStmtsTriedWithPausablePortals)
// We have set the session variable multiple_active_portals_enabled to
// true, but we don't support the underlying query for a pausable portal.
portal.portalPausablity = NotPausablePortalForUnsupportedStmt
}
}
return portal, portal.accountForCopy(ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, name)
}
Expand Down
24 changes: 24 additions & 0 deletions pkg/sql/sem/tree/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,30 @@ type canModifySchema interface {
modifiesSchema() bool
}

// IsAllowedToPause returns true if the stmt cannot either modify the schema or
// write data.
// This function is to gate the queries allowed for pausable portals.
// TODO(janexing): We should be more accurate about the stmt selection here.
// Now we only allow SELECT, but is it too strict? And how to filter out
// SELECT with data writes / schema changes?
func IsAllowedToPause(stmt Statement) bool {
if stmt != nil && !CanModifySchema(stmt) && !CanWriteData(stmt) {
switch t := stmt.(type) {
case *Select:
if t.With != nil {
ctes := t.With.CTEList
for _, cte := range ctes {
if !IsAllowedToPause(cte.Stmt) {
return false
}
}
}
return true
}
}
return false
}

// CanModifySchema returns true if the statement can modify
// the database schema.
func CanModifySchema(stmt Statement) bool {
Expand Down
12 changes: 12 additions & 0 deletions pkg/sql/sqltelemetry/pgwire.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,15 @@ var FlushRequestCounter = telemetry.GetCounterOnce("pgwire.command.flush")
// multiple_active_portals_enabled has been set to true.
// The statement might not satisfy the restriction for a pausable portal.
var StmtsTriedWithPausablePortals = telemetry.GetCounterOnce("pgwire.pausable_portal_stmts")

// NotReadOnlyStmtsTriedWithPausablePortals is to be incremented every time
// there's a not-internal not-read-only statement executed with a pgwire portal
// and the session variable multiple_active_portals_enabled has been set to true.
// In this case the execution cannot be paused.
var NotReadOnlyStmtsTriedWithPausablePortals = telemetry.GetCounterOnce("pgwire.pausable_portal_not_read_only_stmts")

// SubOrPostQueryStmtsTriedWithPausablePortals is to be incremented every time
// there's a not-internal statement with post or sub queries executed with a
// pgwire portal and the session variable multiple_active_portals_enabled has
// been set to true. In this case the execution cannot be paused.
var SubOrPostQueryStmtsTriedWithPausablePortals = telemetry.GetCounterOnce("pgwire.pausable_portal_stmts_with_sub_or_post_queries")

0 comments on commit aa6cd36

Please sign in to comment.