diff --git a/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client.go b/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client.go
index 6f512cd09bf0..23bfb5c36c47 100644
--- a/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client.go
+++ b/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client.go
@@ -11,6 +11,7 @@ package streamclient
 import (
 	"context"
 	gosql "database/sql"
+	"database/sql/driver"
 	"fmt"
 	"strconv"
 
@@ -124,7 +125,15 @@ func (m *sinklessReplicationClient) ConsumePartition(
 			}
 		}
 		if err := rows.Err(); err != nil {
-			errCh <- err
+			if errors.Is(err, driver.ErrBadConn) {
+				select {
+				case eventCh <- streamingccl.MakeGenerationEvent():
+				case <-ctx.Done():
+					errCh <- ctx.Err()
+				}
+			} else {
+				errCh <- err
+			}
 			return
 		}
 	}()
diff --git a/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client_test.go b/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client_test.go
index 296128f4d84e..f46a90b28709 100644
--- a/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client_test.go
+++ b/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client_test.go
@@ -110,4 +110,17 @@ INSERT INTO d.t2 VALUES (2);
 		feed.ObserveResolved(secondObserved.Value.Timestamp)
 		cancelIngestion()
 	})
+
+	t.Run("stream-address-disconnects", func(t *testing.T) {
+		clientCtx, cancelIngestion := context.WithCancel(ctx)
+		eventCh, errCh, err := client.ConsumePartition(clientCtx, pa, startTime)
+		require.NoError(t, err)
+		feedSource := &channelFeedSource{eventCh: eventCh, errCh: errCh}
+		feed := streamingtest.MakeReplicationFeed(t, feedSource)
+
+		h.SysServer.Stopper().Stop(clientCtx)
+
+		require.True(t, feed.ObserveGeneration())
+		cancelIngestion()
+	})
 }
diff --git a/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor.go b/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor.go
index b29fc2a4601e..419e46851415 100644
--- a/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor.go
+++ b/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor.go
@@ -31,6 +31,7 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/util/hlc"
 	"github.com/cockroachdb/cockroach/pkg/util/log"
 	"github.com/cockroachdb/cockroach/pkg/util/protoutil"
+	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
 	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
 	"github.com/cockroachdb/errors"
 	"golang.org/x/sync/errgroup"
@@ -95,14 +96,6 @@ type streamIngestionProcessor struct {
 	// and have attempted to flush them with `internalDrained`.
 	internalDrained bool
 
-	// ingestionErr stores any error that is returned from the worker goroutine so
-	// that it can be forwarded through the DistSQL flow.
-	ingestionErr error
-
-	// pollingErr stores any error that is returned from the poller checking for a
-	// cutover signal so that it can be forwarded through the DistSQL flow.
-	pollingErr error
-
 	// pollingWaitGroup registers the polling goroutine and waits for it to return
 	// when the processor is being drained.
 	pollingWaitGroup sync.WaitGroup
@@ -117,6 +110,20 @@ type streamIngestionProcessor struct {
 	// closePoller is used to shutdown the poller that checks the job for a
 	// cutover signal.
 	closePoller chan struct{}
+
+	// mu is used to provide thread-safe read-write operations to ingestionErr
+	// and pollingErr.
+	mu struct {
+		syncutil.Mutex
+
+		// ingestionErr stores any error that is returned from the worker goroutine so
+		// that it can be forwarded through the DistSQL flow.
+		ingestionErr error
+
+		// pollingErr stores any error that is returned from the poller checking for a
+		// cutover signal so that it can be forwarded through the DistSQL flow.
+		pollingErr error
+	}
 }
 
 // partitionEvent augments a normal event with the partition it came from.
@@ -190,7 +197,9 @@ func (sip *streamIngestionProcessor) Start(ctx context.Context) {
 		defer sip.pollingWaitGroup.Done()
 		err := sip.checkForCutoverSignal(ctx, sip.closePoller)
 		if err != nil {
-			sip.pollingErr = errors.Wrap(err, "error while polling job for cutover signal")
+			sip.mu.Lock()
+			sip.mu.pollingErr = errors.Wrap(err, "error while polling job for cutover signal")
+			sip.mu.Unlock()
 		}
 	}()
 
@@ -220,8 +229,11 @@ func (sip *streamIngestionProcessor) Next() (rowenc.EncDatumRow, *execinfrapb.Pr
 		return nil, sip.DrainHelper()
 	}
 
-	if sip.pollingErr != nil {
-		sip.MoveToDraining(sip.pollingErr)
+	sip.mu.Lock()
+	err := sip.mu.pollingErr
+	sip.mu.Unlock()
+	if err != nil {
+		sip.MoveToDraining(err)
 		return nil, sip.DrainHelper()
 	}
 
@@ -243,8 +255,11 @@ func (sip *streamIngestionProcessor) Next() (rowenc.EncDatumRow, *execinfrapb.Pr
 		return row, nil
 	}
 
-	if sip.ingestionErr != nil {
-		sip.MoveToDraining(sip.ingestionErr)
+	sip.mu.Lock()
+	err = sip.mu.ingestionErr
+	sip.mu.Unlock()
+	if err != nil {
+		sip.MoveToDraining(err)
 		return nil, sip.DrainHelper()
 	}
 
@@ -372,7 +387,10 @@ func (sip *streamIngestionProcessor) merge(
 		})
 	}
 	go func() {
-		sip.ingestionErr = g.Wait()
+		err := g.Wait()
+		sip.mu.Lock()
+		defer sip.mu.Unlock()
+		sip.mu.ingestionErr = err
 		close(merged)
 	}()
 
@@ -426,6 +444,15 @@ func (sip *streamIngestionProcessor) consumeEvents() (*jobspb.ResolvedSpans, err
 			}
 			return sip.flush()
 
+		case streamingccl.GenerationEvent:
+			log.Info(sip.Ctx, "GenerationEvent received")
+			select {
+			case <-sip.cutoverCh:
+				sip.internalDrained = true
+				return nil, nil
+			case <-sip.Ctx.Done():
+				return nil, sip.Ctx.Err()
+			}
 		default:
 			return nil, errors.Newf("unknown streaming event type %v", event.Type())
 		}
diff --git a/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor_test.go b/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor_test.go
index ea886a64b8c7..6137ecd70572 100644
--- a/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor_test.go
+++ b/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor_test.go
@@ -12,6 +12,7 @@ import (
 	"context"
 	"fmt"
 	"strconv"
+	"sync"
 	"testing"
 	"time"
 
@@ -48,9 +49,20 @@
 // partition addresses.
 type mockStreamClient struct {
 	partitionEvents map[streamingccl.PartitionAddress][]streamingccl.Event
+
+	// mu is used to provide a threadsafe interface to interceptors.
+	mu struct {
+		syncutil.Mutex
+
+		// interceptors can be registered to peek at every event generated by this
+		// client.
+		interceptors []func(streamingccl.Event, streamingccl.PartitionAddress)
+		tableID int
+	}
 }
 
 var _ streamclient.Client = &mockStreamClient{}
+var _ streamclient.InterceptableStreamClient = &mockStreamClient{}
 
 // GetTopology implements the Client interface.
 func (m *mockStreamClient) GetTopology(
@@ -61,7 +73,7 @@
 
 // ConsumePartition implements the Client interface.
 func (m *mockStreamClient) ConsumePartition(
-	_ context.Context, address streamingccl.PartitionAddress, _ hlc.Timestamp,
+	ctx context.Context, address streamingccl.PartitionAddress, _ hlc.Timestamp,
 ) (chan streamingccl.Event, chan error, error) {
 	var events []streamingccl.Event
 	var ok bool
@@ -69,16 +81,45 @@ func (m *mockStreamClient) ConsumePartition(
 		return nil, nil, errors.Newf("no events found for paritition %s", address)
 	}
 
-	eventCh := make(chan streamingccl.Event, len(events))
+	eventCh := make(chan streamingccl.Event)
+	errCh := make(chan error)
 
-	for _, event := range events {
-		eventCh <- event
-	}
-	close(eventCh)
+	go func() {
+		defer close(eventCh)
+		defer close(errCh)
+
+		for _, event := range events {
+			select {
+			case eventCh <- event:
+			case <-ctx.Done():
+				errCh <- ctx.Err()
+			}
+
+			func() {
+				m.mu.Lock()
+				defer m.mu.Unlock()
+
+				if len(m.mu.interceptors) > 0 {
+					for _, interceptor := range m.mu.interceptors {
+						if interceptor != nil {
+							interceptor(event, address)
+						}
+					}
+				}
+			}()
+		}
+	}()
 
 	return eventCh, nil, nil
 }
 
+// RegisterInterception implements the InterceptableStreamClient interface.
+func (m *mockStreamClient) RegisterInterception(fn streamclient.InterceptFn) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.mu.interceptors = append(m.mu.interceptors, fn)
+}
+
 // errorStreamClient always returns an error when consuming a partition.
 type errorStreamClient struct{}
 
@@ -171,6 +212,59 @@ func TestStreamIngestionProcessor(t *testing.T) {
 		require.Nil(t, row)
 		testutils.IsError(meta.Err, "this client always returns an error")
 	})
+
+	t.Run("stream ingestion processor shuts down gracefully on losing client connection", func(t *testing.T) {
+		events := []streamingccl.Event{streamingccl.MakeGenerationEvent()}
+		pa := streamingccl.PartitionAddress("partition")
+		mockClient := &mockStreamClient{
+			partitionEvents: map[streamingccl.PartitionAddress][]streamingccl.Event{pa: events},
+		}
+
+		startTime := hlc.Timestamp{WallTime: timeutil.Now().UnixNano()}
+		partitionAddresses := []streamingccl.PartitionAddress{"partition"}
+
+		interceptCh := make(chan struct{})
+		defer close(interceptCh)
+		sendToInterceptCh := func() {
+			interceptCh <- struct{}{}
+		}
+		interceptGeneration := markGenerationEventReceived(sendToInterceptCh)
+		sip, out, err := getStreamIngestionProcessor(ctx, t, registry, kvDB, "randomgen://test/",
+			partitionAddresses, startTime, []streamclient.InterceptFn{interceptGeneration}, mockClient)
+		require.NoError(t, err)
+
+		var wg sync.WaitGroup
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			sip.Run(ctx)
+		}()
+
+		// The channel will block on read if the event has not been intercepted yet.
+		// Once it unblocks, we are guaranteed that the mockClient has sent the
+		// GenerationEvent and the processor has read it.
+		<-interceptCh
+
+		// The sip processor has received a GenerationEvent and is thus
+		// waiting for a cutover signal, so let's send one!
+		sip.cutoverCh <- struct{}{}
+
+		wg.Wait()
+		// Ensure that all the outputs are properly closed.
+		if !out.ProducerClosed() {
+			t.Fatalf("output RowReceiver not closed")
+		}
+
+		for {
+			// No metadata should have been produced since the processor
+			// should have been moved to draining state with a nil error.
+			row := out.NextNoMeta(t)
+			if row == nil {
+				break
+			}
+			t.Fatalf("more output rows than expected")
+		}
+	})
 }
 
 func getPartitionSpanToTableID(
@@ -379,6 +473,30 @@
 	interceptEvents []streamclient.InterceptFn,
 	mockClient streamclient.Client,
 ) (*distsqlutils.RowBuffer, error) {
+	sip, out, err := getStreamIngestionProcessor(ctx, t, registry, kvDB, streamAddr,
+		partitionAddresses, startTime, interceptEvents, mockClient)
+	require.NoError(t, err)
+
+	sip.Run(ctx)
+
+	// Ensure that all the outputs are properly closed.
+	if !out.ProducerClosed() {
+		t.Fatalf("output RowReceiver not closed")
+	}
+	return out, err
+}
+
+func getStreamIngestionProcessor(
+	ctx context.Context,
+	t *testing.T,
+	registry *jobs.Registry,
+	kvDB *kv.DB,
+	streamAddr string,
+	partitionAddresses []streamingccl.PartitionAddress,
+	startTime hlc.Timestamp,
+	interceptEvents []streamclient.InterceptFn,
+	mockClient streamclient.Client,
+) (*streamIngestionProcessor, *distsqlutils.RowBuffer, error) {
 	st := cluster.MakeTestingClusterSettings()
 	evalCtx := tree.MakeTestingEvalContext(st)
 
@@ -423,14 +541,7 @@ func runStreamIngestionProcessor(
 			interceptable.RegisterInterception(interceptor)
 		}
 	}
-
-	sip.Run(ctx)
-
-	// Ensure that all the outputs are properly closed.
-	if !out.ProducerClosed() {
-		t.Fatalf("output RowReceiver not closed")
-	}
-	return out, err
+	return sip, out, err
 }
 
 func registerValidatorWithClient(
@@ -476,3 +587,15 @@
 		}
 	}
 }
+
+// markGenerationEventReceived runs f after seeing a GenerationEvent.
+func markGenerationEventReceived(
+	f func(),
+) func(event streamingccl.Event, pa streamingccl.PartitionAddress) {
+	return func(event streamingccl.Event, pa streamingccl.PartitionAddress) {
+		switch event.Type() {
+		case streamingccl.GenerationEvent:
+			f()
+		}
+	}
+}
diff --git a/pkg/ccl/streamingccl/streamingtest/replication_helpers.go b/pkg/ccl/streamingccl/streamingtest/replication_helpers.go
index 6bf98a25118a..4386f7496657 100644
--- a/pkg/ccl/streamingccl/streamingtest/replication_helpers.go
+++ b/pkg/ccl/streamingccl/streamingtest/replication_helpers.go
@@ -53,6 +53,14 @@ func ResolvedAtLeast(lo hlc.Timestamp) FeedPredicate {
 	}
 }
 
+// ReceivedNewGeneration makes a FeedPredicate that matches when a GenerationEvent has
+// been received.
+func ReceivedNewGeneration() FeedPredicate {
+	return func(msg streamingccl.Event) bool {
+		return msg.Type() == streamingccl.GenerationEvent
+	}
+}
+
 // FeedSource is a source of events for a ReplicationFeed.
 type FeedSource interface {
 	// Next returns the next event, and a flag indicating if there are more events
@@ -92,6 +100,12 @@ func (rf *ReplicationFeed) ObserveResolved(lo hlc.Timestamp) hlc.Timestamp {
 	return *rf.msg.GetResolved()
 }
 
+// ObserveGeneration consumes the feed until we receive a GenerationEvent. Returns true.
+func (rf *ReplicationFeed) ObserveGeneration() bool {
+	require.NoError(rf.t, rf.consumeUntil(ReceivedNewGeneration()))
+	return true
+}
+
 // Close cleans up any resources.
 func (rf *ReplicationFeed) Close() {
 	rf.f.Close()
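
Reviewer note: the common thread in these changes is that a dropped connection to the source cluster (surfacing as driver.ErrBadConn from the event stream) is turned into a GenerationEvent rather than an error, and the ingestion processor treats that event as a cue to stop consuming and wait for the cutover signal instead of failing the job. The sketch below illustrates only the client-side classification step under simplified assumptions: the event type and the classifyRowsErr helper are hypothetical stand-ins for illustration, not part of the streamingccl API, and buffered channels stand in for the real plumbing.

package main

import (
	"context"
	"database/sql/driver"
	"errors"
	"fmt"
)

// event is an illustrative stand-in for streamingccl.Event; only the
// "generation" case matters for this sketch.
type event struct{ generation bool }

// classifyRowsErr mirrors the pattern in ConsumePartition above: a bad
// connection becomes a generation event on eventCh, any other error is
// reported on errCh, and a canceled context short-circuits the send.
func classifyRowsErr(ctx context.Context, rowsErr error, eventCh chan<- event, errCh chan<- error) {
	if rowsErr == nil {
		return
	}
	if errors.Is(rowsErr, driver.ErrBadConn) {
		select {
		case eventCh <- event{generation: true}:
		case <-ctx.Done():
			errCh <- ctx.Err()
		}
		return
	}
	errCh <- rowsErr
}

func main() {
	eventCh := make(chan event, 1)
	errCh := make(chan error, 1)
	classifyRowsErr(context.Background(), driver.ErrBadConn, eventCh, errCh)
	// Prints "true": the consumer observes a generation event instead of an error.
	fmt.Println((<-eventCh).generation)
}

On the ingestion side, the matching half of the pattern is the new GenerationEvent case in consumeEvents: the processor stops pulling events once the cutover signal arrives and drains with a nil error, which is what the new processor test asserts.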