diff --git a/pkg/kv/kvserver/client_merge_test.go b/pkg/kv/kvserver/client_merge_test.go index 09113542db42..1de3d259e88a 100644 --- a/pkg/kv/kvserver/client_merge_test.go +++ b/pkg/kv/kvserver/client_merge_test.go @@ -4813,7 +4813,7 @@ func TestMergeQueueWithSlowNonVoterSnaps(t *testing.T) { 1: { Knobs: base.TestingKnobs{ Store: &kvserver.StoreTestingKnobs{ - ReceiveSnapshot: func(header *kvserverpb.SnapshotRequest_Header) error { + ReceiveSnapshot: func(_ context.Context, header *kvserverpb.SnapshotRequest_Header) error { val := delaySnapshotTrap.Load() if val != nil { fn := val.(func() error) diff --git a/pkg/kv/kvserver/client_migration_test.go b/pkg/kv/kvserver/client_migration_test.go index 340c9bbace8f..f1f60ad42d13 100644 --- a/pkg/kv/kvserver/client_migration_test.go +++ b/pkg/kv/kvserver/client_migration_test.go @@ -139,7 +139,7 @@ func TestMigrateWithInflightSnapshot(t *testing.T) { blockSnapshotsCh := make(chan struct{}) knobs, ltk := makeReplicationTestKnobs() ltk.storeKnobs.DisableRaftSnapshotQueue = true // we'll control it ourselves - ltk.storeKnobs.ReceiveSnapshot = func(h *kvserverpb.SnapshotRequest_Header) error { + ltk.storeKnobs.ReceiveSnapshot = func(_ context.Context, h *kvserverpb.SnapshotRequest_Header) error { // We'll want a signal for when the snapshot was received by the sender. once.Do(func() { close(blockUntilSnapshotCh) }) diff --git a/pkg/kv/kvserver/client_raft_test.go b/pkg/kv/kvserver/client_raft_test.go index 23b6b49d4b84..0e99efc8c186 100644 --- a/pkg/kv/kvserver/client_raft_test.go +++ b/pkg/kv/kvserver/client_raft_test.go @@ -17,6 +17,7 @@ import ( "math" "math/rand" "reflect" + "regexp" "strconv" "strings" "sync" @@ -58,10 +59,13 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/humanizeutil" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/log/logpb" "github.com/cockroachdb/cockroach/pkg/util/protoutil" "github.com/cockroachdb/cockroach/pkg/util/randutil" "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/cockroach/pkg/util/tracing" + "github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb" "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" @@ -1460,6 +1464,187 @@ func (c fakeSnapshotStream) Send(request *kvserverpb.SnapshotResponse) error { return nil } +type snapshotTestSignals struct { + // Receiver-side wait channels. + receiveErrCh chan error + batchReceiveReadyCh chan struct{} + + // Sender-side wait channels. + svrContextDone <-chan struct{} + receiveStartedCh chan struct{} + batchReceiveStartedCh chan struct{} + receiverDoneCh chan struct{} +} + +// TestReceiveSnapshotLogging tests that a snapshot receiver properly captures +// the collected tracing spans in the last response, or logs the span if the +// context is cancelled from the client side. +func TestReceiveSnapshotLogging(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + const senderNodeIdx = 0 + const receiverNodeIdx = 1 + const dummyEventMsg = "test receive snapshot logging - dummy event" + + setupTest := func(t *testing.T) (context.Context, *testcluster.TestCluster, *roachpb.RangeDescriptor, *snapshotTestSignals) { + ctx := context.Background() + + signals := &snapshotTestSignals{ + receiveErrCh: make(chan error), + batchReceiveReadyCh: make(chan struct{}), + + svrContextDone: nil, + receiveStartedCh: make(chan struct{}), + batchReceiveStartedCh: make(chan struct{}), + receiverDoneCh: make(chan struct{}, 1), + } + + tc := testcluster.StartTestCluster(t, 3, base.TestClusterArgs{ + ServerArgs: base.TestServerArgs{ + Knobs: base.TestingKnobs{ + Store: &kvserver.StoreTestingKnobs{ + DisableRaftSnapshotQueue: true, + }, + }, + }, + ReplicationMode: base.ReplicationManual, + ServerArgsPerNode: map[int]base.TestServerArgs{ + receiverNodeIdx: { + Knobs: base.TestingKnobs{ + Store: &kvserver.StoreTestingKnobs{ + DisableRaftSnapshotQueue: true, + ThrottleEmptySnapshots: true, + ReceiveSnapshot: func(ctx context.Context, _ *kvserverpb.SnapshotRequest_Header) error { + t.Logf("incoming snapshot on n2") + log.Event(ctx, dummyEventMsg) + signals.svrContextDone = ctx.Done() + close(signals.receiveStartedCh) + return <-signals.receiveErrCh + }, + BeforeRecvAcceptedSnapshot: func() { + t.Logf("receiving on n2") + signals.batchReceiveStartedCh <- struct{}{} + <-signals.batchReceiveReadyCh + }, + HandleSnapshotDone: func() { + t.Logf("receiver on n2 completed") + signals.receiverDoneCh <- struct{}{} + }, + }, + }, + }, + }, + }) + + _, scratchRange, err := tc.Servers[0].ScratchRangeEx() + require.NoError(t, err) + + return ctx, tc, &scratchRange, signals + } + + snapshotAndValidateLogs := func(t *testing.T, ctx context.Context, tc *testcluster.TestCluster, rngDesc *roachpb.RangeDescriptor, signals *snapshotTestSignals, expectTraceOnSender bool) error { + t.Helper() + + repl := tc.GetFirstStoreFromServer(t, senderNodeIdx).LookupReplica(rngDesc.StartKey) + chgs := kvpb.MakeReplicationChanges(roachpb.ADD_VOTER, tc.Target(receiverNodeIdx)) + + testStartTs := timeutil.Now() + _, pErr := repl.ChangeReplicas(ctx, rngDesc, kvserverpb.SnapshotRequest_REBALANCE, kvserverpb.ReasonRangeUnderReplicated, "", chgs) + + // When ready, flush logs and check messages from store_raft.go since + // call to repl.ChangeReplicas(..). + <-signals.receiverDoneCh + log.Flush() + entries, err := log.FetchEntriesFromFiles(testStartTs.UnixNano(), + math.MaxInt64, 100, regexp.MustCompile(`store_raft\.go`), log.WithMarkedSensitiveData) + require.NoError(t, err) + + errRegexp, err := regexp.Compile(`incoming snapshot stream failed with error`) + require.NoError(t, err) + foundEntry := false + var entry logpb.Entry + for _, entry = range entries { + if errRegexp.MatchString(entry.Message) { + foundEntry = true + break + } + } + expectTraceOnReceiver := !expectTraceOnSender + require.Equal(t, expectTraceOnReceiver, foundEntry) + if expectTraceOnReceiver { + require.Contains(t, entry.Message, dummyEventMsg) + } + + // Check that receiver traces were imported in sender's context on success. + clientTraces := tracing.SpanFromContext(ctx).GetConfiguredRecording() + _, receiverTraceFound := clientTraces.FindLogMessage(dummyEventMsg) + require.Equal(t, expectTraceOnSender, receiverTraceFound) + + return pErr + } + + t.Run("cancel on header", func(t *testing.T) { + ctx, tc, scratchRange, signals := setupTest(t) + defer tc.Stopper().Stop(ctx) + + ctx, sp := tracing.EnsureChildSpan(ctx, tc.GetFirstStoreFromServer(t, senderNodeIdx).GetStoreConfig().Tracer(), + t.Name(), tracing.WithRecording(tracingpb.RecordingVerbose)) + defer sp.Finish() + + ctx, cancel := context.WithCancel(ctx) + go func() { + <-signals.receiveStartedCh + cancel() + <-signals.svrContextDone + time.Sleep(10 * time.Millisecond) + signals.receiveErrCh <- errors.Errorf("header is bad") + }() + err := snapshotAndValidateLogs(t, ctx, tc, scratchRange, signals, false /* expectTraceOnSender */) + require.Error(t, err) + }) + t.Run("cancel during receive", func(t *testing.T) { + ctx, tc, scratchRange, signals := setupTest(t) + defer tc.Stopper().Stop(ctx) + + ctx, sp := tracing.EnsureChildSpan(ctx, tc.GetFirstStoreFromServer(t, senderNodeIdx).GetStoreConfig().Tracer(), + t.Name(), tracing.WithRecording(tracingpb.RecordingVerbose)) + defer sp.Finish() + + ctx, cancel := context.WithCancel(ctx) + close(signals.receiveErrCh) + go func() { + <-signals.receiveStartedCh + <-signals.batchReceiveStartedCh + cancel() + <-signals.svrContextDone + time.Sleep(10 * time.Millisecond) + close(signals.batchReceiveReadyCh) + }() + err := snapshotAndValidateLogs(t, ctx, tc, scratchRange, signals, false /* expectTraceOnSender */) + require.Error(t, err) + }) + t.Run("successful send", func(t *testing.T) { + ctx, tc, scratchRange, signals := setupTest(t) + defer tc.Stopper().Stop(ctx) + + ctx, sp := tracing.EnsureChildSpan(ctx, tc.GetFirstStoreFromServer(t, senderNodeIdx).GetStoreConfig().Tracer(), + t.Name(), tracing.WithRecording(tracingpb.RecordingVerbose)) + defer sp.Finish() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + close(signals.receiveErrCh) + close(signals.batchReceiveReadyCh) + go func() { + <-signals.receiveStartedCh + <-signals.batchReceiveStartedCh + }() + err := snapshotAndValidateLogs(t, ctx, tc, scratchRange, signals, true /* expectTraceOnSender */) + require.NoError(t, err) + }) +} + // TestFailedSnapshotFillsReservation tests that failing to finish applying an // incoming snapshot still cleans up the outstanding reservation that was made. func TestFailedSnapshotFillsReservation(t *testing.T) { diff --git a/pkg/kv/kvserver/replica_learner_test.go b/pkg/kv/kvserver/replica_learner_test.go index 401809b191ec..9cccd955400b 100644 --- a/pkg/kv/kvserver/replica_learner_test.go +++ b/pkg/kv/kvserver/replica_learner_test.go @@ -142,7 +142,7 @@ func TestAddReplicaViaLearner(t *testing.T) { var receivedSnap int64 blockSnapshotsCh := make(chan struct{}) knobs, ltk := makeReplicationTestKnobs() - ltk.storeKnobs.ReceiveSnapshot = func(h *kvserverpb.SnapshotRequest_Header) error { + ltk.storeKnobs.ReceiveSnapshot = func(_ context.Context, h *kvserverpb.SnapshotRequest_Header) error { if atomic.CompareAndSwapInt64(&receivedSnap, 0, 1) { close(blockUntilSnapshotCh) } else { @@ -238,7 +238,7 @@ func TestAddReplicaWithReceiverThrottling(t *testing.T) { activateBlocking := int64(1) var count int64 knobs, ltk := makeReplicationTestKnobs() - ltk.storeKnobs.ReceiveSnapshot = func(h *kvserverpb.SnapshotRequest_Header) error { + ltk.storeKnobs.ReceiveSnapshot = func(_ context.Context, h *kvserverpb.SnapshotRequest_Header) error { if atomic.LoadInt64(&activateBlocking) > 0 { // Signal waitForRebalanceToBlockCh to indicate the testing knob was hit. close(waitForRebalanceToBlockCh) @@ -250,7 +250,7 @@ func TestAddReplicaWithReceiverThrottling(t *testing.T) { ltk.storeKnobs.BeforeSendSnapshotThrottle = func() { atomic.AddInt64(&count, 1) } - ltk.storeKnobs.AfterSendSnapshotThrottle = func() { + ltk.storeKnobs.AfterSnapshotThrottle = func() { atomic.AddInt64(&count, -1) } ctx := context.Background() @@ -492,7 +492,7 @@ func TestDelegateSnapshotFails(t *testing.T) { } setupFn := func(t *testing.T, - receiveFunc func(*kvserverpb.SnapshotRequest_Header) error, + receiveFunc func(context.Context, *kvserverpb.SnapshotRequest_Header) error, sendFunc func(*kvserverpb.DelegateSendSnapshotRequest), processRaft func(roachpb.StoreID) bool, ) ( @@ -618,7 +618,7 @@ func TestDelegateSnapshotFails(t *testing.T) { var block atomic.Int32 tc, scratchKey := setupFn( t, - func(h *kvserverpb.SnapshotRequest_Header) error { + func(_ context.Context, h *kvserverpb.SnapshotRequest_Header) error { // TODO(abaptist): Remove this check once #96841 is fixed. if h.SenderQueueName == kvserverpb.SnapshotRequest_RAFT_SNAPSHOT_QUEUE { return nil @@ -862,7 +862,7 @@ func TestLearnerSnapshotFailsRollback(t *testing.T) { runTest := func(t *testing.T, replicaType roachpb.ReplicaType) { var rejectSnapshotErr atomic.Value // error knobs, ltk := makeReplicationTestKnobs() - ltk.storeKnobs.ReceiveSnapshot = func(h *kvserverpb.SnapshotRequest_Header) error { + ltk.storeKnobs.ReceiveSnapshot = func(_ context.Context, h *kvserverpb.SnapshotRequest_Header) error { if err := rejectSnapshotErr.Load().(error); err != nil { return err } @@ -1374,7 +1374,7 @@ func TestRaftSnapshotQueueSeesLearner(t *testing.T) { blockSnapshotsCh := make(chan struct{}) knobs, ltk := makeReplicationTestKnobs() ltk.storeKnobs.DisableRaftSnapshotQueue = true - ltk.storeKnobs.ReceiveSnapshot = func(h *kvserverpb.SnapshotRequest_Header) error { + ltk.storeKnobs.ReceiveSnapshot = func(_ context.Context, h *kvserverpb.SnapshotRequest_Header) error { select { case <-blockSnapshotsCh: case <-time.After(10 * time.Second): @@ -1438,7 +1438,7 @@ func TestLearnerAdminChangeReplicasRace(t *testing.T) { blockUntilSnapshotCh := make(chan struct{}, 2) blockSnapshotsCh := make(chan struct{}) knobs, ltk := makeReplicationTestKnobs() - ltk.storeKnobs.ReceiveSnapshot = func(h *kvserverpb.SnapshotRequest_Header) error { + ltk.storeKnobs.ReceiveSnapshot = func(_ context.Context, h *kvserverpb.SnapshotRequest_Header) error { blockUntilSnapshotCh <- struct{}{} <-blockSnapshotsCh return nil @@ -1991,7 +1991,7 @@ func TestMergeQueueDoesNotInterruptReplicationChange(t *testing.T) { // Disable load-based splitting, so that the absence of sufficient // QPS measurements do not prevent ranges from merging. DisableLoadBasedSplitting: true, - ReceiveSnapshot: func(_ *kvserverpb.SnapshotRequest_Header) error { + ReceiveSnapshot: func(_ context.Context, _ *kvserverpb.SnapshotRequest_Header) error { if atomic.LoadInt64(&activateSnapshotTestingKnob) == 1 { // While the snapshot RPC should only happen once given // that the cluster is running under manual replication, diff --git a/pkg/kv/kvserver/replicate_queue_test.go b/pkg/kv/kvserver/replicate_queue_test.go index 99c26cec4aac..32d6e3aafc87 100644 --- a/pkg/kv/kvserver/replicate_queue_test.go +++ b/pkg/kv/kvserver/replicate_queue_test.go @@ -840,7 +840,7 @@ func TestReplicateQueueTracingOnError(t *testing.T) { t, 4, base.TestClusterArgs{ ReplicationMode: base.ReplicationManual, ServerArgs: base.TestServerArgs{Knobs: base.TestingKnobs{Store: &kvserver.StoreTestingKnobs{ - ReceiveSnapshot: func(_ *kvserverpb.SnapshotRequest_Header) error { + ReceiveSnapshot: func(_ context.Context, _ *kvserverpb.SnapshotRequest_Header) error { if atomic.LoadInt64(&rejectSnapshots) == 1 { return errors.Newf("boom") } @@ -967,7 +967,7 @@ func TestReplicateQueueDecommissionPurgatoryError(t *testing.T) { t, 4, base.TestClusterArgs{ ReplicationMode: base.ReplicationManual, ServerArgs: base.TestServerArgs{Knobs: base.TestingKnobs{Store: &kvserver.StoreTestingKnobs{ - ReceiveSnapshot: func(_ *kvserverpb.SnapshotRequest_Header) error { + ReceiveSnapshot: func(_ context.Context, _ *kvserverpb.SnapshotRequest_Header) error { if atomic.LoadInt64(&rejectSnapshots) == 1 { return errors.Newf("boom") } diff --git a/pkg/kv/kvserver/store_raft.go b/pkg/kv/kvserver/store_raft.go index 72b8a666b458..82d2065730c3 100644 --- a/pkg/kv/kvserver/store_raft.go +++ b/pkg/kv/kvserver/store_raft.go @@ -203,12 +203,22 @@ func (s *Store) HandleDelegatedSnapshot( func (s *Store) HandleSnapshot( ctx context.Context, header *kvserverpb.SnapshotRequest_Header, stream SnapshotResponseStream, ) error { + if fn := s.cfg.TestingKnobs.HandleSnapshotDone; fn != nil { + defer fn() + } ctx = s.AnnotateCtx(ctx) const name = "storage.Store: handle snapshot" return s.stopper.RunTaskWithErr(ctx, name, func(ctx context.Context) error { s.metrics.RaftRcvdMessages[raftpb.MsgSnap].Inc(1) - return s.receiveSnapshot(ctx, header, stream) + err := s.receiveSnapshot(ctx, header, stream) + if err != nil && ctx.Err() != nil { + // Log trace of incoming snapshot on context cancellation (e.g. + // times out or caller goes away). + log.Infof(ctx, "incoming snapshot stream failed with error: %v\ntrace:\n%v", + err, tracing.SpanFromContext(ctx).GetConfiguredRecording()) + } + return err }) } diff --git a/pkg/kv/kvserver/store_snapshot.go b/pkg/kv/kvserver/store_snapshot.go index e028ade2f952..5f486f0124e6 100644 --- a/pkg/kv/kvserver/store_snapshot.go +++ b/pkg/kv/kvserver/store_snapshot.go @@ -93,43 +93,6 @@ type incomingSnapshotStream interface { Recv() (*kvserverpb.SnapshotRequest, error) } -// loggingIncomingSnapshotStream wraps the interface on a GRPC stream used -// to receive a snapshot over the network, with special handling for logging -// the current tracing span on context cancellation. -type loggingIncomingSnapshotStream struct { - stream incomingSnapshotStream -} - -func (l loggingIncomingSnapshotStream) Send( - ctx context.Context, resp *kvserverpb.SnapshotResponse, -) error { - err := l.stream.Send(resp) - if err != nil && ctx.Err() != nil { - // Log trace of incoming snapshot on context cancellation (e.g. - // times out or caller goes away). - if sp := tracing.SpanFromContext(ctx); sp != nil && !sp.IsNoop() { - log.Infof(ctx, "incoming snapshot stream response send failed with error: %s\ntrace:\n%s", - err, sp.GetConfiguredRecording()) - } - } - return err -} - -func (l loggingIncomingSnapshotStream) Recv( - ctx context.Context, -) (*kvserverpb.SnapshotRequest, error) { - req, err := l.stream.Recv() - if err != nil && ctx.Err() != nil { - // Log trace of incoming snapshot on context cancellation (e.g. - // times out or caller goes away). - if sp := tracing.SpanFromContext(ctx); sp != nil && !sp.IsNoop() { - log.Infof(ctx, "incoming snapshot stream request recv failed with error: %s\ntrace:\n%s", - err, sp.GetConfiguredRecording()) - } - } - return req, err -} - // outgoingSnapshotStream is the minimal interface on a GRPC stream required // to send a snapshot over the network. type outgoingSnapshotStream interface { @@ -152,7 +115,7 @@ type snapshotStrategy interface { Receive( context.Context, *Store, - loggingIncomingSnapshotStream, + incomingSnapshotStream, kvserverpb.SnapshotRequest_Header, snapshotRecordMetrics, ) (IncomingSnapshot, error) @@ -415,11 +378,17 @@ func (tag *snapshotTimingTag) Render() []attribute.KeyValue { func (kvSS *kvBatchSnapshotStrategy) Receive( ctx context.Context, s *Store, - loggingStream loggingIncomingSnapshotStream, + stream incomingSnapshotStream, header kvserverpb.SnapshotRequest_Header, recordBytesReceived snapshotRecordMetrics, ) (IncomingSnapshot, error) { assertStrategy(ctx, header, kvserverpb.SnapshotRequest_KV_BATCH) + if fn := s.cfg.TestingKnobs.BeforeRecvAcceptedSnapshot; fn != nil { + fn() + } + snapshotCtx := ctx + ctx, rSp := tracing.EnsureChildSpan(ctx, s.cfg.Tracer(), "receive snapshot data") + defer rSp.Finish() // Ensure that the tracing span is closed, even if Receive errors // These stopwatches allow us to time the various components of Receive(). // - totalTime Stopwatch measures the total time spent within this function. @@ -451,14 +420,14 @@ func (kvSS *kvBatchSnapshotStrategy) Receive( for { timingTag.start("recv") - req, err := loggingStream.Recv(ctx) + req, err := stream.Recv() timingTag.stop("recv") if err != nil { return noSnap, err } if req.Header != nil { err := errors.New("client error: provided a header mid-stream") - return noSnap, sendSnapshotError(ctx, s, loggingStream, err) + return noSnap, sendSnapshotError(snapshotCtx, s, stream, err) } if req.KVBatch != nil { @@ -524,7 +493,7 @@ func (kvSS *kvBatchSnapshotStrategy) Receive( snapUUID, err := uuid.FromBytes(header.RaftMessageRequest.Message.Snapshot.Data) if err != nil { err = errors.Wrap(err, "client error: invalid snapshot") - return noSnap, sendSnapshotError(ctx, s, loggingStream, err) + return noSnap, sendSnapshotError(snapshotCtx, s, stream, err) } inSnap := IncomingSnapshot{ @@ -823,7 +792,7 @@ func (s *Store) throttleSnapshot( select { case permit = <-task.GetWaitChan(): // Got a spot in the snapshotQueue, continue with sending the snapshot. - if fn := s.cfg.TestingKnobs.AfterSendSnapshotThrottle; fn != nil { + if fn := s.cfg.TestingKnobs.AfterSnapshotThrottle; fn != nil { fn() } log.Event(ctx, "acquired spot in the snapshot snapshotQueue") @@ -1032,7 +1001,6 @@ func (s *Store) getLocalityComparison( func (s *Store) receiveSnapshot( ctx context.Context, header *kvserverpb.SnapshotRequest_Header, stream incomingSnapshotStream, ) error { - loggingStream := loggingIncomingSnapshotStream{stream} // Draining nodes will generally not be rebalanced to (see the filtering that // happens in getStoreListFromIDsLocked()), but in case they are, they should // reject the incoming rebalancing snapshots. @@ -1047,7 +1015,7 @@ func (s *Store) receiveSnapshot( // getStoreListFromIDsLocked(). Is that sound? Don't we want to // upreplicate to draining nodes if there are no other candidates? case kvserverpb.SnapshotRequest_REBALANCE: - return sendSnapshotError(ctx, s, loggingStream, errors.New(storeDrainingMsg)) + return sendSnapshotError(ctx, s, stream, errors.New(storeDrainingMsg)) default: // If this a new snapshot type that this cockroach version does not know // about, we let it through. @@ -1055,10 +1023,10 @@ func (s *Store) receiveSnapshot( } if fn := s.cfg.TestingKnobs.ReceiveSnapshot; fn != nil { - if err := fn(header); err != nil { + if err := fn(ctx, header); err != nil { // NB: we intentionally don't mark this error as errMarkSnapshotError so // that we don't end up retrying injected errors in tests. - return sendSnapshotError(ctx, s, loggingStream, err) + return sendSnapshotError(ctx, s, stream, err) } } @@ -1098,7 +1066,7 @@ func (s *Store) receiveSnapshot( return nil }); pErr != nil { log.Infof(ctx, "cannot accept snapshot: %s", pErr) - return sendSnapshotError(ctx, s, loggingStream, pErr.GoError()) + return sendSnapshotError(ctx, s, stream, pErr.GoError()) } defer func() { @@ -1120,7 +1088,7 @@ func (s *Store) receiveSnapshot( snapUUID, err := uuid.FromBytes(header.RaftMessageRequest.Message.Snapshot.Data) if err != nil { err = errors.Wrap(err, "invalid snapshot") - return sendSnapshotError(ctx, s, loggingStream, err) + return sendSnapshotError(ctx, s, stream, err) } ss = &kvBatchSnapshotStrategy{ @@ -1130,13 +1098,13 @@ func (s *Store) receiveSnapshot( } defer ss.Close(ctx) default: - return sendSnapshotError(ctx, s, loggingStream, + return sendSnapshotError(ctx, s, stream, errors.Errorf("%s,r%d: unknown snapshot strategy: %s", s, header.State.Desc.RangeID, header.Strategy), ) } - if err := loggingStream.Send(ctx, &kvserverpb.SnapshotResponse{Status: kvserverpb.SnapshotResponse_ACCEPTED}); err != nil { + if err := stream.Send(&kvserverpb.SnapshotResponse{Status: kvserverpb.SnapshotResponse_ACCEPTED}); err != nil { return err } if log.V(2) { @@ -1161,9 +1129,7 @@ func (s *Store) receiveSnapshot( s.metrics.RangeSnapshotUnknownRcvdBytes.Inc(inc) } } - ctx, rSp := tracing.EnsureChildSpan(ctx, s.cfg.Tracer(), "receive snapshot data") - defer rSp.Finish() // Ensure that the tracing span is closed, even if ss.Receive errors - inSnap, err := ss.Receive(ctx, s, loggingStream, *header, recordBytesReceived) + inSnap, err := ss.Receive(ctx, s, stream, *header, recordBytesReceived) if err != nil { return err } @@ -1180,9 +1146,9 @@ func (s *Store) receiveSnapshot( // sender as this being a retriable error, see isSnapshotError(). err = errors.Mark(err, errMarkSnapshotError) err = errors.Wrap(err, "failed to apply snapshot") - return sendSnapshotError(ctx, s, loggingStream, err) + return sendSnapshotError(ctx, s, stream, err) } - return loggingStream.Send(ctx, &kvserverpb.SnapshotResponse{ + return stream.Send(&kvserverpb.SnapshotResponse{ Status: kvserverpb.SnapshotResponse_APPLIED, CollectedSpans: tracing.SpanFromContext(ctx).GetConfiguredRecording(), }) @@ -1192,13 +1158,13 @@ func (s *Store) receiveSnapshot( // to signify that it can not accept this snapshot. Internally it increments the // statistic tracking how many invalid snapshots it received. func sendSnapshotError( - ctx context.Context, s *Store, stream loggingIncomingSnapshotStream, err error, + ctx context.Context, s *Store, stream incomingSnapshotStream, err error, ) error { s.metrics.RangeSnapshotRecvFailed.Inc(1) resp := snapRespErr(err) resp.CollectedSpans = tracing.SpanFromContext(ctx).GetConfiguredRecording() - return stream.Send(ctx, resp) + return stream.Send(resp) } func snapRespErr(err error) *kvserverpb.SnapshotResponse { diff --git a/pkg/kv/kvserver/testing_knobs.go b/pkg/kv/kvserver/testing_knobs.go index d8462940073b..b1c25d3efcb9 100644 --- a/pkg/kv/kvserver/testing_knobs.go +++ b/pkg/kv/kvserver/testing_knobs.go @@ -298,7 +298,10 @@ type StoreTestingKnobs struct { // ReceiveSnapshot is run after receiving a snapshot header but before // acquiring snapshot quota or doing shouldAcceptSnapshotData checks. If an // error is returned from the hook, it's sent as an ERROR SnapshotResponse. - ReceiveSnapshot func(*kvserverpb.SnapshotRequest_Header) error + ReceiveSnapshot func(context.Context, *kvserverpb.SnapshotRequest_Header) error + // HandleSnapshotDone is run after the entirety of receiving a snapshot, + // regardless of whether it succeeds, gets cancelled, times out, or errors. + HandleSnapshotDone func() // ReplicaAddSkipLearnerRollback causes replica addition to skip the learner // rollback that happens when either the initial snapshot or the promotion of // a learner to a voter fails. @@ -428,9 +431,12 @@ type StoreTestingKnobs struct { // BeforeSendSnapshotThrottle intercepts replicas before entering send // snapshot throttling. BeforeSendSnapshotThrottle func() - // AfterSendSnapshotThrottle intercepts replicas after receiving a spot in the - // send snapshot semaphore. - AfterSendSnapshotThrottle func() + // AfterSnapshotThrottle intercepts replicas after receiving a spot in the + // send/recv snapshot semaphore. + AfterSnapshotThrottle func() + // BeforeRecvAcceptedSnapshot intercepts replicas before receiving the batches + // of a reserved and accepted snapshot. + BeforeRecvAcceptedSnapshot func() // SelectDelegateSnapshotSender returns an ordered list of replica which will // be used as delegates for sending a snapshot. SelectDelegateSnapshotSender func(*roachpb.RangeDescriptor) []roachpb.ReplicaDescriptor