diff --git a/pkg/kv/kvserver/raft_transport.go b/pkg/kv/kvserver/raft_transport.go index 3090305ce134..d9ad130a8a65 100644 --- a/pkg/kv/kvserver/raft_transport.go +++ b/pkg/kv/kvserver/raft_transport.go @@ -446,18 +446,6 @@ func (t *RaftTransport) RaftMessageBatch(stream MultiRaft_RaftMessageBatchServer taskCtx = t.AnnotateCtx(taskCtx) defer cancel() - var storeIDs []roachpb.StoreID - defer func() { - ctx := t.AnnotateCtx(context.Background()) - t.kvflowControl.mu.Lock() - t.kvflowControl.mu.connectionTracker.markStoresDisconnected(storeIDs) - t.kvflowControl.mu.Unlock() - t.kvflowControl.disconnectListener.OnRaftTransportDisconnected(ctx, storeIDs...) - if fn := t.knobs.OnServerStreamDisconnected; fn != nil { - fn() - } - }() - if err := t.stopper.RunAsyncTaskEx( taskCtx, stop.TaskOpts{ @@ -465,6 +453,18 @@ func (t *RaftTransport) RaftMessageBatch(stream MultiRaft_RaftMessageBatchServer SpanOpt: stop.ChildSpan, }, func(ctx context.Context) { errCh <- func() error { + var storeIDs []roachpb.StoreID + defer func() { + ctx := t.AnnotateCtx(context.Background()) + t.kvflowControl.mu.Lock() + t.kvflowControl.mu.connectionTracker.markStoresDisconnected(storeIDs) + t.kvflowControl.mu.Unlock() + t.kvflowControl.disconnectListener.OnRaftTransportDisconnected(ctx, storeIDs...) + if fn := t.knobs.OnServerStreamDisconnected; fn != nil { + fn() + } + }() + stream := &lockedRaftMessageResponseStream{wrapped: stream} for { batch, err := stream.Recv()