diff --git a/pkg/sql/conn_executor.go b/pkg/sql/conn_executor.go index d1ee58b9c162..28eec80d988e 100644 --- a/pkg/sql/conn_executor.go +++ b/pkg/sql/conn_executor.go @@ -64,6 +64,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/stmtdiagnostics" "github.com/cockroachdb/cockroach/pkg/util" "github.com/cockroachdb/cockroach/pkg/util/buildutil" + "github.com/cockroachdb/cockroach/pkg/util/cancelchecker" "github.com/cockroachdb/cockroach/pkg/util/contextutil" "github.com/cockroachdb/cockroach/pkg/util/envutil" "github.com/cockroachdb/cockroach/pkg/util/errorutil" @@ -2134,9 +2135,11 @@ func (ex *connExecutor) execCmd() error { ex.phaseTimes.SetSessionPhaseTime(sessionphase.SessionQueryReceived, tcmd.TimeReceived) ex.phaseTimes.SetSessionPhaseTime(sessionphase.SessionStartParse, tcmd.ParseStart) ex.phaseTimes.SetSessionPhaseTime(sessionphase.SessionEndParse, tcmd.ParseEnd) + res = ex.clientComm.CreateCopyInResult(pos) - var err error - ev, payload, err = ex.execCopyIn(ctx, tcmd) + stmtCtx := withStatement(ctx, tcmd.Stmt) + ev, payload = ex.execCopyIn(stmtCtx, tcmd) + // Note: we write to ex.statsCollector.phaseTimes, instead of ex.phaseTimes, // because: // - stats use ex.statsCollector, not ex.phasetimes. @@ -2144,26 +2147,14 @@ func (ex *connExecutor) execCmd() error { // was created when the statement started executing (via the // reset() method). ex.statsCollector.PhaseTimes().SetSessionPhaseTime(sessionphase.SessionQueryServiced, timeutil.Now()) - if err != nil { - return err - } case CopyOut: ex.phaseTimes.SetSessionPhaseTime(sessionphase.SessionQueryReceived, tcmd.TimeReceived) ex.phaseTimes.SetSessionPhaseTime(sessionphase.SessionStartParse, tcmd.ParseStart) ex.phaseTimes.SetSessionPhaseTime(sessionphase.SessionEndParse, tcmd.ParseEnd) res = ex.clientComm.CreateCopyOutResult(pos) + stmtCtx := withStatement(ctx, tcmd.Stmt) + ev, payload = ex.execCopyOut(stmtCtx, tcmd) - // Handle conn executor state transitions. 
- switch ex.machine.CurState().(type) { - case stateNoTxn: - ev, payload = ex.execStmtInNoTxnState(ctx, tcmd.Stmt, res.(CopyOutResult)) - case stateOpen: - ev, payload = ex.execCopyOut(ctx, tcmd) - case stateAborted: - ev, payload = ex.execStmtInAbortedState(ctx, tcmd.Stmt, res.(CopyOutResult)) - case stateCommitWait: - ev, payload = ex.execStmtInCommitWaitState(ctx, tcmd.Stmt, res.(CopyOutResult)) - } // Note: we write to ex.statsCollector.phaseTimes, instead of ex.phaseTimes, // because: // - stats use ex.statsCollector, not ex.phasetimes. @@ -2466,76 +2457,208 @@ func isCopyToExternalStorage(cmd CopyIn) bool { func (ex *connExecutor) execCopyOut( ctx context.Context, cmd CopyOut, -) (fsm.Event, fsm.EventPayload) { - err := func() error { - ex.incrementStartedStmtCounter(cmd.Stmt) +) (retEv fsm.Event, retPayload fsm.EventPayload) { + // First handle connExecutor state transitions. + if _, isNoTxn := ex.machine.CurState().(stateNoTxn); isNoTxn { + return ex.beginImplicitTxn(ctx, cmd.ParsedStmt.AST) + } else if _, isAbortedTxn := ex.machine.CurState().(stateAborted); isAbortedTxn { + return ex.makeErrEvent(sqlerrors.NewTransactionAbortedError("" /* customMsg */), cmd.ParsedStmt.AST) + } + + ex.incrementStartedStmtCounter(cmd.Stmt) + var numOutputRows int + var cancelQuery context.CancelFunc + ctx, cancelQuery = contextutil.WithCancel(ctx) + queryID := ex.generateID() + ex.addActiveQuery(cmd.ParsedStmt, nil /* placeholders */, queryID, cancelQuery) + ex.metrics.EngineMetrics.SQLActiveStatements.Inc(1) + + defer func() { + ex.removeActiveQuery(queryID, cmd.Stmt) + cancelQuery() + ex.metrics.EngineMetrics.SQLActiveStatements.Dec(1) + if !payloadHasError(retPayload) { + ex.incrementExecutedStmtCounter(cmd.Stmt) + } var copyErr error - var numOutputRows int - var cancelQuery context.CancelFunc - ctx, cancelQuery = contextutil.WithCancel(ctx) - queryID := ex.generateID() + if p, ok := retPayload.(payloadWithError); ok { + copyErr = p.errorCause() + 
log.SqlExec.Errorf(ctx, "error executing %s: %+v", cmd, copyErr) + } // Log the query for sampling. - defer func() { - ex.removeActiveQuery(queryID, cmd.Stmt) - cancelQuery() - // These fields are not available in COPY, so use the empty value. - f := tree.NewFmtCtx(tree.FmtHideConstants) - f.FormatNode(cmd.Stmt) - stmtFingerprintID := appstatspb.ConstructStatementFingerprintID( - f.CloseAndGetString(), - copyErr != nil, - ex.implicitTxn(), - ex.planner.CurrentDatabase(), - ) - var stats topLevelQueryStats - ex.planner.maybeLogStatement( - ctx, - ex.executorType, - true, /* isCopy */ - int(ex.state.mu.autoRetryCounter), - ex.extraTxnState.txnCounter, - numOutputRows, - 0, /* bulkJobId */ - copyErr, - ex.statsCollector.PhaseTimes().GetSessionPhaseTime(sessionphase.SessionQueryReceived), - &ex.extraTxnState.hasAdminRoleCache, - ex.server.TelemetryLoggingMetrics, - stmtFingerprintID, - &stats, - ) - }() + // These fields are not available in COPY, so use the empty value. + f := tree.NewFmtCtx(tree.FmtHideConstants) + f.FormatNode(cmd.Stmt) + stmtFingerprintID := appstatspb.ConstructStatementFingerprintID( + f.CloseAndGetString(), + copyErr != nil, + ex.implicitTxn(), + ex.planner.CurrentDatabase(), + ) + var stats topLevelQueryStats + ex.planner.maybeLogStatement( + ctx, + ex.executorType, + true, /* isCopy */ + int(ex.state.mu.autoRetryCounter), + ex.extraTxnState.txnCounter, + numOutputRows, + 0, /* bulkJobId */ + copyErr, + ex.statsCollector.PhaseTimes().GetSessionPhaseTime(sessionphase.SessionQueryReceived), + &ex.extraTxnState.hasAdminRoleCache, + ex.server.TelemetryLoggingMetrics, + stmtFingerprintID, + &stats, + ) + }() - stmtTS := ex.server.cfg.Clock.PhysicalTime() - ex.resetPlanner(ctx, &ex.planner, ex.state.mu.txn, stmtTS) - ex.setCopyLoggingFields(cmd.ParsedStmt) + stmtTS := ex.server.cfg.Clock.PhysicalTime() + ex.statsCollector.Reset(ex.applicationStats, ex.phaseTimes) + ex.resetPlanner(ctx, &ex.planner, ex.state.mu.txn, stmtTS) + 
ex.setCopyLoggingFields(cmd.ParsedStmt) - return ex.execWithProfiling(ctx, cmd.Stmt, nil, func(ctx context.Context) error { - // We'll always have a txn on the planner since we called resetPlanner - // above. - txn := ex.planner.Txn() - ex.addActiveQuery(cmd.ParsedStmt, nil /* placeholders */, queryID, cancelQuery) + var queryTimeoutTicker *time.Timer + var txnTimeoutTicker *time.Timer + queryTimedOut := false + txnTimedOut := false + + // queryDoneAfterFunc and txnDoneAfterFunc will be allocated only when + // queryTimeoutTicker or txnTimeoutTicker is non-nil. + var queryDoneAfterFunc chan struct{} + var txnDoneAfterFunc chan struct{} + + defer func(ctx context.Context) { + if queryTimeoutTicker != nil { + if !queryTimeoutTicker.Stop() { + // Wait for the timer callback to complete to avoid a data race on + // queryTimedOut. + <-queryDoneAfterFunc + } + } + if txnTimeoutTicker != nil { + if !txnTimeoutTicker.Stop() { + // Wait for the timer callback to complete to avoid a data race on + // txnTimedOut. + <-txnDoneAfterFunc + } + } - var err error - if numOutputRows, err = runCopyTo(ctx, &ex.planner, txn, cmd); err != nil { - return err + // Detect context cancelation and overwrite whatever error might have been + // set on the result before. The idea is that once the query's context is + // canceled, all sorts of actors can detect the cancelation and set all + // sorts of errors on the result. Rather than trying to impose discipline + // in that jungle, we just overwrite them all here with an error that's + // nicer to look at for the client. + if ctx.Err() != nil { + // Even in the cases where the error is a retryable error, we want to + // intercept the event and payload returned here to ensure that the query + // is not retried. 
+ retEv = eventNonRetriableErr{ + IsCommit: fsm.FromBool(false), } + retPayload = eventNonRetriableErrPayload{err: cancelchecker.QueryCanceledError} + } + + // If the query timed out, we intercept the error, payload, and event here + // for the same reasons we intercept them for canceled queries above. + // Overriding queries with a QueryTimedOut error needs to happen after + // we've checked for canceled queries as some queries may be canceled + // because of a timeout, in which case the appropriate error to return to + // the client is one that indicates the timeout, rather than the more general + // query canceled error. It's important to note that a timed out query may + // not have been canceled (eg. We never even start executing a query + // because the timeout has already expired), and therefore this check needs + // to happen outside the canceled query check above. + if queryTimedOut { + // A timed out query should never produce retryable errors/events/payloads + // so we intercept and overwrite them all here. + retEv = eventNonRetriableErr{ + IsCommit: fsm.FromBool(false), + } + retPayload = eventNonRetriableErrPayload{err: sqlerrors.QueryTimeoutError} + } else if txnTimedOut { + retEv = eventNonRetriableErr{ + IsCommit: fsm.FromBool(false), + } + retPayload = eventNonRetriableErrPayload{err: sqlerrors.TxnTimeoutError} + } + }(ctx) + + if ex.sessionData().StmtTimeout > 0 { + timerDuration := + ex.sessionData().StmtTimeout - timeutil.Since(ex.phaseTimes.GetSessionPhaseTime(sessionphase.SessionQueryReceived)) + // There's no need to proceed with execution if the timer has already expired. 
+ if timerDuration < 0 { + queryTimedOut = true + return ex.makeErrEvent(sqlerrors.QueryTimeoutError, cmd.Stmt) + } + queryDoneAfterFunc = make(chan struct{}, 1) + queryTimeoutTicker = time.AfterFunc( + timerDuration, + func() { + cancelQuery() + queryTimedOut = true + queryDoneAfterFunc <- struct{}{} + }) + } + if ex.sessionData().TransactionTimeout > 0 && !ex.implicitTxn() { + timerDuration := + ex.sessionData().TransactionTimeout - timeutil.Since(ex.phaseTimes.GetSessionPhaseTime(sessionphase.SessionTransactionStarted)) + + // If the timer already expired, but the transaction is not yet aborted, + // we should error immediately without executing. If the timer + // expired but the transaction already is aborted, then we should still + // proceed with executing the statement in order to get a + // TransactionAbortedError. + _, txnAborted := ex.machine.CurState().(stateAborted) + + if timerDuration < 0 && !txnAborted { + txnTimedOut = true + return ex.makeErrEvent(sqlerrors.TxnTimeoutError, cmd.Stmt) + } + + if timerDuration > 0 { + txnDoneAfterFunc = make(chan struct{}, 1) + txnTimeoutTicker = time.AfterFunc( + timerDuration, + func() { + cancelQuery() + txnTimedOut = true + txnDoneAfterFunc <- struct{}{} + }) + } + } + + if copyErr := ex.execWithProfiling(ctx, cmd.Stmt, nil, func(ctx context.Context) error { + ex.mu.Lock() + queryMeta, ok := ex.mu.ActiveQueries[queryID] + if !ok { + return errors.AssertionFailedf("query %d not in registry", queryID) + } + queryMeta.phase = executing + ex.mu.Unlock() - // Finalize execution by sending the statement tag and number of rows read. 
- dummy := tree.CopyTo{} - tag := []byte(dummy.StatementTag()) - tag = append(tag, ' ') - tag = strconv.AppendInt(tag, int64(numOutputRows), 10 /* base */) - return cmd.Conn.SendCommandComplete(tag) - }) - }() - if err != nil { - log.SqlExec.Errorf(ctx, "error executing %s: %+v", cmd, err) - return eventNonRetriableErr{IsCommit: fsm.False}, eventNonRetriableErrPayload{ - err: err, + // We'll always have a txn on the planner since we called resetPlanner + // above. + txn := ex.planner.Txn() + var err error + if numOutputRows, err = runCopyTo(ctx, &ex.planner, txn, cmd); err != nil { + return err } + + // Finalize execution by sending the statement tag and number of rows read. + dummy := tree.CopyTo{} + tag := []byte(dummy.StatementTag()) + tag = append(tag, ' ') + tag = strconv.AppendInt(tag, int64(numOutputRows), 10 /* base */) + return cmd.Conn.SendCommandComplete(tag) + }); copyErr != nil { + ev := eventNonRetriableErr{IsCommit: fsm.False} + payload := eventNonRetriableErrPayload{err: copyErr} + return ev, payload } - ex.incrementExecutedStmtCounter(cmd.Stmt) return nil, nil } @@ -2557,81 +2680,66 @@ func (ex *connExecutor) setCopyLoggingFields(stmt parser.Statement) { // and writing up to the CommandComplete message. func (ex *connExecutor) execCopyIn( ctx context.Context, cmd CopyIn, -) (_ fsm.Event, retPayload fsm.EventPayload, retErr error) { +) (retEv fsm.Event, retPayload fsm.EventPayload) { + // First handle connExecutor state transitions. 
+ if _, isNoTxn := ex.machine.CurState().(stateNoTxn); isNoTxn { + return ex.beginImplicitTxn(ctx, cmd.ParsedStmt.AST) + } else if _, isAbortedTxn := ex.machine.CurState().(stateAborted); isAbortedTxn { + return ex.makeErrEvent(sqlerrors.NewTransactionAbortedError("" /* customMsg */), cmd.ParsedStmt.AST) + } + ex.incrementStartedStmtCounter(cmd.Stmt) + var cancelQuery context.CancelFunc + ctx, cancelQuery = contextutil.WithCancel(ctx) + queryID := ex.generateID() + ex.addActiveQuery(cmd.ParsedStmt, nil /* placeholders */, queryID, cancelQuery) + ex.metrics.EngineMetrics.SQLActiveStatements.Inc(1) + defer func() { - if retErr == nil && !payloadHasError(retPayload) { + ex.removeActiveQuery(queryID, cmd.Stmt) + cancelQuery() + ex.metrics.EngineMetrics.SQLActiveStatements.Dec(1) + if !payloadHasError(retPayload) { ex.incrementExecutedStmtCounter(cmd.Stmt) } if p, ok := retPayload.(payloadWithError); ok { log.SqlExec.Errorf(ctx, "error executing %s: %+v", cmd, p.errorCause()) } - if retErr != nil { - log.SqlExec.Errorf(ctx, "error executing %s: %+v", cmd, retErr) - } }() // When we're done, unblock the network connection. defer cmd.CopyDone.Done() - state := ex.machine.CurState() - _, isNoTxn := state.(stateNoTxn) - _, isOpen := state.(stateOpen) - if !isNoTxn && !isOpen { - ev := eventNonRetriableErr{IsCommit: fsm.False} - payload := eventNonRetriableErrPayload{ - err: sqlerrors.NewTransactionAbortedError("" /* customMsg */)} - return ev, payload, nil - } - - // If we're in an explicit txn, then the copying will be done within that - // txn. Otherwise, we tell the copyMachine to manage its own transactions - // and give it a closure to reset the accumulated extraTxnState. 
- var txnOpt copyTxnOpt - if isOpen { - txnOpt = copyTxnOpt{ - txn: ex.state.mu.txn, - txnTimestamp: ex.state.sqlTimestamp, - stmtTimestamp: ex.server.cfg.Clock.PhysicalTime(), - } - } else { - txnOpt = copyTxnOpt{ - resetExtraTxnState: func(ctx context.Context) { - ex.resetExtraTxnState(ctx, txnEvent{eventType: noEvent}) - }, - } + // The connExecutor state machine has already set us up with a txn at this + // point. + txnOpt := copyTxnOpt{ + txn: ex.state.mu.txn, + txnTimestamp: ex.state.sqlTimestamp, + stmtTimestamp: ex.server.cfg.Clock.PhysicalTime(), + initPlanner: func(ctx context.Context, p *planner) { + ex.initPlanner(ctx, p) + }, + resetPlanner: func(ctx context.Context, p *planner, txn *kv.Txn, txnTS time.Time, stmtTS time.Time) { + ex.statsCollector.Reset(ex.applicationStats, ex.phaseTimes) + ex.resetPlanner(ctx, p, txn, stmtTS) + }, } - - var monToStop *mon.BytesMonitor - defer func() { - if monToStop != nil { - monToStop.Stop(ctx) + // If COPY is not atomic, then each batch must manage the txn state. + if ex.implicitTxn() && !ex.sessionData().CopyFromAtomicEnabled { + txnOpt.resetExtraTxnState = func(ctx context.Context) { + ex.resetExtraTxnState(ctx, txnEvent{eventType: noEvent}) } - }() - if isNoTxn { - // HACK: We're reaching inside ex.state and starting the monitor. Normally - // that's driven by the state machine, but we're bypassing the state machine - // here. - ex.state.mon.StartNoReserved(ctx, ex.sessionMon) - monToStop = ex.state.mon - } - txnOpt.resetPlanner = func(ctx context.Context, p *planner, txn *kv.Txn, txnTS time.Time, stmtTS time.Time) { - // HACK: We're reaching inside ex.state and changing sqlTimestamp by hand. - // It is used by resetPlanner. Normally sqlTimestamp is updated by the - // state machine, but the copyMachine manages its own transactions without - // going through the state machine. 
- ex.state.sqlTimestamp = txnTS - ex.statsCollector.Reset(ex.applicationStats, ex.phaseTimes) - ex.initPlanner(ctx, p) - ex.resetPlanner(ctx, p, txn, stmtTS) } ex.setCopyLoggingFields(cmd.ParsedStmt) var cm copyMachineInterface - var copyErr error // Log the query for sampling. defer func() { + var copyErr error + if p, ok := retPayload.(payloadWithError); ok { + copyErr = p.errorCause() + } var numInsertedRows int if cm != nil { numInsertedRows = cm.numInsertedRows() @@ -2649,13 +2757,14 @@ func (ex *connExecutor) execCopyIn( ex.planner.maybeLogStatement(ctx, ex.executorType, true, int(ex.state.mu.autoRetryCounter), ex.extraTxnState.txnCounter, numInsertedRows, 0 /* bulkJobId */, copyErr, ex.statsCollector.PhaseTimes().GetSessionPhaseTime(sessionphase.SessionQueryReceived), &ex.extraTxnState.hasAdminRoleCache, ex.server.TelemetryLoggingMetrics, stmtFingerprintID, &stats) }() + var copyErr error if isCopyToExternalStorage(cmd) { - cm, copyErr = newFileUploadMachine(ctx, cmd.Conn, cmd.Stmt, txnOpt, ex.server.cfg, ex.state.mon) + cm, copyErr = newFileUploadMachine(ctx, cmd.Conn, cmd.Stmt, txnOpt, &ex.planner, ex.state.mon) } else { // The planner will be prepared before use. 
- p := planner{execCfg: ex.server.cfg} + p := ex.planner cm, copyErr = newCopyMachine( - ctx, cmd.Conn, cmd.Stmt, &p, txnOpt, ex.state.mon, + ctx, cmd.Conn, cmd.Stmt, &p, txnOpt, ex.state.mon, ex.implicitTxn(), // execInsertPlan func(ctx context.Context, p *planner, res RestrictedCommandResult) error { _, err := ex.execWithDistSQLEngine(ctx, p, tree.RowsAffected, res, DistributionTypeNone, nil /* progressAtomic */) @@ -2666,11 +2775,131 @@ func (ex *connExecutor) execCopyIn( if copyErr != nil { ev := eventNonRetriableErr{IsCommit: fsm.False} payload := eventNonRetriableErrPayload{err: copyErr} - return ev, payload, nil + return ev, payload + } + + var queryTimeoutTicker *time.Timer + var txnTimeoutTicker *time.Timer + queryTimedOut := false + txnTimedOut := false + + // queryDoneAfterFunc and txnDoneAfterFunc will be allocated only when + // queryTimeoutTicker or txnTimeoutTicker is non-nil. + var queryDoneAfterFunc chan struct{} + var txnDoneAfterFunc chan struct{} + + defer func(ctx context.Context) { + if queryTimeoutTicker != nil { + if !queryTimeoutTicker.Stop() { + // Wait for the timer callback to complete to avoid a data race on + // queryTimedOut. + <-queryDoneAfterFunc + } + } + if txnTimeoutTicker != nil { + if !txnTimeoutTicker.Stop() { + // Wait for the timer callback to complete to avoid a data race on + // txnTimedOut. + <-txnDoneAfterFunc + } + } + + // Detect context cancelation and overwrite whatever error might have been + // set on the result before. The idea is that once the query's context is + // canceled, all sorts of actors can detect the cancelation and set all + // sorts of errors on the result. Rather than trying to impose discipline + // in that jungle, we just overwrite them all here with an error that's + // nicer to look at for the client. + if ctx.Err() != nil { + // Even in the cases where the error is a retryable error, we want to + // intercept the event and payload returned here to ensure that the query + // is not retried. 
+ retEv = eventNonRetriableErr{ + IsCommit: fsm.FromBool(false), + } + retPayload = eventNonRetriableErrPayload{err: cancelchecker.QueryCanceledError} + } + + cm.Close(ctx) + + // If the query timed out, we intercept the error, payload, and event here + // for the same reasons we intercept them for canceled queries above. + // Overriding queries with a QueryTimedOut error needs to happen after + // we've checked for canceled queries as some queries may be canceled + // because of a timeout, in which case the appropriate error to return to + // the client is one that indicates the timeout, rather than the more general + // query canceled error. It's important to note that a timed out query may + // not have been canceled (eg. We never even start executing a query + // because the timeout has already expired), and therefore this check needs + // to happen outside the canceled query check above. + if queryTimedOut { + // A timed out query should never produce retryable errors/events/payloads + // so we intercept and overwrite them all here. + retEv = eventNonRetriableErr{ + IsCommit: fsm.FromBool(false), + } + retPayload = eventNonRetriableErrPayload{err: sqlerrors.QueryTimeoutError} + } else if txnTimedOut { + retEv = eventNonRetriableErr{ + IsCommit: fsm.FromBool(false), + } + retPayload = eventNonRetriableErrPayload{err: sqlerrors.TxnTimeoutError} + } + }(ctx) + + if ex.sessionData().StmtTimeout > 0 { + timerDuration := + ex.sessionData().StmtTimeout - timeutil.Since(ex.phaseTimes.GetSessionPhaseTime(sessionphase.SessionQueryReceived)) + // There's no need to proceed with execution if the timer has already expired. 
+ if timerDuration < 0 { + queryTimedOut = true + return ex.makeErrEvent(sqlerrors.QueryTimeoutError, cmd.Stmt) + } + queryDoneAfterFunc = make(chan struct{}, 1) + queryTimeoutTicker = time.AfterFunc( + timerDuration, + func() { + cancelQuery() + queryTimedOut = true + queryDoneAfterFunc <- struct{}{} + }) + } + if ex.sessionData().TransactionTimeout > 0 && !ex.implicitTxn() { + timerDuration := + ex.sessionData().TransactionTimeout - timeutil.Since(ex.phaseTimes.GetSessionPhaseTime(sessionphase.SessionTransactionStarted)) + + // If the timer already expired, but the transaction is not yet aborted, + // we should error immediately without executing. If the timer + // expired but the transaction already is aborted, then we should still + // proceed with executing the statement in order to get a + // TransactionAbortedError. + _, txnAborted := ex.machine.CurState().(stateAborted) + + if timerDuration < 0 && !txnAborted { + txnTimedOut = true + return ex.makeErrEvent(sqlerrors.TxnTimeoutError, cmd.Stmt) + } + + if timerDuration > 0 { + txnDoneAfterFunc = make(chan struct{}, 1) + txnTimeoutTicker = time.AfterFunc( + timerDuration, + func() { + cancelQuery() + txnTimedOut = true + txnDoneAfterFunc <- struct{}{} + }) + } } - defer cm.Close(ctx) if copyErr = ex.execWithProfiling(ctx, cmd.Stmt, nil, func(ctx context.Context) error { + ex.mu.Lock() + queryMeta, ok := ex.mu.ActiveQueries[queryID] + if !ok { + return errors.AssertionFailedf("query %d not in registry", queryID) + } + queryMeta.phase = executing + ex.mu.Unlock() return cm.run(ctx) }); copyErr != nil { // TODO(andrei): We don't have a full retriable error story for the copy machine. @@ -2683,9 +2912,9 @@ func (ex *connExecutor) execCopyIn( // errors as query errors. 
ev := eventNonRetriableErr{IsCommit: fsm.False} payload := eventNonRetriableErrPayload{err: copyErr} - return ev, payload, nil + return ev, payload } - return nil, nil, nil + return nil, nil } // stmtHasNoData returns true if describing a result of the input statement diff --git a/pkg/sql/conn_executor_exec.go b/pkg/sql/conn_executor_exec.go index 8d6cb5494d20..82d569a340b1 100644 --- a/pkg/sql/conn_executor_exec.go +++ b/pkg/sql/conn_executor_exec.go @@ -500,7 +500,7 @@ func (ex *connExecutor) execStmtInOpenState( }() } - if ex.sessionData().TransactionTimeout > 0 && !ex.implicitTxn() { + if ex.sessionData().TransactionTimeout > 0 && !ex.implicitTxn() && ex.executorType != executorTypeInternal { timerDuration := ex.sessionData().TransactionTimeout - timeutil.Since(ex.phaseTimes.GetSessionPhaseTime(sessionphase.SessionTransactionStarted)) diff --git a/pkg/sql/conn_io.go b/pkg/sql/conn_io.go index f0e0bcbada7e..cd4e426d8ed6 100644 --- a/pkg/sql/conn_io.go +++ b/pkg/sql/conn_io.go @@ -885,7 +885,6 @@ type CopyInResult interface { // produces no output for the client. type CopyOutResult interface { ResultBase - RestrictedCommandResult } // ClientLock is an interface returned by ClientComm.lockCommunication(). 
It diff --git a/pkg/sql/copy/BUILD.bazel b/pkg/sql/copy/BUILD.bazel index 47c7ba851093..0a7ed2015780 100644 --- a/pkg/sql/copy/BUILD.bazel +++ b/pkg/sql/copy/BUILD.bazel @@ -31,6 +31,7 @@ go_test( "//pkg/testutils/serverutils", "//pkg/testutils/sqlutils", "//pkg/testutils/testcluster", + "//pkg/util/ctxgroup", "//pkg/util/encoding/csv", "//pkg/util/leaktest", "//pkg/util/log", @@ -40,6 +41,7 @@ go_test( "//pkg/util/timetz", "@com_github_cockroachdb_apd_v3//:apd", "@com_github_cockroachdb_datadriven//:datadriven", + "@com_github_cockroachdb_errors//:errors", "@com_github_jackc_pgconn//:pgconn", "@com_github_jackc_pgtype//:pgtype", "@com_github_jackc_pgx_v4//:pgx", diff --git a/pkg/sql/copy/copy_test.go b/pkg/sql/copy/copy_test.go index 3f0d9e13da57..6df9e447977c 100644 --- a/pkg/sql/copy/copy_test.go +++ b/pkg/sql/copy/copy_test.go @@ -14,7 +14,6 @@ import ( "bytes" "context" "database/sql/driver" - "errors" "fmt" "io" "net/url" @@ -22,6 +21,7 @@ import ( "runtime/pprof" "strings" "testing" + "time" "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/cockroach/pkg/base" @@ -29,14 +29,18 @@ import ( "github.com/cockroachdb/cockroach/pkg/security/username" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/datapathutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/ctxgroup" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/datadriven" + "github.com/cockroachdb/errors" "github.com/jackc/pgconn" "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" "github.com/stretchr/testify/require" ) @@ -288,6 +292,182 @@ func TestCopyFromTransaction(t *testing.T) { } } +// slowCopySource is a pgx.CopyFromSource that copies a fixed number of 
rows +// and sleeps for 500 ms in between each one. +type slowCopySource struct { + count int + total int +} + +func (s *slowCopySource) Next() bool { + s.count++ + return s.count < s.total +} + +func (s *slowCopySource) Values() ([]interface{}, error) { + time.Sleep(500 * time.Millisecond) + return []interface{}{s.count}, nil +} + +func (s *slowCopySource) Err() error { + return nil +} + +var _ pgx.CopyFromSource = &slowCopySource{} + +// TestCopyFromTimeout checks that COPY FROM respects the statement_timeout +// and transaction_timeout settings. +func TestCopyFromTimeout(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ctx := context.Background() + + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + + pgURL, cleanup := sqlutils.PGUrl( + t, + s.ServingSQLAddr(), + "TestCopyFromTimeout", + url.User(username.RootUser), + ) + defer cleanup() + + t.Run("copy from", func(t *testing.T) { + conn, err := pgx.Connect(ctx, pgURL.String()) + require.NoError(t, err) + + _, err = conn.Exec(ctx, "CREATE TABLE t (a INT PRIMARY KEY)") + require.NoError(t, err) + + _, err = conn.Exec(ctx, "SET transaction_timeout = '100ms'") + require.NoError(t, err) + + tx, err := conn.Begin(ctx) + require.NoError(t, err) + + _, err = tx.CopyFrom(ctx, pgx.Identifier{"t"}, []string{"a"}, &slowCopySource{total: 2}) + require.ErrorContains(t, err, "query execution canceled due to transaction timeout") + + err = tx.Rollback(ctx) + require.NoError(t, err) + + _, err = conn.Exec(ctx, "SET statement_timeout = '200ms'") + require.NoError(t, err) + + _, err = conn.CopyFrom(ctx, pgx.Identifier{"t"}, []string{"a"}, &slowCopySource{total: 2}) + require.ErrorContains(t, err, "query execution canceled due to statement timeout") + }) + + t.Run("copy to", func(t *testing.T) { + conn, err := pgx.Connect(ctx, pgURL.String()) + require.NoError(t, err) + + _, err = conn.Exec(ctx, "SET transaction_timeout = '100ms'") + require.NoError(t, err) 
+ + tx, err := conn.Begin(ctx) + require.NoError(t, err) + + _, err = tx.Exec(ctx, "COPY (SELECT pg_sleep(1) FROM ROWS FROM (generate_series(1, 60)) AS i) TO STDOUT") + require.ErrorContains(t, err, "query execution canceled due to transaction timeout") + + err = tx.Rollback(ctx) + require.NoError(t, err) + + _, err = conn.Exec(ctx, "SET statement_timeout = '200ms'") + require.NoError(t, err) + + _, err = conn.Exec(ctx, "COPY (SELECT pg_sleep(1) FROM ROWS FROM (generate_series(1, 60)) AS i) TO STDOUT") + require.ErrorContains(t, err, "query execution canceled due to statement timeout") + }) +} + +func TestShowQueriesIncludesCopy(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + ctx := context.Background() + + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + + pgURL, cleanup := sqlutils.PGUrl( + t, + s.ServingSQLAddr(), + "TestShowQueriesIncludesCopy", + url.User(username.RootUser), + ) + defer cleanup() + + showConn, err := pgx.Connect(ctx, pgURL.String()) + require.NoError(t, err) + q := pgURL.Query() + q.Add("application_name", "app_name") + pgURL.RawQuery = q.Encode() + copyConn, err := pgx.Connect(ctx, pgURL.String()) + require.NoError(t, err) + _, err = copyConn.Exec(ctx, "CREATE TABLE t (a INT PRIMARY KEY)") + require.NoError(t, err) + + t.Run("copy to", func(t *testing.T) { + g := ctxgroup.WithContext(ctx) + g.GoCtx(func(ctx context.Context) error { + _, err = copyConn.Exec(ctx, "COPY (SELECT pg_sleep(1) FROM ROWS FROM (generate_series(1, 60)) AS i) TO STDOUT") + return err + }) + + // The COPY query should use the specified app name. SucceedsSoon is used + // since COPY is being executed concurrently. 
+ var appName string + testutils.SucceedsSoon(t, func() error { + err = showConn.QueryRow(ctx, "SELECT application_name FROM [SHOW QUERIES] WHERE query LIKE 'COPY (SELECT pg_sleep(1) %'").Scan(&appName) + if err != nil { + return err + } + if appName != "app_name" { + return errors.New("expected COPY to appear in SHOW QUERIES") + } + return nil + }) + + err = copyConn.PgConn().CancelRequest(ctx) + require.NoError(t, err) + + // An error is expected, since the query was canceled. + err = g.Wait() + require.ErrorContains(t, err, "query execution canceled") + }) + + t.Run("copy from", func(t *testing.T) { + g := ctxgroup.WithContext(ctx) + g.GoCtx(func(ctx context.Context) error { + _, err := copyConn.CopyFrom(ctx, pgx.Identifier{"t"}, []string{"a"}, &slowCopySource{total: 5}) + return err + }) + + // The COPY query should use the specified app name. SucceedsSoon is used + // since COPY is being executed concurrently. + var appName string + testutils.SucceedsSoon(t, func() error { + err = showConn.QueryRow(ctx, "SELECT application_name FROM [SHOW QUERIES] WHERE query ILIKE 'COPY%t%a%FROM%'").Scan(&appName) + if err != nil { + return err + } + if appName != "app_name" { + return errors.New("expected COPY to appear in SHOW QUERIES") + } + return nil + }) + + err = copyConn.PgConn().CancelRequest(ctx) + require.NoError(t, err) + + // An error is expected, since the query was canceled. + err = g.Wait() + require.ErrorContains(t, err, "query execution canceled") + }) +} + // BenchmarkCopyFrom measures copy performance against a TestServer. 
func BenchmarkCopyFrom(b *testing.B) { defer leaktest.AfterTest(b)() diff --git a/pkg/sql/copy_file_upload.go b/pkg/sql/copy_file_upload.go index 5f890f25fd54..d31ebeddc68c 100644 --- a/pkg/sql/copy_file_upload.go +++ b/pkg/sql/copy_file_upload.go @@ -81,16 +81,17 @@ func newFileUploadMachine( conn pgwirebase.Conn, n *tree.CopyFrom, txnOpt copyTxnOpt, - execCfg *ExecutorConfig, + p *planner, parentMon *mon.BytesMonitor, ) (f *fileUploadMachine, retErr error) { if len(n.Columns) != 0 { return nil, errors.New("expected 0 columns specified for file uploads") } c := &copyMachine{ - conn: conn, + conn: conn, + txnOpt: txnOpt, // The planner will be prepared before use. - p: &planner{execCfg: execCfg}, + p: p, } f = &fileUploadMachine{ c: c, @@ -98,7 +99,7 @@ func newFileUploadMachine( // We need a planner to do the initial planning, even if a planner // is not required after that. - cleanup := c.p.preparePlannerForCopy(ctx, &txnOpt, false /* finalBatch */, c.implicitTxn) + cleanup := c.p.preparePlannerForCopy(ctx, &c.txnOpt, false /* finalBatch */, c.implicitTxn) defer func() { retErr = cleanup(ctx, retErr) }() diff --git a/pkg/sql/copy_from.go b/pkg/sql/copy_from.go index 33004cfc0508..97f67e9f977c 100644 --- a/pkg/sql/copy_from.go +++ b/pkg/sql/copy_from.go @@ -31,7 +31,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/rowcontainer" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" - "github.com/cockroachdb/cockroach/pkg/sql/sessiondatapb" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/util" "github.com/cockroachdb/cockroach/pkg/util/buildutil" @@ -254,6 +253,7 @@ func newCopyMachine( p *planner, txnOpt copyTxnOpt, parentMon *mon.BytesMonitor, + implicitTxn bool, execInsertPlan func(ctx context.Context, p *planner, res RestrictedCommandResult) error, ) (_ *copyMachine, retErr error) { cOpts, err := processCopyOptions(ctx, p, n.Options) @@ -271,7 +271,7 @@ func
newCopyMachine( txnOpt: txnOpt, p: p, execInsertPlan: execInsertPlan, - implicitTxn: txnOpt.txn == nil, + implicitTxn: implicitTxn, } // We need a planner to do the initial planning, in addition // to those used for the main execution of the COPY afterwards. @@ -356,6 +356,7 @@ type copyTxnOpt struct { txn *kv.Txn txnTimestamp time.Time stmtTimestamp time.Time + initPlanner func(ctx context.Context, p *planner) resetPlanner func(ctx context.Context, p *planner, txn *kv.Txn, txnTS time.Time, stmtTS time.Time) // resetExtraTxnState should be called upon completing a batch from the copy @@ -792,37 +793,23 @@ func (c *copyMachine) readBinarySignature() ([]byte, error) { // Depending on how the requesting COPY machine was configured, a new // transaction might be created. // -// It returns a cleanup function that needs to be called when we're -// done with the planner (before preparePlannerForCopy is called -// again). The cleanup function commits the txn (if it hasn't already -// been committed) or rolls it back depending on whether it is passed -// an error. If an error is passed in to the cleanup function, the -// same error is returned. +// It returns a cleanup function that needs to be called when we're done with +// the planner (before preparePlannerForCopy is called again). For an implicit +// txn with CopyFromAtomicEnabled set to false, and only when this is not the +// final batch, the cleanup function commits the txn (if it hasn't already been +// committed) or rolls it back depending on whether it is passed an error, and +// then starts a new txn for the next batch. If an error is passed in to the +// cleanup function, the same error is returned.
func (p *planner) preparePlannerForCopy( ctx context.Context, txnOpt *copyTxnOpt, finalBatch bool, implicitTxn bool, ) func(context.Context, error) error { - txn := txnOpt.txn - txnTs := txnOpt.txnTimestamp - stmtTs := txnOpt.stmtTimestamp - autoCommit := finalBatch && implicitTxn - if txn == nil { - nodeID, _ := p.execCfg.NodeInfo.NodeID.OptionalNodeID() - // The session data stack in the planner is not set up at this point, so use - // the default Normal QoSLevel. - txn = kv.NewTxnWithSteppingEnabled(ctx, p.execCfg.DB, nodeID, sessiondatapb.Normal) - txnTs = p.execCfg.Clock.PhysicalTime() - stmtTs = txnTs - } - txnOpt.resetPlanner(ctx, p, txn, txnTs, stmtTs) + autoCommit := false + txnOpt.resetPlanner(ctx, p, txnOpt.txn, txnOpt.txnTimestamp, txnOpt.stmtTimestamp) if implicitTxn { - // For atomic implicit COPY remember txn for next time so we don't start a new one. if p.SessionData().CopyFromAtomicEnabled { - txnOpt.txn = txn - txnOpt.txnTimestamp = txnTs - txnOpt.stmtTimestamp = txnTs + // If the COPY should be atomic, only the final batch can commit. autoCommit = finalBatch } else { - // We're doing original behavior of committing each batch. + // Otherwise we do the original behavior of committing each batch. autoCommit = true } } @@ -831,26 +818,33 @@ func (p *planner) preparePlannerForCopy( return func(ctx context.Context, prevErr error) (err error) { // Ensure that we clean up any accumulated extraTxnState state if we've // been handed a mechanism to do so. - if txnOpt.resetExtraTxnState != nil { + // If this is the finalBatch, then the connExecutor state machine takes + // care of this cleanup. + if implicitTxn && !p.SessionData().CopyFromAtomicEnabled && !finalBatch { defer txnOpt.resetExtraTxnState(ctx) - } - if prevErr == nil { - // Ensure that the txn is committed if the copyMachine is in charge of - // committing its transactions and the execution didn't already commit it - // (through the planner.autoCommit optimization). 
- if autoCommit && !txn.IsCommitted() { - err = txn.Commit(ctx) - if err != nil { - if rollbackErr := txn.Rollback(ctx); rollbackErr != nil { - log.Eventf(ctx, "rollback failed: %s", rollbackErr) + + if prevErr == nil { + // Ensure that the txn is committed if the copyMachine is in charge of + // committing its transactions and the execution didn't already commit it + // (through the planner.autoCommit optimization). + if !txnOpt.txn.IsCommitted() { + err = txnOpt.txn.Commit(ctx) + if err != nil { + if rollbackErr := txnOpt.txn.Rollback(ctx); rollbackErr != nil { + log.Eventf(ctx, "rollback failed: %s", rollbackErr) + } + return err } } - return err + } else if rollbackErr := txnOpt.txn.Rollback(ctx); rollbackErr != nil { + log.Eventf(ctx, "rollback failed: %s", rollbackErr) } - return nil - } - if rollbackErr := txn.Rollback(ctx); rollbackErr != nil { - log.Eventf(ctx, "rollback failed: %s", rollbackErr) + + // Start the implicit txn for the next batch. + nodeID, _ := p.execCfg.NodeInfo.NodeID.OptionalNodeID() + txnOpt.txn = kv.NewTxnWithSteppingEnabled(ctx, p.execCfg.DB, nodeID, p.SessionData().DefaultTxnQualityOfService) + txnOpt.txnTimestamp = p.execCfg.Clock.PhysicalTime() + txnOpt.stmtTimestamp = txnOpt.txnTimestamp } return prevErr } @@ -931,6 +925,7 @@ func (c *copyMachine) insertRowsInternal(ctx context.Context, finalBatch bool) ( }, Returning: tree.AbsentReturningClause, } + c.txnOpt.initPlanner(ctx, c.p) if err := c.p.makeOptimizerPlan(ctx); err != nil { return err } diff --git a/pkg/sql/pgwire/testdata/pgtest/copy b/pkg/sql/pgwire/testdata/pgtest/copy index 5dd6c94e7931..b53cc110f4dd 100644 --- a/pkg/sql/pgwire/testdata/pgtest/copy +++ b/pkg/sql/pgwire/testdata/pgtest/copy @@ -994,3 +994,49 @@ ReadyForQuery {"Type":"DataRow","Values":[{"text":"t"}]} {"Type":"CommandComplete","CommandTag":"SELECT 1"} {"Type":"ReadyForQuery","TxStatus":"I"} + +# Verify that a CopyFail message does not break the conn_executor state machine. 
+send +Query {"String": "DROP TABLE IF EXISTS copy_fail_test"} +Query {"String": "CREATE TABLE copy_fail_test(a INT PRIMARY KEY)"} +---- + +until ignore=NoticeResponse +ReadyForQuery +ReadyForQuery +---- +{"Type":"CommandComplete","CommandTag":"DROP TABLE"} +{"Type":"ReadyForQuery","TxStatus":"I"} +{"Type":"CommandComplete","CommandTag":"CREATE TABLE"} +{"Type":"ReadyForQuery","TxStatus":"I"} + +send +Query {"String": "COPY copy_fail_test(a) FROM STDIN"} +CopyFail +---- + +until +ErrorResponse +ReadyForQuery +---- +{"Type":"CopyInResponse","ColumnFormatCodes":[0]} +{"Type":"ErrorResponse","Code":"57014"} +{"Type":"ReadyForQuery","TxStatus":"I"} + +# Previously there was a bug that caused the Parse to fail with "query execution +# canceled" since the cancellation from the CopyFail message was not cleaned up. +send +Parse {"Query": "SELECT count(*) FROM copy_fail_test"} +Bind +Execute +Sync +---- + +until +ReadyForQuery +---- +{"Type":"ParseComplete"} +{"Type":"BindComplete"} +{"Type":"DataRow","Values":[{"text":"0"}]} +{"Type":"CommandComplete","CommandTag":"SELECT 1"} +{"Type":"ReadyForQuery","TxStatus":"I"}