From f5244f499144c7dbfb34849e18fc8a69fb3783dc Mon Sep 17 00:00:00 2001 From: Anne Zhu Date: Wed, 28 Jul 2021 14:53:51 -0400 Subject: [PATCH] streamingccl: hang processors on losing connection with sinkless stream client Previously, if the sinkless client loses connection, the processor would receive an error and move to draining. With the concept of generation, the sinkless client should send over a `GenerationEvent` once it has lost connection. On receiving a `GenerationEvent`, the processor should wait for a cutover signal to be sent (the mechanism for issuing cutover signals on new generation will be implemented in a following PR). Release note: None --- .../cockroach_sinkless_replication_client.go | 11 +- ...kroach_sinkless_replication_client_test.go | 13 ++ .../stream_ingestion_processor.go | 55 +++++-- .../stream_ingestion_processor_test.go | 151 ++++++++++++++++-- .../streamingtest/replication_helpers.go | 14 ++ 5 files changed, 215 insertions(+), 29 deletions(-) diff --git a/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client.go b/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client.go index 6f512cd09bf0..23bfb5c36c47 100644 --- a/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client.go +++ b/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client.go @@ -11,6 +11,7 @@ package streamclient import ( "context" gosql "database/sql" + "database/sql/driver" "fmt" "strconv" @@ -124,7 +125,15 @@ func (m *sinklessReplicationClient) ConsumePartition( } } if err := rows.Err(); err != nil { - errCh <- err + if errors.Is(err, driver.ErrBadConn) { + select { + case eventCh <- streamingccl.MakeGenerationEvent(): + case <-ctx.Done(): + errCh <- ctx.Err() + } + } else { + errCh <- err + } return } }() diff --git a/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client_test.go b/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client_test.go index 296128f4d84e..f46a90b28709 100644 --- a/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client_test.go +++ b/pkg/ccl/streamingccl/streamclient/cockroach_sinkless_replication_client_test.go @@ -110,4 +110,17 @@ INSERT INTO d.t2 VALUES (2); feed.ObserveResolved(secondObserved.Value.Timestamp) cancelIngestion() }) + + t.Run("stream-address-disconnects", func(t *testing.T) { + clientCtx, cancelIngestion := context.WithCancel(ctx) + eventCh, errCh, err := client.ConsumePartition(clientCtx, pa, startTime) + require.NoError(t, err) + feedSource := &channelFeedSource{eventCh: eventCh, errCh: errCh} + feed := streamingtest.MakeReplicationFeed(t, feedSource) + + h.SysServer.Stopper().Stop(clientCtx) + + require.True(t, feed.ObserveGeneration()) + cancelIngestion() + }) } diff --git a/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor.go b/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor.go index b29fc2a4601e..419e46851415 100644 --- a/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor.go +++ b/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor.go @@ -31,6 +31,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/protoutil" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" "golang.org/x/sync/errgroup" @@ -95,14 +96,6 @@ type streamIngestionProcessor struct { // and have attempted to flush them with `internalDrained`. internalDrained bool - // ingestionErr stores any error that is returned from the worker goroutine so - // that it can be forwarded through the DistSQL flow. - ingestionErr error - - // pollingErr stores any error that is returned from the poller checking for a - // cutover signal so that it can be forwarded through the DistSQL flow. - pollingErr error - // pollingWaitGroup registers the polling goroutine and waits for it to return // when the processor is being drained. pollingWaitGroup sync.WaitGroup @@ -117,6 +110,20 @@ type streamIngestionProcessor struct { // closePoller is used to shutdown the poller that checks the job for a // cutover signal. closePoller chan struct{} + + // mu is used to provide thread-safe read-write operations to ingestionErr + // and pollingErr. + mu struct { + syncutil.Mutex + + // ingestionErr stores any error that is returned from the worker goroutine so + // that it can be forwarded through the DistSQL flow. + ingestionErr error + + // pollingErr stores any error that is returned from the poller checking for a + // cutover signal so that it can be forwarded through the DistSQL flow. + pollingErr error + } } // partitionEvent augments a normal event with the partition it came from. @@ -190,7 +197,9 @@ func (sip *streamIngestionProcessor) Start(ctx context.Context) { defer sip.pollingWaitGroup.Done() err := sip.checkForCutoverSignal(ctx, sip.closePoller) if err != nil { - sip.pollingErr = errors.Wrap(err, "error while polling job for cutover signal") + sip.mu.Lock() + sip.mu.pollingErr = errors.Wrap(err, "error while polling job for cutover signal") + sip.mu.Unlock() } }() @@ -220,8 +229,11 @@ func (sip *streamIngestionProcessor) Next() (rowenc.EncDatumRow, *execinfrapb.Pr return nil, sip.DrainHelper() } - if sip.pollingErr != nil { - sip.MoveToDraining(sip.pollingErr) + sip.mu.Lock() + err := sip.mu.pollingErr + sip.mu.Unlock() + if err != nil { + sip.MoveToDraining(err) return nil, sip.DrainHelper() } @@ -243,8 +255,11 @@ func (sip *streamIngestionProcessor) Next() (rowenc.EncDatumRow, *execinfrapb.Pr return row, nil } - if sip.ingestionErr != nil { - sip.MoveToDraining(sip.ingestionErr) + sip.mu.Lock() + err = sip.mu.ingestionErr + sip.mu.Unlock() + if err != nil { + sip.MoveToDraining(err) return nil, sip.DrainHelper() } @@ -372,7 +387,10 @@ func (sip *streamIngestionProcessor) merge( }) } go func() { - sip.ingestionErr = g.Wait() + err := g.Wait() + sip.mu.Lock() + defer sip.mu.Unlock() + sip.mu.ingestionErr = err close(merged) }() @@ -426,6 +444,15 @@ func (sip *streamIngestionProcessor) consumeEvents() (*jobspb.ResolvedSpans, err } return sip.flush() + case streamingccl.GenerationEvent: + log.Info(sip.Ctx, "GenerationEvent received") + select { + case <-sip.cutoverCh: + sip.internalDrained = true + return nil, nil + case <-sip.Ctx.Done(): + return nil, sip.Ctx.Err() + } default: return nil, errors.Newf("unknown streaming event type %v", event.Type()) } diff --git a/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor_test.go b/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor_test.go index ea886a64b8c7..6137ecd70572 100644 --- a/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor_test.go +++ b/pkg/ccl/streamingccl/streamingest/stream_ingestion_processor_test.go @@ -12,6 +12,7 @@ import ( "context" "fmt" "strconv" + "sync" "testing" "time" @@ -48,9 +49,20 @@ import ( // partition addresses. type mockStreamClient struct { partitionEvents map[streamingccl.PartitionAddress][]streamingccl.Event + + // mu is used to provide a threadsafe interface to interceptors. + mu struct { + syncutil.Mutex + + // interceptors can be registered to peek at every event generated by this + // client. + interceptors []func(streamingccl.Event, streamingccl.PartitionAddress) + tableID int + } } var _ streamclient.Client = &mockStreamClient{} +var _ streamclient.InterceptableStreamClient = &mockStreamClient{} // GetTopology implements the Client interface. func (m *mockStreamClient) GetTopology( @@ -61,7 +73,7 @@ func (m *mockStreamClient) GetTopology( // ConsumePartition implements the Client interface. func (m *mockStreamClient) ConsumePartition( - _ context.Context, address streamingccl.PartitionAddress, _ hlc.Timestamp, + ctx context.Context, address streamingccl.PartitionAddress, _ hlc.Timestamp, ) (chan streamingccl.Event, chan error, error) { var events []streamingccl.Event var ok bool @@ -69,16 +81,45 @@ func (m *mockStreamClient) ConsumePartition( return nil, nil, errors.Newf("no events found for paritition %s", address) } - eventCh := make(chan streamingccl.Event, len(events)) + eventCh := make(chan streamingccl.Event) + errCh := make(chan error) - for _, event := range events { - eventCh <- event - } - close(eventCh) + go func() { + defer close(eventCh) + defer close(errCh) + + for _, event := range events { + select { + case eventCh <- event: + case <-ctx.Done(): + errCh <- ctx.Err() + } + + func() { + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.mu.interceptors) > 0 { + for _, interceptor := range m.mu.interceptors { + if interceptor != nil { + interceptor(event, address) + } + } + } + }() + } + }() return eventCh, nil, nil } +// RegisterInterception implements the InterceptableStreamClient interface. +func (m *mockStreamClient) RegisterInterception(fn streamclient.InterceptFn) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.interceptors = append(m.mu.interceptors, fn) +} + // errorStreamClient always returns an error when consuming a partition. type errorStreamClient struct{} @@ -171,6 +212,59 @@ func TestStreamIngestionProcessor(t *testing.T) { require.Nil(t, row) testutils.IsError(meta.Err, "this client always returns an error") }) + + t.Run("stream ingestion processor shuts down gracefully on losing client connection", func(t *testing.T) { + events := []streamingccl.Event{streamingccl.MakeGenerationEvent()} + pa := streamingccl.PartitionAddress("partition") + mockClient := &mockStreamClient{ + partitionEvents: map[streamingccl.PartitionAddress][]streamingccl.Event{pa: events}, + } + + startTime := hlc.Timestamp{WallTime: timeutil.Now().UnixNano()} + partitionAddresses := []streamingccl.PartitionAddress{"partition"} + + interceptCh := make(chan struct{}) + defer close(interceptCh) + sendToInterceptCh := func() { + interceptCh <- struct{}{} + } + interceptGeneration := markGenerationEventReceived(sendToInterceptCh) + sip, out, err := getStreamIngestionProcessor(ctx, t, registry, kvDB, "randomgen://test/", + partitionAddresses, startTime, []streamclient.InterceptFn{interceptGeneration}, mockClient) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + sip.Run(ctx) + }() + + // The channel will block on read if the event has not been intercepted yet. + // Once it unblocks, we are guaranteed that the mockClient has sent the + // GenerationEvent and the processor has read it. + <-interceptCh + + // The sip processor has received a GenerationEvent and is thus + // waiting for a cutover signal, so let's send one! + sip.cutoverCh <- struct{}{} + + wg.Wait() + // Ensure that all the outputs are properly closed. + if !out.ProducerClosed() { + t.Fatalf("output RowReceiver not closed") + } + + for { + // No metadata should have been produced since the processor + // should have been moved to draining state with a nil error. + row := out.NextNoMeta(t) + if row == nil { + break + } + t.Fatalf("more output rows than expected") + } + }) } func getPartitionSpanToTableID( @@ -379,6 +473,30 @@ func runStreamIngestionProcessor( interceptEvents []streamclient.InterceptFn, mockClient streamclient.Client, ) (*distsqlutils.RowBuffer, error) { + sip, out, err := getStreamIngestionProcessor(ctx, t, registry, kvDB, streamAddr, + partitionAddresses, startTime, interceptEvents, mockClient) + require.NoError(t, err) + + sip.Run(ctx) + + // Ensure that all the outputs are properly closed. + if !out.ProducerClosed() { + t.Fatalf("output RowReceiver not closed") + } + return out, err +} + +func getStreamIngestionProcessor( + ctx context.Context, + t *testing.T, + registry *jobs.Registry, + kvDB *kv.DB, + streamAddr string, + partitionAddresses []streamingccl.PartitionAddress, + startTime hlc.Timestamp, + interceptEvents []streamclient.InterceptFn, + mockClient streamclient.Client, +) (*streamIngestionProcessor, *distsqlutils.RowBuffer, error) { st := cluster.MakeTestingClusterSettings() evalCtx := tree.MakeTestingEvalContext(st) @@ -423,14 +541,7 @@ func runStreamIngestionProcessor( interceptable.RegisterInterception(interceptor) } } - - sip.Run(ctx) - - // Ensure that all the outputs are properly closed. - if !out.ProducerClosed() { - t.Fatalf("output RowReceiver not closed") - } - return out, err + return sip, out, err } func registerValidatorWithClient( @@ -476,3 +587,15 @@ func makeCheckpointEventCounter( } } } + +// markGenerationEventReceived runs f after seeing a GenerationEvent. +func markGenerationEventReceived( + f func(), +) func(event streamingccl.Event, pa streamingccl.PartitionAddress) { + return func(event streamingccl.Event, pa streamingccl.PartitionAddress) { + switch event.Type() { + case streamingccl.GenerationEvent: + f() + } + } +} diff --git a/pkg/ccl/streamingccl/streamingtest/replication_helpers.go b/pkg/ccl/streamingccl/streamingtest/replication_helpers.go index 6bf98a25118a..4386f7496657 100644 --- a/pkg/ccl/streamingccl/streamingtest/replication_helpers.go +++ b/pkg/ccl/streamingccl/streamingtest/replication_helpers.go @@ -53,6 +53,14 @@ func ResolvedAtLeast(lo hlc.Timestamp) FeedPredicate { } } +// ReceivedNewGeneration makes a FeedPredicate that matches when a GenerationEvent has +// been received. +func ReceivedNewGeneration() FeedPredicate { + return func(msg streamingccl.Event) bool { + return msg.Type() == streamingccl.GenerationEvent + } +} + // FeedSource is a source of events for a ReplicationFeed. type FeedSource interface { // Next returns the next event, and a flag indicating if there are more events @@ -92,6 +100,12 @@ func (rf *ReplicationFeed) ObserveResolved(lo hlc.Timestamp) hlc.Timestamp { return *rf.msg.GetResolved() } +// ObserveGeneration consumes the feed until we received a GenerationEvent. Returns true. +func (rf *ReplicationFeed) ObserveGeneration() bool { + require.NoError(rf.t, rf.consumeUntil(ReceivedNewGeneration())) + return true +} + // Close cleans up any resources. func (rf *ReplicationFeed) Close() { rf.f.Close()