Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: fix internal executor when it encounters a retry error #101477

Merged
merged 5 commits into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/ccl/changefeedccl/changefeed_dist.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,8 @@ func (w *changefeedResultWriter) AddRow(ctx context.Context, row tree.Datums) er
return nil
}
}
func (w *changefeedResultWriter) IncrementRowsAffected(ctx context.Context, n int) {
w.rowsAffected += n
func (w *changefeedResultWriter) SetRowsAffected(ctx context.Context, n int) {
w.rowsAffected = n
}
func (w *changefeedResultWriter) SetError(err error) {
w.err = err
Expand Down
17 changes: 12 additions & 5 deletions pkg/sql/conn_executor_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,14 @@ func mustParseOne(s string) statements.Statement[tree.Statement] {
// need to read from it.
func startConnExecutor(
ctx context.Context,
) (*StmtBuf, <-chan []resWithPos, <-chan error, *stop.Stopper, ieResultReader, error) {
) (
*StmtBuf,
<-chan []*streamingCommandResult,
<-chan error,
*stop.Stopper,
ieResultReader,
error,
) {
// A lot of boilerplate for creating a connExecutor.
stopper := stop.NewStopper()
clock := hlc.NewClockForTesting(nil)
Expand Down Expand Up @@ -340,10 +347,10 @@ func startConnExecutor(

s := NewServer(cfg, pool)
buf := NewStmtBuf()
syncResults := make(chan []resWithPos, 1)
syncResults := make(chan []*streamingCommandResult, 1)
resultChannel := newAsyncIEResultChannel()
var cc ClientComm = &internalClientComm{
sync: func(res []resWithPos) {
sync: func(res []*streamingCommandResult) {
syncResults <- res
},
w: resultChannel,
Expand Down Expand Up @@ -380,9 +387,9 @@ func TestSessionCloseWithPendingTempTableInTxn(t *testing.T) {

srv := s.SQLServer().(*Server)
stmtBuf := NewStmtBuf()
flushed := make(chan []resWithPos)
flushed := make(chan []*streamingCommandResult)
clientComm := &internalClientComm{
sync: func(res []resWithPos) {
sync: func(res []*streamingCommandResult) {
flushed <- res
},
}
Expand Down
71 changes: 39 additions & 32 deletions pkg/sql/conn_io.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ const (
//
// ClientComm is implemented by the pgwire connection.
type ClientComm interface {
// createStatementResult creates a StatementResult for stmt.
// CreateStatementResult creates a StatementResult for stmt.
//
// descOpt specifies if result needs to inform the client about row schema. If
// it doesn't, a SetColumns call becomes a no-op.
Expand Down Expand Up @@ -683,7 +683,7 @@ type ClientComm interface {
// CreateDrainResult creates a result for a Drain command.
CreateDrainResult(pos CmdPos) DrainResult

// lockCommunication ensures that no further results are delivered to the
// LockCommunication ensures that no further results are delivered to the
// client. The returned ClientLock can be queried to see what results have
// been already delivered to the client and to discard results that haven't
// been delivered.
Expand Down Expand Up @@ -791,12 +791,13 @@ type RestrictedCommandResult interface {
// AddBatch is undefined.
SupportsAddBatch() bool

// IncrementRowsAffected increments a counter by n. This is used for all
// SetRowsAffected sets RowsAffected counter to n. This is used for all
// result types other than tree.Rows.
IncrementRowsAffected(ctx context.Context, n int)
SetRowsAffected(ctx context.Context, n int)

// RowsAffected returns either the number of times AddRow was called, or the
// sum of all n passed into IncrementRowsAffected.
// RowsAffected returns either the number of times AddRow was called, total
// number of rows pushed via AddBatch, or the last value of n passed into
// SetRowsAffected.
RowsAffected() int

// DisableBuffering can be called during execution to ensure that
Expand Down Expand Up @@ -927,10 +928,11 @@ type ClientLock interface {
// connection.
ClientPos() CmdPos

// RTrim iterates backwards through the results and drops all results with
// position >= pos.
// It is illegal to call rtrim with a position <= clientPos(). In other words,
// results can
// RTrim drops all results with position >= pos.
//
// It is illegal to call RTrim with a position <= ClientPos(). In other
// words, results can only be trimmed if they haven't been sent to the
// client.
RTrim(ctx context.Context, pos CmdPos)
}

Expand Down Expand Up @@ -963,24 +965,27 @@ func (rc *rewindCapability) close() {
rc.cl.Close()
}

type resCloseType bool

const closed resCloseType = true
const discarded resCloseType = false

// streamingCommandResult is a CommandResult that streams rows on the channel
// and can call a provided callback when closed.
type streamingCommandResult struct {
pos CmdPos

// All the data (the rows and the metadata) are written into w. The
// goroutine writing into this streamingCommandResult might block depending
// on the synchronization strategy.
w ieResultWriter

// cannotRewind indicates whether this result has communicated some data
// (rows or metadata) such that the corresponding command cannot be rewound.
cannotRewind bool

err error
rowsAffected int

// closeCallback, if set, is called when Close()/Discard() is called.
closeCallback func(*streamingCommandResult, resCloseType)
// closeCallback, if set, is called when Close() is called.
closeCallback func()
// discardCallback, if set, is called when Discard() is called.
discardCallback func()
}

var _ RestrictedCommandResult = &streamingCommandResult{}
Expand All @@ -993,7 +998,7 @@ func (r *streamingCommandResult) ErrAllowReleased() error {

// RevokePortalPausability is part of the sql.RestrictedCommandResult interface.
func (r *streamingCommandResult) RevokePortalPausability() error {
return errors.AssertionFailedf("forPausablePortal is for limitedCommandResult only")
return errors.AssertionFailedf("RevokePortalPausability is for limitedCommandResult only")
}

// SetColumns is part of the RestrictedCommandResult interface.
Expand All @@ -1003,6 +1008,8 @@ func (r *streamingCommandResult) SetColumns(ctx context.Context, cols colinfo.Re
if cols == nil {
cols = colinfo.ResultColumns{}
}
// NB: we do not set r.cannotRewind here because the correct columns will be
// set in rowsIterator.Next.
_ = r.w.addResult(ctx, ieIteratorResult{cols: cols})
}

Expand All @@ -1023,12 +1030,15 @@ func (r *streamingCommandResult) ResetStmtType(stmt tree.Statement) {

// AddRow is part of the RestrictedCommandResult interface.
func (r *streamingCommandResult) AddRow(ctx context.Context, row tree.Datums) error {
// AddRow() and IncrementRowsAffected() are never called on the same command
// AddRow() and SetRowsAffected() are never called on the same command
// result, so we will not double count the affected rows by an increment
// here.
r.rowsAffected++
rowCopy := make(tree.Datums, len(row))
copy(rowCopy, row)
// Once we add this row to the writer, it can be immediately consumed by the
// reader, so this result can no longer be rewound.
r.cannotRewind = true
return r.w.addResult(ctx, ieIteratorResult{row: rowCopy})
}

Expand Down Expand Up @@ -1056,7 +1066,7 @@ func (r *streamingCommandResult) SetError(err error) {
// in execStmtInOpenState().
}

// GetEntryFromExtraInfo is part of the sql.RestrictedCommandResult interface.
// GetBulkJobId is part of the sql.RestrictedCommandResult interface.
func (r *streamingCommandResult) GetBulkJobId() uint64 {
return 0
}
Expand All @@ -1066,13 +1076,15 @@ func (r *streamingCommandResult) Err() error {
return r.err
}

// IncrementRowsAffected is part of the RestrictedCommandResult interface.
func (r *streamingCommandResult) IncrementRowsAffected(ctx context.Context, n int) {
r.rowsAffected += n
// SetRowsAffected is part of the RestrictedCommandResult interface.
func (r *streamingCommandResult) SetRowsAffected(ctx context.Context, n int) {
r.rowsAffected = n
// streamingCommandResult might be used outside of the internal executor
// (i.e. not by rowsIterator) in which case the channel is not set.
if r.w != nil {
_ = r.w.addResult(ctx, ieIteratorResult{rowsAffectedIncrement: &n})
// NB: we do not set r.cannotRewind here because rowsAffected value will
// be overwritten in rowsIterator.Next correctly if necessary.
_ = r.w.addResult(ctx, ieIteratorResult{rowsAffected: &n})
}
}

Expand All @@ -1084,14 +1096,14 @@ func (r *streamingCommandResult) RowsAffected() int {
// Close is part of the CommandResultClose interface.
func (r *streamingCommandResult) Close(context.Context, TransactionStatusIndicator) {
if r.closeCallback != nil {
r.closeCallback(r, closed)
r.closeCallback()
}
}

// Discard is part of the CommandResult interface.
func (r *streamingCommandResult) Discard() {
if r.closeCallback != nil {
r.closeCallback(r, discarded)
if r.discardCallback != nil {
r.discardCallback()
}
}

Expand All @@ -1110,11 +1122,6 @@ func (r *streamingCommandResult) SetPortalOutput(
) {
}

// SetRowsAffected is part of the sql.CopyInResult interface.
func (r *streamingCommandResult) SetRowsAffected(ctx context.Context, rows int) {
r.rowsAffected = rows
}

// SendCopyOut is part of the sql.CopyOutResult interface.
func (r *streamingCommandResult) SendCopyOut(
ctx context.Context, cols colinfo.ResultColumns, format pgwirebase.FormatCode,
Expand Down
22 changes: 11 additions & 11 deletions pkg/sql/distsql_running.go
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,7 @@ type rowResultWriter interface {
// AddRow writes a result row.
// Note that the caller owns the row slice and might reuse it.
AddRow(ctx context.Context, row tree.Datums) error
IncrementRowsAffected(ctx context.Context, n int)
SetRowsAffected(ctx context.Context, n int)
SetError(error)
Err() error
}
Expand Down Expand Up @@ -1088,8 +1088,8 @@ func (w *errOnlyResultWriter) AddBatch(ctx context.Context, batch coldata.Batch)
panic("AddBatch not supported by errOnlyResultWriter")
}

func (w *errOnlyResultWriter) IncrementRowsAffected(ctx context.Context, n int) {
panic("IncrementRowsAffected not supported by errOnlyResultWriter")
func (w *errOnlyResultWriter) SetRowsAffected(ctx context.Context, n int) {
panic("SetRowsAffected not supported by errOnlyResultWriter")
}

// RowResultWriter is a thin wrapper around a RowContainer.
Expand All @@ -1106,9 +1106,9 @@ func NewRowResultWriter(rowContainer *rowContainerHelper) *RowResultWriter {
return &RowResultWriter{rowContainer: rowContainer}
}

// IncrementRowsAffected implements the rowResultWriter interface.
func (b *RowResultWriter) IncrementRowsAffected(ctx context.Context, n int) {
b.rowsAffected += n
// SetRowsAffected implements the rowResultWriter interface.
func (b *RowResultWriter) SetRowsAffected(ctx context.Context, n int) {
b.rowsAffected = n
}

// AddRow implements the rowResultWriter interface.
Expand Down Expand Up @@ -1146,9 +1146,9 @@ func NewCallbackResultWriter(
return &CallbackResultWriter{fn: fn}
}

// IncrementRowsAffected is part of the rowResultWriter interface.
func (c *CallbackResultWriter) IncrementRowsAffected(ctx context.Context, n int) {
c.rowsAffected += n
// SetRowsAffected is part of the rowResultWriter interface.
func (c *CallbackResultWriter) SetRowsAffected(ctx context.Context, n int) {
c.rowsAffected = n
}

// AddRow is part of the rowResultWriter interface.
Expand Down Expand Up @@ -1432,7 +1432,7 @@ func (r *DistSQLReceiver) Push(
// We only need the row count. planNodeToRowSource is set up to handle
// ensuring that the last stage in the pipeline will return a single-column
// row with the row count in it, so just grab that and exit.
r.resultWriterMu.row.IncrementRowsAffected(r.ctx, n)
r.resultWriterMu.row.SetRowsAffected(r.ctx, n)
return r.status
}

Expand Down Expand Up @@ -1516,7 +1516,7 @@ func (r *DistSQLReceiver) PushBatch(
// We only need the row count. planNodeToRowSource is set up to handle
// ensuring that the last stage in the pipeline will return a single-column
// row with the row count in it, so just grab that and exit.
r.resultWriterMu.row.IncrementRowsAffected(r.ctx, int(batch.ColVec(0).Int64()[0]))
r.resultWriterMu.row.SetRowsAffected(r.ctx, int(batch.ColVec(0).Int64()[0]))
return r.status
}

Expand Down
Loading