From 598c6f71133c631349c1caae29681ca9e79da0b6 Mon Sep 17 00:00:00 2001
From: Andrew Werner <awerner32@gmail.com>
Date: Tue, 24 Aug 2021 07:26:16 -0400
Subject: [PATCH] kvclient/rangefeed: emit checkpoint events

Grafted from #69269. This seems like a useful primitive for users of
this library. We intend to use it in #69661 and #69614.

Release note: None

Co-authored-by: irfan sharif <irfanmahmoudsharif@gmail.com>
---
 pkg/kv/kvclient/rangefeed/config.go           | 13 +++
 pkg/kv/kvclient/rangefeed/rangefeed.go        |  3 +
 .../rangefeed/rangefeed_external_test.go      | 99 +++++++++++++++++--
 3 files changed, 105 insertions(+), 10 deletions(-)

diff --git a/pkg/kv/kvclient/rangefeed/config.go b/pkg/kv/kvclient/rangefeed/config.go
index 73ea4d522ff3..7a391aef076f 100644
--- a/pkg/kv/kvclient/rangefeed/config.go
+++ b/pkg/kv/kvclient/rangefeed/config.go
@@ -13,6 +13,7 @@ package rangefeed
 import (
 	"context"
 
+	"github.com/cockroachdb/cockroach/pkg/roachpb"
 	"github.com/cockroachdb/cockroach/pkg/util/retry"
 )
 
@@ -27,6 +28,7 @@ type config struct {
 	withInitialScan    bool
 	withDiff           bool
 	onInitialScanError OnInitialScanError
+	onCheckpoint       OnCheckpoint
 }
 
 type optionFunc func(*config)
@@ -80,6 +82,17 @@ func WithRetry(options retry.Options) Option {
 	})
 }
 
