From bd4aa8deef83668085931acd11540a775dc1433a Mon Sep 17 00:00:00 2001
From: Wenyi <wenyi.hu@cockroachlabs.com>
Date: Tue, 30 May 2023 15:42:44 -0400
Subject: [PATCH] kvserver: refactor getSnapshotBytesMetrics

This commit refactors `getSnapshotBytesMetrics` in `replica_learner_test.go`
to return a `map[string]snapshotBytesMetrics` instead of
`map[SnapshotRequest_Priority]snapshotBytesMetrics`. This allows us to include
and compare different types of snapshot metrics, removing the constraint of
being limited to `SnapshotRequest_Priority`. This commit does not change any
existing functionality; its main purpose is to make future commits cleaner.

Part of: cockroachdb#104124
Release note: none
---
 pkg/kv/kvserver/replica_learner_test.go | 207 ++++++++++++------------
 1 file changed, 100 insertions(+), 107 deletions(-)

diff --git a/pkg/kv/kvserver/replica_learner_test.go b/pkg/kv/kvserver/replica_learner_test.go
index 80299b6344f6..c5a0570fa040 100644
--- a/pkg/kv/kvserver/replica_learner_test.go
+++ b/pkg/kv/kvserver/replica_learner_test.go
@@ -967,8 +967,8 @@ func testRaftSnapshotsToNonVoters(t *testing.T, drainReceivingNode bool) {
 	g, ctx := errgroup.WithContext(ctx)
 
 	// Record the snapshot metrics before anything has been sent / received.
-	senderTotalBefore, senderMetricsMapBefore := getSnapshotBytesMetrics(t, tc, 0 /* serverIdx */)
-	receiverTotalBefore, receiverMetricsMapBefore := getSnapshotBytesMetrics(t, tc, 1 /* serverIdx */)
+	senderMetricsMapBefore := getSnapshotBytesMetrics(t, tc, 0 /* serverIdx */)
+	receiverMetricsMapBefore := getSnapshotBytesMetrics(t, tc, 1 /* serverIdx */)
 
 	// Add a new voting replica, but don't initialize it. Note that
 	// `tc.AddNonVoters` will not return until the newly added non-voter is
@@ -1044,36 +1044,34 @@ func testRaftSnapshotsToNonVoters(t *testing.T, drainReceivingNode bool) {
 	require.NoError(t, g.Wait())
 
 	// Record the snapshot metrics for the sender after the raft snapshot was sent.
-	senderTotalAfter, senderMetricsMapAfter := getSnapshotBytesMetrics(t, tc, 0)
+	senderMetricsMapAfter := getSnapshotBytesMetrics(t, tc, 0)
 
 	// Asserts that the raft snapshot (aka recovery snapshot) bytes sent have been
 	// recorded and that it was not double counted in a different metric.
-	senderTotalDelta, senderMapDelta := getSnapshotMetricsDiff(senderTotalBefore, senderMetricsMapBefore, senderTotalAfter, senderMetricsMapAfter)
+	senderMapDelta := getSnapshotMetricsDiff(senderMetricsMapBefore, senderMetricsMapAfter)
 
-	senderTotalExpected := snapshotBytesMetrics{sentBytes: snapshotLength, rcvdBytes: 0}
-	senderMapExpected := map[kvserverpb.SnapshotRequest_Priority]snapshotBytesMetrics{
-		kvserverpb.SnapshotRequest_REBALANCE: {sentBytes: 0, rcvdBytes: 0},
-		kvserverpb.SnapshotRequest_RECOVERY:  {sentBytes: snapshotLength, rcvdBytes: 0},
-		kvserverpb.SnapshotRequest_UNKNOWN:   {sentBytes: 0, rcvdBytes: 0},
+	senderMapExpected := map[string]snapshotBytesMetrics{
+		".rebalancing": {sentBytes: 0, rcvdBytes: 0},
+		".recovery":    {sentBytes: snapshotLength, rcvdBytes: 0},
+		".unknown":     {sentBytes: 0, rcvdBytes: 0},
+		"":             {sentBytes: snapshotLength, rcvdBytes: 0},
 	}
-	require.Equal(t, senderTotalExpected, senderTotalDelta)
 	require.Equal(t, senderMapExpected, senderMapDelta)
 
 	// Record the snapshot metrics for the receiver after the raft snapshot was
 	// received.
-	receiverTotalAfter, receiverMetricsMapAfter := getSnapshotBytesMetrics(t, tc, 1)
+	receiverMetricsMapAfter := getSnapshotBytesMetrics(t, tc, 1)
 
 	// Asserts that the raft snapshot (aka recovery snapshot) bytes received have
 	// been recorded and that it was not double counted in a different metric.
-	receiverTotalDelta, receiverMapDelta := getSnapshotMetricsDiff(receiverTotalBefore, receiverMetricsMapBefore, receiverTotalAfter, receiverMetricsMapAfter)
+	receiverMapDelta := getSnapshotMetricsDiff(receiverMetricsMapBefore, receiverMetricsMapAfter)
 
-	receiverTotalExpected := snapshotBytesMetrics{sentBytes: 0, rcvdBytes: snapshotLength}
-	receiverMapExpected := map[kvserverpb.SnapshotRequest_Priority]snapshotBytesMetrics{
-		kvserverpb.SnapshotRequest_REBALANCE: {sentBytes: 0, rcvdBytes: 0},
-		kvserverpb.SnapshotRequest_RECOVERY:  {sentBytes: 0, rcvdBytes: snapshotLength},
-		kvserverpb.SnapshotRequest_UNKNOWN:   {sentBytes: 0, rcvdBytes: 0},
+	receiverMapExpected := map[string]snapshotBytesMetrics{
+		".rebalancing": {sentBytes: 0, rcvdBytes: 0},
+		".recovery":    {sentBytes: 0, rcvdBytes: snapshotLength},
+		".unknown":     {sentBytes: 0, rcvdBytes: 0},
+		"":             {sentBytes: 0, rcvdBytes: snapshotLength},
 	}
-	require.Equal(t, receiverTotalExpected, receiverTotalDelta)
 	require.Equal(t, receiverMapExpected, receiverMapDelta)
 }
 
@@ -2199,28 +2197,22 @@ type snapshotBytesMetrics struct {
 // and granularMetrics is the map mentioned above.
 func getSnapshotBytesMetrics(
 	t *testing.T, tc *testcluster.TestCluster, serverIdx int,
-) (snapshotBytesMetrics, map[kvserverpb.SnapshotRequest_Priority]snapshotBytesMetrics) {
-	granularMetrics := make(map[kvserverpb.SnapshotRequest_Priority]snapshotBytesMetrics)
-
-	granularMetrics[kvserverpb.SnapshotRequest_UNKNOWN] = snapshotBytesMetrics{
-		sentBytes: getFirstStoreMetric(t, tc.Server(serverIdx), "range.snapshots.unknown.sent-bytes"),
-		rcvdBytes: getFirstStoreMetric(t, tc.Server(serverIdx), "range.snapshots.unknown.rcvd-bytes"),
-	}
-	granularMetrics[kvserverpb.SnapshotRequest_RECOVERY] = snapshotBytesMetrics{
-		sentBytes: getFirstStoreMetric(t, tc.Server(serverIdx), "range.snapshots.recovery.sent-bytes"),
-		rcvdBytes: getFirstStoreMetric(t, tc.Server(serverIdx), "range.snapshots.recovery.rcvd-bytes"),
-	}
-	granularMetrics[kvserverpb.SnapshotRequest_REBALANCE] = snapshotBytesMetrics{
-		sentBytes: getFirstStoreMetric(t, tc.Server(serverIdx), "range.snapshots.rebalancing.sent-bytes"),
-		rcvdBytes: getFirstStoreMetric(t, tc.Server(serverIdx), "range.snapshots.rebalancing.rcvd-bytes"),
+) map[string]snapshotBytesMetrics {
+	metrics := make(map[string]snapshotBytesMetrics)
+
+	findSnapshotBytesMetrics := func(metricName string) snapshotBytesMetrics {
+		sentMetricStr := fmt.Sprintf("range.snapshots%v.sent-bytes", metricName)
+		rcvdMetricStr := fmt.Sprintf("range.snapshots%v.rcvd-bytes", metricName)
+		return snapshotBytesMetrics{
+			sentBytes: getFirstStoreMetric(t, tc.Server(serverIdx), sentMetricStr),
+			rcvdBytes: getFirstStoreMetric(t, tc.Server(serverIdx), rcvdMetricStr),
+		}
 	}
-
-	totalBytes := snapshotBytesMetrics{
-		sentBytes: getFirstStoreMetric(t, tc.Server(serverIdx), "range.snapshots.sent-bytes"),
-		rcvdBytes: getFirstStoreMetric(t, tc.Server(serverIdx), "range.snapshots.rcvd-bytes"),
+	types := [4]string{".unknown", ".recovery", ".rebalancing", ""}
+	for _, v := range types {
+		metrics[v] = findSnapshotBytesMetrics(v)
 	}
-
-	return totalBytes, granularMetrics
+	return metrics
 }
 
 // getSnapshotMetricsDiff returns the delta between snapshot byte metrics
@@ -2231,31 +2223,16 @@ func getSnapshotBytesMetrics(
 // sent/received, and granularMetrics is the map of snapshotBytesMetrics structs
 // containing deltas for each type of snapshot.
 func getSnapshotMetricsDiff(
-	beforeTotal snapshotBytesMetrics,
-	beforeMap map[kvserverpb.SnapshotRequest_Priority]snapshotBytesMetrics,
-	afterTotal snapshotBytesMetrics,
-	afterMap map[kvserverpb.SnapshotRequest_Priority]snapshotBytesMetrics,
-) (snapshotBytesMetrics, map[kvserverpb.SnapshotRequest_Priority]snapshotBytesMetrics) {
-	diffTotal := snapshotBytesMetrics{
-		sentBytes: afterTotal.sentBytes - beforeTotal.sentBytes,
-		rcvdBytes: afterTotal.rcvdBytes - beforeTotal.rcvdBytes,
-	}
-	diffMap := map[kvserverpb.SnapshotRequest_Priority]snapshotBytesMetrics{
-		kvserverpb.SnapshotRequest_REBALANCE: {
-			sentBytes: afterMap[kvserverpb.SnapshotRequest_REBALANCE].sentBytes - beforeMap[kvserverpb.SnapshotRequest_REBALANCE].sentBytes,
-			rcvdBytes: afterMap[kvserverpb.SnapshotRequest_REBALANCE].rcvdBytes - beforeMap[kvserverpb.SnapshotRequest_REBALANCE].rcvdBytes,
-		},
-		kvserverpb.SnapshotRequest_RECOVERY: {
-			sentBytes: afterMap[kvserverpb.SnapshotRequest_RECOVERY].sentBytes - beforeMap[kvserverpb.SnapshotRequest_RECOVERY].sentBytes,
-			rcvdBytes: afterMap[kvserverpb.SnapshotRequest_RECOVERY].rcvdBytes - beforeMap[kvserverpb.SnapshotRequest_RECOVERY].rcvdBytes,
-		},
-		kvserverpb.SnapshotRequest_UNKNOWN: {
-			sentBytes: afterMap[kvserverpb.SnapshotRequest_UNKNOWN].sentBytes - beforeMap[kvserverpb.SnapshotRequest_UNKNOWN].sentBytes,
-			rcvdBytes: afterMap[kvserverpb.SnapshotRequest_UNKNOWN].rcvdBytes - beforeMap[kvserverpb.SnapshotRequest_UNKNOWN].rcvdBytes,
-		},
+	beforeMap map[string]snapshotBytesMetrics, afterMap map[string]snapshotBytesMetrics,
+) map[string]snapshotBytesMetrics {
+	diffMap := make(map[string]snapshotBytesMetrics)
+	for metricName, beforeValue := range beforeMap {
+		diffMap[metricName] = snapshotBytesMetrics{
+			afterMap[metricName].sentBytes - beforeValue.sentBytes,
+			afterMap[metricName].rcvdBytes - beforeValue.rcvdBytes,
+		}
 	}
-
-	return diffTotal, diffMap
+	return diffMap
 }
 
 // This function returns the number of bytes sent for a snapshot. It follows the
@@ -2328,11 +2305,14 @@ func TestRebalancingSnapshotMetrics(t *testing.T) {
 	knobs, ltk := makeReplicationTestKnobs()
 	ltk.storeKnobs.DisableRaftSnapshotQueue = true
 
-	// Synchronize on the moment before the snapshot gets sent so we can measure
-	// the state at that time.
 	blockUntilSnapshotSendCh := make(chan struct{})
 	blockSnapshotSendCh := make(chan struct{})
 	ltk.storeKnobs.SendSnapshot = func(request *kvserverpb.DelegateSendSnapshotRequest) {
+		// This testing knob allows accurate calculation of expected snapshot bytes
+		// by unblocking the current goroutine when `HandleDelegatedSnapshot` is
+		// about to send the snapshot. In addition, it also blocks the new
+		// goroutine, which was created to send the snapshot, until the calculation
+		// is complete.
 		close(blockUntilSnapshotSendCh)
 		select {
 		case <-blockSnapshotSendCh:
@@ -2348,62 +2328,75 @@ func TestRebalancingSnapshotMetrics(t *testing.T) {
 	})
 	defer tc.Stopper().Stop(ctx)
 
-	scratchStartKey := tc.ScratchRange(t)
-
-	// Record the snapshot metrics before anything has been sent / received.
-	senderTotalBefore, senderMetricsMapBefore := getSnapshotBytesMetrics(t, tc, 0 /* serverIdx */)
-	receiverTotalBefore, receiverMetricsMapBefore := getSnapshotBytesMetrics(t, tc, 1 /* serverIdx */)
+	// sendSnapshotToServer is a testing helper function that sends a learner
+	// snapshot from server[0] to server[serverIndex]. It adds a replica of the
+	// given key range on server[serverIndex] as a voter, resulting in a learner
+	// snapshot being sent. The function returns the updated range descriptor and
+	// the expected snapshot length.
+	sendSnapshotToServer := func(key roachpb.Key, serverIndex int) (roachpb.RangeDescriptor, int64) {
+		blockUntilSnapshotSendCh = make(chan struct{})
+		blockSnapshotSendCh = make(chan struct{})
+		g := ctxgroup.WithContext(ctx)
+		rangeDesc := roachpb.RangeDescriptor{}
+		g.GoCtx(func(ctx context.Context) error {
+			// A new replica at servers[serverIndex] is now added to the cluster,
+			// resulting in a learner snapshot to be sent from servers[0] to
+			// servers[serverIndex]. This function is executed in a new goroutine to
+			// help us capture the expected snapshot bytes count accurately.
+			rangeDesc = tc.AddVotersOrFatal(t, key, tc.Target(serverIndex))
+			return nil
+		})
 
-	g := ctxgroup.WithContext(ctx)
-	g.GoCtx(func(ctx context.Context) error {
-		_, err := tc.AddVoters(scratchStartKey, tc.Target(1))
-		return err
-	})
+		// The current goroutine is blocked until the new goroutine, which has just
+		// been added, is about to send the snapshot (see the testing knob above).
+		// This allows us to calculate the snapshot bytes count accurately,
+		// accounting for any state changes that happen between calling
+		// AddVotersOrFatal and the snapshot being sent.
+		<-blockUntilSnapshotSendCh
+		store, repl := getFirstStoreReplica(t, tc.Server(0), key)
+		snapshotLength, err := getExpectedSnapshotSizeBytes(ctx, store, repl, kvserverpb.SnapshotRequest_INITIAL)
+		require.NoError(t, err)
 
-	// Wait until the snapshot is about to be sent before calculating what the
-	// snapshot size should be. This allows our snapshot measurement to account
-	// for any state changes that happen between calling AddVoters and the
-	// snapshot being sent.
-	<-blockUntilSnapshotSendCh
-	store, repl := getFirstStoreReplica(t, tc.Server(0), scratchStartKey)
-	snapshotLength, err := getExpectedSnapshotSizeBytes(ctx, store, repl, kvserverpb.SnapshotRequest_INITIAL)
-	require.NoError(t, err)
+		close(blockSnapshotSendCh)
+		// Wait for the new goroutine (sending the snapshot) to complete before
+		// measuring the after-sending-snapshot metrics.
+		require.NoError(t, g.Wait())
+		return rangeDesc, snapshotLength
+	}
 
-	close(blockSnapshotSendCh)
-	require.NoError(t, g.Wait())
+	// Record the snapshot metrics before anything has been sent / received.
+	senderMetricsMapBefore := getSnapshotBytesMetrics(t, tc, 0 /* serverIdx */)
+	receiverMetricsMapBefore := getSnapshotBytesMetrics(t, tc, 1 /* serverIdx */)
+	scratchStartKey := tc.ScratchRange(t)
 
-	// Record the snapshot metrics for the sender after a voter has been added. A
-	// learner snapshot should have been sent from the sender to the receiver.
-	senderTotalAfter, senderMetricsMapAfter := getSnapshotBytesMetrics(t, tc, 0)
+	// A learner snapshot should have been sent from the sender to the receiver.
+	_, snapshotLength := sendSnapshotToServer(scratchStartKey, 1)
 
+	// Record the snapshot metrics for the sender after a voter has been added.
+	senderMetricsMapAfter := getSnapshotBytesMetrics(t, tc, 0)
 	// Asserts that the learner snapshot (aka rebalancing snapshot) bytes sent
-	// have been recorded and that it was not double counted in a different
+	// have been recorded, and that it was not double counted in a different
 	// metric.
-	senderTotalDelta, senderMapDelta := getSnapshotMetricsDiff(senderTotalBefore, senderMetricsMapBefore, senderTotalAfter, senderMetricsMapAfter)
-
-	senderTotalExpected := snapshotBytesMetrics{sentBytes: snapshotLength, rcvdBytes: 0}
-	senderMapExpected := map[kvserverpb.SnapshotRequest_Priority]snapshotBytesMetrics{
-		kvserverpb.SnapshotRequest_REBALANCE: {sentBytes: snapshotLength, rcvdBytes: 0},
-		kvserverpb.SnapshotRequest_RECOVERY:  {sentBytes: 0, rcvdBytes: 0},
-		kvserverpb.SnapshotRequest_UNKNOWN:   {sentBytes: 0, rcvdBytes: 0},
+	senderMapDelta := getSnapshotMetricsDiff(senderMetricsMapBefore, senderMetricsMapAfter)
+	senderMapExpected := map[string]snapshotBytesMetrics{
+		".rebalancing": {sentBytes: snapshotLength, rcvdBytes: 0},
+		".recovery":    {sentBytes: 0, rcvdBytes: 0},
+		".unknown":     {sentBytes: 0, rcvdBytes: 0},
+		"":             {sentBytes: snapshotLength, rcvdBytes: 0},
 	}
-	require.Equal(t, senderTotalExpected, senderTotalDelta)
 	require.Equal(t, senderMapExpected, senderMapDelta)
 
 	// Record the snapshot metrics for the receiver after a voter has been added.
-	receiverTotalAfter, receiverMetricsMapAfter := getSnapshotBytesMetrics(t, tc, 1)
-
+	receiverMetricsMapAfter := getSnapshotBytesMetrics(t, tc, 1)
 	// Asserts that the learner snapshot (aka rebalancing snapshot) bytes received
-	// have been recorded and that it was not double counted in a different
+	// have been recorded, and that it was not double counted in a different
 	// metric.
-	receiverTotalDelta, receiverMapDelta := getSnapshotMetricsDiff(receiverTotalBefore, receiverMetricsMapBefore, receiverTotalAfter, receiverMetricsMapAfter)
-
-	receiverTotalExpected := snapshotBytesMetrics{sentBytes: 0, rcvdBytes: snapshotLength}
-	receiverMapExpected := map[kvserverpb.SnapshotRequest_Priority]snapshotBytesMetrics{
-		kvserverpb.SnapshotRequest_REBALANCE: {sentBytes: 0, rcvdBytes: snapshotLength},
-		kvserverpb.SnapshotRequest_RECOVERY:  {sentBytes: 0, rcvdBytes: 0},
-		kvserverpb.SnapshotRequest_UNKNOWN:   {sentBytes: 0, rcvdBytes: 0},
+	receiverMapDelta := getSnapshotMetricsDiff(receiverMetricsMapBefore, receiverMetricsMapAfter)
+	receiverMapExpected := map[string]snapshotBytesMetrics{
+		".rebalancing": {sentBytes: 0, rcvdBytes: snapshotLength},
+		".recovery":    {sentBytes: 0, rcvdBytes: 0},
+		".unknown":     {sentBytes: 0, rcvdBytes: 0},
+		"":             {sentBytes: 0, rcvdBytes: snapshotLength},
 	}
-	require.Equal(t, receiverTotalExpected, receiverTotalDelta)
 	require.Equal(t, receiverMapExpected, receiverMapDelta)
 }