diff --git a/network/network.go b/network/network.go index 1bdaac75065d..591987197824 100644 --- a/network/network.go +++ b/network/network.go @@ -558,7 +558,7 @@ func (n *network) Track(peerID ids.NodeID, claimedIPPorts []*ips.ClaimedIPPort) // Stop tracking the old IP and start tracking the new one. tracked := tracked.trackNewIP(ip.IPPort) n.trackedIPs[nodeID] = tracked - n.dial(n.onCloseCtx, nodeID, tracked) + n.dial(nodeID, tracked) } case verifiedIP && shouldDial: // Invariant: [isTracked] is false here. @@ -575,7 +575,7 @@ func (n *network) Track(peerID ids.NodeID, claimedIPPorts []*ips.ClaimedIPPort) tracked := newTrackedIP(ip.IPPort) n.trackedIPs[nodeID] = tracked - n.dial(n.onCloseCtx, nodeID, tracked) + n.dial(nodeID, tracked) default: // This IP isn't desired n.metrics.numUselessPeerListBytes.Add(float64(ip.BytesLen())) @@ -829,7 +829,7 @@ func (n *network) ManuallyTrack(nodeID ids.NodeID, ip ips.IPPort) { if !isTracked { tracked := newTrackedIP(ip) n.trackedIPs[nodeID] = tracked - n.dial(n.onCloseCtx, nodeID, tracked) + n.dial(nodeID, tracked) } } @@ -965,7 +965,7 @@ func (n *network) disconnectedFromConnecting(nodeID ids.NodeID) { if n.wantsConnection(nodeID) { tracked := tracked.trackNewIP(tracked.ip) n.trackedIPs[nodeID] = tracked - n.dial(n.onCloseCtx, nodeID, tracked) + n.dial(nodeID, tracked) } else { tracked.stopTracking() delete(n.peerIPs, nodeID) @@ -989,7 +989,7 @@ func (n *network) disconnectedFromConnected(peer peer.Peer, nodeID ids.NodeID) { prevIP := n.peerIPs[nodeID] tracked := newTrackedIP(prevIP.IPPort) n.trackedIPs[nodeID] = tracked - n.dial(n.onCloseCtx, nodeID, tracked) + n.dial(nodeID, tracked) } else { delete(n.peerIPs, nodeID) } @@ -1064,7 +1064,7 @@ func (n *network) peerIPStatus(nodeID ids.NodeID, ip *ips.ClaimedIPPort) (*ips.C // If initiating a connection to [ip] fails, then dial will reattempt. However, // there is a randomized exponential backoff to avoid spamming connection // attempts. -func (n *network) dial(ctx context.Context, nodeID ids.NodeID, ip *trackedIP) { +func (n *network) dial(nodeID ids.NodeID, ip *trackedIP) { go func() { n.metrics.numTracked.Inc() defer n.metrics.numTracked.Dec() @@ -1073,6 +1073,9 @@ func (n *network) dial(ctx context.Context, nodeID ids.NodeID, ip *trackedIP) { timer := time.NewTimer(ip.getDelay()) select { + case <-n.onCloseCtx.Done(): + timer.Stop() + return case <-ip.onStopTracking: timer.Stop() return @@ -1141,7 +1144,7 @@ func (n *network) dial(ctx context.Context, nodeID ids.NodeID, ip *trackedIP) { continue } - conn, err := n.dialer.Dial(ctx, ip.ip) + conn, err := n.dialer.Dial(n.onCloseCtx, ip.ip) if err != nil { n.peerConfig.Log.Verbo( "failed to reach peer, attempting again", diff --git a/network/network_test.go b/network/network_test.go index e8251c4788dc..0e8913082608 100644 --- a/network/network_test.go +++ b/network/network_test.go @@ -196,13 +196,13 @@ func newMessageCreator(t *testing.T) message.Creator { return mc } -func newFullyConnectedTestNetwork(t *testing.T, handlers []router.InboundHandler) ([]ids.NodeID, []Network, *sync.WaitGroup) { +func newFullyConnectedTestNetwork(t *testing.T, handlers []router.InboundHandler) ([]ids.NodeID, []*network, *sync.WaitGroup) { require := require.New(t) dialer, listeners, nodeIDs, configs := newTestNetwork(t, len(handlers)) var ( - networks = make([]Network, len(configs)) + networks = make([]*network, len(configs)) globalLock sync.Mutex numConnected int @@ -278,7 +278,7 @@ func newFullyConnectedTestNetwork(t *testing.T, handlers []router.InboundHandler }, ) require.NoError(err) - networks[i] = net + networks[i] = net.(*network) } wg := sync.WaitGroup{} @@ -403,7 +403,7 @@ func TestTrackVerifiesSignatures(t *testing.T) { _, networks, wg := newFullyConnectedTestNetwork(t, []router.InboundHandler{nil}) - network := networks[0].(*network) + network := networks[0] nodeID, tlsCert, _ := getTLS(t, 1) require.NoError(validators.Add(network.config.Validators, constants.PrimaryNetworkID, nodeID, nil, ids.Empty, 1)) @@ -632,3 +632,60 @@ func TestDialDeletesNonValidators(t *testing.T) { } wg.Wait() } + +// Test that cancelling the context passed into dial +// causes dial to return immediately. +func TestDialContext(t *testing.T) { + _, networks, wg := newFullyConnectedTestNetwork(t, []router.InboundHandler{nil}) + defer wg.Done() + + dialer := newTestDialer() + network := networks[0] + network.dialer = dialer + + var ( + neverDialedNodeID = ids.GenerateTestNodeID() + dialedNodeID = ids.GenerateTestNodeID() + + dynamicNeverDialedIP, neverDialedListener = dialer.NewListener() + dynamicDialedIP, dialedListener = dialer.NewListener() + + neverDialedIP = &trackedIP{ + ip: dynamicNeverDialedIP.IPPort(), + } + dialedIP = &trackedIP{ + ip: dynamicDialedIP.IPPort(), + } + ) + + network.manuallyTrackedIDs.Add(neverDialedNodeID) + network.manuallyTrackedIDs.Add(dialedNodeID) + + // Sanity check that when a non-cancelled context is given, + // we actually dial the peer. + network.dial(dialedNodeID, dialedIP) + + gotDialedIPConn := make(chan struct{}) + go func() { + _, _ = dialedListener.Accept() + close(gotDialedIPConn) + }() + <-gotDialedIPConn + + // Asset that when [n.onCloseCtx] is cancelled, dial returns immediately. + // That is, [neverDialedListener] doesn't accept a connection. + network.onCloseCtxCancel() + network.dial(neverDialedNodeID, neverDialedIP) + + gotNeverDialedIPConn := make(chan struct{}) + go func() { + _, _ = neverDialedListener.Accept() + close(gotNeverDialedIPConn) + }() + + select { + case <-gotNeverDialedIPConn: + require.FailNow(t, "unexpectedly connected to peer") + default: + } +}