diff --git a/pkg/sql/internal.go b/pkg/sql/internal.go index 76703b60adaf..03959108d732 100644 --- a/pkg/sql/internal.go +++ b/pkg/sql/internal.go @@ -159,13 +159,14 @@ func (ie *InternalExecutor) runWithEx( ctx context.Context, txn *kv.Txn, w ieResultWriter, + mode ieExecutionMode, sd *sessiondata.SessionData, stmtBuf *StmtBuf, wg *sync.WaitGroup, syncCallback func([]*streamingCommandResult), errCallback func(error), ) error { - ex, err := ie.initConnEx(ctx, txn, w, sd, stmtBuf, syncCallback) + ex, err := ie.initConnEx(ctx, txn, w, mode, sd, stmtBuf, syncCallback) if err != nil { return err } @@ -201,13 +202,19 @@ func (ie *InternalExecutor) initConnEx( ctx context.Context, txn *kv.Txn, w ieResultWriter, + mode ieExecutionMode, sd *sessiondata.SessionData, stmtBuf *StmtBuf, syncCallback func([]*streamingCommandResult), ) (*connExecutor, error) { clientComm := &internalClientComm{ w: w, + mode: mode, sync: syncCallback, + resetRowsAffected: func() { + var zero int + _ = w.addResult(ctx, ieIteratorResult{rowsAffected: &zero}) + }, } applicationStats := ie.s.sqlStats.GetApplicationStats(sd.ApplicationName, true /* internal */) @@ -363,6 +370,8 @@ type rowsIterator struct { rowsAffected int resultCols colinfo.ResultColumns + mode ieExecutionMode + // first, if non-nil, is the first object read from r. We block the return // of the created rowsIterator in execInternal() until the producer writes // something into the corresponding ieResultWriter because this indicates @@ -432,6 +441,12 @@ func (r *rowsIterator) Next(ctx context.Context) (_ bool, retErr error) { return r.Next(ctx) } if data.cols != nil { + if r.mode == rowsAffectedIEExecutionMode { + // In "rows affected" execution mode we simply ignore the column + // schema since we always return the number of rows affected + // (i.e. a single integer column). + return r.Next(ctx) + } // At this point we don't expect to see the columns - we should only // return the rowsIterator to the caller of execInternal after the // columns have been determined. @@ -559,7 +574,7 @@ func (ie *InternalExecutor) queryInternalBuffered( // We will run the query to completion, so we can use an async result // channel. rw := newAsyncIEResultChannel() - it, err := ie.execInternal(ctx, opName, rw, txn, sessionDataOverride, stmt, qargs...) + it, err := ie.execInternal(ctx, opName, rw, defaultIEExecutionMode, txn, sessionDataOverride, stmt, qargs...) if err != nil { return nil, nil, err } @@ -661,7 +676,11 @@ func (ie *InternalExecutor) ExecEx( // We will run the query to completion, so we can use an async result // channel. rw := newAsyncIEResultChannel() - it, err := ie.execInternal(ctx, opName, rw, txn, session, stmt, qargs...) + // Since we only return the number of rows affected as given by the + // rowsIterator, we execute this stmt in "rows affected" mode allowing the + // internal executor to transparently retry. + const mode = rowsAffectedIEExecutionMode + it, err := ie.execInternal(ctx, opName, rw, mode, txn, session, stmt, qargs...) if err != nil { return 0, err } @@ -700,7 +719,7 @@ func (ie *InternalExecutor) QueryIteratorEx( qargs ...interface{}, ) (isql.Rows, error) { return ie.execInternal( - ctx, opName, newSyncIEResultChannel(), txn, session, stmt, qargs..., + ctx, opName, newSyncIEResultChannel(), defaultIEExecutionMode, txn, session, stmt, qargs..., ) } @@ -834,6 +853,12 @@ var rowsAffectedResultColumns = colinfo.ResultColumns{ // - If the retry error occurs after some rows have been sent from the // streamingCommandResult to the rowsIterator, we have no choice but to return // the retry error to the caller. +// - The only exception to this is when the stmt of "Rows" type was issued via +// ExecEx call. In such a scenario, we only need to report the number of +// "rows affected" that we obtain by counting all rows seen by the +// rowsIterator. With such a setup, we can transparently retry the execution +// of the corresponding command by simply resetting the counter when +// discarding the result of Sync command after the retry error occurs. // - If the retry error occurs after the "rows affected" metadata was sent for // stmts of "RowsAffected" type, then we will always retry transparently. This // is achieved by overriding the "rows affected" number, stored in the @@ -868,6 +893,7 @@ func (ie *InternalExecutor) execInternal( ctx context.Context, opName string, rw *ieResultChannel, + mode ieExecutionMode, txn *kv.Txn, sessionDataOverride sessiondata.InternalExecutorOverride, stmt string, @@ -986,7 +1012,7 @@ func (ie *InternalExecutor) execInternal( errCallback := func(err error) { _ = rw.addResult(ctx, ieIteratorResult{err: err}) } - err = ie.runWithEx(ctx, txn, rw, sd, stmtBuf, &wg, syncCallback, errCallback) + err = ie.runWithEx(ctx, txn, rw, mode, sd, stmtBuf, &wg, syncCallback, errCallback) if err != nil { return nil, err } @@ -1047,6 +1073,7 @@ func (ie *InternalExecutor) execInternal( } r = &rowsIterator{ r: rw, + mode: mode, stmtBuf: stmtBuf, wg: &wg, } @@ -1112,7 +1139,7 @@ func (ie *InternalExecutor) commitTxn(ctx context.Context) error { rw := newAsyncIEResultChannel() stmtBuf := NewStmtBuf() - ex, err := ie.initConnEx(ctx, ie.extraTxnState.txn, rw, sd, stmtBuf, nil /* syncCallback */) + ex, err := ie.initConnEx(ctx, ie.extraTxnState.txn, rw, defaultIEExecutionMode, sd, stmtBuf, nil /* syncCallback */) if err != nil { return errors.Wrap(err, "cannot create conn executor to commit txn") } @@ -1165,6 +1192,26 @@ func (ie *InternalExecutor) checkIfTxnIsConsistent(txn *kv.Txn) error { return nil } +// ieExecutionMode determines how the internal executor consumes the results of +// the statement evaluation. +type ieExecutionMode int + +const ( + // defaultIEExecutionMode is the execution mode in which the results of the + // statement evaluation are consumed according to the statement's type. + defaultIEExecutionMode ieExecutionMode = iota + // rowsAffectedIEExecutionMode is the execution mode in which the internal + // executor is only interested in the number of rows affected, regardless of + // the statement's type. + // + // With this mode, if a stmt encounters a retry error, the internal executor + // will proceed to transparently reset the number of rows affected (if any + // have been seen by the rowsIterator) and retry the corresponding command. + // Such behavior makes sense given that in production code at most one + // command in the StmtBuf results in "rows affected". + rowsAffectedIEExecutionMode +) + // internalClientComm is an implementation of ClientComm used by the // InternalExecutor. Result rows are streamed on the channel to the // ieResultWriter. @@ -1183,6 +1230,15 @@ type internalClientComm struct { // The results of the query execution will be written into w. w ieResultWriter + // mode determines how the results of the query execution are consumed. + mode ieExecutionMode + + // resetRowsAffected is a callback that sends a single ieIteratorResult + // object to w in order to set the number of rows affected to zero. Only + // used in rowsAffectedIEExecutionMode when discarding a result (indicating + // that a command will be retried). + resetRowsAffected func() + // sync, if set, is called whenever a Sync is executed with all accumulated // results since the last Sync. sync func([]*streamingCommandResult) @@ -1220,6 +1276,9 @@ func (icc *internalClientComm) createRes(pos CmdPos) *streamingCommandResult { // results slice at the moment and all previous results have been // "finalized"). icc.results = icc.results[:len(icc.results)-1] + if icc.mode == rowsAffectedIEExecutionMode { + icc.resetRowsAffected() + } }, } icc.results = append(icc.results, res) @@ -1315,6 +1374,14 @@ func (icc *internalClientComm) Close() {} // ClientPos is part of the ClientLock interface. func (icc *internalClientComm) ClientPos() CmdPos { + if icc.mode == rowsAffectedIEExecutionMode { + // With the "rows affected" mode, any command can be rewound since we + // assume that only a single command results in actual "rows affected", + // and in Discard we will reset the number to zero (if we were in + // process of evaluation that command when we encountered the retry + // error). + return -1 + } // Find the latest result that cannot be rewound. lastDelivered := CmdPos(-1) for _, r := range icc.results { diff --git a/pkg/sql/internal_test.go b/pkg/sql/internal_test.go index 315edc66c656..acc8d22f129d 100644 --- a/pkg/sql/internal_test.go +++ b/pkg/sql/internal_test.go @@ -695,7 +695,7 @@ func TestInternalExecutorEncountersRetry(t *testing.T) { ctx := context.Background() params, _ := tests.CreateTestServerParams() - s, db, _ := serverutils.StartServer(t, params) + s, db, kvDB := serverutils.StartServer(t, params) defer s.Stopper().Stop(ctx) if _, err := db.Exec("CREATE DATABASE test; CREATE TABLE test.t (c) AS SELECT 1"); err != nil { @@ -728,12 +728,35 @@ func TestInternalExecutorEncountersRetry(t *testing.T) { } }) + const rowsStmt = `SELECT * FROM test.t` + // This test case verifies that if the retry error occurs after some rows // have been communicated to the client, then the stmt results in the retry // error too - the IE cannot transparently retry it. t.Run("Rows stmt", func(t *testing.T) { - const stmt = `SELECT * FROM test.t` - _, err := ie.QueryBufferedEx(ctx, "read rows", nil /* txn */, ieo, stmt) + _, err := ie.QueryBufferedEx(ctx, "read rows", nil /* txn */, ieo, rowsStmt) + if !testutils.IsError(err, "inject_retry_errors_enabled") { + t.Fatalf("expected to see injected retry error, got %v", err) + } + }) + + // This test case verifies that ExecEx of a stmt of Rows type correctly and + // transparently to us retries the stmt. + t.Run("ExecEx retries in implicit txn", func(t *testing.T) { + numRows, err := ie.ExecEx(ctx, "read rows", nil /* txn */, ieo, rowsStmt) + if err != nil { + t.Fatal(err) + } + if numRows != 1 { + t.Fatalf("expected 1 rowsAffected, got %d", numRows) + } + }) + + // This test case verifies that ExecEx doesn't retry when it's provided with + // an explicit txn. + t.Run("ExecEx doesn't retry in explicit txn", func(t *testing.T) { + txn := kvDB.NewTxn(ctx, "explicit") + _, err := ie.ExecEx(ctx, "read rows", txn, ieo, rowsStmt) if !testutils.IsError(err, "inject_retry_errors_enabled") { t.Fatalf("expected to see injected retry error, got %v", err) }