diff --git a/pkg/kv/kvclient/kvcoord/dist_sender_mux_rangefeed.go b/pkg/kv/kvclient/kvcoord/dist_sender_mux_rangefeed.go index c3f1823142af..dd4706040681 100644 --- a/pkg/kv/kvclient/kvcoord/dist_sender_mux_rangefeed.go +++ b/pkg/kv/kvclient/kvcoord/dist_sender_mux_rangefeed.go @@ -106,14 +106,14 @@ func muxRangeFeed( // the entire stream must be torn down, and all active range feeds should be // restarted. type muxStream struct { - nodeID roachpb.NodeID - streams syncutil.IntMap // streamID -> *activeMuxRangeFeed. + nodeID roachpb.NodeID // mu must be held when starting rangefeed. mu struct { syncutil.Mutex - sender rangeFeedRequestSender - closed bool + sender rangeFeedRequestSender + streams map[int64]*activeMuxRangeFeed + closed bool } } @@ -311,6 +311,7 @@ func (m *rangefeedMuxer) startNodeMuxRangeFeed( ms := muxStream{nodeID: nodeID} ms.mu.sender = mux + ms.mu.streams = make(map[int64]*activeMuxRangeFeed) if err := future.MustSet(stream, muxStreamOrError{stream: &ms}); err != nil { return err } @@ -330,7 +331,7 @@ func (m *rangefeedMuxer) startNodeMuxRangeFeed( recvErr = nil } - return ms.closeWithRestart(ctx, recvErr, func(_ int64, a *activeMuxRangeFeed) error { + return ms.closeWithRestart(ctx, recvErr, func(a *activeMuxRangeFeed) error { return m.restartActiveRangeFeed(ctx, a, recvErr) }) } @@ -399,7 +400,7 @@ func (m *rangefeedMuxer) receiveEventsFromNode( if active.catchupRes != nil { m.ds.metrics.RangefeedErrorCatchup.Inc(1) } - ms.streams.Delete(event.StreamID) + ms.deleteStream(event.StreamID) if err := m.restartActiveRangeFeed(ctx, active, t.Error.GoError()); err != nil { return err } @@ -421,14 +422,10 @@ func (m *rangefeedMuxer) receiveEventsFromNode( // get stuck in the first place. if timeutil.Now().Before(nextStuckCheck) { if threshold := stuckThreshold(); threshold > 0 { - if _, err := ms.eachStream(func(id int64, a *activeMuxRangeFeed) error { - if !a.startAfter.IsEmpty() && timeutil.Since(a.startAfter.GoTime()) > stuckThreshold() { - ms.streams.Delete(id) - return m.restartActiveRangeFeed(ctx, a, errRestartStuckRange) + for _, a := range ms.purgeStuckStreams(threshold) { + if err := m.restartActiveRangeFeed(ctx, a, errRestartStuckRange); err != nil { + return err } - return nil - }); err != nil { - return err } } nextStuckCheck = timeutil.Now().Add(stuckCheckFreq()) @@ -491,29 +488,46 @@ func (c *muxStream) startRangeFeed( return err } - // mu must be held while marking this stream in flight (streams.Store) to - // synchronize with mux termination. When node mux terminates, it invokes - // c.closeWithRestart(), which marks this mux stream connection closed and - // restarts all active streams. Thus, we must make sure that this streamID - // gets properly recorded even if mux go routine terminates right after the - // above sender.Send() succeeded. - c.streams.Store(streamID, unsafe.Pointer(stream)) + // As soon as we issue Send above, the stream may return an error that + // may be seen by the event consumer (receiveEventsFromNode). + // Therefore, we update streams map under the lock to ensure that the + // receiver will be able to observe this stream. + c.mu.streams[streamID] = stream return nil } -func (c *muxStream) lookupStream(streamID int64) *activeMuxRangeFeed { - v, ok := c.streams.Load(streamID) - if ok { - return (*activeMuxRangeFeed)(v) +func (c *muxStream) lookupStream(streamID int64) (a *activeMuxRangeFeed) { + c.mu.Lock() + a = c.mu.streams[streamID] + c.mu.Unlock() + return a +} + +func (c *muxStream) purgeStuckStreams(threshold time.Duration) (stuck []*activeMuxRangeFeed) { + c.mu.Lock() + for streamID, a := range c.mu.streams { + if !a.startAfter.IsEmpty() && timeutil.Since(a.startAfter.GoTime()) > threshold { + stuck = append(stuck, a) + delete(c.mu.streams, streamID) + } } - return nil + c.mu.Unlock() + return stuck +} + +func (c *muxStream) deleteStream(streamID int64) { + c.mu.Lock() + delete(c.mu.streams, streamID) + c.mu.Unlock() } func (c *muxStream) closeWithRestart( - ctx context.Context, reason error, restartFn func(streamID int64, a *activeMuxRangeFeed) error, + ctx context.Context, reason error, restartFn func(a *activeMuxRangeFeed) error, ) error { c.mu.Lock() c.mu.closed = true + toRestart := c.mu.streams + c.mu.streams = nil c.mu.Unlock() // make sure that the underlying error is not fatal. If it is, there is no @@ -522,22 +536,14 @@ func (c *muxStream) closeWithRestart( return err } - n, err := c.eachStream(restartFn) - if log.V(1) { - log.Infof(ctx, "mux to node %d restarted %d streams: err=%v", c.nodeID, n, err) + for _, a := range toRestart { + if err := restartFn(a); err != nil { + return err + } } - return err -} -// eachStream invokes provided function for each stream. If the function -// returns an error, iteration stops. Returns number of streams processed. -func (c *muxStream) eachStream( - fn func(streamID int64, a *activeMuxRangeFeed) error, -) (n int, err error) { - c.streams.Range(func(key int64, value unsafe.Pointer) bool { - err = fn(key, (*activeMuxRangeFeed)(value)) - n++ - return err == nil - }) - return n, err + if log.V(1) { + log.Infof(ctx, "mux to node %d restarted %d streams", c.nodeID, len(toRestart)) + } + return nil }