From d25b624165381de6547167a9eeb5f805f6644d08 Mon Sep 17 00:00:00 2001 From: Jane Xing Date: Sun, 12 Mar 2023 21:49:03 -0400 Subject: [PATCH] sql: set the clean-up steps for pausable portal This is part of the implementation of multiple active portals. To enable executing portals in an interleaved manner, we need to persist certain resources for them, and delay their clean-up until we close the portal. Also, these resources don't need to be set up again when we re-execute a portal. Thus we store these cleanup steps in the `__Cleanup` function stacks in `portalPauseInfo`, and they are called when 1. the sql txn is committed; 2. the sql txn is rolled back; 3. the conn executor is closed. The cleanup functions should be called in the same order as for a normal portal. Since a portal's execution goes through the `execPortal() -> execStmtInOpenState() -> dispatchToExecutionEngine() -> flow.Run()` function flow, we categorize the cleanup functions accordingly into 4 "layers": `exhaustPortal`, `execStmtCleanup`, `dispatchToExecEngCleanup` and `flowCleanup`. The cleanup is always LIFO, i.e. it follows the `flowCleanup -> dispatchToExecEngCleanup -> execStmtCleanup -> exhaustPortal` order. Also, when an error happens in a layer, we clean up that layer and all deeper layers. E.g. if we encounter an error in `execStmtInOpenState()`, we run `flowCleanup` and `dispatchToExecEngCleanup` (the deeper layers) and then `execStmtCleanup` (the current layer), and return the error to `execPortal()`, where `exhaustPortal` will eventually be called. We also pass the PreparedPortal by reference to the planner in `execStmtInOpenState()`, so that the portal's flow can be set and reused. Release note: None --- pkg/sql/conn_executor.go | 51 +++-- pkg/sql/conn_executor_exec.go | 378 ++++++++++++++++++++++++------- pkg/sql/conn_io.go | 9 + pkg/sql/distsql_running.go | 48 +++- pkg/sql/instrumentation.go | 43 ++++ pkg/sql/pgwire/command_result.go | 5 + pkg/sql/prepared_stmt.go | 4 + 7 files changed, 440 insertions(+), 98 deletions(-) diff --git a/pkg/sql/conn_executor.go b/pkg/sql/conn_executor.go index 47599cff4343..74f8c916e9f7 100644 --- a/pkg/sql/conn_executor.go +++ b/pkg/sql/conn_executor.go @@ -1143,6 +1143,14 @@ func (ex *connExecutor) close(ctx context.Context, closeType closeType) { txnEvType = txnRollback } + // Close all portals, otherwise there will be leftover bytes. + ex.extraTxnState.prepStmtsNamespace.closeAllPortals( + ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, + ) + ex.extraTxnState.prepStmtsNamespaceAtTxnRewindPos.closeAllPortals( + ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, + ) + if closeType == normalClose { // We'll cleanup the SQL txn by creating a non-retriable (commit:true) event. // This event is guaranteed to be accepted in every state. @@ -1670,6 +1678,15 @@ func (ns prepStmtNamespace) HasPortal(s string) bool { return ok } +func (ns prepStmtNamespace) closeAllPortals( + ctx context.Context, prepStmtsNamespaceMemAcc *mon.BoundAccount, +) { + for name, p := range ns.portals { + p.close(ctx, prepStmtsNamespaceMemAcc, name) + delete(ns.portals, name) + } +} + // MigratablePreparedStatements returns a mapping of all prepared statements. 
func (ns prepStmtNamespace) MigratablePreparedStatements() []sessiondatapb.MigratableSession_PreparedStatement { ret := make([]sessiondatapb.MigratableSession_PreparedStatement, 0, len(ns.prepStmts)) @@ -1721,10 +1738,8 @@ func (ns *prepStmtNamespace) resetTo( p.decRef(ctx) delete(ns.prepStmts, name) } - for name, p := range ns.portals { - p.close(ctx, prepStmtsNamespaceMemAcc, name) - delete(ns.portals, name) - } + + ns.closeAllPortals(ctx, prepStmtsNamespaceMemAcc) for name, ps := range to.prepStmts { ps.incRef(ctx) @@ -1761,10 +1776,9 @@ func (ex *connExecutor) resetExtraTxnState(ctx context.Context, ev txnEvent) { } // Close all portals. - for name, p := range ex.extraTxnState.prepStmtsNamespace.portals { - p.close(ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, name) - delete(ex.extraTxnState.prepStmtsNamespace.portals, name) - } + ex.extraTxnState.prepStmtsNamespace.closeAllPortals( + ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, + ) // Close all cursors. if err := ex.extraTxnState.sqlCursors.closeAll(false /* errorOnWithHold */); err != nil { @@ -1775,10 +1789,9 @@ func (ex *connExecutor) resetExtraTxnState(ctx context.Context, ev txnEvent) { switch ev.eventType { case txnCommit, txnRollback: - for name, p := range ex.extraTxnState.prepStmtsNamespaceAtTxnRewindPos.portals { - p.close(ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, name) - delete(ex.extraTxnState.prepStmtsNamespaceAtTxnRewindPos.portals, name) - } + ex.extraTxnState.prepStmtsNamespaceAtTxnRewindPos.closeAllPortals( + ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, + ) ex.extraTxnState.savepoints.clear() ex.onTxnFinish(ctx, ev) case txnRestart: @@ -1925,7 +1938,6 @@ func (ex *connExecutor) run( return err } } - } // errDrainingComplete is returned by execCmd when the connExecutor previously got @@ -1997,7 +2009,7 @@ func (ex *connExecutor) execCmd() (retErr error) { (tcmd.LastInBatchBeforeShowCommitTimestamp || tcmd.LastInBatch || !implicitTxnForBatch) ev, payload, err = ex.execStmt( - ctx, tcmd.Statement, nil /* prepared */, nil /* pinfo */, stmtRes, canAutoCommit, + ctx, tcmd.Statement, nil /* portal */, nil /* pinfo */, stmtRes, canAutoCommit, ) return err @@ -2067,6 +2079,9 @@ func (ex *connExecutor) execCmd() (retErr error) { ex.implicitTxn(), portal.portalPausablity, ) + if portal.pauseInfo != nil { + portal.pauseInfo.curRes = stmtRes + } res = stmtRes // In the extended protocol, autocommit is not always allowed. The postgres @@ -2085,6 +2100,7 @@ func (ex *connExecutor) execCmd() (retErr error) { // - ex.statsCollector merely contains a copy of the times, that // was created when the statement started executing (via the // reset() method). + // TODO(sql-sessions): fix the phase time for pausable portals. 
ex.statsCollector.PhaseTimes().SetSessionPhaseTime(sessionphase.SessionQueryServiced, timeutil.Now()) if err != nil { return err @@ -3486,8 +3502,13 @@ func (ex *connExecutor) txnStateTransitionsApplyWrapper( } fallthrough - case txnRestart, txnRollback: + case txnRestart: + ex.resetExtraTxnState(ex.Ctx(), advInfo.txnEvent) + case txnRollback: ex.resetExtraTxnState(ex.Ctx(), advInfo.txnEvent) + ex.extraTxnState.prepStmtsNamespaceAtTxnRewindPos.closeAllPortals( + ex.Ctx(), &ex.extraTxnState.prepStmtsNamespaceMemAcc, + ) default: return advanceInfo{}, errors.AssertionFailedf( "unexpected event: %v", errors.Safe(advInfo.txnEvent)) diff --git a/pkg/sql/conn_executor_exec.go b/pkg/sql/conn_executor_exec.go index 09f4094e9cd5..5ef8ef64926d 100644 --- a/pkg/sql/conn_executor_exec.go +++ b/pkg/sql/conn_executor_exec.go @@ -95,7 +95,7 @@ const numTxnRetryErrors = 3 func (ex *connExecutor) execStmt( ctx context.Context, parserStmt parser.Statement, - prepared *PreparedStatement, + portal *PreparedPortal, pinfo *tree.PlaceholderInfo, res RestrictedCommandResult, canAutoCommit bool, @@ -133,8 +133,12 @@ func (ex *connExecutor) execStmt( ev, payload = ex.execStmtInNoTxnState(ctx, ast, res) case stateOpen: - err = ex.execWithProfiling(ctx, ast, prepared, func(ctx context.Context) error { - ev, payload, err = ex.execStmtInOpenState(ctx, parserStmt, prepared, pinfo, res, canAutoCommit) + var preparedStmt *PreparedStatement + if portal != nil { + preparedStmt = portal.Stmt + } + err = ex.execWithProfiling(ctx, ast, preparedStmt, func(ctx context.Context) error { + ev, payload, err = ex.execStmtInOpenState(ctx, parserStmt, portal, pinfo, res, canAutoCommit) return err }) switch ev.(type) { @@ -202,6 +206,22 @@ func (ex *connExecutor) execPortal( pinfo *tree.PlaceholderInfo, canAutoCommit bool, ) (ev fsm.Event, payload fsm.EventPayload, err error) { + defer func() { + if portal.isPausable() { + if !portal.pauseInfo.exhaustPortal.isComplete { + portal.pauseInfo.exhaustPortal.appendFunc(namedFunc{fName: "exhaust portal", f: func() { + ex.exhaustPortal(portalName) + }}) + portal.pauseInfo.exhaustPortal.isComplete = true + } + // If we encountered an error when executing a pausable portal, clean up + // the retained resources. + if err != nil { + portal.pauseInfo.cleanupAll() + } + } + }() + switch ex.machine.CurState().(type) { case stateOpen: // We're about to execute the statement in an open state which @@ -223,23 +243,19 @@ func (ex *connExecutor) execPortal( if portal.exhausted { return nil, nil, nil } - ev, payload, err = ex.execStmt(ctx, portal.Stmt.Statement, portal.Stmt, pinfo, stmtRes, canAutoCommit) - // Portal suspension is supported via a "side" state machine - // (see pgwire.limitedCommandResult for details), so when - // execStmt returns, we know for sure that the portal has been - // executed to completion, thus, it is exhausted. - // Note that the portal is considered exhausted regardless of - // the fact whether an error occurred or not - if it did, we - // still don't want to re-execute the portal from scratch. + ev, payload, err = ex.execStmt(ctx, portal.Stmt.Statement, &portal, pinfo, stmtRes, canAutoCommit) + // For a non-pausable portal, it is considered exhausted regardless of the + // fact whether an error occurred or not - if it did, we still don't want + // to re-execute the portal from scratch. // The current statement may have just closed and deleted the portal, // so only exhaust it if it still exists. 
- if _, ok := ex.extraTxnState.prepStmtsNamespace.portals[portalName]; ok { - ex.exhaustPortal(portalName) + if _, ok := ex.extraTxnState.prepStmtsNamespace.portals[portalName]; ok && !portal.isPausable() { + defer ex.exhaustPortal(portalName) } return ev, payload, err default: - return ex.execStmt(ctx, portal.Stmt.Statement, portal.Stmt, pinfo, stmtRes, canAutoCommit) + return ex.execStmt(ctx, portal.Stmt.Statement, &portal, pinfo, stmtRes, canAutoCommit) } } @@ -259,15 +275,59 @@ func (ex *connExecutor) execPortal( func (ex *connExecutor) execStmtInOpenState( ctx context.Context, parserStmt parser.Statement, - prepared *PreparedStatement, + portal *PreparedPortal, pinfo *tree.PlaceholderInfo, res RestrictedCommandResult, canAutoCommit bool, ) (retEv fsm.Event, retPayload fsm.EventPayload, retErr error) { - ctx, sp := tracing.EnsureChildSpan(ctx, ex.server.cfg.AmbientCtx.Tracer, "sql query") - // TODO(andrei): Consider adding the placeholders as tags too. - sp.SetTag("statement", attribute.StringValue(parserStmt.SQL)) - defer sp.Finish() + isPausablePortal := portal != nil && portal.isPausable() + // For pausable portals, we delay the clean-up until closing the portal by + // adding the function to the execStmtInOpenStateCleanup. + // Otherwise, perform the clean-up step within every execution. + processCleanupFunc := func(fName string, f func()) { + if !isPausablePortal { + f() + } else if !portal.pauseInfo.execStmtInOpenStateCleanup.isComplete { + portal.pauseInfo.execStmtInOpenStateCleanup.appendFunc(namedFunc{ + fName: fName, + f: f, + }) + } + } + defer func() { + // This is the first defer, so it will always be called after any cleanup + // func being added to the stack from the defers below. + if portal != nil && portal.isPausable() && !portal.pauseInfo.execStmtInOpenStateCleanup.isComplete { + portal.pauseInfo.execStmtInOpenStateCleanup.isComplete = true + } + // If there's any error, do the cleanup right here. + if (retErr != nil || payloadHasError(retPayload)) && isPausablePortal { + portal.pauseInfo.flowCleanup.run() + portal.pauseInfo.dispatchToExecEngCleanup.run() + portal.pauseInfo.execStmtInOpenStateCleanup.run() + } + }() + + var sp *tracing.Span + if !isPausablePortal || !portal.pauseInfo.execStmtInOpenStateCleanup.isComplete { + ctx, sp = tracing.EnsureChildSpan(ctx, ex.server.cfg.AmbientCtx.Tracer, "sql query") + // TODO(andrei): Consider adding the placeholders as tags too. + sp.SetTag("statement", attribute.StringValue(parserStmt.SQL)) + if isPausablePortal { + portal.pauseInfo.sp = sp + } + defer func() { + spToFinish := sp + // For pausable portals, we need to persist the span as it shares the ctx + // with the underlying flow. If it gets cleaned up before we close the + // flow, we will hit `span used after finished` whenever we log an event + // when cleaning up the flow. + if isPausablePortal { + spToFinish = portal.pauseInfo.sp + } + processCleanupFunc("cleanup span", spToFinish.Finish) + }() + } ast := parserStmt.AST ctx = withStatement(ctx, ast) @@ -277,7 +337,17 @@ func (ex *connExecutor) execStmtInOpenState( } var stmt Statement - queryID := ex.generateID() + var queryID clusterunique.ID + + if isPausablePortal { + if !portal.pauseInfo.isQueryIDSet() { + portal.pauseInfo.queryID = ex.generateID() + } + queryID = portal.pauseInfo.queryID + } else { + queryID = ex.generateID() + } + // Update the deadline on the transaction based on the collections. 
err := ex.extraTxnState.descCollection.MaybeUpdateDeadline(ctx, ex.state.mu.txn) if err != nil { @@ -285,39 +355,53 @@ func (ex *connExecutor) execStmtInOpenState( } os := ex.machine.CurState().(stateOpen) - isExtendedProtocol := prepared != nil + isExtendedProtocol := portal != nil && portal.Stmt != nil if isExtendedProtocol { - stmt = makeStatementFromPrepared(prepared, queryID) + stmt = makeStatementFromPrepared(portal.Stmt, queryID) } else { stmt = makeStatement(parserStmt, queryID) } - ex.incrementStartedStmtCounter(ast) - defer func() { - if retErr == nil && !payloadHasError(retPayload) { - ex.incrementExecutedStmtCounter(ast) - } - }() - - func(st *txnState) { - st.mu.Lock() - defer st.mu.Unlock() - st.mu.stmtCount++ - }(&ex.state) - var queryTimeoutTicker *time.Timer var txnTimeoutTicker *time.Timer queryTimedOut := false txnTimedOut := false - // queryDoneAfterFunc and txnDoneAfterFunc will be allocated only when // queryTimeoutTicker or txnTimeoutTicker is non-nil. var queryDoneAfterFunc chan struct{} var txnDoneAfterFunc chan struct{} - var cancelQuery context.CancelFunc - ctx, cancelQuery = contextutil.WithCancel(ctx) - ex.addActiveQuery(parserStmt, pinfo, queryID, cancelQuery) + ctx, cancelQuery := contextutil.WithCancel(ctx) + + addActiveQuery := func() { + ex.incrementStartedStmtCounter(ast) + func(st *txnState) { + st.mu.Lock() + defer st.mu.Unlock() + st.mu.stmtCount++ + }(&ex.state) + ex.addActiveQuery(parserStmt, pinfo, queryID, cancelQuery) + } + + // For pausable portal, the active query needs to be set up only when + // the portal is executed for the first time. + if !isPausablePortal || !portal.pauseInfo.execStmtInOpenStateCleanup.isComplete { + addActiveQuery() + if isPausablePortal { + portal.pauseInfo.cancelQueryFunc = cancelQuery + portal.pauseInfo.cancelQueryCtx = ctx + } + defer func() { + processCleanupFunc( + "increment executed stmt cnt", + func() { + if retErr == nil && !payloadHasError(retPayload) { + ex.incrementExecutedStmtCounter(ast) + } + }, + ) + }() + } // Make sure that we always unregister the query. It also deals with // overwriting res.Error to a more user-friendly message in case of query @@ -338,25 +422,46 @@ func (ex *connExecutor) execStmtInOpenState( } } - // Detect context cancelation and overwrite whatever error might have been - // set on the result before. The idea is that once the query's context is - // canceled, all sorts of actors can detect the cancelation and set all - // sorts of errors on the result. Rather than trying to impose discipline - // in that jungle, we just overwrite them all here with an error that's - // nicer to look at for the client. - if res != nil && ctx.Err() != nil && res.Err() != nil { - // Even in the cases where the error is a retryable error, we want to - // intercept the event and payload returned here to ensure that the query - // is not retried. - retEv = eventNonRetriableErr{ - IsCommit: fsm.FromBool(isCommit(ast)), + processCleanupFunc("cancel query", func() { + cancelQueryCtx := ctx + cancelQueryFunc := cancelQuery + if isPausablePortal { + cancelQueryCtx = portal.pauseInfo.cancelQueryCtx + cancelQueryFunc = portal.pauseInfo.cancelQueryFunc } - res.SetError(cancelchecker.QueryCanceledError) - retPayload = eventNonRetriableErrPayload{err: cancelchecker.QueryCanceledError} - } + // Detect context cancelation and overwrite whatever error might have been + // set on the result before. 
The idea is that once the query's context is + canceled, all sorts of actors can detect the cancelation and set all + sorts of errors on the result. Rather than trying to impose discipline + in that jungle, we just overwrite them all here with an error that's + nicer to look at for the client. + var resToPushErr RestrictedCommandResult + resToPushErr = res + // For pausable portals, we retain the query but update the result for + // each execution. When the query context is cancelled and we're in the + // middle of a portal execution, push the error to the current result. + if isPausablePortal { + resToPushErr = portal.pauseInfo.curRes + } + // Explaining why we need to check if the result has been released: + // If we're checking for a portal, it can happen after all executions have + // finished, the result has been released, and we're simply closing the + // connExecutor. We should allow this case, so we don't want to return an + // assertion error in resToPushErr.Err() when the result is released. + if resToPushErr != nil && cancelQueryCtx.Err() != nil && !resToPushErr.IsReleased() && resToPushErr.Err() != nil { + // Even in the cases where the error is a retryable error, we want to + // intercept the event and payload returned here to ensure that the query + // is not retried. + retEv = eventNonRetriableErr{ + IsCommit: fsm.FromBool(isCommit(ast)), + } + resToPushErr.SetError(cancelchecker.QueryCanceledError) + retPayload = eventNonRetriableErrPayload{err: cancelchecker.QueryCanceledError} + } + ex.removeActiveQuery(queryID, ast) + cancelQueryFunc() + }) - ex.removeActiveQuery(queryID, ast) - cancelQuery() if ex.executorType != executorTypeInternal { ex.metrics.EngineMetrics.SQLActiveStatements.Dec(1) } @@ -480,25 +585,55 @@ func (ex *connExecutor) execStmtInOpenState( } var needFinish bool - ctx, needFinish = ih.Setup( - ctx, ex.server.cfg, ex.statsCollector, p, ex.stmtDiagnosticsRecorder, - stmt.StmtNoConstants, os.ImplicitTxn.Get(), ex.extraTxnState.shouldCollectTxnExecutionStats, - ) + // For a pausable portal, the instrumentation helper needs to be set up only when + // the portal is executed for the first time. + if !isPausablePortal || portal.pauseInfo.ihWrapper == nil { + ctx, needFinish = ih.Setup( + ctx, ex.server.cfg, ex.statsCollector, p, ex.stmtDiagnosticsRecorder, + stmt.StmtNoConstants, os.ImplicitTxn.Get(), ex.extraTxnState.shouldCollectTxnExecutionStats, + ) + } + // For pausable portals, we need to persist the instrumentationHelper as it + // shares the ctx with the underlying flow. If it got cleaned up before we + // clean up the flow, we will hit `span used after finished` whenever we log + // an event when cleaning up the flow. + // We need this seemingly weird wrapper here because we set the planner's ih + // with its pointer. However, for a pausable portal, we'd like to persist the + // ih and reuse it for all re-executions. So the planner's ih and the portal's + // ih should never have the same address, otherwise changing the former will + // change the latter, and we will never be able to persist it. 
+ if isPausablePortal { + if portal.pauseInfo.ihWrapper == nil { + portal.pauseInfo.ihWrapper = &instrumentationHelperWrapper{ + *ih, + } + } else { + p.instrumentation = portal.pauseInfo.ihWrapper.ih + } + } if needFinish { sql := stmt.SQL defer func() { - retErr = ih.Finish( - ex.server.cfg, - ex.statsCollector, - &ex.extraTxnState.accumulatedStats, - ih.collectExecStats, - p, - ast, - sql, - res, - retPayload, - retErr, - ) + processCleanupFunc("finish instrumentation helper", func() { + // We need this copy to make sure we're closing the correct + // instrumentation helper for the paused portal. + ihToFinish := *ih + if isPausablePortal && portal.pauseInfo.ihWrapper != nil { + ihToFinish = portal.pauseInfo.ihWrapper.ih + } + retErr = ihToFinish.Finish( + ex.server.cfg, + ex.statsCollector, + &ex.extraTxnState.accumulatedStats, + ihToFinish.collectExecStats, + p, + ast, + sql, + res, + retPayload, + retErr, + ) + }) }() } @@ -564,6 +699,7 @@ func (ex *connExecutor) execStmtInOpenState( if retEv != nil || retErr != nil { return } + // As portals come from the extended protocol, we don't auto-commit for them. if canAutoCommit && !isExtendedProtocol { retEv, retPayload = ex.handleAutoCommit(ctx, ast) } @@ -662,8 +798,13 @@ func (ex *connExecutor) execStmtInOpenState( // For regular statements (the ones that get to this point), we // don't return any event unless an error happens. - if err := ex.handleAOST(ctx, ast); err != nil { - return makeErrEvent(err) + // For a portal (prepared stmt), since handleAOST() is called when + // preparing the statement, and this function is idempotent, we don't need to + // call it again during execution. + if portal == nil { + if err := ex.handleAOST(ctx, ast); err != nil { + return makeErrEvent(err) + } } // The first order of business is to ensure proper sequencing @@ -711,6 +852,9 @@ func (ex *connExecutor) execStmtInOpenState( p.extendedEvalCtx.Placeholders = &p.semaCtx.Placeholders p.extendedEvalCtx.Annotations = &p.semaCtx.Annotations p.stmt = stmt + if isPausablePortal { + p.pausablePortal = portal + } p.cancelChecker.Reset(ctx) // Auto-commit is disallowed during statement execution if we previously @@ -1105,6 +1249,8 @@ func (ex *connExecutor) commitSQLTransactionInternal(ctx context.Context) error return err } + ex.extraTxnState.prepStmtsNamespace.closeAllPortals(ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc) + // We need to step the transaction before committing if it has stepping // enabled. If it doesn't have stepping enabled, then we just set the // stepping mode back to what it was. @@ -1193,6 +1339,9 @@ func (ex *connExecutor) rollbackSQLTransaction( if err := ex.extraTxnState.sqlCursors.closeAll(false /* errorOnWithHold */); err != nil { return ex.makeErrEvent(err, stmt) } + + ex.extraTxnState.prepStmtsNamespace.closeAllPortals(ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc) + if err := ex.state.mu.txn.Rollback(ctx); err != nil { log.Warningf(ctx, "txn rollback failed: %s", err) } @@ -1215,9 +1364,28 @@ func (ex *connExecutor) rollbackSQLTransaction( // producing an appropriate state machine event. 
func (ex *connExecutor) dispatchToExecutionEngine( ctx context.Context, planner *planner, res RestrictedCommandResult, -) error { +) (retErr error) { + getPausablePortalInfo := func() *portalPauseInfo { + if planner != nil && planner.pausablePortal != nil { + return planner.pausablePortal.pauseInfo + } + return nil + } + defer func() { + if ppInfo := getPausablePortalInfo(); ppInfo != nil { + if !ppInfo.dispatchToExecEngCleanup.isComplete { + ppInfo.dispatchToExecEngCleanup.isComplete = true + } + if retErr != nil || res.Err() != nil { + ppInfo.flowCleanup.run() + ppInfo.dispatchToExecEngCleanup.run() + } + } + }() + stmt := planner.stmt ex.sessionTracing.TracePlanStart(ctx, stmt.AST.StatementTag()) + // TODO(sql-sessions): fix the phase time for pausable portals. ex.statsCollector.PhaseTimes().SetSessionPhaseTime(sessionphase.PlannerStartLogicalPlan, timeutil.Now()) if multitenant.TenantRUEstimateEnabled.Get(ex.server.cfg.SV()) { @@ -1243,10 +1411,25 @@ func (ex *connExecutor) dispatchToExecutionEngine( ex.extraTxnState.hasAdminRoleCache.IsSet = true } } - // Prepare the plan. Note, the error is processed below. Everything - // between here and there needs to happen even if there's an error. - err := ex.makeExecPlan(ctx, planner) - defer planner.curPlan.close(ctx) + + var err error + if ppInfo := getPausablePortalInfo(); ppInfo != nil { + if !ppInfo.dispatchToExecEngCleanup.isComplete { + err = ex.makeExecPlan(ctx, planner) + ppInfo.planTop = planner.curPlan + ppInfo.dispatchToExecEngCleanup.appendFunc(namedFunc{ + fName: "close planTop", + f: func() { ppInfo.planTop.close(ctx) }, + }) + } else { + planner.curPlan = ppInfo.planTop + } + } else { + // Prepare the plan. Note, the error is processed below. Everything + // between here and there needs to happen even if there's an error. 
+ err = ex.makeExecPlan(ctx, planner) + defer planner.curPlan.close(ctx) + } // include gist in error reports ctx = withPlanGist(ctx, planner.instrumentation.planGist.String()) @@ -1272,9 +1455,23 @@ func (ex *connExecutor) dispatchToExecutionEngine( case *tree.Import, *tree.Restore, *tree.Backup: bulkJobId = res.GetBulkJobId() } - planner.maybeLogStatement(ctx, ex.executorType, false, int(ex.state.mu.autoRetryCounter), ex.extraTxnState.txnCounter, nonBulkJobNumRows, bulkJobId, res.Err(), ex.statsCollector.PhaseTimes().GetSessionPhaseTime(sessionphase.SessionQueryReceived), &ex.extraTxnState.hasAdminRoleCache, ex.server.TelemetryLoggingMetrics, stmtFingerprintID, &stats) + if ppInfo := getPausablePortalInfo(); ppInfo != nil && !ppInfo.dispatchToExecEngCleanup.isComplete { + ppInfo.dispatchToExecEngCleanup.appendFunc(namedFunc{ + fName: "log statement", + f: func() { + var resErr error + if !ppInfo.curRes.IsReleased() { + resErr = ppInfo.curRes.Err() + } + planner.maybeLogStatement(ctx, ex.executorType, false, int(ex.state.mu.autoRetryCounter), ex.extraTxnState.txnCounter, nonBulkJobNumRows, bulkJobId, resErr, ex.statsCollector.PhaseTimes().GetSessionPhaseTime(sessionphase.SessionQueryReceived), &ex.extraTxnState.hasAdminRoleCache, ex.server.TelemetryLoggingMetrics, stmtFingerprintID, ppInfo.queryStats) + }, + }) + } else { + planner.maybeLogStatement(ctx, ex.executorType, false, int(ex.state.mu.autoRetryCounter), ex.extraTxnState.txnCounter, nonBulkJobNumRows, bulkJobId, res.Err(), ex.statsCollector.PhaseTimes().GetSessionPhaseTime(sessionphase.SessionQueryReceived), &ex.extraTxnState.hasAdminRoleCache, ex.server.TelemetryLoggingMetrics, stmtFingerprintID, &stats) + } }() + // TODO(sql-sessions): fix the phase time for pausable portals. ex.statsCollector.PhaseTimes().SetSessionPhaseTime(sessionphase.PlannerEndLogicalPlan, timeutil.Now()) ex.sessionTracing.TracePlanEnd(ctx, err) @@ -1304,6 +1501,7 @@ func (ex *connExecutor) dispatchToExecutionEngine( ex.server.cfg.TestingKnobs.BeforeExecute(ctx, stmt.String(), planner.Descriptors()) } + // TODO(sql-sessions): fix the phase time for pausable portals. ex.statsCollector.PhaseTimes().SetSessionPhaseTime(sessionphase.PlannerStartExecStmt, timeutil.Now()) progAtomic, err := func() (*uint64, error) { @@ -1352,6 +1550,12 @@ func (ex *connExecutor) dispatchToExecutionEngine( stats, err = ex.execWithDistSQLEngine( ctx, planner, stmt.AST.StatementReturnType(), res, distribute, progAtomic, ) + if ppInfo := getPausablePortalInfo(); ppInfo != nil { + // For pausable portals, we log the stats when closing the portal, so we need + // to aggregate the stats for all executions. + ppInfo.queryStats.add(&stats) + } + if res.Err() == nil { isSetOrShow := stmt.AST.StatementTag() == "SET" || stmt.AST.StatementTag() == "SHOW" if ex.sessionData().InjectRetryErrorsEnabled && !isSetOrShow && @@ -1366,13 +1570,25 @@ func (ex *connExecutor) dispatchToExecutionEngine( } } ex.sessionTracing.TraceExecEnd(ctx, res.Err(), res.RowsAffected()) + // TODO(sql-sessions): fix the phase time for pausable portals. 
ex.statsCollector.PhaseTimes().SetSessionPhaseTime(sessionphase.PlannerEndExecStmt, timeutil.Now()) ex.extraTxnState.rowsRead += stats.rowsRead ex.extraTxnState.bytesRead += stats.bytesRead ex.extraTxnState.rowsWritten += stats.rowsWritten - populateQueryLevelStatsAndRegions(ctx, planner, ex.server.cfg, &stats, &ex.cpuStatsCollector) + if ppInfo := getPausablePortalInfo(); ppInfo != nil && !ppInfo.dispatchToExecEngCleanup.isComplete { + ppInfo.dispatchToExecEngCleanup.appendFunc(namedFunc{ + fName: "populate query level stats and regions", + f: func() { + // TODO(janexing): properly set the instrumentation helper. + planner.curPlan = ppInfo.planTop + populateQueryLevelStatsAndRegions(ctx, planner, ex.server.cfg, ppInfo.queryStats, &ex.cpuStatsCollector) + }, + }) + } else { + populateQueryLevelStatsAndRegions(ctx, planner, ex.server.cfg, &stats, &ex.cpuStatsCollector) + } // The transaction (from planner.txn) may already have been committed at this point, // due to one-phase commit optimization or an error. Since we use that transaction diff --git a/pkg/sql/conn_io.go b/pkg/sql/conn_io.go index 6c202eaf1f1e..add515717bdd 100644 --- a/pkg/sql/conn_io.go +++ b/pkg/sql/conn_io.go @@ -808,6 +808,10 @@ type RestrictedCommandResult interface { // This is currently used for sinkless changefeeds. DisableBuffering() + // IsReleased returns true if the result has been released. This method is + // only implemented by pgwire.commandResult. + IsReleased() bool + // GetBulkJobId returns the id of the job for the query, if the query is // IMPORT, BACKUP or RESTORE. GetBulkJobId() uint64 @@ -1035,6 +1039,11 @@ func (r *streamingCommandResult) Err() error { return r.err } +// IsReleased is part of the sql.RestrictedCommandResult interface. +func (r *streamingCommandResult) IsReleased() bool { + return false +} + // IncrementRowsAffected is part of the RestrictedCommandResult interface. func (r *streamingCommandResult) IncrementRowsAffected(ctx context.Context, n int) { r.rowsAffected += n diff --git a/pkg/sql/distsql_running.go b/pkg/sql/distsql_running.go index 9e4c8ab7d580..e3df7013b675 100644 --- a/pkg/sql/distsql_running.go +++ b/pkg/sql/distsql_running.go @@ -46,6 +46,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sessiondatapb" "github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry" "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/buildutil" "github.com/cockroachdb/cockroach/pkg/util/contextutil" "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" "github.com/cockroachdb/cockroach/pkg/util/hlc" @@ -843,6 +844,7 @@ func (dsp *DistSQLPlanner) Run( ctx, evalCtx, planCtx, leafInputState, flows, recv, localState, statementSQL, ) if m != nil { + // TODO(yuzefovich): add a check that this flow runs in a single goroutine. m.flow = flow m.outputTypes = plan.GetResultTypes() } @@ -1587,7 +1589,15 @@ func (dsp *DistSQLPlanner) PlanAndRunAll( planner *planner, recv *DistSQLReceiver, evalCtxFactory func(usedConcurrently bool) *extendedEvalContext, -) error { +) (retErr error) { + defer func() { + if ppInfo := planCtx.getPortalPauseInfo(); ppInfo != nil && !ppInfo.flowCleanup.isComplete { + ppInfo.flowCleanup.isComplete = true + } + if retErr != nil && planCtx.getPortalPauseInfo() != nil { + planCtx.getPortalPauseInfo().flowCleanup.run() + } + }() if len(planner.curPlan.subqueryPlans) != 0 { // Create a separate memory account for the results of the subqueries. 
// Note that we intentionally defer the closure of the account until we @@ -1612,11 +1622,45 @@ func (dsp *DistSQLPlanner) PlanAndRunAll( recv.discardRows = planner.instrumentation.ShouldDiscardRows() func() { finishedSetupFn, cleanup := getFinishedSetupFn(planner) - defer cleanup() + defer func() { + if ppInfo := planCtx.getPortalPauseInfo(); ppInfo != nil { + if !ppInfo.flowCleanup.isComplete { + ppInfo.flowCleanup.appendFunc(namedFunc{ + fName: "cleanup inner plan", + f: cleanup, + }) + } + } else { + cleanup() + } + }() dsp.PlanAndRun( ctx, evalCtx, planCtx, planner.txn, planner.curPlan.main, recv, finishedSetupFn, ) }() + + if p := planCtx.getPortalPauseInfo(); p != nil { + if buildutil.CrdbTestBuild && planCtx.getPortalPauseInfo().flow == nil { + checkErr := errors.AssertionFailedf("flow for portal %s cannot be found", planner.pausablePortal.Name) + if recv.commErr != nil { + recv.commErr = errors.CombineErrors(recv.commErr, checkErr) + } else { + return checkErr + } + } + if recv.commErr != nil { + p.flow.Cleanup(ctx) + } else if !p.flowCleanup.isComplete { + flow := p.flow + p.flowCleanup.appendFunc(namedFunc{ + fName: "cleanup flow", f: func() { + flow.Cleanup(ctx) + }, + }) + p.flowCleanup.isComplete = true + } + } + if recv.commErr != nil || recv.getError() != nil { return recv.commErr } diff --git a/pkg/sql/instrumentation.go b/pkg/sql/instrumentation.go index 3e09e80914ba..fc571ba8e9ef 100644 --- a/pkg/sql/instrumentation.go +++ b/pkg/sql/instrumentation.go @@ -818,3 +818,46 @@ func (ih *instrumentationHelper) SetIndexRecommendations( reset, ) } + +// CopyTo is to make a copy of the original instrumentation helper. +func (ih *instrumentationHelper) CopyTo(dst *instrumentationHelper) { + dst.outputMode = ih.outputMode + dst.explainFlags = ih.explainFlags + dst.fingerprint = ih.fingerprint + dst.implicitTxn = ih.implicitTxn + dst.codec = ih.codec + dst.collectBundle = ih.collectBundle + dst.collectExecStats = ih.collectExecStats + dst.isTenant = ih.isTenant + dst.discardRows = ih.discardRows + dst.diagRequestID = ih.diagRequestID + dst.diagRequest = ih.diagRequest + dst.stmtDiagnosticsRecorder = ih.stmtDiagnosticsRecorder + dst.withStatementTrace = ih.withStatementTrace + dst.sp = ih.sp + dst.shouldFinishSpan = ih.shouldFinishSpan + dst.origCtx = ih.origCtx + dst.evalCtx = ih.evalCtx + dst.queryLevelStatsWithErr = ih.queryLevelStatsWithErr + dst.savePlanForStats = ih.savePlanForStats + dst.explainPlan = ih.explainPlan + dst.distribution = ih.distribution + dst.vectorized = ih.vectorized + dst.containsMutation = ih.containsMutation + dst.traceMetadata = ih.traceMetadata + dst.regions = ih.regions + dst.planGist = ih.planGist + dst.costEstimate = ih.costEstimate + dst.indexRecs = ih.indexRecs + dst.maxFullScanRows = ih.maxFullScanRows + dst.totalScanRows = ih.totalScanRows + dst.totalScanRowsWithoutForecasts = ih.totalScanRowsWithoutForecasts + dst.outputRows = ih.outputRows + dst.statsAvailable = ih.statsAvailable + dst.nanosSinceStatsCollected = ih.nanosSinceStatsCollected + dst.nanosSinceStatsForecasted = ih.nanosSinceStatsForecasted + dst.joinTypeCounts = ih.joinTypeCounts + dst.joinAlgorithmCounts = ih.joinAlgorithmCounts + dst.scanCounts = ih.scanCounts + dst.indexesUsed = ih.indexesUsed +} diff --git a/pkg/sql/pgwire/command_result.go b/pkg/sql/pgwire/command_result.go index 31da52ca745c..001cf2f6170a 100644 --- a/pkg/sql/pgwire/command_result.go +++ b/pkg/sql/pgwire/command_result.go @@ -350,6 +350,11 @@ func (r *commandResult) release() { *r = commandResult{released: 
true} } +// IsReleased returns true if the current commandResult has been released. +func (r *commandResult) IsReleased() bool { + return r.released +} + // assertNotReleased asserts that the commandResult is not being used after // being freed by one of the methods in the CommandResultClose interface. The // assertion can have false negatives, where it fails to detect a use-after-free diff --git a/pkg/sql/prepared_stmt.go b/pkg/sql/prepared_stmt.go index a538d19dd1fd..ae41982ae9d9 100644 --- a/pkg/sql/prepared_stmt.go +++ b/pkg/sql/prepared_stmt.go @@ -207,6 +207,10 @@ func (p *PreparedPortal) close( ) { prepStmtsNamespaceMemAcc.Shrink(ctx, p.size(portalName)) p.Stmt.decRef(ctx) + if p.pauseInfo != nil { + p.pauseInfo.cleanupAll() + p.pauseInfo = nil + } } func (p *PreparedPortal) size(portalName string) int64 {
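The following is a minimal, self-contained Go sketch (not part of the patch) of the cleanup-stack pattern described in the commit message: named cleanup closures appended to per-layer stacks, run in LIFO order, with the deeper layers drained before the current one when an error occurs. The names namedFunc, cleanupStack and portalCleanup are simplified stand-ins modeled on portalPauseInfo and its `__Cleanup` stacks, not the actual CockroachDB definitions.

package main

import "fmt"

// namedFunc pairs a cleanup step with a name for logging/debugging.
type namedFunc struct {
	fName string
	f     func()
}

// cleanupStack collects the cleanup steps registered by one "layer"
// (exhaustPortal, execStmtCleanup, dispatchToExecEngCleanup, flowCleanup).
type cleanupStack struct {
	// isComplete mirrors the patch's flag: set once a layer has registered its
	// steps, so re-executions of the portal don't append duplicates.
	isComplete bool
	funcs      []namedFunc
}

func (s *cleanupStack) appendFunc(nf namedFunc) { s.funcs = append(s.funcs, nf) }

// run executes the registered steps in LIFO order and clears the stack, so
// running the same layer again later (e.g. when the portal is finally closed)
// is a no-op.
func (s *cleanupStack) run() {
	for i := len(s.funcs) - 1; i >= 0; i-- {
		fmt.Println("cleanup:", s.funcs[i].fName)
		s.funcs[i].f()
	}
	s.funcs = nil
}

// portalCleanup holds one stack per layer, outermost first.
type portalCleanup struct {
	exhaustPortal            cleanupStack
	execStmtCleanup          cleanupStack
	dispatchToExecEngCleanup cleanupStack
	flowCleanup              cleanupStack
}

// cleanupAll tears everything down in the flowCleanup ->
// dispatchToExecEngCleanup -> execStmtCleanup -> exhaustPortal order.
func (p *portalCleanup) cleanupAll() {
	p.flowCleanup.run()
	p.dispatchToExecEngCleanup.run()
	p.execStmtCleanup.run()
	p.exhaustPortal.run()
}

func main() {
	var p portalCleanup
	p.exhaustPortal.appendFunc(namedFunc{fName: "exhaust portal", f: func() {}})
	p.execStmtCleanup.appendFunc(namedFunc{fName: "cleanup span", f: func() {}})
	p.dispatchToExecEngCleanup.appendFunc(namedFunc{fName: "close planTop", f: func() {}})
	p.flowCleanup.appendFunc(namedFunc{fName: "cleanup flow", f: func() {}})

	// An error inside execStmtInOpenState() runs the deeper layers first and
	// then the current layer; exhaustPortal runs later, back in execPortal().
	p.flowCleanup.run()
	p.dispatchToExecEngCleanup.run()
	p.execStmtCleanup.run()

	// Closing the portal (sql txn commit/rollback, or connExecutor close) runs
	// whatever is left, still in LIFO order within each layer.
	p.cleanupAll()
}

Running the deeper stacks before the current one keeps the teardown order the same as what a normal, non-pausable portal would see, which is the invariant the patch relies on.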