Skip to content

Commit

Permalink
Make network.dial honor context cancellation. (#2061)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Laine authored Sep 26, 2023
1 parent a713afe commit d0655ae
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 11 deletions.
17 changes: 10 additions & 7 deletions network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ func (n *network) Track(peerID ids.NodeID, claimedIPPorts []*ips.ClaimedIPPort)
// Stop tracking the old IP and start tracking the new one.
tracked := tracked.trackNewIP(ip.IPPort)
n.trackedIPs[nodeID] = tracked
n.dial(n.onCloseCtx, nodeID, tracked)
n.dial(nodeID, tracked)
}
case verifiedIP && shouldDial:
// Invariant: [isTracked] is false here.
Expand All @@ -575,7 +575,7 @@ func (n *network) Track(peerID ids.NodeID, claimedIPPorts []*ips.ClaimedIPPort)

tracked := newTrackedIP(ip.IPPort)
n.trackedIPs[nodeID] = tracked
n.dial(n.onCloseCtx, nodeID, tracked)
n.dial(nodeID, tracked)
default:
// This IP isn't desired
n.metrics.numUselessPeerListBytes.Add(float64(ip.BytesLen()))
Expand Down Expand Up @@ -829,7 +829,7 @@ func (n *network) ManuallyTrack(nodeID ids.NodeID, ip ips.IPPort) {
if !isTracked {
tracked := newTrackedIP(ip)
n.trackedIPs[nodeID] = tracked
n.dial(n.onCloseCtx, nodeID, tracked)
n.dial(nodeID, tracked)
}
}

Expand Down Expand Up @@ -965,7 +965,7 @@ func (n *network) disconnectedFromConnecting(nodeID ids.NodeID) {
if n.wantsConnection(nodeID) {
tracked := tracked.trackNewIP(tracked.ip)
n.trackedIPs[nodeID] = tracked
n.dial(n.onCloseCtx, nodeID, tracked)
n.dial(nodeID, tracked)
} else {
tracked.stopTracking()
delete(n.peerIPs, nodeID)
Expand All @@ -989,7 +989,7 @@ func (n *network) disconnectedFromConnected(peer peer.Peer, nodeID ids.NodeID) {
prevIP := n.peerIPs[nodeID]
tracked := newTrackedIP(prevIP.IPPort)
n.trackedIPs[nodeID] = tracked
n.dial(n.onCloseCtx, nodeID, tracked)
n.dial(nodeID, tracked)
} else {
delete(n.peerIPs, nodeID)
}
Expand Down Expand Up @@ -1064,7 +1064,7 @@ func (n *network) peerIPStatus(nodeID ids.NodeID, ip *ips.ClaimedIPPort) (*ips.C
// If initiating a connection to [ip] fails, then dial will reattempt. However,
// there is a randomized exponential backoff to avoid spamming connection
// attempts.
func (n *network) dial(ctx context.Context, nodeID ids.NodeID, ip *trackedIP) {
func (n *network) dial(nodeID ids.NodeID, ip *trackedIP) {
go func() {
n.metrics.numTracked.Inc()
defer n.metrics.numTracked.Dec()
Expand All @@ -1073,6 +1073,9 @@ func (n *network) dial(ctx context.Context, nodeID ids.NodeID, ip *trackedIP) {
timer := time.NewTimer(ip.getDelay())

select {
case <-n.onCloseCtx.Done():
timer.Stop()
return
case <-ip.onStopTracking:
timer.Stop()
return
Expand Down Expand Up @@ -1141,7 +1144,7 @@ func (n *network) dial(ctx context.Context, nodeID ids.NodeID, ip *trackedIP) {
continue
}

conn, err := n.dialer.Dial(ctx, ip.ip)
conn, err := n.dialer.Dial(n.onCloseCtx, ip.ip)
if err != nil {
n.peerConfig.Log.Verbo(
"failed to reach peer, attempting again",
Expand Down
65 changes: 61 additions & 4 deletions network/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,13 @@ func newMessageCreator(t *testing.T) message.Creator {
return mc
}

func newFullyConnectedTestNetwork(t *testing.T, handlers []router.InboundHandler) ([]ids.NodeID, []Network, *sync.WaitGroup) {
func newFullyConnectedTestNetwork(t *testing.T, handlers []router.InboundHandler) ([]ids.NodeID, []*network, *sync.WaitGroup) {
require := require.New(t)

dialer, listeners, nodeIDs, configs := newTestNetwork(t, len(handlers))

var (
networks = make([]Network, len(configs))
networks = make([]*network, len(configs))

globalLock sync.Mutex
numConnected int
Expand Down Expand Up @@ -278,7 +278,7 @@ func newFullyConnectedTestNetwork(t *testing.T, handlers []router.InboundHandler
},
)
require.NoError(err)
networks[i] = net
networks[i] = net.(*network)
}

wg := sync.WaitGroup{}
Expand Down Expand Up @@ -403,7 +403,7 @@ func TestTrackVerifiesSignatures(t *testing.T) {

_, networks, wg := newFullyConnectedTestNetwork(t, []router.InboundHandler{nil})

network := networks[0].(*network)
network := networks[0]
nodeID, tlsCert, _ := getTLS(t, 1)
require.NoError(validators.Add(network.config.Validators, constants.PrimaryNetworkID, nodeID, nil, ids.Empty, 1))

Expand Down Expand Up @@ -632,3 +632,60 @@ func TestDialDeletesNonValidators(t *testing.T) {
}
wg.Wait()
}

// Test that cancelling the context passed into dial
// causes dial to return immediately.
func TestDialContext(t *testing.T) {
_, networks, wg := newFullyConnectedTestNetwork(t, []router.InboundHandler{nil})
defer wg.Done()

dialer := newTestDialer()
network := networks[0]
network.dialer = dialer

var (
neverDialedNodeID = ids.GenerateTestNodeID()
dialedNodeID = ids.GenerateTestNodeID()

dynamicNeverDialedIP, neverDialedListener = dialer.NewListener()
dynamicDialedIP, dialedListener = dialer.NewListener()

neverDialedIP = &trackedIP{
ip: dynamicNeverDialedIP.IPPort(),
}
dialedIP = &trackedIP{
ip: dynamicDialedIP.IPPort(),
}
)

network.manuallyTrackedIDs.Add(neverDialedNodeID)
network.manuallyTrackedIDs.Add(dialedNodeID)

// Sanity check that when a non-cancelled context is given,
// we actually dial the peer.
network.dial(dialedNodeID, dialedIP)

gotDialedIPConn := make(chan struct{})
go func() {
_, _ = dialedListener.Accept()
close(gotDialedIPConn)
}()
<-gotDialedIPConn

// Asset that when [n.onCloseCtx] is cancelled, dial returns immediately.
// That is, [neverDialedListener] doesn't accept a connection.
network.onCloseCtxCancel()
network.dial(neverDialedNodeID, neverDialedIP)

gotNeverDialedIPConn := make(chan struct{})
go func() {
_, _ = neverDialedListener.Accept()
close(gotNeverDialedIPConn)
}()

select {
case <-gotNeverDialedIPConn:
require.FailNow(t, "unexpectedly connected to peer")
default:
}
}

0 comments on commit d0655ae

Please sign in to comment.