diff --git a/pushchannelmonitor/pushchannelmonitor.go b/channelmonitor/channelmonitor.go similarity index 59% rename from pushchannelmonitor/pushchannelmonitor.go rename to channelmonitor/channelmonitor.go index f14382cf..2493f1e8 100644 --- a/pushchannelmonitor/pushchannelmonitor.go +++ b/channelmonitor/channelmonitor.go @@ -1,4 +1,4 @@ -package pushchannelmonitor +package channelmonitor import ( "context" @@ -13,7 +13,7 @@ import ( "github.com/filecoin-project/go-data-transfer/channels" ) -var log = logging.Logger("dt-pushchanmon") +var log = logging.Logger("dt-chanmon") type monitorAPI interface { SubscribeToEvents(subscriber datatransfer.Subscriber) datatransfer.Unsubscribe @@ -21,8 +21,8 @@ type monitorAPI interface { CloseDataTransferChannelWithError(ctx context.Context, chid datatransfer.ChannelID, cherr error) error } -// Monitor watches the data-rate for push channels, and restarts -// a channel if the data-rate falls too low +// Monitor watches the data-rate for data transfer channels, and restarts +// a channel if the data-rate falls too low or if there are timeouts / errors type Monitor struct { ctx context.Context stop context.CancelFunc @@ -30,16 +30,16 @@ type Monitor struct { cfg *Config lk sync.RWMutex - channels map[*monitoredChannel]struct{} + channels map[monitoredChan]struct{} } type Config struct { - // Max time to wait for other side to accept push before attempting restart + // Max time to wait for other side to accept open channel request before attempting restart AcceptTimeout time.Duration // Interval between checks of transfer rate Interval time.Duration - // Min bytes that must be sent in interval - MinBytesSent uint64 + // Min bytes that must be sent / received in interval + MinBytesTransferred uint64 // Number of times to check transfer rate per interval ChecksPerInterval uint32 // Backoff after restarting @@ -59,7 +59,7 @@ func NewMonitor(mgr monitorAPI, cfg *Config) *Monitor { stop: cancel, mgr: mgr, cfg: cfg, - channels: make(map[*monitoredChannel]struct{}), + channels: make(map[monitoredChan]struct{}), } } @@ -68,7 +68,7 @@ func checkConfig(cfg *Config) { return } - prefix := "data-transfer channel push monitor config " + prefix := "data-transfer channel monitor config " if cfg.AcceptTimeout <= 0 { panic(fmt.Sprintf(prefix+"AcceptTimeout is %s but must be > 0", cfg.AcceptTimeout)) } @@ -78,8 +78,8 @@ func checkConfig(cfg *Config) { if cfg.ChecksPerInterval == 0 { panic(fmt.Sprintf(prefix+"ChecksPerInterval is %d but must be > 0", cfg.ChecksPerInterval)) } - if cfg.MinBytesSent == 0 { - panic(fmt.Sprintf(prefix+"MinBytesSent is %d but must be > 0", cfg.MinBytesSent)) + if cfg.MinBytesTransferred == 0 { + panic(fmt.Sprintf(prefix+"MinBytesTransferred is %d but must be > 0", cfg.MinBytesTransferred)) } if cfg.MaxConsecutiveRestarts == 0 { panic(fmt.Sprintf(prefix+"MaxConsecutiveRestarts is %d but must be > 0", cfg.MaxConsecutiveRestarts)) @@ -89,8 +89,25 @@ func checkConfig(cfg *Config) { } } -// AddChannel adds a channel to the push channel monitor -func (m *Monitor) AddChannel(chid datatransfer.ChannelID) *monitoredChannel { +// This interface just makes it easier to abstract some methods between the +// push and pull monitor implementations +type monitoredChan interface { + Shutdown() + checkDataRate() +} + +// AddPushChannel adds a push channel to the channel monitor +func (m *Monitor) AddPushChannel(chid datatransfer.ChannelID) monitoredChan { + return m.addChannel(chid, true) +} + +// AddPullChannel adds a pull channel to the channel monitor +func (m 
*Monitor) AddPullChannel(chid datatransfer.ChannelID) monitoredChan { + return m.addChannel(chid, false) +} + +// addChannel adds a channel to the channel monitor +func (m *Monitor) addChannel(chid datatransfer.ChannelID, isPush bool) monitoredChan { if !m.enabled() { return nil } @@ -98,7 +115,12 @@ func (m *Monitor) AddChannel(chid datatransfer.ChannelID) *monitoredChannel { m.lk.Lock() defer m.lk.Unlock() - mpc := newMonitoredChannel(m.mgr, chid, m.cfg, m.onMonitoredChannelShutdown) + var mpc monitoredChan + if isPush { + mpc = newMonitoredPushChannel(m.mgr, chid, m.cfg, m.onMonitoredChannelShutdown) + } else { + mpc = newMonitoredPullChannel(m.mgr, chid, m.cfg, m.onMonitoredChannelShutdown) + } m.channels[mpc] = struct{}{} return mpc } @@ -127,7 +149,7 @@ func (m *Monitor) onMonitoredChannelShutdown(mpc *monitoredChannel) { delete(m.channels, mpc) } -// enabled indicates whether the push channel monitor is running +// enabled indicates whether the channel monitor is running func (m *Monitor) enabled() bool { return m.cfg != nil } @@ -148,9 +170,9 @@ func (m *Monitor) run() { ticker := time.NewTicker(tickInterval) defer ticker.Stop() - log.Infof("Starting push channel monitor with "+ + log.Infof("Starting data-transfer channel monitor with "+ "%d checks per %s interval (check interval %s); min bytes per interval: %d, restart backoff: %s; max consecutive restarts: %d", - m.cfg.ChecksPerInterval, m.cfg.Interval, tickInterval, m.cfg.MinBytesSent, m.cfg.RestartBackoff, m.cfg.MaxConsecutiveRestarts) + m.cfg.ChecksPerInterval, m.cfg.Interval, tickInterval, m.cfg.MinBytesTransferred, m.cfg.RestartBackoff, m.cfg.MaxConsecutiveRestarts) for { select { @@ -172,7 +194,7 @@ func (m *Monitor) checkDataRate() { } } -// monitoredChannel keeps track of the data-rate for a push channel, and +// monitoredChannel keeps track of the data-rate for a channel, and // restarts the channel if the rate falls below the minimum allowed type monitoredChannel struct { ctx context.Context @@ -182,16 +204,12 @@ type monitoredChannel struct { cfg *Config unsub datatransfer.Unsubscribe onShutdown func(*monitoredChannel) + onDTEvent datatransfer.Subscriber shutdownLk sync.Mutex - statsLk sync.RWMutex - queued uint64 - sent uint64 - dataRatePoints chan *dataRatePoint + restartLk sync.RWMutex + restartedAt time.Time consecutiveRestarts int - - restartLk sync.RWMutex - restartedAt time.Time } func newMonitoredChannel( @@ -199,21 +217,26 @@ func newMonitoredChannel( chid datatransfer.ChannelID, cfg *Config, onShutdown func(*monitoredChannel), + onDTEvent datatransfer.Subscriber, ) *monitoredChannel { ctx, cancel := context.WithCancel(context.Background()) mpc := &monitoredChannel{ - ctx: ctx, - cancel: cancel, - mgr: mgr, - chid: chid, - cfg: cfg, - onShutdown: onShutdown, - dataRatePoints: make(chan *dataRatePoint, cfg.ChecksPerInterval), + ctx: ctx, + cancel: cancel, + mgr: mgr, + chid: chid, + cfg: cfg, + onShutdown: onShutdown, + onDTEvent: onDTEvent, } mpc.start() return mpc } +// Overridden by sub-classes +func (mc *monitoredChannel) checkDataRate() { +} + // Cancel the context and unsubscribe from events func (mc *monitoredChannel) Shutdown() { mc.shutdownLk.Lock() @@ -238,7 +261,7 @@ func (mc *monitoredChannel) start() { mc.shutdownLk.Lock() defer mc.shutdownLk.Unlock() - log.Debugf("%s: starting push channel data-rate monitoring", mc.chid) + log.Debugf("%s: starting channel data-rate monitoring", mc.chid) // Watch to make sure the responder accepts the channel in time cancelAcceptTimer := 
mc.watchForResponderAccept() @@ -249,13 +272,10 @@ func (mc *monitoredChannel) start() { return } - mc.statsLk.Lock() - defer mc.statsLk.Unlock() - // Once the channel completes, shut down the monitor state := channelState.Status() if channels.IsChannelCleaningUp(state) || channels.IsChannelTerminated(state) { - log.Debugf("%s: stopping push channel data-rate monitoring", mc.chid) + log.Debugf("%s: stopping channel data-rate monitoring", mc.chid) go mc.Shutdown() return } @@ -264,23 +284,20 @@ func (mc *monitoredChannel) start() { case datatransfer.Accept: // The Accept event is fired when we receive an Accept message from the responder cancelAcceptTimer() - case datatransfer.Error: - // If there's an error, attempt to restart the channel - log.Debugf("%s: data transfer error, restarting", mc.chid) + case datatransfer.SendDataError: + // If the transport layer reports an error sending data over the wire, + // attempt to restart the channel + log.Warnf("%s: data transfer transport send error, restarting data transfer", mc.chid) go mc.restartChannel() - case datatransfer.DataQueued: - // Keep track of the amount of data queued - mc.queued = channelState.Queued() - case datatransfer.DataSent: - // Keep track of the amount of data sent - mc.sent = channelState.Sent() - // Some data was sent so reset the consecutive restart counter - mc.consecutiveRestarts = 0 case datatransfer.FinishTransfer: // The client has finished sending all data. Watch to make sure // that the responder sends a message to acknowledge that the // transfer is complete go mc.watchForResponderComplete() + default: + // Delegate to the push channel monitor or pull channel monitor to + // handle the event + mc.onDTEvent(event, channelState) } }) } @@ -326,71 +343,38 @@ func (mc *monitoredChannel) watchForResponderComplete() { } } -type dataRatePoint struct { - pending uint64 - sent uint64 -} - -// check if the amount of data sent in the interval was too low, and if so -// restart the channel -func (mc *monitoredChannel) checkDataRate() { - mc.statsLk.Lock() - defer mc.statsLk.Unlock() - - // Before returning, add the current data rate stats to the queue - defer func() { - var pending uint64 - if mc.queued > mc.sent { // should always be true but just in case - pending = mc.queued - mc.sent - } - mc.dataRatePoints <- &dataRatePoint{ - pending: pending, - sent: mc.sent, - } - }() - - // Check that there are enough data points that an interval has elapsed - if len(mc.dataRatePoints) < int(mc.cfg.ChecksPerInterval) { - log.Debugf("%s: not enough data points to check data rate yet (%d / %d)", - mc.chid, len(mc.dataRatePoints), mc.cfg.ChecksPerInterval) - - return - } - - // Pop the data point from one interval ago - atIntervalStart := <-mc.dataRatePoints +// clear the consecutive restart count (we do this when data is sent or +// received) +func (mc *monitoredChannel) resetConsecutiveRestarts() { + mc.restartLk.Lock() + defer mc.restartLk.Unlock() - // If there was enough pending data to cover the minimum required amount, - // and the amount sent was lower than the minimum required, restart the - // channel - sentInInterval := mc.sent - atIntervalStart.sent - log.Debugf("%s: since last check: sent: %d - %d = %d, pending: %d, required %d", - mc.chid, mc.sent, atIntervalStart.sent, sentInInterval, atIntervalStart.pending, mc.cfg.MinBytesSent) - if atIntervalStart.pending > sentInInterval && sentInInterval < mc.cfg.MinBytesSent { - go mc.restartChannel() - } + mc.consecutiveRestarts = 0 } func (mc *monitoredChannel) restartChannel() 
{ - // Check if the channel is already being restarted + var restartCount int + var restartedAt time.Time mc.restartLk.Lock() - restartedAt := mc.restartedAt - if restartedAt.IsZero() { - mc.restartedAt = time.Now() + { + // If the channel is not already being restarted, record the restart + // time and increment the consecutive restart count + restartedAt = mc.restartedAt + if mc.restartedAt.IsZero() { + mc.restartedAt = time.Now() + mc.consecutiveRestarts++ + restartCount = mc.consecutiveRestarts + } } mc.restartLk.Unlock() + // Check if channel is already being restarted if !restartedAt.IsZero() { log.Debugf("%s: restart called but already restarting channel (for %s so far; restart backoff is %s)", - mc.chid, time.Since(mc.restartedAt), mc.cfg.RestartBackoff) + mc.chid, time.Since(restartedAt), mc.cfg.RestartBackoff) return } - mc.statsLk.Lock() - mc.consecutiveRestarts++ - restartCount := mc.consecutiveRestarts - mc.statsLk.Unlock() - if uint32(restartCount) > mc.cfg.MaxConsecutiveRestarts { // If no data has been transferred since the last transfer, and we've // reached the consecutive restart limit, close the channel and @@ -423,6 +407,8 @@ func (mc *monitoredChannel) restartChannel() { mc.chid, mc.cfg.RestartBackoff) } + // Restart complete, so clear the restart time so that another restart + // can begin mc.restartLk.Lock() mc.restartedAt = time.Time{} mc.restartLk.Unlock() @@ -432,8 +418,171 @@ func (mc *monitoredChannel) closeChannelAndShutdown(cherr error) { log.Errorf("closing data-transfer channel: %s", cherr) err := mc.mgr.CloseDataTransferChannelWithError(mc.ctx, mc.chid, cherr) if err != nil { - log.Errorf("error closing data-transfer channel %s: %w", mc.chid, err) + log.Errorf("error closing data-transfer channel %s: %s", mc.chid, err) } mc.Shutdown() } + +// Snapshot of the pending and sent data at a particular point in time. +// The push channel monitor takes regular snapshots and compares them to +// decide if the data rate has fallen too low. 
+type dataRatePoint struct { + pending uint64 + sent uint64 +} + +// Keeps track of the data rate for a push channel +type monitoredPushChannel struct { + *monitoredChannel + + statsLk sync.RWMutex + queued uint64 + sent uint64 + dataRatePoints chan *dataRatePoint +} + +func newMonitoredPushChannel( + mgr monitorAPI, + chid datatransfer.ChannelID, + cfg *Config, + onShutdown func(*monitoredChannel), +) *monitoredPushChannel { + mpc := &monitoredPushChannel{ + dataRatePoints: make(chan *dataRatePoint, cfg.ChecksPerInterval), + } + mpc.monitoredChannel = newMonitoredChannel(mgr, chid, cfg, onShutdown, mpc.onDTEvent) + return mpc +} + +// check if the amount of data sent in the interval was too low, and if so +// restart the channel +func (mc *monitoredPushChannel) checkDataRate() { + mc.statsLk.Lock() + defer mc.statsLk.Unlock() + + // Before returning, add the current data rate stats to the queue + defer func() { + var pending uint64 + if mc.queued > mc.sent { // should always be true but just in case + pending = mc.queued - mc.sent + } + mc.dataRatePoints <- &dataRatePoint{ + pending: pending, + sent: mc.sent, + } + }() + + // Check that there are enough data points that an interval has elapsed + if len(mc.dataRatePoints) < int(mc.cfg.ChecksPerInterval) { + log.Debugf("%s: not enough data points to check data rate yet (%d / %d)", + mc.chid, len(mc.dataRatePoints), mc.cfg.ChecksPerInterval) + + return + } + + // Pop the data point from one interval ago + atIntervalStart := <-mc.dataRatePoints + + // If there was enough pending data to cover the minimum required amount, + // and the amount sent was lower than the minimum required, restart the + // channel + sentInInterval := mc.sent - atIntervalStart.sent + log.Debugf("%s: since last check: sent: %d - %d = %d, pending: %d, required %d", + mc.chid, mc.sent, atIntervalStart.sent, sentInInterval, atIntervalStart.pending, mc.cfg.MinBytesTransferred) + if atIntervalStart.pending > sentInInterval && sentInInterval < mc.cfg.MinBytesTransferred { + log.Warnf("%s: data-rate too low, restarting channel: since last check %s ago: sent: %d, required %d", + mc.chid, mc.cfg.Interval, mc.sent, mc.cfg.MinBytesTransferred) + go mc.restartChannel() + } +} + +// Update the queued / sent amount each time it changes +func (mc *monitoredPushChannel) onDTEvent(event datatransfer.Event, channelState datatransfer.ChannelState) { + switch event.Code { + case datatransfer.DataQueued: + // Keep track of the amount of data queued + mc.statsLk.Lock() + mc.queued = channelState.Queued() + mc.statsLk.Unlock() + + case datatransfer.DataSent: + // Keep track of the amount of data sent + mc.statsLk.Lock() + mc.sent = channelState.Sent() + mc.statsLk.Unlock() + + // Some data was sent so reset the consecutive restart counter + mc.resetConsecutiveRestarts() + } +} + +// Keeps track of the data rate for a pull channel +type monitoredPullChannel struct { + *monitoredChannel + + statsLk sync.RWMutex + received uint64 + dataRatePoints chan uint64 +} + +func newMonitoredPullChannel( + mgr monitorAPI, + chid datatransfer.ChannelID, + cfg *Config, + onShutdown func(*monitoredChannel), +) *monitoredPullChannel { + mpc := &monitoredPullChannel{ + dataRatePoints: make(chan uint64, cfg.ChecksPerInterval), + } + mpc.monitoredChannel = newMonitoredChannel(mgr, chid, cfg, onShutdown, mpc.onDTEvent) + return mpc +} + +// check if the amount of data received in the interval was too low, and if so +// restart the channel +func (mc *monitoredPullChannel) checkDataRate() { + mc.statsLk.Lock() + 
defer mc.statsLk.Unlock() + + // Before returning, add the current data rate stats to the queue + defer func() { + mc.dataRatePoints <- mc.received + }() + + // Check that there are enough data points that an interval has elapsed + if len(mc.dataRatePoints) < int(mc.cfg.ChecksPerInterval) { + log.Debugf("%s: not enough data points to check data rate yet (%d / %d)", + mc.chid, len(mc.dataRatePoints), mc.cfg.ChecksPerInterval) + + return + } + + // Pop the data point from one interval ago + atIntervalStart := <-mc.dataRatePoints + + // If the amount received was lower than the minimum required, restart the + // channel + rcvdInInterval := mc.received - atIntervalStart + log.Debugf("%s: since last check: received: %d - %d = %d, required %d", + mc.chid, mc.received, atIntervalStart, rcvdInInterval, mc.cfg.MinBytesTransferred) + if rcvdInInterval < mc.cfg.MinBytesTransferred { + log.Warnf("%s: data-rate too low, restarting channel: since last check %s ago: received: %d, required %d", + mc.chid, mc.cfg.Interval, mc.received, mc.cfg.MinBytesTransferred) + go mc.restartChannel() + } +} + +// Update the received amount each time it changes +func (mc *monitoredPullChannel) onDTEvent(event datatransfer.Event, channelState datatransfer.ChannelState) { + switch event.Code { + case datatransfer.DataReceived: + // Keep track of the amount of data received + mc.statsLk.Lock() + mc.received = channelState.Received() + mc.statsLk.Unlock() + + // Some data was received so reset the consecutive restart counter + mc.resetConsecutiveRestarts() + } +} diff --git a/pushchannelmonitor/pushchannelmonitor_test.go b/channelmonitor/channelmonitor_test.go similarity index 51% rename from pushchannelmonitor/pushchannelmonitor_test.go rename to channelmonitor/channelmonitor_test.go index c21ac8ab..86105637 100644 --- a/pushchannelmonitor/pushchannelmonitor_test.go +++ b/channelmonitor/channelmonitor_test.go @@ -1,4 +1,4 @@ -package pushchannelmonitor +package channelmonitor import ( "context" @@ -15,6 +15,12 @@ import ( datatransfer "github.com/filecoin-project/go-data-transfer" ) +var ch1 = datatransfer.ChannelID{ + Initiator: "initiator", + Responder: "responder", + ID: 1, +} + func TestPushChannelMonitorAutoRestart(t *testing.T) { type testCase struct { name string @@ -47,11 +53,6 @@ func TestPushChannelMonitorAutoRestart(t *testing.T) { errorEvent: true, }} - ch1 := datatransfer.ChannelID{ - Initiator: "initiator", - Responder: "responder", - ID: 1, - } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { ch := &mockChannelState{chid: ch1} @@ -61,29 +62,34 @@ func TestPushChannelMonitorAutoRestart(t *testing.T) { AcceptTimeout: time.Hour, Interval: 10 * time.Millisecond, ChecksPerInterval: 10, - MinBytesSent: 1, + MinBytesTransferred: 1, MaxConsecutiveRestarts: 3, CompleteTimeout: time.Hour, }) m.Start() - m.AddChannel(ch1) - mch := getFirstMonitoredChannel(m) + mch := m.AddPushChannel(ch1).(*monitoredPushChannel) + // Simulate the responder sending Accept mockAPI.accept() + + // Simulate data being queued and sent + // If sent - queued > MinBytesTransferred it should cause a restart mockAPI.dataQueued(tc.dataQueued) mockAPI.dataSent(tc.dataSent) + if tc.errorEvent { - mockAPI.errorEvent() + // Fire an error event, should cause a restart + mockAPI.sendDataErrorEvent() } if tc.errOnRestart { - // If there is no recovery from restart, wait for the push - // channel to be closed + // If there is an error attempting to restart, just wait for + // the push channel to be closed <-mockAPI.closed return } - // 
Verify that channel was restarted + // Verify that channel is restarted within interval select { case <-time.After(100 * time.Millisecond): require.Fail(t, "failed to restart channel") @@ -100,7 +106,87 @@ func TestPushChannelMonitorAutoRestart(t *testing.T) { mockAPI.completed() // Verify that channel has been shutdown - verifyChannelShutdown(t, mch) + verifyChannelShutdown(t, mch.ctx) + }) + } +} + +func TestPullChannelMonitorAutoRestart(t *testing.T) { + type testCase struct { + name string + errOnRestart bool + dataRcvd uint64 + errorEvent bool + } + testCases := []testCase{{ + name: "attempt restart", + errOnRestart: false, + dataRcvd: 10, + }, { + name: "fail attempt restart", + errOnRestart: true, + dataRcvd: 10, + }, { + name: "error event", + errOnRestart: false, + dataRcvd: 10, + errorEvent: true, + }, { + name: "error event then fail attempt restart", + errOnRestart: true, + dataRcvd: 10, + errorEvent: true, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ch := &mockChannelState{chid: ch1} + mockAPI := newMockMonitorAPI(ch, tc.errOnRestart) + + m := NewMonitor(mockAPI, &Config{ + AcceptTimeout: time.Hour, + Interval: 10 * time.Millisecond, + ChecksPerInterval: 10, + MinBytesTransferred: 1, + MaxConsecutiveRestarts: 3, + CompleteTimeout: time.Hour, + }) + m.Start() + mch := m.AddPullChannel(ch1).(*monitoredPullChannel) + + // Simulate the responder sending Accept + mockAPI.accept() + + // Simulate receiving some data + mockAPI.dataReceived(tc.dataRcvd) + + if tc.errorEvent { + // Fire an error event, should cause a restart + mockAPI.sendDataErrorEvent() + } + + if tc.errOnRestart { + // If there is an error attempting to restart, just wait for + // the pull channel to be closed + <-mockAPI.closed + return + } + + // Verify that channel is restarted within interval + select { + case <-time.After(100 * time.Millisecond): + require.Fail(t, "failed to restart channel") + case <-mockAPI.restarts: + } + + // Simulate sending more data + mockAPI.dataSent(tc.dataRcvd) + + // Simulate the complete event + mockAPI.completed() + + // Verify that channel has been shutdown + verifyChannelShutdown(t, mch.ctx) }) } } @@ -117,7 +203,7 @@ func TestPushChannelMonitorDataRate(t *testing.T) { expectRestart bool } testCases := []testCase{{ - name: "restart when min sent (1) < pending (10)", + name: "restart when sent (10) < pending (20)", minBytesSent: 1, dataPoints: []dataPoint{{ queued: 20, @@ -125,7 +211,7 @@ func TestPushChannelMonitorDataRate(t *testing.T) { }}, expectRestart: true, }, { - name: "dont restart when min sent (20) >= pending (10)", + name: "dont restart when sent (20) >= pending (10)", minBytesSent: 1, dataPoints: []dataPoint{{ queued: 20, @@ -136,7 +222,7 @@ func TestPushChannelMonitorDataRate(t *testing.T) { }}, expectRestart: false, }, { - name: "restart when min sent (5) < pending (10)", + name: "restart when sent (5) < pending (10)", minBytesSent: 10, dataPoints: []dataPoint{{ queued: 20, @@ -213,11 +299,6 @@ func TestPushChannelMonitorDataRate(t *testing.T) { expectRestart: false, }} - ch1 := datatransfer.ChannelID{ - Initiator: "initiator", - Responder: "responder", - ID: 1, - } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { ch := &mockChannelState{chid: ch1} @@ -228,14 +309,14 @@ func TestPushChannelMonitorDataRate(t *testing.T) { AcceptTimeout: time.Hour, Interval: time.Hour, ChecksPerInterval: checksPerInterval, - MinBytesSent: tc.minBytesSent, + MinBytesTransferred: tc.minBytesSent, MaxConsecutiveRestarts: 3, 
CompleteTimeout: time.Hour, }) // Note: Don't start monitor, we'll call checkDataRate() manually - m.AddChannel(ch1) + m.AddPushChannel(ch1) totalChecks := checksPerInterval + uint32(len(tc.dataPoints)) for i := uint32(0); i < totalChecks; i++ { @@ -262,185 +343,265 @@ func TestPushChannelMonitorDataRate(t *testing.T) { } } -func TestPushChannelMonitorMaxConsecutiveRestarts(t *testing.T) { - ch1 := datatransfer.ChannelID{ - Initiator: "initiator", - Responder: "responder", - ID: 1, - } - ch := &mockChannelState{chid: ch1} - mockAPI := newMockMonitorAPI(ch, false) - - maxConsecutiveRestarts := 3 - m := NewMonitor(mockAPI, &Config{ - AcceptTimeout: time.Hour, - Interval: time.Hour, - ChecksPerInterval: 1, - MinBytesSent: 2, - MaxConsecutiveRestarts: uint32(maxConsecutiveRestarts), - CompleteTimeout: time.Hour, - }) - - // Note: Don't start monitor, we'll call checkDataRate() manually - - m.AddChannel(ch1) - mch := getFirstMonitoredChannel(m) - - mockAPI.dataQueued(10) - mockAPI.dataSent(5) - - // Check once to add a data point to the queue. - // Subsequent checks will compare against the previous data point. - m.checkDataRate() - - // Each check should trigger a restart up to the maximum number of restarts - triggerMaxRestarts := func() { - for i := 0; i < maxConsecutiveRestarts; i++ { - m.checkDataRate() - - err := mockAPI.awaitRestart() - require.NoError(t, err) - } - } - triggerMaxRestarts() - - // When data is sent it should reset the consecutive restarts back to zero - mockAPI.dataSent(6) - - // Trigger restarts up to max again - triggerMaxRestarts() - - // Reached max restarts, so now there should not be another restart - // attempt. - // Instead the channel should be closed and the monitor shut down. - m.checkDataRate() - err := mockAPI.awaitRestart() - require.Error(t, err) // require error because expecting no restart - verifyChannelShutdown(t, mch) -} - -func TestPushChannelMonitorTimeouts(t *testing.T) { +func TestPullChannelMonitorDataRate(t *testing.T) { type testCase struct { - name string - expectAccept bool - expectComplete bool + name string + minBytesTransferred uint64 + dataPoints []uint64 + expectRestart bool } testCases := []testCase{{ - name: "accept in time", - expectAccept: true, + name: "restart when received (5) < min required (10)", + minBytesTransferred: 10, + dataPoints: []uint64{10, 15}, + expectRestart: true, }, { - name: "accept too late", - expectAccept: false, - }, { - name: "complete in time", - expectAccept: true, - expectComplete: true, + name: "dont restart when received (5) > min required (1)", + minBytesTransferred: 1, + dataPoints: []uint64{10, 15}, + expectRestart: false, }, { - name: "complete too late", - expectAccept: true, - expectComplete: true, + name: "dont restart with typical progression", + minBytesTransferred: 1, + dataPoints: []uint64{10, 20, 30, 40}, + expectRestart: false, }} - ch1 := datatransfer.ChannelID{ - Initiator: "initiator", - Responder: "responder", - ID: 1, - } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { ch := &mockChannelState{chid: ch1} mockAPI := newMockMonitorAPI(ch, false) - verifyClosedAndShutdown := func(mch *monitoredChannel, timeout time.Duration) { - // Verify channel has been closed - select { - case <-time.After(timeout): - require.Fail(t, "failed to close channel") - case <-mockAPI.closed: - } + checksPerInterval := uint32(1) + m := NewMonitor(mockAPI, &Config{ + AcceptTimeout: time.Hour, + Interval: time.Hour, + ChecksPerInterval: checksPerInterval, + MinBytesTransferred: 
tc.minBytesTransferred, + MaxConsecutiveRestarts: 3, + CompleteTimeout: time.Hour, + }) + + // Note: Don't start monitor, we'll call checkDataRate() manually + + m.AddPullChannel(ch1) - // Verify that channel has been shutdown - verifyChannelShutdown(t, mch) + totalChecks := uint32(len(tc.dataPoints)) + for i := uint32(0); i < totalChecks; i++ { + if i < uint32(len(tc.dataPoints)) { + rcvd := tc.dataPoints[i] + mockAPI.dataReceived(rcvd) + } + m.checkDataRate() } - verifyNotClosed := func(timeout time.Duration) { - // Verify channel has not been closed - select { - case <-time.After(timeout): - case <-mockAPI.closed: - require.Fail(t, "expected channel not to have been closed") + // Check if channel was restarted + select { + case <-time.After(5 * time.Millisecond): + if tc.expectRestart { + require.Fail(t, "failed to restart channel") + } + case <-mockAPI.restarts: + if !tc.expectRestart { + require.Fail(t, "expected no channel restart") } } + }) + } +} + +func TestChannelMonitorMaxConsecutiveRestarts(t *testing.T) { + runTest := func(name string, isPush bool) { + t.Run(name, func(t *testing.T) { + ch := &mockChannelState{chid: ch1} + mockAPI := newMockMonitorAPI(ch, false) - acceptTimeout := 10 * time.Millisecond - completeTimeout := 10 * time.Millisecond + maxConsecutiveRestarts := 3 m := NewMonitor(mockAPI, &Config{ - AcceptTimeout: acceptTimeout, + AcceptTimeout: time.Hour, Interval: time.Hour, ChecksPerInterval: 1, - MinBytesSent: 1, - MaxConsecutiveRestarts: 1, - CompleteTimeout: completeTimeout, + MinBytesTransferred: 2, + MaxConsecutiveRestarts: uint32(maxConsecutiveRestarts), + CompleteTimeout: time.Hour, }) - m.Start() - m.AddChannel(ch1) - mch := getFirstMonitoredChannel(m) - if tc.expectAccept { - // Fire the Accept event - mockAPI.accept() - } + // Note: Don't start monitor, we'll call checkDataRate() manually - if !tc.expectAccept { - // If we are expecting the test to have a timeout waiting for - // the accept event verify that channel was closed (because a - // timeout error occurred) - verifyClosedAndShutdown(mch, 5*acceptTimeout) - return + var chanCtx context.Context + if isPush { + mch := m.AddPushChannel(ch1).(*monitoredPushChannel) + chanCtx = mch.ctx + + mockAPI.dataQueued(10) + mockAPI.dataSent(5) + } else { + mch := m.AddPullChannel(ch1).(*monitoredPullChannel) + chanCtx = mch.ctx + + mockAPI.dataReceived(5) } - // If we're not expecting the test to have a timeout waiting for - // the accept event, verify that channel was not closed - verifyNotClosed(2 * acceptTimeout) + // Check once to add a data point to the queue. + // Subsequent checks will compare against the previous data point. 
+ m.checkDataRate() + + // Each check should trigger a restart up to the maximum number of restarts + triggerMaxRestarts := func() { + for i := 0; i < maxConsecutiveRestarts; i++ { + m.checkDataRate() - // Fire the FinishTransfer event - mockAPI.finishTransfer() - if tc.expectComplete { - // Fire the Complete event - mockAPI.completed() + err := mockAPI.awaitRestart() + require.NoError(t, err) + } } + triggerMaxRestarts() - if !tc.expectComplete { - // If we are expecting the test to have a timeout waiting for - // the complete event verify that channel was closed (because a - // timeout error occurred) - verifyClosedAndShutdown(mch, 5*completeTimeout) - return + // When data is transferred it should reset the consecutive restarts back to zero + if isPush { + mockAPI.dataSent(6) + } else { + mockAPI.dataReceived(5) } - // If we're not expecting the test to have a timeout waiting for - // the accept event, verify that channel was not closed - verifyNotClosed(2 * completeTimeout) + // Trigger restarts up to max again + triggerMaxRestarts() + + // Reached max restarts, so now there should not be another restart + // attempt. + // Instead the channel should be closed and the monitor shut down. + m.checkDataRate() + err := mockAPI.awaitRestart() + require.Error(t, err) // require error because expecting no restart + verifyChannelShutdown(t, chanCtx) }) } + + // test push channel + runTest("push", true) + // test pull channel + runTest("pull", false) } -func getFirstMonitoredChannel(m *Monitor) *monitoredChannel { - m.lk.Lock() - defer m.lk.Unlock() +func TestChannelMonitorTimeouts(t *testing.T) { + type testCase struct { + name string + expectAccept bool + expectComplete bool + } + testCases := []testCase{{ + name: "accept in time", + expectAccept: true, + expectComplete: true, + }, { + name: "accept too late", + expectAccept: false, + }, { + name: "complete in time", + expectAccept: true, + expectComplete: true, + }, { + name: "complete too late", + expectAccept: true, + expectComplete: false, + }} + + runTest := func(name string, isPush bool) { + for _, tc := range testCases { + t.Run(name+": "+tc.name, func(t *testing.T) { + ch := &mockChannelState{chid: ch1} + mockAPI := newMockMonitorAPI(ch, false) + + verifyClosedAndShutdown := func(chCtx context.Context, timeout time.Duration) { + // Verify channel has been closed + select { + case <-time.After(timeout): + require.Fail(t, "failed to close channel within "+timeout.String()) + case <-mockAPI.closed: + } - var mch *monitoredChannel - for mch = range m.channels { - return mch + // Verify that channel has been shutdown + verifyChannelShutdown(t, chCtx) + } - panic("no channels") + + verifyNotClosed := func(timeout time.Duration) { + // Verify channel has not been closed + select { + case <-time.After(timeout): + case <-mockAPI.closed: + require.Fail(t, "expected channel not to have been closed") + } + } + + acceptTimeout := 10 * time.Millisecond + completeTimeout := 10 * time.Millisecond + m := NewMonitor(mockAPI, &Config{ + AcceptTimeout: acceptTimeout, + Interval: time.Hour, + ChecksPerInterval: 1, + MinBytesTransferred: 1, + MaxConsecutiveRestarts: 1, + CompleteTimeout: completeTimeout, + }) + m.Start() + + var chCtx context.Context + if isPush { + mch := m.AddPushChannel(ch1).(*monitoredPushChannel) + chCtx = mch.ctx + } else { + mch := m.AddPullChannel(ch1).(*monitoredPullChannel) + chCtx = mch.ctx + } + + if tc.expectAccept { + // Fire the Accept event + mockAPI.accept() + } else { + // If we are expecting the test to have a timeout waiting for + // the 
accept event, verify that channel was closed (because a + // timeout error occurred) + verifyClosedAndShutdown(chCtx, 5*acceptTimeout) + return + } + + // If we're not expecting the test to have a timeout waiting for + // the accept event, verify that channel was not closed + verifyNotClosed(2 * acceptTimeout) + + // Fire the FinishTransfer event + mockAPI.finishTransfer() + if tc.expectComplete { + // Fire the Complete event + mockAPI.completed() + } + + if !tc.expectComplete { + // If we are expecting the test to have a timeout waiting for + // the complete event verify that channel was closed (because a + // timeout error occurred) + verifyClosedAndShutdown(chCtx, 5*completeTimeout) + return + } + + // If we're not expecting the test to have a timeout waiting for + // the accept event, verify that channel was not closed + verifyNotClosed(2 * completeTimeout) + }) + } } - panic("no channels") + + // test push channel + runTest("push", true) + // test pull channel + runTest("pull", false) } -func verifyChannelShutdown(t *testing.T, mch *monitoredChannel) { +func verifyChannelShutdown(t *testing.T, shutdownCtx context.Context) { select { case <-time.After(10 * time.Millisecond): require.Fail(t, "failed to shutdown channel") - case <-mch.ctx.Done(): + case <-shutdownCtx.Done(): } } @@ -528,6 +689,11 @@ func (m *mockMonitorAPI) dataSent(n uint64) { m.callSubscriber(datatransfer.Event{Code: datatransfer.DataSent}, m.ch) } +func (m *mockMonitorAPI) dataReceived(n uint64) { + m.ch.received = n + m.callSubscriber(datatransfer.Event{Code: datatransfer.DataReceived}, m.ch) +} + func (m *mockMonitorAPI) finishTransfer() { m.callSubscriber(datatransfer.Event{Code: datatransfer.FinishTransfer}, m.ch) } @@ -537,14 +703,15 @@ func (m *mockMonitorAPI) completed() { m.callSubscriber(datatransfer.Event{Code: datatransfer.Complete}, m.ch) } -func (m *mockMonitorAPI) errorEvent() { - m.callSubscriber(datatransfer.Event{Code: datatransfer.Error}, m.ch) +func (m *mockMonitorAPI) sendDataErrorEvent() { + m.callSubscriber(datatransfer.Event{Code: datatransfer.SendDataError}, m.ch) } type mockChannelState struct { chid datatransfer.ChannelID queued uint64 sent uint64 + received uint64 complete bool } @@ -556,6 +723,10 @@ func (m *mockChannelState) Sent() uint64 { return m.sent } +func (m *mockChannelState) Received() uint64 { + return m.received +} + func (m *mockChannelState) ChannelID() datatransfer.ChannelID { return m.chid } @@ -607,10 +778,6 @@ func (m *mockChannelState) SelfPeer() peer.ID { panic("implement me") } -func (m *mockChannelState) Received() uint64 { - panic("implement me") -} - func (m *mockChannelState) Message() string { panic("implement me") } diff --git a/channels/channels.go b/channels/channels.go index d16a6524..6e4116f2 100644 --- a/channels/channels.go +++ b/channels/channels.go @@ -315,8 +315,22 @@ func (c *Channels) Error(chid datatransfer.ChannelID, err error) error { return c.send(chid, datatransfer.Error, err) } -func (c *Channels) Disconnected(chid datatransfer.ChannelID) error { - return c.send(chid, datatransfer.Disconnected) +// Disconnected indicates that the connection went down and it was not possible +// to restart it +func (c *Channels) Disconnected(chid datatransfer.ChannelID, err error) error { + return c.send(chid, datatransfer.Disconnected, err) +} + +// RequestTimedOut indicates that the transport layer had a timeout trying to +// make a request +func (c *Channels) RequestTimedOut(chid datatransfer.ChannelID, err error) error { + return c.send(chid, 
datatransfer.RequestTimedOut, err) +} + +// SendDataError indicates that the transport layer had an error trying +// to send data to the remote peer +func (c *Channels) SendDataError(chid datatransfer.ChannelID, err error) error { + return c.send(chid, datatransfer.SendDataError, err) } // HasChannel returns true if the given channel id is being tracked diff --git a/channels/channels_fsm.go b/channels/channels_fsm.go index e16aaecc..731d6d45 100644 --- a/channels/channels_fsm.go +++ b/channels/channels_fsm.go @@ -24,8 +24,12 @@ var transferringStates = []fsm.StateKey{ // ChannelEvents describe the events taht can var ChannelEvents = fsm.Events{ + // Open a channel fsm.Event(datatransfer.Open).FromAny().To(datatransfer.Requested), + + // Remote peer has accepted the Open channel request fsm.Event(datatransfer.Accept).From(datatransfer.Requested).To(datatransfer.Ongoing), + fsm.Event(datatransfer.Restart).FromAny().ToNoChange().Action(func(chst *internal.ChannelState) error { chst.Message = "" return nil @@ -52,8 +56,19 @@ var ChannelEvents = fsm.Events{ chst.Queued += delta return nil }), - fsm.Event(datatransfer.Disconnected).FromAny().ToNoChange().Action(func(chst *internal.ChannelState) error { - chst.Message = datatransfer.ErrDisconnected.Error() + + fsm.Event(datatransfer.Disconnected).FromAny().ToNoChange().Action(func(chst *internal.ChannelState, err error) error { + chst.Message = err.Error() + return nil + }), + + fsm.Event(datatransfer.SendDataError).FromAny().ToNoChange().Action(func(chst *internal.ChannelState, err error) error { + chst.Message = err.Error() + return nil + }), + + fsm.Event(datatransfer.RequestTimedOut).FromAny().ToNoChange().Action(func(chst *internal.ChannelState, err error) error { + chst.Message = err.Error() return nil }), @@ -61,6 +76,7 @@ var ChannelEvents = fsm.Events{ chst.Message = err.Error() return nil }), + fsm.Event(datatransfer.NewVoucher).FromAny().ToNoChange(). Action(func(chst *internal.ChannelState, vtype datatransfer.TypeIdentifier, voucherBytes []byte) error { chst.Vouchers = append(chst.Vouchers, internal.EncodedVoucher{Type: vtype, Voucher: &cbg.Deferred{Raw: voucherBytes}}) @@ -72,32 +88,41 @@ var ChannelEvents = fsm.Events{ internal.EncodedVoucherResult{Type: vtype, VoucherResult: &cbg.Deferred{Raw: voucherResultBytes}}) return nil }), + fsm.Event(datatransfer.PauseInitiator). FromMany(datatransfer.Requested, datatransfer.Ongoing).To(datatransfer.InitiatorPaused). From(datatransfer.ResponderPaused).To(datatransfer.BothPaused). FromAny().ToJustRecord(), + fsm.Event(datatransfer.PauseResponder). FromMany(datatransfer.Requested, datatransfer.Ongoing).To(datatransfer.ResponderPaused). From(datatransfer.InitiatorPaused).To(datatransfer.BothPaused). FromAny().ToJustRecord(), + fsm.Event(datatransfer.ResumeInitiator). From(datatransfer.InitiatorPaused).To(datatransfer.Ongoing). From(datatransfer.BothPaused).To(datatransfer.ResponderPaused). FromAny().ToJustRecord(), + fsm.Event(datatransfer.ResumeResponder). From(datatransfer.ResponderPaused).To(datatransfer.Ongoing). From(datatransfer.BothPaused).To(datatransfer.InitiatorPaused). From(datatransfer.Finalizing).To(datatransfer.Completing). FromAny().ToJustRecord(), + + // The transfer has finished on the local node - all data was sent / received fsm.Event(datatransfer.FinishTransfer). FromAny().To(datatransfer.TransferFinished). FromMany(datatransfer.Failing, datatransfer.Cancelling).ToJustRecord(). From(datatransfer.ResponderCompleted).To(datatransfer.Completing). 
From(datatransfer.ResponderFinalizing).To(datatransfer.ResponderFinalizingTransferFinished), + fsm.Event(datatransfer.ResponderBeginsFinalization). FromAny().To(datatransfer.ResponderFinalizing). FromMany(datatransfer.Failing, datatransfer.Cancelling).ToJustRecord(). From(datatransfer.TransferFinished).To(datatransfer.ResponderFinalizingTransferFinished), + + // The remote peer sent a Complete message, meaning it has sent / received all data fsm.Event(datatransfer.ResponderCompletes). FromAny().To(datatransfer.ResponderCompleted). FromMany(datatransfer.Failing, datatransfer.Cancelling).ToJustRecord(). @@ -105,8 +130,12 @@ var ChannelEvents = fsm.Events{ From(datatransfer.TransferFinished).To(datatransfer.Completing). From(datatransfer.ResponderFinalizing).To(datatransfer.ResponderCompleted). From(datatransfer.ResponderFinalizingTransferFinished).To(datatransfer.Completing), + fsm.Event(datatransfer.BeginFinalizing).FromAny().To(datatransfer.Finalizing), + + // Both the local node and the remote peer have completed the transfer fsm.Event(datatransfer.Complete).FromAny().To(datatransfer.Completing), + fsm.Event(datatransfer.CleanupComplete). From(datatransfer.Cancelling).To(datatransfer.Cancelled). From(datatransfer.Failing).To(datatransfer.Failed). diff --git a/channels/channels_test.go b/channels/channels_test.go index a34cd7a2..133d745b 100644 --- a/channels/channels_test.go +++ b/channels/channels_test.go @@ -11,6 +11,7 @@ import ( "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" + dss "github.com/ipfs/go-datastore/sync" "github.com/ipld/go-ipld-prime/codec/dagcbor" basicnode "github.com/ipld/go-ipld-prime/node/basic" "github.com/ipld/go-ipld-prime/traversal/selector/builder" @@ -37,7 +38,7 @@ func TestChannels(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) received := make(chan event) notifier := func(evt datatransfer.Event, chst datatransfer.ChannelState) { received <- event{evt, chst} @@ -127,7 +128,7 @@ func TestChannels(t *testing.T) { }) t.Run("updating send/receive values", func(t *testing.T) { - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) dir := os.TempDir() cidLists, err := cidlists.NewCIDLists(dir) require.NoError(t, err) @@ -302,7 +303,7 @@ func TestChannels(t *testing.T) { }) t.Run("test disconnected", func(t *testing.T) { - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) received := make(chan event) notifier := func(evt datatransfer.Event, chst datatransfer.ChannelState) { received <- event{evt, chst} @@ -320,10 +321,11 @@ func TestChannels(t *testing.T) { state := checkEvent(ctx, t, received, datatransfer.Open) require.Equal(t, datatransfer.Requested, state.Status()) - err = channelList.Disconnected(chid) + disconnectErr := xerrors.Errorf("disconnected") + err = channelList.Disconnected(chid, disconnectErr) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.Disconnected) - require.Equal(t, datatransfer.ErrDisconnected.Error(), state.Message()) + require.Equal(t, disconnectErr.Error(), state.Message()) }) t.Run("test self peer and other peer", func(t *testing.T) { @@ -364,7 +366,7 @@ func TestMigrationsV0(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) received := make(chan event) notifier := 
func(evt datatransfer.Event, chst datatransfer.ChannelState) { received <- event{evt, chst} @@ -484,7 +486,7 @@ func TestMigrationsV1(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() - ds := datastore.NewMapDatastore() + ds := dss.MutexWrap(datastore.NewMapDatastore()) received := make(chan event) notifier := func(evt datatransfer.Event, chst datatransfer.ChannelState) { received <- event{evt, chst} diff --git a/errors.go b/errors.go index f50bcbfc..0e9903f6 100644 --- a/errors.go +++ b/errors.go @@ -30,9 +30,3 @@ const ErrRejected = errorType("response rejected") // ErrUnsupported indicates an operation is not supported by the transport protocol const ErrUnsupported = errorType("unsupported") - -// ErrDisconnected indicates the other peer may have hung up and you should try restarting the channel. -const ErrDisconnected = errorType("other peer appears to have hung up. restart Channel") - -// ErrRemoved indicates the channel was inactive long enough that it was put in a permaneant error state -const ErrRemoved = errorType("channel removed due to inactivity") diff --git a/events.go b/events.go index c4216b34..a71bf5a6 100644 --- a/events.go +++ b/events.go @@ -90,6 +90,14 @@ const ( // the remote peer. It is used to measure progress of how much of the total // data has been received. DataReceivedProgress + + // RequestTimedOut indicates that the transport layer had a timeout trying to + // make a request + RequestTimedOut + + // SendDataError indicates that the transport layer had an error trying + // to send data to the remote peer + SendDataError ) // Events are human readable names for data transfer events diff --git a/impl/environment.go b/impl/environment.go index f1a71155..81671717 100644 --- a/impl/environment.go +++ b/impl/environment.go @@ -23,8 +23,5 @@ func (ce *channelEnvironment) ID() peer.ID { } func (ce *channelEnvironment) CleanupChannel(chid datatransfer.ChannelID) { - ce.m.reconnectsLk.Lock() - delete(ce.m.reconnects, chid) - ce.m.reconnectsLk.Unlock() ce.m.transport.CleanupChannel(chid) } diff --git a/impl/events.go b/impl/events.go index 1c207bf6..fe49b57d 100644 --- a/impl/events.go +++ b/impl/events.go @@ -3,7 +3,6 @@ package impl import ( "context" "errors" - "time" "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" @@ -12,7 +11,6 @@ import ( "golang.org/x/xerrors" datatransfer "github.com/filecoin-project/go-data-transfer" - "github.com/filecoin-project/go-data-transfer/channels" "github.com/filecoin-project/go-data-transfer/encoding" "github.com/filecoin-project/go-data-transfer/registry" ) @@ -36,19 +34,6 @@ func (m *manager) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, si return err } - m.reconnectsLk.RLock() - reconnect, ok := m.reconnects[chid] - var alreadyReconnected bool - select { - case <-reconnect: - alreadyReconnected = true - default: - } - if ok && !alreadyReconnected { - close(reconnect) - } - m.reconnectsLk.RUnlock() - if chid.Initiator != m.peerID { var result datatransfer.VoucherResult var err error @@ -99,18 +84,6 @@ func (m *manager) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size } func (m *manager) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size uint64) error { - m.reconnectsLk.RLock() - reconnect, ok := m.reconnects[chid] - var alreadyReconnected bool - select { - case <-reconnect: - alreadyReconnected = true - default: - } - if ok && !alreadyReconnected { - close(reconnect) - } - m.reconnectsLk.RUnlock() return m.channels.DataSent(chid, 
link.(cidlink.Link).Cid, size) } @@ -201,92 +174,26 @@ func (m *manager) OnResponseReceived(chid datatransfer.ChannelID, response datat return m.resumeOther(chid) } -func (m *manager) OnRequestTimedOut(ctx context.Context, chid datatransfer.ChannelID) error { - log.Warnf("channel %+v has timed out", chid) - - m.reconnectsLk.Lock() - reconnect, ok := m.reconnects[chid] - var alreadyReconnected bool - select { - case <-reconnect: - alreadyReconnected = true - default: - } - if !ok || alreadyReconnected { - reconnect = make(chan struct{}) - m.reconnects[chid] = reconnect - } - m.reconnectsLk.Unlock() - timer := time.NewTimer(m.channelRemoveTimeout) - - go func() { - select { - case <-ctx.Done(): - case <-reconnect: - case <-timer.C: - channel, err := m.channels.GetByID(ctx, chid) - if err == nil { - if !(channels.IsChannelTerminated(channel.Status()) || - channels.IsChannelCleaningUp(channel.Status())) { - if err := m.channels.Error(chid, datatransfer.ErrRemoved); err != nil { - log.Errorf("failed to cancel timed-out channel: %v", err) - return - } - log.Warnf("channel %+v has ben cancelled because of timeout", chid) - } - } - } - }() - - return nil +func (m *manager) OnRequestTimedOut(chid datatransfer.ChannelID, err error) error { + log.Warnf("channel %+v has timed out: %s", chid, err) + return m.channels.RequestTimedOut(chid, err) } -func (m *manager) OnRequestDisconnected(ctx context.Context, chid datatransfer.ChannelID) error { - log.Warnf("channel %+v has stalled or disconnected", chid) - - // mark peer disconnected for informational purposes - err := m.channels.Disconnected(chid) - if err != nil { - return err - } - - m.reconnectsLk.Lock() - reconnect, ok := m.reconnects[chid] - var alreadyReconnected bool - select { - case <-reconnect: - alreadyReconnected = true - default: - } - if !ok || alreadyReconnected { - reconnect = make(chan struct{}) - m.reconnects[chid] = reconnect - } - m.reconnectsLk.Unlock() - timer := time.NewTimer(m.channelRemoveTimeout) - go func() { - select { - case <-ctx.Done(): - case <-reconnect: - case <-timer.C: - channel, err := m.channels.GetByID(ctx, chid) - if err == nil { - if !(channels.IsChannelTerminated(channel.Status()) || - channels.IsChannelCleaningUp(channel.Status())) { - if err := m.channels.Error(chid, datatransfer.ErrRemoved); err != nil { - log.Errorf("failed to cancel timed-out channel: %v", err) - return - } - log.Warnf("channel %+v has ben cancelled because of timeout", chid) - } - } - } - }() +func (m *manager) OnRequestDisconnected(chid datatransfer.ChannelID, err error) error { + log.Warnf("channel %+v has stalled or disconnected: %s", chid, err) + return m.channels.Disconnected(chid, err) +} - return nil +func (m *manager) OnSendDataError(chid datatransfer.ChannelID, err error) error { + log.Warnf("channel %+v had transport send error: %s", chid, err) + return m.channels.SendDataError(chid, err) } +// OnChannelCompleted is called +// - by the requester when all data for a transfer has been received +// - by the responder when all data for a transfer has been sent func (m *manager) OnChannelCompleted(chid datatransfer.ChannelID, completeErr error) error { + // If the channel completed successfully if completeErr == nil { // If the channel was initiated by the other peer if chid.Initiator != m.peerID { @@ -298,8 +205,9 @@ func (m *manager) OnChannelCompleted(chid datatransfer.ChannelID, completeErr er // Send the other peer a message that the transfer has completed log.Infof("channel %s: sending completion message to initiator", chid) if 
err := m.dataTransferNetwork.SendMessage(context.Background(), chid.Initiator, msg); err != nil {
-			log.Warnf("channel %s: failed to send completion message to initiator: %s", chid, err)
-			return m.OnRequestDisconnected(context.TODO(), chid)
+			err := xerrors.Errorf("channel %s: failed to send completion message to initiator: %w", chid, err)
+			log.Warn(err)
+			return m.OnRequestDisconnected(chid, err)
 		}
 	}
 	if msg.Accepted() {
@@ -315,6 +223,8 @@ func (m *manager) OnChannelCompleted(chid datatransfer.ChannelID, completeErr er
 		log.Infof("channel %s: transfer initiated by local node is complete", chid)
 		return m.channels.FinishTransfer(chid)
 	}
+
+	// There was an error so fire an Error event
 	chst, err := m.channels.GetByID(context.TODO(), chid)
 	if err != nil {
 		return err
diff --git a/impl/impl.go b/impl/impl.go
index de1f7b63..c1b6788f 100644
--- a/impl/impl.go
+++ b/impl/impl.go
@@ -4,8 +4,6 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"sync"
-	"time"
 
 	"github.com/hannahhoward/go-pubsub"
 	"github.com/ipfs/go-cid"
@@ -19,35 +17,32 @@ import (
 	"github.com/filecoin-project/go-storedcounter"
 
 	datatransfer "github.com/filecoin-project/go-data-transfer"
+	"github.com/filecoin-project/go-data-transfer/channelmonitor"
 	"github.com/filecoin-project/go-data-transfer/channels"
 	"github.com/filecoin-project/go-data-transfer/cidlists"
 	"github.com/filecoin-project/go-data-transfer/encoding"
 	"github.com/filecoin-project/go-data-transfer/message"
 	"github.com/filecoin-project/go-data-transfer/network"
-	"github.com/filecoin-project/go-data-transfer/pushchannelmonitor"
 	"github.com/filecoin-project/go-data-transfer/registry"
 )
 
 var log = logging.Logger("dt-impl")
 
 type manager struct {
-	dataTransferNetwork   network.DataTransferNetwork
-	validatedTypes        *registry.Registry
-	resultTypes           *registry.Registry
-	revalidators          *registry.Registry
-	transportConfigurers  *registry.Registry
-	pubSub                *pubsub.PubSub
-	readySub              *pubsub.PubSub
-	channels              *channels.Channels
-	peerID                peer.ID
-	transport             datatransfer.Transport
-	storedCounter         *storedcounter.StoredCounter
-	channelRemoveTimeout  time.Duration
-	reconnectsLk          sync.RWMutex
-	reconnects            map[datatransfer.ChannelID]chan struct{}
-	cidLists              cidlists.CIDLists
-	pushChannelMonitor    *pushchannelmonitor.Monitor
-	pushChannelMonitorCfg *pushchannelmonitor.Config
+	dataTransferNetwork  network.DataTransferNetwork
+	validatedTypes       *registry.Registry
+	resultTypes          *registry.Registry
+	revalidators         *registry.Registry
+	transportConfigurers *registry.Registry
+	pubSub               *pubsub.PubSub
+	readySub             *pubsub.PubSub
+	channels             *channels.Channels
+	peerID               peer.ID
+	transport            datatransfer.Transport
+	storedCounter        *storedcounter.StoredCounter
+	cidLists             cidlists.CIDLists
+	channelMonitor       *channelmonitor.Monitor
+	channelMonitorCfg    *channelmonitor.Config
 }
 
 type internalEvent struct {
@@ -84,23 +79,14 @@ func readyDispatcher(evt pubsub.Event, fn pubsub.SubscriberFn) error {
 // DataTransferOption configures the data transfer manager
 type DataTransferOption func(*manager)
 
-// ChannelRemoveTimeout sets the timeout after which channels are removed from the manager
-func ChannelRemoveTimeout(timeout time.Duration) DataTransferOption {
+// ChannelRestartConfig sets the configuration options for automatically
+// restarting push and pull channels
+func ChannelRestartConfig(cfg channelmonitor.Config) DataTransferOption {
 	return func(m *manager) {
-		m.channelRemoveTimeout = timeout
+		m.channelMonitorCfg = &cfg
 	}
 }
 
-// PushChannelRestartConfig sets the configuration options for automatically
-// restarting push channels
-func PushChannelRestartConfig(cfg pushchannelmonitor.Config) DataTransferOption {
-	return func(m *manager) {
-		m.pushChannelMonitorCfg = &cfg
-	}
-}
-
-const defaultChannelRemoveTimeout = 1 * time.Hour
-
 // NewDataTransfer initializes a new instance of a data transfer manager
 func NewDataTransfer(ds datastore.Batching, cidListsDir string, dataTransferNetwork network.DataTransferNetwork, transport datatransfer.Transport, storedCounter *storedcounter.StoredCounter, options ...DataTransferOption) (datatransfer.Manager, error) {
 	m := &manager{
@@ -114,8 +100,6 @@ func NewDataTransfer(ds datastore.Batching, cidListsDir string, dataTransferNetw
 		peerID:               dataTransferNetwork.ID(),
 		transport:            transport,
 		storedCounter:        storedCounter,
-		channelRemoveTimeout: defaultChannelRemoveTimeout,
-		reconnects:           make(map[datatransfer.ChannelID]chan struct{}),
 	}
 
 	cidLists, err := cidlists.NewCIDLists(cidListsDir)
@@ -134,10 +118,10 @@ func NewDataTransfer(ds datastore.Batching, cidListsDir string, dataTransferNetw
 		option(m)
 	}
 
-	// Start push channel monitor after applying config options as the config
+	// Start push / pull channel monitor after applying config options as the config
 	// options may apply to the monitor
-	m.pushChannelMonitor = pushchannelmonitor.NewMonitor(m, m.pushChannelMonitorCfg)
-	m.pushChannelMonitor.Start()
+	m.channelMonitor = channelmonitor.NewMonitor(m, m.channelMonitorCfg)
+	m.channelMonitor.Start()
 
 	return m, nil
 }
@@ -185,7 +169,7 @@ func (m *manager) OnReady(ready datatransfer.ReadyFunc) {
 // Stop terminates all data transfers and ends processing
 func (m *manager) Stop(ctx context.Context) error {
 	log.Info("stop data-transfer module")
-	m.pushChannelMonitor.Shutdown()
+	m.channelMonitor.Shutdown()
 	return m.transport.Shutdown(ctx)
 }
 
@@ -223,7 +207,7 @@ func (m *manager) OpenPushDataChannel(ctx context.Context, requestTo peer.ID, vo
 		transportConfigurer(chid, voucher, m.transport)
 	}
 	m.dataTransferNetwork.Protect(requestTo, chid.String())
-	monitoredChan := m.pushChannelMonitor.AddChannel(chid)
+	monitoredChan := m.channelMonitor.AddPushChannel(chid)
 	if err := m.dataTransferNetwork.SendMessage(ctx, requestTo, req); err != nil {
 		err = fmt.Errorf("Unable to send request: %w", err)
 		_ = m.channels.Error(chid, err)
@@ -261,9 +245,17 @@ func (m *manager) OpenPullDataChannel(ctx context.Context, requestTo peer.ID, vo
 		transportConfigurer(chid, voucher, m.transport)
 	}
 	m.dataTransferNetwork.Protect(requestTo, chid.String())
+	monitoredChan := m.channelMonitor.AddPullChannel(chid)
 	if err := m.transport.OpenChannel(ctx, requestTo, chid, cidlink.Link{Cid: baseCid}, selector, nil, req); err != nil {
 		err = fmt.Errorf("Unable to send request: %w", err)
 		_ = m.channels.Error(chid, err)
+
+		// If pull channel monitoring is enabled, shut down the monitor as it
+		// wasn't possible to start the data transfer
+		if monitoredChan != nil {
+			monitoredChan.Shutdown()
+		}
+
 		return chid, err
 	}
 	return chid, nil
@@ -284,7 +276,7 @@ func (m *manager) SendVoucher(ctx context.Context, channelID datatransfer.Channe
 	}
 	if err := m.dataTransferNetwork.SendMessage(ctx, chst.OtherPeer(), updateRequest); err != nil {
 		err = fmt.Errorf("Unable to send request: %w", err)
-		_ = m.OnRequestDisconnected(ctx, channelID)
+		_ = m.OnRequestDisconnected(channelID, err)
 		return err
 	}
 	return m.channels.NewVoucher(channelID, voucher)
@@ -311,7 +303,7 @@ func (m *manager) CloseDataTransferChannel(ctx context.Context, chid datatransfe
 	if err != nil {
 		err = fmt.Errorf("unable to send cancel message for channel %s to peer %s: %w", chid, m.peerID,
 			err)
-		_ = m.OnRequestDisconnected(ctx, chid)
+		_ = m.OnRequestDisconnected(chid, err)
 		log.Warn(err)
 	}
 
@@ -384,7 +376,7 @@ func (m *manager) PauseDataTransferChannel(ctx context.Context, chid datatransfe
 	if err := m.dataTransferNetwork.SendMessage(ctx, chid.OtherParty(m.peerID), m.pauseMessage(chid)); err != nil {
 		err = fmt.Errorf("Unable to send pause message: %w", err)
-		_ = m.OnRequestDisconnected(ctx, chid)
+		_ = m.OnRequestDisconnected(chid, err)
 		return err
 	}
 
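Note: a minimal sketch of how a caller might wire the new ChannelRestartConfig option end to end. The helper name, the placeholder parameters and the timing values are illustrative only and are not part of this patch; the field names come from channelmonitor.Config as changed above.

    package example

    import (
    	"context"
    	"time"

    	"github.com/ipfs/go-cid"
    	"github.com/ipfs/go-datastore"
    	"github.com/ipld/go-ipld-prime"
    	"github.com/libp2p/go-libp2p-core/peer"

    	"github.com/filecoin-project/go-storedcounter"

    	datatransfer "github.com/filecoin-project/go-data-transfer"
    	"github.com/filecoin-project/go-data-transfer/channelmonitor"
    	dtimpl "github.com/filecoin-project/go-data-transfer/impl"
    	"github.com/filecoin-project/go-data-transfer/network"
    )

    // openMonitoredPull is an illustrative helper: the caller supplies the datastore,
    // network, transport, counter, voucher, root CID and selector it already has.
    func openMonitoredPull(ctx context.Context, ds datastore.Batching, cidListsDir string,
    	dtNet network.DataTransferNetwork, transport datatransfer.Transport, sc *storedcounter.StoredCounter,
    	to peer.ID, voucher datatransfer.Voucher, baseCid cid.Cid, selector ipld.Node) (datatransfer.ChannelID, error) {

    	// Enable automatic restarts for both push and pull channels.
    	// The durations below are examples, not recommended defaults.
    	restart := dtimpl.ChannelRestartConfig(channelmonitor.Config{
    		AcceptTimeout:          30 * time.Second,
    		Interval:               1 * time.Minute,
    		MinBytesTransferred:    1 << 10, // at least 1 KiB must move per interval
    		ChecksPerInterval:      10,
    		RestartBackoff:         20 * time.Second,
    		MaxConsecutiveRestarts: 5,
    		CompleteTimeout:        30 * time.Second,
    	})

    	dt, err := dtimpl.NewDataTransfer(ds, cidListsDir, dtNet, transport, sc, restart)
    	if err != nil {
    		return datatransfer.ChannelID{}, err
    	}
    	if err := dt.Start(ctx); err != nil {
    		return datatransfer.ChannelID{}, err
    	}

    	// The pull channel is registered with the channel monitor inside OpenPullDataChannel
    	return dt.OpenPullDataChannel(ctx, to, voucher, baseCid, selector)
    }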
diff --git a/impl/initiating_test.go b/impl/initiating_test.go
index ae0dea37..4b13a009 100644
--- a/impl/initiating_test.go
+++ b/impl/initiating_test.go
@@ -85,75 +85,6 @@ func TestDataTransferInitiating(t *testing.T) {
 				testutil.AssertFakeDTVoucher(t, receivedRequest, h.voucher)
 			},
 		},
-		"Remove Timed-out request": {
-			expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.Error, datatransfer.CleanupComplete},
-			options:        []DataTransferOption{ChannelRemoveTimeout(10 * time.Millisecond)},
-			verify: func(t *testing.T, h *harness) {
-				channelID, err := h.dt.OpenPullDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor)
-				require.NoError(t, err)
-				require.NoError(t, h.transport.EventHandler.OnRequestTimedOut(ctx, channelID))
-				// need time for the events to take place
-				time.Sleep(1 * time.Second)
-			},
-		},
-		"Remove disconnected request": {
-			expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.Disconnected, datatransfer.Error, datatransfer.CleanupComplete},
-			options:        []DataTransferOption{ChannelRemoveTimeout(10 * time.Millisecond)},
-			verify: func(t *testing.T, h *harness) {
-				channelID, err := h.dt.OpenPullDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor)
-				require.NoError(t, err)
-				require.NoError(t, h.transport.EventHandler.OnRequestDisconnected(ctx, channelID))
-				// need time for the events to take place
-				time.Sleep(1 * time.Second)
-			},
-		},
-		"Remove disconnected push request": {
-			expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.Accept, datatransfer.ResumeResponder, datatransfer.Disconnected, datatransfer.Error, datatransfer.CleanupComplete},
-			options:        []DataTransferOption{ChannelRemoveTimeout(10 * time.Millisecond)},
-			verify: func(t *testing.T, h *harness) {
-				channelID, err := h.dt.OpenPushDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor)
-				require.NoError(t, err)
-				response, err := message.NewResponse(channelID.ID, true, false, datatransfer.EmptyTypeIdentifier, nil)
-				require.NoError(t, err)
-				err = h.transport.EventHandler.OnResponseReceived(channelID, response)
-				require.NoError(t, err)
-				require.NoError(t, h.transport.EventHandler.OnRequestDisconnected(ctx, channelID))
-				// need time for the events to take place
-				time.Sleep(1 * time.Second)
-			},
-		},
-		"Disconnected request resumes": {
-			expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.Disconnected, datatransfer.DataReceivedProgress, datatransfer.DataReceived},
-			options:        []DataTransferOption{ChannelRemoveTimeout(10 * time.Millisecond)},
-			verify: func(t *testing.T, h *harness) {
-				channelID, err := h.dt.OpenPullDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor)
-				require.NoError(t, err)
-				require.NoError(t, h.transport.EventHandler.OnRequestDisconnected(ctx, channelID))
-				testCids := testutil.GenerateCids(1)
-				require.NoError(t, h.transport.EventHandler.OnDataReceived(channelID, cidlink.Link{Cid: testCids[0]}, uint64(12345)))
-
-				// need time for the events to take place
-				time.Sleep(1 * time.Second)
-			},
-		},
-		"Disconnected request resumes, push": {
-			expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.Accept, datatransfer.ResumeResponder, datatransfer.Disconnected, datatransfer.DataSentProgress, datatransfer.DataSent},
-			options:        []DataTransferOption{ChannelRemoveTimeout(10 * time.Millisecond)},
-			verify: func(t *testing.T, h *harness) {
-				channelID, err := h.dt.OpenPullDataChannel(h.ctx, h.peers[1], h.voucher, h.baseCid, h.stor)
-				require.NoError(t, err)
-				response, err := message.NewResponse(channelID.ID, true, false, datatransfer.EmptyTypeIdentifier, nil)
-				require.NoError(t, err)
-				err = h.transport.EventHandler.OnResponseReceived(channelID, response)
-				require.NoError(t, err)
-				require.NoError(t, h.transport.EventHandler.OnRequestDisconnected(ctx, channelID))
-				testCids := testutil.GenerateCids(1)
-				require.NoError(t, h.transport.EventHandler.OnDataSent(channelID, cidlink.Link{Cid: testCids[0]}, uint64(12345)))
-
-				// need time for the events to take place
-				time.Sleep(1 * time.Second)
-			},
-		},
 		"SendVoucher with no channel open": {
 			verify: func(t *testing.T, h *harness) {
 				err := h.dt.SendVoucher(h.ctx, datatransfer.ChannelID{Initiator: h.peers[1], Responder: h.peers[0], ID: 999999}, h.voucher)
diff --git a/impl/integration_test.go b/impl/integration_test.go
index 594fe5ce..2e744df2 100644
--- a/impl/integration_test.go
+++ b/impl/integration_test.go
@@ -32,11 +32,11 @@ import (
 	"github.com/filecoin-project/go-storedcounter"
 
 	datatransfer "github.com/filecoin-project/go-data-transfer"
+	"github.com/filecoin-project/go-data-transfer/channelmonitor"
 	"github.com/filecoin-project/go-data-transfer/encoding"
 	. "github.com/filecoin-project/go-data-transfer/impl"
 	"github.com/filecoin-project/go-data-transfer/message"
 	"github.com/filecoin-project/go-data-transfer/network"
-	"github.com/filecoin-project/go-data-transfer/pushchannelmonitor"
 	"github.com/filecoin-project/go-data-transfer/testutil"
 	tp "github.com/filecoin-project/go-data-transfer/transport/graphsync"
 	"github.com/filecoin-project/go-data-transfer/transport/graphsync/extension"
@@ -524,22 +524,24 @@ func (dc *disconnectCoordinator) onDisconnect() {
 	close(dc.disconnected)
 }
 
-// TestPushRequestAutoRestart tests that if the connection for a push request
+// TestAutoRestart tests that if the connection for a push or pull request
 // goes down, it will automatically restart (given the right config options)
-func TestPushRequestAutoRestart(t *testing.T) {
-	//logging.SetLogLevel("dt-pushchanmon", "debug")
+func TestAutoRestart(t *testing.T) {
+	//SetDTLogLevelDebug()
 
 	testCases := []struct {
 		name                        string
+		isPush                      bool
 		expectInitiatorDTFail       bool
 		disconnectOnRequestComplete bool
 		registerResponder           func(responder datatransfer.Manager, dc *disconnectCoordinator)
 	}{{
-		// Verify that the client fires an error event when the disconnect
+		// Push: Verify that the client fires an error event when the disconnect
 		// occurs right when the responder receives the open channel request
 		// (ie the responder doesn't get a chance to respond to the open
 		// channel request)
-		name:                  "when responder receives incoming request",
+		name:                  "push: when responder receives incoming request",
+		isPush:                true,
 		expectInitiatorDTFail: true,
 		registerResponder: func(responder datatransfer.Manager, dc *disconnectCoordinator) {
 			subscriber := func(event datatransfer.Event, channelState datatransfer.ChannelState) {
@@ -550,10 +552,27 @@ func TestPushRequestAutoRestart(t *testing.T) {
 			responder.SubscribeToEvents(subscriber)
 		},
 	}, {
-		// Verify that if a disconnect happens right after the responder
+		// Pull: Verify that the client fires an error event when the disconnect
+		// occurs right when the responder receives the open channel request
+		// (ie the responder doesn't get a chance to respond to the open
+		// channel request)
+		name:                  "pull: when responder receives incoming request",
+		isPush:                false,
+		expectInitiatorDTFail: true,
+		registerResponder: func(responder datatransfer.Manager, dc *disconnectCoordinator) {
+			subscriber := func(event datatransfer.Event, channelState datatransfer.ChannelState) {
+				if event.Code == datatransfer.Open {
+					dc.signalReadyForDisconnect(true)
+				}
+			}
+			responder.SubscribeToEvents(subscriber)
+		},
+	}, {
+		// Push: Verify that if a disconnect happens right after the responder
 		// receives the first block, the transfer will complete automatically
 		// when the link comes back up
-		name: "when responder receives first block",
+		name:   "push: when responder receives first block",
+		isPush: true,
 		registerResponder: func(responder datatransfer.Manager, dc *disconnectCoordinator) {
 			rcvdCount := 0
 			subscriber := func(event datatransfer.Event, channelState datatransfer.ChannelState) {
@@ -568,11 +587,39 @@ func TestPushRequestAutoRestart(t *testing.T) {
 			responder.SubscribeToEvents(subscriber)
 		},
 	}, {
-		// Verify that the client fires an error event when disconnect occurs
-		// right before the responder sends the complete message (ie all blocks
-		// have been received but the responder doesn't get a chance to tell
+		// Pull: Verify that if a disconnect happens right after the responder
+		// enqueues the first block, the transfer will complete automatically
+		// when the link comes back up
+		name:   "pull: when responder sends first block",
+		isPush: false,
+		registerResponder: func(responder datatransfer.Manager, dc *disconnectCoordinator) {
+			sentCount := 0
+			subscriber := func(event datatransfer.Event, channelState datatransfer.ChannelState) {
+				if event.Code == datatransfer.DataSent {
+					sentCount++
+					if sentCount == 1 {
+						dc.signalReadyForDisconnect(false)
+					}
+				}
+			}
+			responder.SubscribeToEvents(subscriber)
+		},
+	}, {
+		// Push: Verify that the client fires an error event when disconnect occurs
+		// right before the responder sends the complete message (ie the responder
+		// has received all blocks but the responder doesn't get a chance to tell
 		// the initiator before the disconnect)
-		name: "before requester sends complete message",
+		name:                        "push: before requester sends complete message",
+		isPush:                      true,
+		expectInitiatorDTFail:       true,
+		disconnectOnRequestComplete: true,
+	}, {
+		// Pull: Verify that the client fires an error event when disconnect occurs
+		// right before the responder sends the complete message (ie responder sent
+		// all blocks, but the responder doesn't get a chance to tell the initiator
+		// before the disconnect)
+		name:                        "pull: before requester sends complete message",
+		isPush:                      false,
 		expectInitiatorDTFail:       true,
 		disconnectOnRequestComplete: true,
 	}}
@@ -589,37 +636,50 @@ func TestPushRequestAutoRestart(t *testing.T) {
 			// the Complete message, add a hook to do so
 			var responderTransportOpts []tp.Option
 			if tc.disconnectOnRequestComplete {
-				responderTransportOpts = []tp.Option{
-					tp.RegisterCompletedRequestListener(func(chid datatransfer.ChannelID) {
-						dc.signalReadyForDisconnect(true)
-					}),
+				if tc.isPush {
+					responderTransportOpts = []tp.Option{
+						tp.RegisterCompletedRequestListener(func(chid datatransfer.ChannelID) {
+							dc.signalReadyForDisconnect(true)
+						}),
+					}
+				} else {
+					responderTransportOpts = []tp.Option{
+						tp.RegisterCompletedResponseListener(func(chid datatransfer.ChannelID) {
+							dc.signalReadyForDisconnect(true)
+						}),
+					}
 				}
 			}
 
-			gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil)
+			// The retry config for the network layer: make 5 attempts, backing off by 1s each time
 			netRetry := network.RetryParameters(time.Second, time.Second, 5, 1)
+			gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil)
 			gsData.DtNet1 = network.NewFromLibp2pHost(gsData.Host1, netRetry)
-			host1 := gsData.Host1 // initiator, data sender
-			host2 := gsData.Host2 // data recipient
+			initiatorHost := gsData.Host1 // initiator (sends the data for a push, receives it for a pull)
+			responderHost := gsData.Host2 // responder (receives the data for a push, sends it for a pull)
 
 			initiatorGSTspt := gsData.SetupGSTransportHost1()
 			responderGSTspt := gsData.SetupGSTransportHost2(responderTransportOpts...)
 
-			restartConf := PushChannelRestartConfig(pushchannelmonitor.Config{
+			// Set up the restart config for the initiator's channel monitor
+			restartConf := ChannelRestartConfig(channelmonitor.Config{
 				AcceptTimeout:          100 * time.Millisecond,
 				Interval:               100 * time.Millisecond,
-				MinBytesSent:           1,
+				MinBytesTransferred:    1,
 				ChecksPerInterval:      10,
-				RestartBackoff:         200 * time.Millisecond,
+				RestartBackoff:         500 * time.Millisecond,
 				MaxConsecutiveRestarts: 5,
 				CompleteTimeout:        100 * time.Millisecond,
 			})
 			initiator, err := NewDataTransfer(gsData.DtDs1, gsData.TempDir1, gsData.DtNet1, initiatorGSTspt, gsData.StoredCounter1, restartConf)
 			require.NoError(t, err)
 			testutil.StartAndWaitForReady(ctx, t, initiator)
+			defer initiator.Stop(ctx)
+
 			responder, err := NewDataTransfer(gsData.DtDs2, gsData.TempDir2, gsData.DtNet2, responderGSTspt, gsData.StoredCounter2)
 			require.NoError(t, err)
 			testutil.StartAndWaitForReady(ctx, t, responder)
+			defer responder.Stop(ctx)
 
 			//initiator.SubscribeToEvents(func(event datatransfer.Event, channelState datatransfer.ChannelState) {
 			//	t.Logf("clnt: evt %s / status %s", datatransfer.Events[event.Code], datatransfer.Statuses[channelState.Status()])
@@ -637,8 +697,14 @@ func TestPushRequestAutoRestart(t *testing.T) {
 			voucher := testutil.FakeDTType{Data: "applesauce"}
 			sv := testutil.NewStubbedValidator()
 
-			sourceDagService := gsData.DagService1
-			destDagService := gsData.DagService2
+			var sourceDagService, destDagService ipldformat.DAGService
+			if tc.isPush {
+				sourceDagService = gsData.DagService1
+				destDagService = gsData.DagService2
+			} else {
+				sourceDagService = gsData.DagService2
+				destDagService = gsData.DagService1
+			}
 			root, origBytes := testutil.LoadUnixFSFile(ctx, t, sourceDagService, loremFile)
 			rootCid := root.(cidlink.Link).Cid
@@ -662,8 +728,14 @@ func TestPushRequestAutoRestart(t *testing.T) {
 				})
 			}
 
-			// Open a push channel
-			chid, err := initiator.OpenPushDataChannel(ctx, host2.ID(), &voucher, rootCid, gsData.AllSelector)
+			var chid datatransfer.ChannelID
+			if tc.isPush {
+				// Open a push channel
+				chid, err = initiator.OpenPushDataChannel(ctx, responderHost.ID(), &voucher, rootCid, gsData.AllSelector)
+			} else {
+				// Open a pull channel
+				chid, err = initiator.OpenPullDataChannel(ctx, responderHost.ID(), &voucher, rootCid, gsData.AllSelector)
+			}
 			require.NoError(t, err)
 
 			// Wait for the moment at which the test case should experience a disconnect
@@ -675,8 +747,8 @@ func TestPushRequestAutoRestart(t *testing.T) {
 
 			// Break connection
 			t.Logf("Breaking connection to peer")
-			require.NoError(t, gsData.Mn.UnlinkPeers(host1.ID(), host2.ID()))
-			require.NoError(t, gsData.Mn.DisconnectPeers(host1.ID(), host2.ID()))
+			require.NoError(t, gsData.Mn.UnlinkPeers(initiatorHost.ID(), responderHost.ID()))
+			require.NoError(t, gsData.Mn.DisconnectPeers(initiatorHost.ID(), responderHost.ID()))
 
 			// Inform the coordinator that the disconnect has occurred
 			dc.onDisconnect()
@@ -1612,7 +1684,7 @@ func (r *receiver) ReceiveRestartExistingChannelRequest(ctx context.Context,
 
 //func SetDTLogLevelDebug() {
 //	_ = logging.SetLogLevel("dt-impl", "debug")
-//	_ = logging.SetLogLevel("dt-pushchanmon", "debug")
+//	_ = logging.SetLogLevel("dt-chanmon", "debug")
 //	_ = logging.SetLogLevel("dt_graphsync", "debug")
 //	_ = logging.SetLogLevel("data_transfer", "debug")
 //	_ = logging.SetLogLevel("data_transfer_network", "debug")
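Note: the SetDTLogLevelDebug helper referenced above stays commented out; when chasing restart behaviour locally, a roughly equivalent standalone helper would look like the sketch below. It assumes the go-log/v2 package this repo already uses for its loggers; the function name is illustrative.

    package example

    import logging "github.com/ipfs/go-log/v2"

    // enableChanMonDebug raises the renamed channel monitor logger ("dt-chanmon",
    // formerly "dt-pushchanmon") to debug level. Import path assumed from the repo's
    // existing go-log usage.
    func enableChanMonDebug() error {
    	return logging.SetLogLevel("dt-chanmon", "debug")
    }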
diff --git a/transport.go b/transport.go
index 9f9812c9..e271e7be 100644
--- a/transport.go
+++ b/transport.go
@@ -55,11 +55,14 @@ type EventsHandler interface {
 	// OnRequestTimedOut is called when a request we opened (with the given channel Id) to receive data times out.
 	// Error returns are logged but otherwise have no effect
-	OnRequestTimedOut(ctx context.Context, chid ChannelID) error
+	OnRequestTimedOut(chid ChannelID, err error) error
 
-	// OnRequestDisconnected is called when a network error occurs in a graphsync request
-	// or we appear to stall while receiving data
-	OnRequestDisconnected(ctx context.Context, chid ChannelID) error
+	// OnRequestDisconnected is called when a network error occurs trying to send a request
+	OnRequestDisconnected(chid ChannelID, err error) error
+
+	// OnSendDataError is called when a network error occurs sending data
+	// at the transport layer
+	OnSendDataError(chid ChannelID, err error) error
 }
 
 /*
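Note: a minimal illustration of the new callback shapes. The type name is made up, and only the three methods touched by this change are shown; a real EventsHandler implements the rest of the interface as well.

    package example

    import (
    	datatransfer "github.com/filecoin-project/go-data-transfer"
    )

    // noopErrorEvents is an illustrative stub for the changed EventsHandler methods.
    type noopErrorEvents struct{}

    // OnRequestTimedOut now receives the error that caused the timeout instead of a context.
    func (noopErrorEvents) OnRequestTimedOut(chid datatransfer.ChannelID, err error) error {
    	return nil
    }

    // OnRequestDisconnected is now scoped to failures sending a request, and also carries the error.
    func (noopErrorEvents) OnRequestDisconnected(chid datatransfer.ChannelID, err error) error {
    	return nil
    }

    // OnSendDataError is the new callback for network errors while sending data at the transport layer.
    func (noopErrorEvents) OnSendDataError(chid datatransfer.ChannelID, err error) error {
    	return nil
    }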
diff --git a/transport/graphsync/graphsync.go b/transport/graphsync/graphsync.go
index 469c2d5a..194047c5 100644
--- a/transport/graphsync/graphsync.go
+++ b/transport/graphsync/graphsync.go
@@ -45,23 +45,31 @@ func RegisterCompletedRequestListener(l func(channelID datatransfer.ChannelID))
 	}
 }
 
+// RegisterCompletedResponseListener is used by the tests to listen for when a response completes
+func RegisterCompletedResponseListener(l func(channelID datatransfer.ChannelID)) Option {
+	return func(t *Transport) {
+		t.completedResponseListener = l
+	}
+}
+
 // Transport manages graphsync hooks for data transfer, translating from
 // graphsync hooks to semantic data transfer events
 type Transport struct {
-	events                   datatransfer.EventsHandler
-	gs                       graphsync.GraphExchange
-	peerID                   peer.ID
-	dataLock                 sync.RWMutex
-	graphsyncRequestMap      map[graphsyncKey]datatransfer.ChannelID
-	channelIDMap             map[datatransfer.ChannelID]graphsyncKey
-	contextCancelMap         map[datatransfer.ChannelID]func()
-	pending                  map[datatransfer.ChannelID]chan struct{}
-	requestorCancelledMap    map[datatransfer.ChannelID]struct{}
-	pendingExtensions        map[datatransfer.ChannelID][]graphsync.ExtensionData
-	stores                   map[datatransfer.ChannelID]struct{}
-	supportedExtensions      []graphsync.ExtensionName
-	unregisterFuncs          []graphsync.UnregisterHookFunc
-	completedRequestListener func(channelID datatransfer.ChannelID)
+	events                    datatransfer.EventsHandler
+	gs                        graphsync.GraphExchange
+	peerID                    peer.ID
+	dataLock                  sync.RWMutex
+	graphsyncRequestMap       map[graphsyncKey]datatransfer.ChannelID
+	channelIDMap              map[datatransfer.ChannelID]graphsyncKey
+	contextCancelMap          map[datatransfer.ChannelID]func()
+	pending                   map[datatransfer.ChannelID]chan struct{}
+	requestorCancelledMap     map[datatransfer.ChannelID]struct{}
+	pendingExtensions         map[datatransfer.ChannelID][]graphsync.ExtensionData
+	stores                    map[datatransfer.ChannelID]struct{}
+	supportedExtensions       []graphsync.ExtensionName
+	unregisterFuncs           []graphsync.UnregisterHookFunc
+	completedRequestListener  func(channelID datatransfer.ChannelID)
+	completedResponseListener func(channelID datatransfer.ChannelID)
 }
 
 // NewTransport makes a new hooks manager with the given hook events interface
@@ -130,7 +138,7 @@ func (t *Transport) OpenChannel(ctx context.Context,
 	}
 
 	responseChan, errChan := t.gs.Request(internalCtx, dataSender, root, stor, exts...)
-	go t.executeGsRequest(ctx, internalCtx, channelID, responseChan, errChan)
+	go t.executeGsRequest(internalCtx, channelID, responseChan, errChan)
 	return nil
 }
 
@@ -144,12 +152,13 @@ func (t *Transport) consumeResponses(responseChan <-chan graphsync.ResponseProgr
 	return lastError
 }
 
-func (t *Transport) executeGsRequest(ctx context.Context, internalCtx context.Context, channelID datatransfer.ChannelID, responseChan <-chan graphsync.ResponseProgress, errChan <-chan error) {
+func (t *Transport) executeGsRequest(internalCtx context.Context, channelID datatransfer.ChannelID, responseChan <-chan graphsync.ResponseProgress, errChan <-chan error) {
 	lastError := t.consumeResponses(responseChan, errChan)
 
 	if _, ok := lastError.(graphsync.RequestContextCancelledErr); ok {
-		log.Warnf("graphsync request context cancelled, channel Id: %v", channelID)
-		if err := t.events.OnRequestTimedOut(ctx, channelID); err != nil {
+		terr := xerrors.Errorf("graphsync request context cancelled")
+		log.Warnf("channel id %v: %s", channelID, terr)
+		if err := t.events.OnRequestTimedOut(channelID, terr); err != nil {
 			log.Error(err)
 		}
 		return
@@ -321,7 +330,7 @@ func (t *Transport) SetEventHandler(events datatransfer.EventsHandler) error {
 	t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterIncomingResponseHook(t.gsIncomingResponseHook))
 	t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterRequestUpdatedHook(t.gsRequestUpdatedHook))
 	t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterRequestorCancelledListener(t.gsRequestorCancelledListener))
-	t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterNetworkErrorListener(t.gsNetworkErrorListener))
+	t.unregisterFuncs = append(t.unregisterFuncs, t.gs.RegisterNetworkErrorListener(t.gsNetworkSendErrorListener))
 	return nil
 }
 
@@ -561,6 +570,12 @@ func (t *Transport) gsCompletedResponseListener(p peer.ID, request graphsync.Req
 		statusStr := gsResponseStatusCodeString(status)
 		completeErr = xerrors.Errorf("graphsync response to peer %s did not complete: response status code %s", p, statusStr)
 	}
+
+	// Used by the tests to listen for when a response completes
+	if t.completedResponseListener != nil {
+		t.completedResponseListener(chid)
+	}
+
 	err := t.events.OnChannelCompleted(chid, completeErr)
 	if err != nil {
 		log.Error(err)
@@ -709,15 +724,18 @@ func (t *Transport) gsRequestorCancelledListener(p peer.ID, request graphsync.Re
 	}
 }
 
-func (t *Transport) gsNetworkErrorListener(p peer.ID, request graphsync.RequestData, err error) {
+// Called when there is a graphsync error sending data
+func (t *Transport) gsNetworkSendErrorListener(p peer.ID, request graphsync.RequestData, gserr error) {
 	t.dataLock.Lock()
 	defer t.dataLock.Unlock()
 
 	chid, ok := t.graphsyncRequestMap[graphsyncKey{request.ID(), p}]
-	if ok {
-		err := t.events.OnRequestDisconnected(context.TODO(), chid)
-		if err != nil {
-			log.Error(err)
-		}
+	if !ok {
+		return
+	}
+
+	err := t.events.OnSendDataError(chid, gserr)
+	if err != nil {
+		log.Errorf("failed to fire transport send error %s: %s", gserr, err)
 	}
 }
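Note: a small sketch of how a test might build the new option; the function name below is illustrative, and the resulting option is passed wherever the graphsync Transport is constructed (for example via the SetupGSTransportHost2 test helper used in the integration test above).

    package example

    import (
    	"fmt"

    	datatransfer "github.com/filecoin-project/go-data-transfer"
    	tp "github.com/filecoin-project/go-data-transfer/transport/graphsync"
    )

    // completedResponseOption returns a transport option whose callback fires once
    // graphsync reports the outgoing response for a channel as complete.
    func completedResponseOption() tp.Option {
    	return tp.RegisterCompletedResponseListener(func(chid datatransfer.ChannelID) {
    		fmt.Printf("response for channel %s completed\n", chid)
    	})
    }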
diff --git a/transport/graphsync/graphsync_test.go b/transport/graphsync/graphsync_test.go
index c477b2db..512ab78c 100644
--- a/transport/graphsync/graphsync_test.go
+++ b/transport/graphsync/graphsync_test.go
@@ -633,7 +633,7 @@ func TestManager(t *testing.T) {
 			},
 			check: func(t *testing.T, events *fakeEvents, gsData *harness) {
 				require.Equal(t, 1, events.OnRequestReceivedCallCount)
-				require.True(t, events.OnRequestDisconnectedCalled)
+				require.True(t, events.OnSendDataErrorCalled)
 			},
 		},
 		"open channel adds doNotSendCids to the DoNotSend extension": {
@@ -974,10 +974,10 @@ type fakeEvents struct {
 	OnDataQueuedMessage            datatransfer.Message
 	OnDataQueuedError              error
 
-	OnRequestTimedOutCalled        bool
-	OnRequestTimedOutChannelId     datatransfer.ChannelID
-	OnRequestDisconnectedCalled    bool
-	OnRequestDisconnectedChannelID datatransfer.ChannelID
+	OnRequestTimedOutCalled    bool
+	OnRequestTimedOutChannelId datatransfer.ChannelID
+	OnSendDataErrorCalled      bool
+	OnSendDataErrorChannelID   datatransfer.ChannelID
 
 	ChannelCompletedSuccess bool
 	RequestReceivedRequest  datatransfer.Request
@@ -991,16 +991,20 @@ func (fe *fakeEvents) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link,
 	return fe.OnDataQueuedMessage, fe.OnDataQueuedError
 }
 
-func (fe *fakeEvents) OnRequestTimedOut(_ context.Context, chid datatransfer.ChannelID) error {
+func (fe *fakeEvents) OnRequestTimedOut(chid datatransfer.ChannelID, err error) error {
 	fe.OnRequestTimedOutCalled = true
 	fe.OnRequestTimedOutChannelId = chid
 	return nil
 }
 
-func (fe *fakeEvents) OnRequestDisconnected(_ context.Context, chid datatransfer.ChannelID) error {
-	fe.OnRequestDisconnectedCalled = true
-	fe.OnRequestDisconnectedChannelID = chid
+func (fe *fakeEvents) OnRequestDisconnected(chid datatransfer.ChannelID, err error) error {
+	return nil
+}
+
+func (fe *fakeEvents) OnSendDataError(chid datatransfer.ChannelID, err error) error {
+	fe.OnSendDataErrorCalled = true
+	fe.OnSendDataErrorChannelID = chid
 	return nil
 }