diff --git a/pkg/kv/kvserver/rangefeed/BUILD.bazel b/pkg/kv/kvserver/rangefeed/BUILD.bazel index c02cd9ec3161..16e0e1fcb991 100644 --- a/pkg/kv/kvserver/rangefeed/BUILD.bazel +++ b/pkg/kv/kvserver/rangefeed/BUILD.bazel @@ -14,6 +14,7 @@ go_library( "scheduled_processor.go", "scheduler.go", "stream_muxer.go", + "stream_muxer_test_helper.go", "task.go", "testutil.go", ], @@ -65,6 +66,7 @@ go_test( "registry_test.go", "resolved_timestamp_test.go", "scheduler_test.go", + "stream_muxer_test.go", "task_test.go", ], embed = [":rangefeed"], diff --git a/pkg/kv/kvserver/rangefeed/stream_muxer.go b/pkg/kv/kvserver/rangefeed/stream_muxer.go index 8fa54c03fa40..07b519b25ac9 100644 --- a/pkg/kv/kvserver/rangefeed/stream_muxer.go +++ b/pkg/kv/kvserver/rangefeed/stream_muxer.go @@ -15,6 +15,7 @@ import ( "sync" "github.com/cockroachdb/cockroach/pkg/kv/kvpb" + "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/syncutil" @@ -105,6 +106,9 @@ type StreamMuxer struct { // thread safety. sender ServerStreamSender + // streamID -> context.CancelFunc for active rangefeeds + activeStreams sync.Map + // notifyMuxError is a buffered channel of size 1 used to signal the presence // of muxErrors. Additional signals are dropped if the channel is already full // so that it's non-blocking. @@ -128,20 +132,71 @@ func NewStreamMuxer(sender ServerStreamSender) *StreamMuxer { } } -// AppendMuxError appends a mux rangefeed completion error to be sent back to +// AddStream registers a server rangefeed stream with the StreamMuxer. It +// remains active until DisconnectStreamWithError is called with the same +// streamID. Caller must ensure no duplicate stream IDs are added without +// disconnecting the old one first. +func (sm *StreamMuxer) AddStream(streamID int64, cancel context.CancelFunc) { + if _, loaded := sm.activeStreams.LoadOrStore(streamID, cancel); loaded { + log.Fatalf(context.Background(), "stream %d already exists", streamID) + } + +} + +// transformRangefeedErrToClientError converts a rangefeed error to a client +// error to be sent back to client. This also handles nil values, preventing nil +// pointer dereference. +func transformRangefeedErrToClientError(err *kvpb.Error) *kvpb.Error { + if err == nil { + // When processor is stopped when it no longer has any registrations, it + // would attempt to close all feeds again with a nil error. Theoretically, + // this should never happen as processor would always stop with a reason if + // feeds are active. + return kvpb.NewError(kvpb.NewRangeFeedRetryError(kvpb.RangeFeedRetryError_REASON_RANGEFEED_CLOSED)) + } + return err +} + +// appendMuxError appends a mux rangefeed completion error to be sent back to // the client. Note that this method cannot block on IO. If the underlying // stream is broken, the error will be dropped. -func (sm *StreamMuxer) AppendMuxError(e *kvpb.MuxRangeFeedEvent) { +func (sm *StreamMuxer) appendMuxError(e *kvpb.MuxRangeFeedEvent) { sm.mu.Lock() defer sm.mu.Unlock() sm.mu.muxErrors = append(sm.mu.muxErrors, e) - // Note that notify is non-blocking. + // Note that notifyMuxError is non-blocking. select { case sm.notifyMuxError <- struct{}{}: default: } } +// DisconnectStreamWithError disconnects a stream with an error. Safe to call +// repeatedly for the same stream, but subsequent errors are ignored. It ensures +// 1. the stream context is cancelled 2. exactly one error is sent back to the +// client on behalf of the stream. +// +// Note that this function can be called by the processor worker while holding +// raftMu, so it is important that this function doesn't block IO. It does so by +// delegating the responsibility of sending mux error to StreamMuxer.run. +func (sm *StreamMuxer) DisconnectStreamWithError( + streamID int64, rangeID roachpb.RangeID, err *kvpb.Error, +) { + if cancelFunc, ok := sm.activeStreams.LoadAndDelete(streamID); ok { + f := cancelFunc.(context.CancelFunc) + f() + clientErrorEvent := transformRangefeedErrToClientError(err) + ev := &kvpb.MuxRangeFeedEvent{ + StreamID: streamID, + RangeID: rangeID, + } + ev.MustSetValue(&kvpb.RangeFeedError{ + Error: *clientErrorEvent, + }) + sm.appendMuxError(ev) + } +} + // detachMuxErrors returns muxErrors and clears the slice. Caller must ensure // the returned errors are sent back to the client. func (sm *StreamMuxer) detachMuxErrors() []*kvpb.MuxRangeFeedEvent { @@ -160,7 +215,8 @@ func (sm *StreamMuxer) run(ctx context.Context, stopper *stop.Stopper) error { for { select { case <-sm.notifyMuxError: - for _, clientErr := range sm.detachMuxErrors() { + toSend := sm.detachMuxErrors() + for _, clientErr := range toSend { if err := sm.sender.Send(clientErr); err != nil { log.Errorf(ctx, "failed to send rangefeed completion error back to client due to broken stream: %v", err) diff --git a/pkg/kv/kvserver/rangefeed/stream_muxer_test.go b/pkg/kv/kvserver/rangefeed/stream_muxer_test.go new file mode 100644 index 000000000000..8355a00bcad5 --- /dev/null +++ b/pkg/kv/kvserver/rangefeed/stream_muxer_test.go @@ -0,0 +1,168 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package rangefeed + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/kv/kvpb" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/util/hlc" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" +) + +// TestStreamMuxer tests that correctly forwards rangefeed completion errors to +// the server stream. +func TestStreamMuxer(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + testServerStream := newTestServerStream() + muxer := NewStreamMuxer(testServerStream) + require.NoError(t, muxer.Start(ctx, stopper)) + defer muxer.Stop() + + t.Run("nil handling", func(t *testing.T) { + const streamID = 0 + const rangeID = 1 + streamCtx, cancel := context.WithCancel(context.Background()) + muxer.AddStream(0, cancel) + // Note that kvpb.NewError(nil) == nil. + muxer.DisconnectStreamWithError(streamID, rangeID, kvpb.NewError(nil)) + require.Equal(t, context.Canceled, streamCtx.Err()) + expectedErrEvent := &kvpb.MuxRangeFeedEvent{ + StreamID: streamID, + RangeID: rangeID, + } + expectedErrEvent.MustSetValue(&kvpb.RangeFeedError{ + Error: *kvpb.NewError(kvpb.NewRangeFeedRetryError(kvpb.RangeFeedRetryError_REASON_RANGEFEED_CLOSED)), + }) + time.Sleep(10 * time.Millisecond) + require.Equal(t, 1, testServerStream.totalEventsSent()) + require.True(t, testServerStream.hasEvent(expectedErrEvent)) + + // Repeat closing the stream does nothing. + muxer.DisconnectStreamWithError(streamID, rangeID, + kvpb.NewError(kvpb.NewRangeFeedRetryError(kvpb.RangeFeedRetryError_REASON_RANGEFEED_CLOSED))) + time.Sleep(10 * time.Millisecond) + require.Equal(t, 1, testServerStream.totalEventsSent()) + }) + + t.Run("send rangefeed completion error", func(t *testing.T) { + testRangefeedCompletionErrors := []struct { + streamID int64 + rangeID roachpb.RangeID + Error error + }{ + {0, 1, kvpb.NewRangeFeedRetryError(kvpb.RangeFeedRetryError_REASON_RANGEFEED_CLOSED)}, + {1, 1, context.Canceled}, + {2, 2, &kvpb.NodeUnavailableError{}}, + } + + for _, muxError := range testRangefeedCompletionErrors { + muxer.AddStream(muxError.streamID, func() {}) + } + + var wg sync.WaitGroup + for _, muxError := range testRangefeedCompletionErrors { + wg.Add(1) + go func(streamID int64, rangeID roachpb.RangeID, err error) { + defer wg.Done() + muxer.DisconnectStreamWithError(streamID, rangeID, kvpb.NewError(err)) + }(muxError.streamID, muxError.rangeID, muxError.Error) + } + wg.Wait() + + for _, muxError := range testRangefeedCompletionErrors { + testutils.SucceedsSoon(t, func() error { + ev := &kvpb.MuxRangeFeedEvent{ + StreamID: muxError.streamID, + RangeID: muxError.rangeID, + } + ev.MustSetValue(&kvpb.RangeFeedError{ + Error: *kvpb.NewError(muxError.Error), + }) + if testServerStream.hasEvent(ev) { + return nil + } + return errors.Newf("expected error %v not found", muxError) + }) + } + }) +} + +// TestStreamMuxerOnBlockingIO tests that the +// StreamMuxer.DisconnectStreamWithError doesn't block on IO. +func TestStreamMuxerOnBlockingIO(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + testServerStream := newTestServerStream() + muxer := NewStreamMuxer(testServerStream) + require.NoError(t, muxer.Start(ctx, stopper)) + defer muxer.Stop() + + const streamID = 0 + const rangeID = 1 + streamCtx, streamCancel := context.WithCancel(context.Background()) + muxer.AddStream(0, streamCancel) + + ev := &kvpb.MuxRangeFeedEvent{ + StreamID: streamID, + RangeID: rangeID, + } + ev.MustSetValue(&kvpb.RangeFeedCheckpoint{ + Span: roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("m")}, + ResolvedTS: hlc.Timestamp{WallTime: 1}, + }) + require.NoError(t, muxer.sender.Send(ev)) + require.Truef(t, testServerStream.hasEvent(ev), + "expected event %v not found in %v", ev, testServerStream) + + // Block the stream. + unblock := testServerStream.BlockSend() + + // Although stream is blocked, we should be able to disconnect the stream + // without blocking. + muxer.DisconnectStreamWithError(streamID, rangeID, + kvpb.NewError(kvpb.NewRangeFeedRetryError(kvpb.RangeFeedRetryError_REASON_NO_LEASEHOLDER))) + require.Equal(t, streamCtx.Err(), context.Canceled) + unblock() + time.Sleep(100 * time.Millisecond) + expectedErrEvent := &kvpb.MuxRangeFeedEvent{ + StreamID: streamID, + RangeID: rangeID, + } + expectedErrEvent.MustSetValue(&kvpb.RangeFeedError{ + Error: *kvpb.NewError(kvpb.NewRangeFeedRetryError(kvpb.RangeFeedRetryError_REASON_NO_LEASEHOLDER)), + }) + // Receive the event after non-blocking. + require.Truef(t, testServerStream.hasEvent(expectedErrEvent), + "expected event %v not found in %v", ev, testServerStream) +} diff --git a/pkg/kv/kvserver/rangefeed/stream_muxer_test_helper.go b/pkg/kv/kvserver/rangefeed/stream_muxer_test_helper.go new file mode 100644 index 000000000000..80564f6f56ee --- /dev/null +++ b/pkg/kv/kvserver/rangefeed/stream_muxer_test_helper.go @@ -0,0 +1,89 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package rangefeed + +import ( + "fmt" + "reflect" + "strings" + "sync" + + "github.com/cockroachdb/cockroach/pkg/kv/kvpb" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" +) + +// testServerStream mocks grpc server stream for testing. +type testServerStream struct { + syncutil.Mutex + // eventsSent is the total number of events sent. + eventsSent int + // streamEvents is a map of streamID to a list of events sent to that stream. + streamEvents map[int64][]*kvpb.MuxRangeFeedEvent +} + +func newTestServerStream() *testServerStream { + return &testServerStream{ + streamEvents: make(map[int64][]*kvpb.MuxRangeFeedEvent), + } +} + +func (s *testServerStream) totalEventsSent() int { + s.Lock() + defer s.Unlock() + return s.eventsSent +} + +// hasEvent returns true if the event is found in the streamEvents map. Note +// that it does a deep equal comparison. +func (s *testServerStream) hasEvent(e *kvpb.MuxRangeFeedEvent) bool { + if e == nil { + return false + } + s.Lock() + defer s.Unlock() + for _, streamEvent := range s.streamEvents[e.StreamID] { + if reflect.DeepEqual(e, streamEvent) { + return true + } + } + return false +} + +// String returns a string representation of the events sent in the stream. +func (s *testServerStream) String() string { + var str strings.Builder + for streamID, eventList := range s.streamEvents { + fmt.Fprintf(&str, "StreamID:%d, Len:%d\n", streamID, len(eventList)) + } + return str.String() +} + +func (s *testServerStream) SendIsThreadSafe() {} + +// Send mocks grpc.ServerStream Send method. It only counts events and stores +// events by streamID in streamEvents. +func (s *testServerStream) Send(e *kvpb.MuxRangeFeedEvent) error { + s.Lock() + defer s.Unlock() + s.eventsSent++ + s.streamEvents[e.StreamID] = append(s.streamEvents[e.StreamID], e) + return nil +} + +// BlockSend blocks any subsequent Send methods until the unblock callback is +// called. +func (s *testServerStream) BlockSend() (unblock func()) { + s.Lock() + var once sync.Once + return func() { + once.Do(s.Unlock) //nolint:deferunlockcheck + } +} diff --git a/pkg/server/node.go b/pkg/server/node.go index a71cae3384fe..b1b05a9a7bc6 100644 --- a/pkg/server/node.go +++ b/pkg/server/node.go @@ -17,7 +17,6 @@ import ( "net" "sort" "strings" - "sync" "sync/atomic" "time" @@ -1882,8 +1881,6 @@ func (n *Node) MuxRangeFeed(stream kvpb.Internal_MuxRangeFeedServer) error { n.metrics.ActiveMuxRangeFeed.Inc(1) defer n.metrics.ActiveMuxRangeFeed.Inc(-1) - var activeStreams sync.Map - for { select { case err := <-streamMuxer.Error(): @@ -1899,17 +1896,10 @@ func (n *Node) MuxRangeFeed(stream kvpb.Internal_MuxRangeFeedServer) error { } if req.CloseStream { - // Client issued a request to close previously established stream. - if v, loaded := activeStreams.LoadAndDelete(req.StreamID); loaded { - s := v.(*setRangeIDEventSink) - s.cancel() - } else { - // This is a bit strange, but it could happen if this stream completes - // just before we receive close request. So, just print out a warning. - if log.V(1) { - log.Infof(ctx, "closing unknown rangefeed stream ID %d", req.StreamID) - } - } + // Note that we will call disconnect again when future.Error returns, + // but DisconnectStreamWithError will ignore subsequent errors. + streamMuxer.DisconnectStreamWithError(req.StreamID, req.RangeID, + kvpb.NewError(kvpb.NewRangeFeedRetryError(kvpb.RangeFeedRetryError_REASON_RANGEFEED_CLOSED))) continue } @@ -1925,48 +1915,14 @@ func (n *Node) MuxRangeFeed(stream kvpb.Internal_MuxRangeFeedServer) error { streamID: req.StreamID, wrapped: muxStream, } - activeStreams.Store(req.StreamID, streamSink) + streamMuxer.AddStream(req.StreamID, cancel) n.metrics.NumMuxRangeFeed.Inc(1) n.metrics.ActiveMuxRangeFeed.Inc(1) f := n.stores.RangeFeed(req, streamSink) f.WhenReady(func(err error) { n.metrics.ActiveMuxRangeFeed.Inc(-1) - - _, loaded := activeStreams.LoadAndDelete(req.StreamID) - streamClosedByClient := !loaded - streamSink.cancel() - - if streamClosedByClient && streamSink.ctx.Err() != nil { - // If the stream was explicitly closed by the client, we expect to see - // context.Canceled error. In this case, return - // kvpb.RangeFeedRetryError_REASON_RANGEFEED_CLOSED to the client. - err = kvpb.NewRangeFeedRetryError(kvpb.RangeFeedRetryError_REASON_RANGEFEED_CLOSED) - } - - if err == nil { - cause := kvpb.RangeFeedRetryError_REASON_RANGEFEED_CLOSED - err = kvpb.NewRangeFeedRetryError(cause) - } - - e := &kvpb.MuxRangeFeedEvent{ - RangeID: req.RangeID, - StreamID: req.StreamID, - } - - e.SetValue(&kvpb.RangeFeedError{ - Error: *kvpb.NewError(err), - }) - - // When rangefeed completes, we must notify the client about that. - // - // NB: even though calling sink.Send() to send notification might seem - // correct, it is also unsafe. This future may be completed at any point, - // including during critical section when some important lock (such as - // raftMu in processor) may be held. Issuing potentially blocking IO - // during that time is not a good idea. Thus, we shunt the notification to - // a dedicated goroutine. - streamMuxer.AppendMuxError(e) + streamMuxer.DisconnectStreamWithError(req.StreamID, req.RangeID, kvpb.NewError(err)) }) } }