diff --git a/impl/integration_test.go b/impl/integration_test.go
index db02479e..9b2030ed 100644
--- a/impl/integration_test.go
+++ b/impl/integration_test.go
@@ -885,6 +885,221 @@ func TestAutoRestart(t *testing.T) {
 	}
 }
 
+// TestAutoRestartAfterBouncingInitiator verifies correct behaviour in the
+// following scenario:
+// 1. An "initiator" opens a push / pull channel to a "responder"
+// 2. The initiator is shut down when the first block is received
+// 3. The initiator is brought back up
+// 4. The initiator restarts the data transfer with RestartDataTransferChannel
+// 5. The connection is broken when the first block is received
+// 6. The connection is automatically re-established and the transfer completes
+func TestAutoRestartAfterBouncingInitiator(t *testing.T) {
+	//SetDTLogLevelDebug()
+
+	runTest := func(t *testing.T, isPush bool) {
+		ctx := context.Background()
+		ctx, cancel := context.WithTimeout(ctx, 20*time.Second)
+		defer cancel()
+
+		// The retry config for the network layer: make 5 attempts, backing off by 1s each time
+		netRetry := network.RetryParameters(time.Second, time.Second, 5, 1)
+		gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil)
+		gsData.DtNet1 = network.NewFromLibp2pHost(gsData.Host1, netRetry)
+		initiatorHost := gsData.Host1 // initiator (sends the data in the push case)
+		responderHost := gsData.Host2 // responder (receives the data in the push case)
+
+		initiatorGSTspt := gsData.SetupGSTransportHost1()
+		responderGSTspt := gsData.SetupGSTransportHost2()
+
+		// Set up the restart configuration for the initiator
+		restartConf := ChannelRestartConfig(channelmonitor.Config{
+			AcceptTimeout:          10 * time.Second,
+			RestartDebounce:        500 * time.Millisecond,
+			RestartBackoff:         500 * time.Millisecond,
+			MaxConsecutiveRestarts: 10,
+			CompleteTimeout:        100 * time.Millisecond,
+		})
+		initiator, err := NewDataTransfer(gsData.DtDs1, gsData.TempDir1, gsData.DtNet1, initiatorGSTspt, restartConf)
+		require.NoError(t, err)
+		testutil.StartAndWaitForReady(ctx, t, initiator)
+		defer initiator.Stop(ctx)
+
+		responder, err := NewDataTransfer(gsData.DtDs2, gsData.TempDir2, gsData.DtNet2, responderGSTspt)
+		require.NoError(t, err)
+		testutil.StartAndWaitForReady(ctx, t, responder)
+		defer responder.Stop(ctx)
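Note on the wiring that follows: completeSubscriber and onDataReceivedChan both use the same one-shot signal pattern, subscribing to manager events and pushing into a buffered channel the first time a condition holds. A minimal standalone sketch of that pattern, using only the go-data-transfer types that appear in this diff (waitForStatus is a hypothetical helper name, not part of the library):

	import (
		datatransfer "github.com/filecoin-project/go-data-transfer"
	)

	// waitForStatus signals once, the first time any channel managed by mgr
	// reaches the given status. Sketch only; a real helper would also filter
	// on a specific ChannelID when several transfers may be in flight.
	func waitForStatus(mgr datatransfer.Manager, want datatransfer.Status) chan struct{} {
		done := make(chan struct{}, 1)
		fired := false
		mgr.SubscribeToEvents(func(event datatransfer.Event, state datatransfer.ChannelState) {
			if !fired && state.Status() == want {
				fired = true
				done <- struct{}{} // buffered send: never blocks event dispatch
			}
		})
		return done
	}

The buffering matters: subscribers are invoked from the manager's event dispatch, so an unbuffered send could stall event delivery.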
+
+		// Watch for the Completed event on the responder
+		// (below we watch for the Completed event on the initiator)
+		finished := make(chan struct{}, 2)
+		var completeSubscriber datatransfer.Subscriber = func(event datatransfer.Event, channelState datatransfer.ChannelState) {
+			if channelState.Status() == datatransfer.Completed {
+				finished <- struct{}{}
+			}
+		}
+		responder.SubscribeToEvents(completeSubscriber)
+
+		// onDataReceivedChan watches for the first DataReceived event
+		dataReceiver := initiator
+		if isPush {
+			dataReceiver = responder
+		}
+		onDataReceivedChan := func(dataRcvr datatransfer.Manager) chan struct{} {
+			dataReceived := make(chan struct{}, 1)
+			rcvdCount := 0
+			dataRcvdSubscriber := func(event datatransfer.Event, channelState datatransfer.ChannelState) {
+				//t.Logf("resp: %s / %s\n", datatransfer.Events[event.Code], datatransfer.Statuses[channelState.Status()])
+				if event.Code == datatransfer.DataReceived {
+					rcvdCount++
+					if rcvdCount == 1 {
+						dataReceived <- struct{}{}
+					}
+				}
+			}
+			dataRcvr.SubscribeToEvents(dataRcvdSubscriber)
+			return dataReceived
+		}
+		dataReceived := onDataReceivedChan(dataReceiver)
+
+		voucher := testutil.FakeDTType{Data: "applesauce"}
+		sv := testutil.NewStubbedValidator()
+
+		var sourceDagService, destDagService ipldformat.DAGService
+		if isPush {
+			sourceDagService = gsData.DagService1
+			destDagService = gsData.DagService2
+		} else {
+			sourceDagService = gsData.DagService2
+			destDagService = gsData.DagService1
+		}
+
+		root, origBytes := testutil.LoadUnixFSFile(ctx, t, sourceDagService, loremFile)
+		rootCid := root.(cidlink.Link).Cid
+
+		require.NoError(t, initiator.RegisterVoucherType(&testutil.FakeDTType{}, sv))
+		require.NoError(t, responder.RegisterVoucherType(&testutil.FakeDTType{}, sv))
+
+		// Register a revalidator that records calls to OnPullDataSent and OnPushDataReceived
+		srv := newRestartRevalidator()
+		require.NoError(t, responder.RegisterRevalidator(testutil.NewFakeDTType(), srv))
+
+		var chid datatransfer.ChannelID
+		if isPush {
+			// Open a push channel
+			chid, err = initiator.OpenPushDataChannel(ctx, responderHost.ID(), &voucher, rootCid, gsData.AllSelector)
+		} else {
+			// Open a pull channel
+			chid, err = initiator.OpenPullDataChannel(ctx, responderHost.ID(), &voucher, rootCid, gsData.AllSelector)
+		}
+		require.NoError(t, err)
+
+		// Wait for the first block to be received
+		select {
+		case <-time.After(time.Second):
+			t.Fatal("Timed out waiting for point at which to break connection")
+		case <-dataReceived:
+		}
+
+		// Shut down the initiator of the data transfer
+		t.Logf("Stopping initiator")
+		err = initiator.Stop(ctx)
+		require.NoError(t, err)
+
+		// Break connection
+		t.Logf("Breaking connection to peer")
+		require.NoError(t, gsData.Mn.UnlinkPeers(initiatorHost.ID(), responderHost.ID()))
+		require.NoError(t, gsData.Mn.DisconnectPeers(initiatorHost.ID(), responderHost.ID()))
+
+		// Create a new initiator
+		initiator2GSTspt := gsData.SetupGSTransportHost1()
+		initiator2, err := NewDataTransfer(gsData.DtDs1, gsData.TempDir1, gsData.DtNet1, initiator2GSTspt, restartConf)
+		require.NoError(t, err)
+		require.NoError(t, initiator2.RegisterVoucherType(&testutil.FakeDTType{}, sv))
+		initiator2.SubscribeToEvents(completeSubscriber)
+
+		testutil.StartAndWaitForReady(ctx, t, initiator2)
+		defer initiator2.Stop(ctx)
+
+		t.Logf("Sleep for a second")
+		time.Sleep(1 * time.Second)
+
+		// Restore link
+		t.Logf("Restore link")
+		require.NoError(t, gsData.Mn.LinkAll())
+		time.Sleep(200 * time.Millisecond)
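The unlink / disconnect / sleep / re-link sequence above is performed again after the restart below. Factored into a helper the shape is easier to see; a sketch assuming the same mocknet instance the test drives via gsData.Mn (bounceConnection is a hypothetical name, and the import paths may need adjusting to the libp2p version pinned in go.mod):

	import (
		"testing"
		"time"

		"github.com/libp2p/go-libp2p-core/peer"
		mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
		"github.com/stretchr/testify/require"
	)

	// bounceConnection severs the mocknet link so redials fail, drops any live
	// connections, leaves the link down long enough for the channel monitor to
	// observe the outage, then re-links so the network layer's retries succeed.
	func bounceConnection(t *testing.T, mn mocknet.Mocknet, a, b peer.ID, downFor time.Duration) {
		require.NoError(t, mn.UnlinkPeers(a, b))
		require.NoError(t, mn.DisconnectPeers(a, b))
		time.Sleep(downFor)
		require.NoError(t, mn.LinkAll())
	}

UnlinkPeers alone is not enough: it only prevents new connections, so DisconnectPeers is also needed to drop the one already established.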
+
+		// Watch for data received event
+		dataReceiver = initiator2
+		if isPush {
+			dataReceiver = responder
+		}
+		dataReceivedAfterRestart := onDataReceivedChan(dataReceiver)
+
+		// Restart the data transfer on the new initiator.
+		// (this is equivalent to shutting down and restarting a node running
+		// the initiator)
+		err = initiator2.RestartDataTransferChannel(ctx, chid)
+		require.NoError(t, err)
+
+		// Wait for the first block to be received
+		select {
+		case <-time.After(time.Second):
+			t.Fatal("Timed out waiting for point at which to break connection")
+		case <-dataReceivedAfterRestart:
+		}
+
+		// Break connection
+		t.Logf("Breaking connection to peer")
+		require.NoError(t, gsData.Mn.UnlinkPeers(initiatorHost.ID(), responderHost.ID()))
+		require.NoError(t, gsData.Mn.DisconnectPeers(initiatorHost.ID(), responderHost.ID()))
+
+		t.Logf("Sleep for a second")
+		time.Sleep(1 * time.Second)
+
+		// Restore link
+		t.Logf("Restore link")
+		require.NoError(t, gsData.Mn.LinkAll())
+		time.Sleep(200 * time.Millisecond)
+
+		// Wait for the transfer to complete on both the initiator and the responder
+		t.Logf("Waiting for auto-restart on channel %s", chid)
+
+		(func() {
+			finishedCount := 0
+			for {
+				select {
+				case <-ctx.Done():
+					t.Fatal("Did not complete successful data transfer")
+					return
+				case <-finished:
+					finishedCount++
+					if finishedCount == 2 {
+						return
+					}
+				}
+			}
+		})()
+
+		// Verify that the total amount of data sent / received that was
+		// reported to the revalidator is correct
+		if isPush {
+			require.EqualValues(t, loremFileTransferBytes, srv.pushDataSum(chid))
+		} else {
+			require.EqualValues(t, loremFileTransferBytes, srv.pullDataSum(chid))
+		}
+
+		// Verify that the file was transferred to the destination node
+		testutil.VerifyHasFile(ctx, t, destDagService, root, origBytes)
+	}
+
+	t.Run("push", func(t *testing.T) {
+		runTest(t, true)
+	})
+	t.Run("pull", func(t *testing.T) {
+		runTest(t, false)
+	})
+}
+
 func TestRoundTripCancelledRequest(t *testing.T) {
 	ctx := context.Background()
 	testCases := map[string]struct {
diff --git a/impl/restart.go b/impl/restart.go
index 9b6fcce7..697985cd 100644
--- a/impl/restart.go
+++ b/impl/restart.go
@@ -99,8 +99,16 @@ func (m *manager) openPushRestartChannel(ctx context.Context, channel datatransf
 	}
 	m.dataTransferNetwork.Protect(requestTo, chid.String())
+	// Monitor the state of the connection for the channel
+	monitoredChan := m.channelMonitor.AddPushChannel(chid)
 	log.Infof("sending push restart channel to %s for channel %s", requestTo, chid)
 	if err := m.dataTransferNetwork.SendMessage(ctx, requestTo, req); err != nil {
+		// If push channel monitoring is enabled, shut down the monitor, as it
+		// wasn't possible to start the data transfer
+		if monitoredChan != nil {
+			monitoredChan.Shutdown()
+		}
+
 		return xerrors.Errorf("Unable to send restart request: %w", err)
 	}
 
@@ -126,8 +134,16 @@ func (m *manager) openPullRestartChannel(ctx context.Context, channel datatransf
 	}
 	m.dataTransferNetwork.Protect(requestTo, chid.String())
+	// Monitor the state of the connection for the channel
+	monitoredChan := m.channelMonitor.AddPullChannel(chid)
 	log.Infof("sending open channel to %s to restart channel %s", requestTo, chid)
 	if err := m.transport.OpenChannel(ctx, requestTo, chid, cidlink.Link{Cid: baseCid}, selector, channel, req); err != nil {
+		// If pull channel monitoring is enabled, shut down the monitor, as it
+		// wasn't possible to start the data transfer
+		if monitoredChan != nil {
+			monitoredChan.Shutdown()
+		}
+
 		return xerrors.Errorf("Unable to send open channel restart request: %w", err)
 	}
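Both restart.go hunks add the same start-then-clean-up-on-error shape around an existing fallible call: register the monitored channel first, then tear it down if the restart request cannot be sent. A distilled sketch of that shape (Monitor and restartWithMonitor are illustrative names, not the go-data-transfer API; in the real code the handle comes from channelMonitor.AddPushChannel / AddPullChannel, which return nil when monitoring is disabled):

	import "fmt"

	// Monitor stands in for the monitored-channel handle; the only
	// capability this pattern needs is Shutdown.
	type Monitor interface {
		Shutdown()
	}

	// restartWithMonitor starts monitoring before the fallible send and shuts
	// the monitor down on failure, so no monitor is left watching a channel
	// whose restart request never went out.
	func restartWithMonitor(startMonitor func() Monitor, sendRestart func() error) error {
		mon := startMonitor() // nil when channel monitoring is disabled
		if err := sendRestart(); err != nil {
			if mon != nil {
				mon.Shutdown()
			}
			return fmt.Errorf("unable to send restart request: %w", err)
		}
		return nil
	}

Registering the monitor before the send, rather than only on success, presumably avoids a race where the responder's first messages arrive before the monitor exists.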