streamingccl: hang processors on losing connection with sinkless stream client

Previously, if the sinkless client lost its connection, the processor would receive an error and move to draining. With the concept of generations, the sinkless client now sends a `GenerationEvent` once it detects that the connection has been lost. On receiving a `GenerationEvent`, the processor waits for a cutover signal to be sent (the mechanism for issuing cutover signals on a new generation will be implemented in a follow-up PR).

Release note: None
azhu-crl committed Aug 10, 2021
1 parent ff40d73 commit f5244f4
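
To make the described control flow concrete, here is a minimal, hypothetical sketch of a per-partition consumer loop that reacts to a `GenerationEvent` by blocking on a cutover signal rather than draining on error. This is not the committed implementation (that follows in the diffs below); the channel names eventCh, errCh, and cutoverCh are assumptions for illustration, while streamingccl.Event, event.Type(), and streamingccl.GenerationEvent come from the code in this commit.

package example

import (
	"context"

	"github.com/cockroachdb/cockroach/pkg/ccl/streamingccl"
)

// consumePartitionEvents drains eventCh until the stream ends. On a
// GenerationEvent it parks until cutoverCh fires (or ctx is canceled)
// instead of treating the lost connection as a fatal error.
func consumePartitionEvents(
	ctx context.Context,
	eventCh <-chan streamingccl.Event,
	errCh <-chan error,
	cutoverCh <-chan struct{},
) error {
	for {
		select {
		case event, ok := <-eventCh:
			if !ok {
				return nil // stream closed cleanly
			}
			switch event.Type() {
			case streamingccl.GenerationEvent:
				// The client lost its connection; wait for a cutover signal.
				select {
				case <-cutoverCh:
					return nil
				case <-ctx.Done():
					return ctx.Err()
				}
			default:
				// KV and checkpoint events would be ingested here as before.
			}
		case err := <-errCh:
			return err
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

This mirrors the select on sip.cutoverCh that the commit adds to consumeEvents in stream_ingestion_processor.go.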
Showing 5 changed files with 215 additions and 29 deletions.
@@ -11,6 +11,7 @@ package streamclient
import (
"context"
gosql "database/sql"
"database/sql/driver"
"fmt"
"strconv"

@@ -124,7 +125,15 @@ func (m *sinklessReplicationClient) ConsumePartition(
}
}
if err := rows.Err(); err != nil {
errCh <- err
if errors.Is(err, driver.ErrBadConn) {
select {
case eventCh <- streamingccl.MakeGenerationEvent():
case <-ctx.Done():
errCh <- ctx.Err()
}
} else {
errCh <- err
}
return
}
}()
@@ -110,4 +110,17 @@ INSERT INTO d.t2 VALUES (2);
feed.ObserveResolved(secondObserved.Value.Timestamp)
cancelIngestion()
})

t.Run("stream-address-disconnects", func(t *testing.T) {
clientCtx, cancelIngestion := context.WithCancel(ctx)
eventCh, errCh, err := client.ConsumePartition(clientCtx, pa, startTime)
require.NoError(t, err)
feedSource := &channelFeedSource{eventCh: eventCh, errCh: errCh}
feed := streamingtest.MakeReplicationFeed(t, feedSource)

h.SysServer.Stopper().Stop(clientCtx)

require.True(t, feed.ObserveGeneration())
cancelIngestion()
})
}
55 changes: 41 additions & 14 deletions pkg/ccl/streamingccl/streamingest/stream_ingestion_processor.go
@@ -31,6 +31,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/hlc"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/protoutil"
"github.com/cockroachdb/cockroach/pkg/util/syncutil"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/cockroachdb/errors"
"golang.org/x/sync/errgroup"
@@ -95,14 +96,6 @@ type streamIngestionProcessor struct {
// and have attempted to flush them with `internalDrained`.
internalDrained bool

// ingestionErr stores any error that is returned from the worker goroutine so
// that it can be forwarded through the DistSQL flow.
ingestionErr error

// pollingErr stores any error that is returned from the poller checking for a
// cutover signal so that it can be forwarded through the DistSQL flow.
pollingErr error

// pollingWaitGroup registers the polling goroutine and waits for it to return
// when the processor is being drained.
pollingWaitGroup sync.WaitGroup
@@ -117,6 +110,20 @@ type streamIngestionProcessor struct {
// closePoller is used to shutdown the poller that checks the job for a
// cutover signal.
closePoller chan struct{}

// mu is used to provide thread-safe read-write operations to ingestionErr
// and pollingErr.
mu struct {
syncutil.Mutex

// ingestionErr stores any error that is returned from the worker goroutine so
// that it can be forwarded through the DistSQL flow.
ingestionErr error

// pollingErr stores any error that is returned from the poller checking for a
// cutover signal so that it can be forwarded through the DistSQL flow.
pollingErr error
}
}

// partitionEvent augments a normal event with the partition it came from.
@@ -190,7 +197,9 @@ func (sip *streamIngestionProcessor) Start(ctx context.Context) {
defer sip.pollingWaitGroup.Done()
err := sip.checkForCutoverSignal(ctx, sip.closePoller)
if err != nil {
sip.pollingErr = errors.Wrap(err, "error while polling job for cutover signal")
sip.mu.Lock()
sip.mu.pollingErr = errors.Wrap(err, "error while polling job for cutover signal")
sip.mu.Unlock()
}
}()

@@ -220,8 +229,11 @@ func (sip *streamIngestionProcessor) Next() (rowenc.EncDatumRow, *execinfrapb.Pr
return nil, sip.DrainHelper()
}

if sip.pollingErr != nil {
sip.MoveToDraining(sip.pollingErr)
sip.mu.Lock()
err := sip.mu.pollingErr
sip.mu.Unlock()
if err != nil {
sip.MoveToDraining(err)
return nil, sip.DrainHelper()
}

Expand All @@ -243,8 +255,11 @@ func (sip *streamIngestionProcessor) Next() (rowenc.EncDatumRow, *execinfrapb.Pr
return row, nil
}

if sip.ingestionErr != nil {
sip.MoveToDraining(sip.ingestionErr)
sip.mu.Lock()
err = sip.mu.ingestionErr
sip.mu.Unlock()
if err != nil {
sip.MoveToDraining(err)
return nil, sip.DrainHelper()
}

@@ -372,7 +387,10 @@ func (sip *streamIngestionProcessor) merge(
})
}
go func() {
sip.ingestionErr = g.Wait()
err := g.Wait()
sip.mu.Lock()
defer sip.mu.Unlock()
sip.mu.ingestionErr = err
close(merged)
}()

@@ -426,6 +444,15 @@ func (sip *streamIngestionProcessor) consumeEvents() (*jobspb.ResolvedSpans, err
}

return sip.flush()
case streamingccl.GenerationEvent:
log.Info(sip.Ctx, "GenerationEvent received")
select {
case <-sip.cutoverCh:
sip.internalDrained = true
return nil, nil
case <-sip.Ctx.Done():
return nil, sip.Ctx.Err()
}
default:
return nil, errors.Newf("unknown streaming event type %v", event.Type())
}
151 changes: 137 additions & 14 deletions pkg/ccl/streamingccl/streamingest/stream_ingestion_processor_test.go
@@ -12,6 +12,7 @@ import (
"context"
"fmt"
"strconv"
"sync"
"testing"
"time"

@@ -48,9 +49,20 @@ import (
// partition addresses.
type mockStreamClient struct {
partitionEvents map[streamingccl.PartitionAddress][]streamingccl.Event

// mu is used to provide a threadsafe interface to interceptors.
mu struct {
syncutil.Mutex

// interceptors can be registered to peek at every event generated by this
// client.
interceptors []func(streamingccl.Event, streamingccl.PartitionAddress)
tableID int
}
}

var _ streamclient.Client = &mockStreamClient{}
var _ streamclient.InterceptableStreamClient = &mockStreamClient{}

// GetTopology implements the Client interface.
func (m *mockStreamClient) GetTopology(
@@ -61,24 +73,53 @@ func (m *mockStreamClient) GetTopology(

// ConsumePartition implements the Client interface.
func (m *mockStreamClient) ConsumePartition(
_ context.Context, address streamingccl.PartitionAddress, _ hlc.Timestamp,
ctx context.Context, address streamingccl.PartitionAddress, _ hlc.Timestamp,
) (chan streamingccl.Event, chan error, error) {
var events []streamingccl.Event
var ok bool
if events, ok = m.partitionEvents[address]; !ok {
return nil, nil, errors.Newf("no events found for partition %s", address)
}

eventCh := make(chan streamingccl.Event, len(events))
eventCh := make(chan streamingccl.Event)
errCh := make(chan error)

for _, event := range events {
eventCh <- event
}
close(eventCh)
go func() {
defer close(eventCh)
defer close(errCh)

for _, event := range events {
select {
case eventCh <- event:
case <-ctx.Done():
errCh <- ctx.Err()
}

func() {
m.mu.Lock()
defer m.mu.Unlock()

if len(m.mu.interceptors) > 0 {
for _, interceptor := range m.mu.interceptors {
if interceptor != nil {
interceptor(event, address)
}
}
}
}()
}
}()

return eventCh, nil, nil
}

// RegisterInterception implements the InterceptableStreamClient interface.
func (m *mockStreamClient) RegisterInterception(fn streamclient.InterceptFn) {
m.mu.Lock()
defer m.mu.Unlock()
m.mu.interceptors = append(m.mu.interceptors, fn)
}

// errorStreamClient always returns an error when consuming a partition.
type errorStreamClient struct{}

@@ -171,6 +212,59 @@ func TestStreamIngestionProcessor(t *testing.T) {
require.Nil(t, row)
testutils.IsError(meta.Err, "this client always returns an error")
})

t.Run("stream ingestion processor shuts down gracefully on losing client connection", func(t *testing.T) {
events := []streamingccl.Event{streamingccl.MakeGenerationEvent()}
pa := streamingccl.PartitionAddress("partition")
mockClient := &mockStreamClient{
partitionEvents: map[streamingccl.PartitionAddress][]streamingccl.Event{pa: events},
}

startTime := hlc.Timestamp{WallTime: timeutil.Now().UnixNano()}
partitionAddresses := []streamingccl.PartitionAddress{"partition"}

interceptCh := make(chan struct{})
defer close(interceptCh)
sendToInterceptCh := func() {
interceptCh <- struct{}{}
}
interceptGeneration := markGenerationEventReceived(sendToInterceptCh)
sip, out, err := getStreamIngestionProcessor(ctx, t, registry, kvDB, "randomgen://test/",
partitionAddresses, startTime, []streamclient.InterceptFn{interceptGeneration}, mockClient)
require.NoError(t, err)

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
sip.Run(ctx)
}()

// The channel will block on read if the event has not been intercepted yet.
// Once it unblocks, we are guaranteed that the mockClient has sent the
// GenerationEvent and the processor has read it.
<-interceptCh

// The sip processor has received a GenerationEvent and is thus
// waiting for a cutover signal, so let's send one!
sip.cutoverCh <- struct{}{}

wg.Wait()
// Ensure that all the outputs are properly closed.
if !out.ProducerClosed() {
t.Fatalf("output RowReceiver not closed")
}

for {
// No metadata should have been produced since the processor
// should have been moved to draining state with a nil error.
row := out.NextNoMeta(t)
if row == nil {
break
}
t.Fatalf("more output rows than expected")
}
})
}

func getPartitionSpanToTableID(
@@ -379,6 +473,30 @@ func runStreamIngestionProcessor(
interceptEvents []streamclient.InterceptFn,
mockClient streamclient.Client,
) (*distsqlutils.RowBuffer, error) {
sip, out, err := getStreamIngestionProcessor(ctx, t, registry, kvDB, streamAddr,
partitionAddresses, startTime, interceptEvents, mockClient)
require.NoError(t, err)

sip.Run(ctx)

// Ensure that all the outputs are properly closed.
if !out.ProducerClosed() {
t.Fatalf("output RowReceiver not closed")
}
return out, err
}

func getStreamIngestionProcessor(
ctx context.Context,
t *testing.T,
registry *jobs.Registry,
kvDB *kv.DB,
streamAddr string,
partitionAddresses []streamingccl.PartitionAddress,
startTime hlc.Timestamp,
interceptEvents []streamclient.InterceptFn,
mockClient streamclient.Client,
) (*streamIngestionProcessor, *distsqlutils.RowBuffer, error) {
st := cluster.MakeTestingClusterSettings()
evalCtx := tree.MakeTestingEvalContext(st)

@@ -423,14 +541,7 @@ func runStreamIngestionProcessor(
interceptable.RegisterInterception(interceptor)
}
}

sip.Run(ctx)

// Ensure that all the outputs are properly closed.
if !out.ProducerClosed() {
t.Fatalf("output RowReceiver not closed")
}
return out, err
return sip, out, err
}

func registerValidatorWithClient(
@@ -476,3 +587,15 @@ func makeCheckpointEventCounter(
}
}
}

// markGenerationEventReceived runs f after seeing a GenerationEvent.
func markGenerationEventReceived(
f func(),
) func(event streamingccl.Event, pa streamingccl.PartitionAddress) {
return func(event streamingccl.Event, pa streamingccl.PartitionAddress) {
switch event.Type() {
case streamingccl.GenerationEvent:
f()
}
}
}
14 changes: 14 additions & 0 deletions pkg/ccl/streamingccl/streamingtest/replication_helpers.go
@@ -53,6 +53,14 @@ func ResolvedAtLeast(lo hlc.Timestamp) FeedPredicate {
}
}

// ReceivedNewGeneration makes a FeedPredicate that matches when a GenerationEvent has
// been received.
func ReceivedNewGeneration() FeedPredicate {
return func(msg streamingccl.Event) bool {
return msg.Type() == streamingccl.GenerationEvent
}
}

// FeedSource is a source of events for a ReplicationFeed.
type FeedSource interface {
// Next returns the next event, and a flag indicating if there are more events
@@ -92,6 +100,12 @@ func (rf *ReplicationFeed) ObserveResolved(lo hlc.Timestamp) hlc.Timestamp {
return *rf.msg.GetResolved()
}

// ObserveGeneration consumes the feed until a GenerationEvent is received. Returns true.
func (rf *ReplicationFeed) ObserveGeneration() bool {
require.NoError(rf.t, rf.consumeUntil(ReceivedNewGeneration()))
return true
}

// Close cleans up any resources.
func (rf *ReplicationFeed) Close() {
rf.f.Close()