+// OnCheckpoint is called when a rangefeed checkpoint occurs.
+type OnCheckpoint func(ctx context.Context, checkpoint *roachpb.RangeFeedCheckpoint)
+
+// WithOnCheckpoint sets up a callback that's invoked whenever a check point
+// event is emitted.
+func WithOnCheckpoint(f OnCheckpoint) Option {
+	return optionFunc(func(c *config) {
+		c.onCheckpoint = f
+	})
+}
+
 func initConfig(c *config, options []Option) {
 	*c = config{} // the default config is its zero value
 	for _, o := range options {
diff --git a/pkg/kv/kvclient/rangefeed/rangefeed.go b/pkg/kv/kvclient/rangefeed/rangefeed.go
index aff22fe5135e..c98fd6c17bd5 100644
--- a/pkg/kv/kvclient/rangefeed/rangefeed.go
+++ b/pkg/kv/kvclient/rangefeed/rangefeed.go
@@ -321,6 +321,9 @@ func (f *RangeFeed) processEvents(
 				if _, err := frontier.Forward(ev.Checkpoint.Span, ev.Checkpoint.ResolvedTS); err != nil {
 					return err
 				}
+				if f.onCheckpoint != nil {
+					f.onCheckpoint(ctx, ev.Checkpoint)
+				}
 			case ev.Error != nil:
 				// Intentionally do nothing, we'll get an error returned from the
 				// call to RangeFeed.
diff --git a/pkg/kv/kvclient/rangefeed/rangefeed_external_test.go b/pkg/kv/kvclient/rangefeed/rangefeed_external_test.go
index 9b130b8374d8..bb4ee7feeaec 100644
--- a/pkg/kv/kvclient/rangefeed/rangefeed_external_test.go
+++ b/pkg/kv/kvclient/rangefeed/rangefeed_external_test.go
@@ -12,11 +12,13 @@ package rangefeed_test
 
 import (
 	"context"
+	"errors"
 	"testing"
 
 	"github.com/cockroachdb/cockroach/pkg/base"
 	"github.com/cockroachdb/cockroach/pkg/kv/kvclient/rangefeed"
 	"github.com/cockroachdb/cockroach/pkg/roachpb"
+	"github.com/cockroachdb/cockroach/pkg/testutils"
 	"github.com/cockroachdb/cockroach/pkg/testutils/testcluster"
 	"github.com/cockroachdb/cockroach/pkg/util/encoding"
 	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
@@ -64,16 +66,18 @@ func TestRangeFeedIntegration(t *testing.T) {
 	require.NoError(t, err)
 	rows := make(chan *roachpb.RangeFeedValue)
 	initialScanDone := make(chan struct{})
-	r, err := f.RangeFeed(ctx, "test", sp, afterB, func(
-		ctx context.Context, value *roachpb.RangeFeedValue,
-	) {
-		select {
-		case rows <- value:
-		case <-ctx.Done():
-		}
-	}, rangefeed.WithDiff(), rangefeed.WithInitialScan(func(ctx context.Context) {
-		close(initialScanDone)
-	}))
+	r, err := f.RangeFeed(ctx, "test", sp, afterB,
+		func(ctx context.Context, value *roachpb.RangeFeedValue) {
+			select {
+			case rows <- value:
+			case <-ctx.Done():
+			}
+		},
+		rangefeed.WithDiff(),
+		rangefeed.WithInitialScan(func(ctx context.Context) {
+			close(initialScanDone)
+		}),
+	)
 	require.NoError(t, err)
 	defer r.Close()
 	{
@@ -106,3 +110,78 @@ func TestRangeFeedIntegration(t *testing.T) {
 		require.Equal(t, int64(4), updated)
 	}
 }
+
+// TestWithOnCheckpoint verifies that we correctly emit rangefeed checkpoint
+// events.
+func TestWithOnCheckpoint(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	ctx := context.Background()
+	tc := testcluster.StartTestCluster(t, 1, base.TestClusterArgs{})
+	defer tc.Stopper().Stop(ctx)
+
+	db := tc.Server(0).DB()
+	scratchKey := tc.ScratchRange(t)
+	scratchKey = scratchKey[:len(scratchKey):len(scratchKey)]
+	mkKey := func(k string) roachpb.Key {
+		return encoding.EncodeStringAscending(scratchKey, k)
+	}
+
+	sp := roachpb.Span{
+		Key:    scratchKey,
+		EndKey: scratchKey.PrefixEnd(),
+	}
+	{
+		// Enable rangefeeds, otherwise the thing will retry until they are enabled.
+		_, err := tc.ServerConn(0).Exec("SET CLUSTER SETTING kv.rangefeed.enabled = true")
+		require.NoError(t, err)
+	}
+	{
+		// Lower the closed timestamp target duration to speed up the test.
+		_, err := tc.ServerConn(0).Exec("SET CLUSTER SETTING kv.closed_timestamp.target_duration = '100ms'")
+		require.NoError(t, err)
+	}
+
+	f, err := rangefeed.NewFactory(tc.Stopper(), db, nil)
+	require.NoError(t, err)
+	rows := make(chan *roachpb.RangeFeedValue)
+	checkpoints := make(chan *roachpb.RangeFeedCheckpoint)
+	r, err := f.RangeFeed(ctx, "test", sp, db.Clock().Now(),
+		func(ctx context.Context, value *roachpb.RangeFeedValue) {
+			select {
+			case rows <- value:
+			case <-ctx.Done():
+			}
+		},
+		rangefeed.WithOnCheckpoint(func(ctx context.Context, checkpoint *roachpb.RangeFeedCheckpoint) {
+			select {
+			case checkpoints <- checkpoint:
+			case <-ctx.Done():
+			}
+		}),
+	)
+	require.NoError(t, err)
+	defer r.Close()
+
+	require.NoError(t, db.Put(ctx, mkKey("a"), 1))
+	afterA := db.Clock().Now()
+	{
+		v := <-rows
+		require.Equal(t, mkKey("a"), v.Key)
+	}
+
+	// We should expect a checkpoint event covering the key we just wrote, at a
+	// timestamp greater than when we wrote it.
+	testutils.SucceedsSoon(t, func() error {
+		for {
+			select {
+			case c := <-checkpoints:
+				if afterA.LessEq(c.ResolvedTS) && c.Span.ContainsKey(mkKey("a")) {
+					return nil
+				}
+			default:
+				return errors.New("no valid checkpoints found")
+			}
+		}
+	})
+}