Skip to content

Commit

Permalink
fix(dot/peerset): remove race conditions from peerset package (#2267)
Browse files Browse the repository at this point in the history
* chore: remove race conditions from `peerset` package

Co-authored-by: Quentin McGaw <[email protected]>
  • Loading branch information
EclesioMeloJunior and qdm12 authored May 5, 2022
1 parent ea95ffd commit df09d45
Show file tree
Hide file tree
Showing 19 changed files with 681 additions and 322 deletions.
10 changes: 5 additions & 5 deletions dot/network/connmgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ func TestMaxPeers(t *testing.T) {
continue
}

n.host.h.Peerstore().AddAddrs(ainfo.ID, ainfo.Addrs, peerstore.PermanentAddrTTL)
n.host.p2pHost.Peerstore().AddAddrs(ainfo.ID, ainfo.Addrs, peerstore.PermanentAddrTTL)
n.host.cm.peerSetHandler.AddPeer(0, ainfo.ID)
}

time.Sleep(200 * time.Millisecond)
p := nodes[0].host.h.Peerstore().Peers()
p := nodes[0].host.p2pHost.Peerstore().Peers()
require.LessOrEqual(t, max, len(p))
}

Expand Down Expand Up @@ -152,15 +152,15 @@ func TestPersistentPeers(t *testing.T) {
time.Sleep(time.Millisecond * 600)

// B should have connected to A during bootstrap
conns := nodeB.host.h.Network().ConnsToPeer(nodeA.host.id())
conns := nodeB.host.p2pHost.Network().ConnsToPeer(nodeA.host.id())
require.NotEqual(t, 0, len(conns))

// if A disconnects from B, B should reconnect
nodeA.host.cm.peerSetHandler.DisconnectPeer(0, nodeB.host.id())

time.Sleep(time.Millisecond * 500)

conns = nodeB.host.h.Network().ConnsToPeer(nodeA.host.id())
conns = nodeB.host.p2pHost.Network().ConnsToPeer(nodeA.host.id())
require.NotEqual(t, 0, len(conns))
}

Expand Down Expand Up @@ -239,7 +239,7 @@ func TestSetReservedPeer(t *testing.T) {

require.Equal(t, 2, node3.host.peerCount())

node3.host.h.Peerstore().AddAddrs(addrC.ID, addrC.Addrs, peerstore.PermanentAddrTTL)
node3.host.p2pHost.Peerstore().AddAddrs(addrC.ID, addrC.Addrs, peerstore.PermanentAddrTTL)
node3.host.cm.peerSetHandler.SetReservedPeer(0, addrC.ID)
time.Sleep(200 * time.Millisecond)

Expand Down
4 changes: 2 additions & 2 deletions dot/network/discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func newTestDiscovery(t *testing.T, num int) []*discovery {
require.NoError(t, err)
disc := &discovery{
ctx: srvc.ctx,
h: srvc.host.h,
h: srvc.host.p2pHost,
ds: ds,
}

Expand Down Expand Up @@ -200,7 +200,7 @@ func TestBeginDiscovery_ThreeNodes(t *testing.T) {
time.Sleep(time.Millisecond * 500)

// assert B and C can discover each other
addrs := nodeB.host.h.Peerstore().Addrs(nodeC.host.id())
addrs := nodeB.host.p2pHost.Peerstore().Addrs(nodeC.host.id())
require.NotEqual(t, 0, len(addrs))

}
44 changes: 22 additions & 22 deletions dot/network/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ const (
// host wraps libp2p host with network host configuration and services
type host struct {
ctx context.Context
h libp2phost.Host
p2pHost libp2phost.Host
discovery *discovery
bootnodes []peer.AddrInfo
persistentPeers []peer.AddrInfo
Expand Down Expand Up @@ -211,7 +211,7 @@ func newHost(ctx context.Context, cfg *Config) (*host, error) {

host := &host{
ctx: ctx,
h: h,
p2pHost: h,
discovery: discovery,
bootnodes: bns,
protocolID: pid,
Expand All @@ -236,14 +236,14 @@ func (h *host) close() error {
}

// close libp2p host
err = h.h.Close()
err = h.p2pHost.Close()
if err != nil {
logger.Errorf("Failed to close libp2p host: %s", err)
return err
}

h.closeSync.Do(func() {
err = h.h.Peerstore().Close()
err = h.p2pHost.Peerstore().Close()
if err != nil {
logger.Errorf("Failed to close libp2p peerstore: %s", err)
return
Expand All @@ -260,28 +260,28 @@ func (h *host) close() error {

// registerStreamHandler registers the stream handler for the given protocol id.
func (h *host) registerStreamHandler(pid protocol.ID, handler func(libp2pnetwork.Stream)) {
h.h.SetStreamHandler(pid, handler)
h.p2pHost.SetStreamHandler(pid, handler)
}

// connect connects the host to a specific peer address
func (h *host) connect(p peer.AddrInfo) (err error) {
h.h.Peerstore().AddAddrs(p.ID, p.Addrs, peerstore.PermanentAddrTTL)
h.p2pHost.Peerstore().AddAddrs(p.ID, p.Addrs, peerstore.PermanentAddrTTL)
ctx, cancel := context.WithTimeout(h.ctx, connectTimeout)
defer cancel()
err = h.h.Connect(ctx, p)
err = h.p2pHost.Connect(ctx, p)
return err
}

// bootstrap connects the host to the configured bootnodes
func (h *host) bootstrap() {
for _, info := range h.persistentPeers {
h.h.Peerstore().AddAddrs(info.ID, info.Addrs, peerstore.PermanentAddrTTL)
h.p2pHost.Peerstore().AddAddrs(info.ID, info.Addrs, peerstore.PermanentAddrTTL)
h.cm.peerSetHandler.AddReservedPeer(0, info.ID)
}

for _, addrInfo := range h.bootnodes {
logger.Debugf("bootstrapping to peer %s", addrInfo.ID)
h.h.Peerstore().AddAddrs(addrInfo.ID, addrInfo.Addrs, peerstore.PermanentAddrTTL)
h.p2pHost.Peerstore().AddAddrs(addrInfo.ID, addrInfo.Addrs, peerstore.PermanentAddrTTL)
h.cm.peerSetHandler.AddPeer(0, addrInfo.ID)
}
}
Expand All @@ -290,7 +290,7 @@ func (h *host) bootstrap() {
// the newly created stream.
func (h *host) send(p peer.ID, pid protocol.ID, msg Message) (libp2pnetwork.Stream, error) {
// open outbound stream with host protocol id
stream, err := h.h.NewStream(h.ctx, p, pid)
stream, err := h.p2pHost.NewStream(h.ctx, p, pid)
if err != nil {
logger.Tracef("failed to open new stream with peer %s using protocol %s: %s", p, pid, err)
return nil, err
Expand Down Expand Up @@ -334,12 +334,12 @@ func (h *host) writeToStream(s libp2pnetwork.Stream, msg Message) error {

// id returns the host id
func (h *host) id() peer.ID {
return h.h.ID()
return h.p2pHost.ID()
}

// Peers returns connected peers
func (h *host) peers() []peer.ID {
return h.h.Network().Peers()
return h.p2pHost.Network().Peers()
}

// addReservedPeers adds the peers `addrs` to the protected peers list and connects to them
Expand All @@ -354,7 +354,7 @@ func (h *host) addReservedPeers(addrs ...string) error {
if err != nil {
return err
}
h.h.Peerstore().AddAddrs(addrInfo.ID, addrInfo.Addrs, peerstore.PermanentAddrTTL)
h.p2pHost.Peerstore().AddAddrs(addrInfo.ID, addrInfo.Addrs, peerstore.PermanentAddrTTL)
h.cm.peerSetHandler.AddReservedPeer(0, addrInfo.ID)
}

Expand All @@ -369,7 +369,7 @@ func (h *host) removeReservedPeers(ids ...string) error {
return err
}
h.cm.peerSetHandler.RemoveReservedPeer(0, peerID)
h.h.ConnManager().Unprotect(peerID, "")
h.p2pHost.ConnManager().Unprotect(peerID, "")
}

return nil
Expand All @@ -378,7 +378,7 @@ func (h *host) removeReservedPeers(ids ...string) error {
// supportsProtocol checks if the protocol is supported by peerID
// returns an error if could not get peer protocols
func (h *host) supportsProtocol(peerID peer.ID, protocol protocol.ID) (bool, error) {
peerProtocols, err := h.h.Peerstore().SupportsProtocols(peerID, string(protocol))
peerProtocols, err := h.p2pHost.Peerstore().SupportsProtocols(peerID, string(protocol))
if err != nil {
return false, err
}
Expand All @@ -388,21 +388,21 @@ func (h *host) supportsProtocol(peerID peer.ID, protocol protocol.ID) (bool, err

// peerCount returns the number of connected peers
func (h *host) peerCount() int {
peers := h.h.Network().Peers()
peers := h.p2pHost.Network().Peers()
return len(peers)
}

// addrInfo returns the libp2p peer.AddrInfo of the host
func (h *host) addrInfo() peer.AddrInfo {
return peer.AddrInfo{
ID: h.h.ID(),
Addrs: h.h.Addrs(),
ID: h.p2pHost.ID(),
Addrs: h.p2pHost.Addrs(),
}
}

// multiaddrs returns the multiaddresses of the host
func (h *host) multiaddrs() (multiaddrs []ma.Multiaddr) {
addrs := h.h.Addrs()
addrs := h.p2pHost.Addrs()
for _, addr := range addrs {
multiaddr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", addr, h.id()))
if err != nil {
Expand All @@ -415,16 +415,16 @@ func (h *host) multiaddrs() (multiaddrs []ma.Multiaddr) {

// protocols returns all protocols currently supported by the node
func (h *host) protocols() []string {
return h.h.Mux().Protocols()
return h.p2pHost.Mux().Protocols()
}

// closePeer closes connection with peer.
func (h *host) closePeer(peer peer.ID) error {
return h.h.Network().ClosePeer(peer)
return h.p2pHost.Network().ClosePeer(peer)
}

func (h *host) closeProtocolStream(pID protocol.ID, p peer.ID) {
connToPeer := h.h.Network().ConnsToPeer(p)
connToPeer := h.p2pHost.Network().ConnsToPeer(p)
for _, c := range connToPeer {
for _, st := range c.GetStreams() {
if st.Protocol() != pID {
Expand Down
12 changes: 6 additions & 6 deletions dot/network/host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,13 @@ func TestBootstrap(t *testing.T) {

peerCountA := nodeA.host.peerCount()
if peerCountA == 0 {
peerCountA := len(nodeA.host.h.Peerstore().Peers())
peerCountA := len(nodeA.host.p2pHost.Peerstore().Peers())
require.NotZero(t, peerCountA)
}

peerCountB := nodeB.host.peerCount()
if peerCountB == 0 {
peerCountB := len(nodeB.host.h.Peerstore().Peers())
peerCountB := len(nodeB.host.p2pHost.Peerstore().Peers())
require.NotZero(t, peerCountB)
}
}
Expand Down Expand Up @@ -498,7 +498,7 @@ func Test_RemoveReservedPeers(t *testing.T) {
time.Sleep(100 * time.Millisecond)

require.Equal(t, 1, nodeA.host.peerCount())
isProtected := nodeA.host.h.ConnManager().IsProtected(nodeB.host.addrInfo().ID, "")
isProtected := nodeA.host.p2pHost.ConnManager().IsProtected(nodeB.host.addrInfo().ID, "")
require.False(t, isProtected)

err = nodeA.host.removeReservedPeers("unknown_perr_id")
Expand Down Expand Up @@ -583,7 +583,7 @@ func TestPeerConnect(t *testing.T) {
nodeB.noGossip = true

addrInfoB := nodeB.host.addrInfo()
nodeA.host.h.Peerstore().AddAddrs(addrInfoB.ID, addrInfoB.Addrs, peerstore.PermanentAddrTTL)
nodeA.host.p2pHost.Peerstore().AddAddrs(addrInfoB.ID, addrInfoB.Addrs, peerstore.PermanentAddrTTL)
nodeA.host.cm.peerSetHandler.AddPeer(0, addrInfoB.ID)

time.Sleep(100 * time.Millisecond)
Expand Down Expand Up @@ -621,7 +621,7 @@ func TestBannedPeer(t *testing.T) {
nodeB.noGossip = true

addrInfoB := nodeB.host.addrInfo()
nodeA.host.h.Peerstore().AddAddrs(addrInfoB.ID, addrInfoB.Addrs, peerstore.PermanentAddrTTL)
nodeA.host.p2pHost.Peerstore().AddAddrs(addrInfoB.ID, addrInfoB.Addrs, peerstore.PermanentAddrTTL)
nodeA.host.cm.peerSetHandler.AddPeer(0, addrInfoB.ID)

time.Sleep(100 * time.Millisecond)
Expand Down Expand Up @@ -674,7 +674,7 @@ func TestPeerReputation(t *testing.T) {
nodeB.noGossip = true

addrInfoB := nodeB.host.addrInfo()
nodeA.host.h.Peerstore().AddAddrs(addrInfoB.ID, addrInfoB.Addrs, peerstore.PermanentAddrTTL)
nodeA.host.p2pHost.Peerstore().AddAddrs(addrInfoB.ID, addrInfoB.Addrs, peerstore.PermanentAddrTTL)
nodeA.host.cm.peerSetHandler.AddPeer(0, addrInfoB.ID)

time.Sleep(100 * time.Millisecond)
Expand Down
2 changes: 1 addition & 1 deletion dot/network/light_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func TestHandleLightMessage_Response(t *testing.T) {
}
require.NoError(t, err)

stream, err := s.host.h.NewStream(s.ctx, b.host.id(), s.host.protocolID+lightID)
stream, err := s.host.p2pHost.NewStream(s.ctx, b.host.id(), s.host.protocolID+lightID)
require.NoError(t, err)

// Testing empty request
Expand Down
4 changes: 2 additions & 2 deletions dot/network/mdns.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (m *mdns) start() {
// create and start service
mdns, err := libp2pdiscovery.NewMdnsService(
m.host.ctx,
m.host.h,
m.host.p2pHost,
MDNSPeriod,
string(m.host.protocolID),
)
Expand Down Expand Up @@ -89,7 +89,7 @@ func (n Notifee) HandlePeerFound(p peer.AddrInfo) {
"Peer %s found using mDNS discovery, with host %s",
p.ID, n.host.id())

n.host.h.Peerstore().AddAddrs(p.ID, p.Addrs, peerstore.PermanentAddrTTL)
n.host.p2pHost.Peerstore().AddAddrs(p.ID, p.Addrs, peerstore.PermanentAddrTTL)
// connect to found peer
n.host.cm.peerSetHandler.AddPeer(0, p.ID)
}
4 changes: 2 additions & 2 deletions dot/network/mdns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ func TestMDNS(t *testing.T) {

if peerCountA == 0 {
// check peerstore for disconnected peers
peerCountA := len(nodeA.host.h.Peerstore().Peers())
peerCountA := len(nodeA.host.p2pHost.Peerstore().Peers())
require.NotZero(t, peerCountA)
}

if peerCountB == 0 {
// check peerstore for disconnected peers
peerCountB := len(nodeB.host.h.Peerstore().Peers())
peerCountB := len(nodeB.host.p2pHost.Peerstore().Peers())
require.NotZero(t, peerCountB)
}
}
14 changes: 7 additions & 7 deletions dot/network/notifications_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounce(t *testing.T) {
}
require.NoError(t, err)

stream, err := s.host.h.NewStream(s.ctx, b.host.id(), s.host.protocolID+blockAnnounceID)
stream, err := s.host.p2pHost.NewStream(s.ctx, b.host.id(), s.host.protocolID+blockAnnounceID)
require.NoError(t, err)

// create info and handler
Expand Down Expand Up @@ -181,7 +181,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T)
}
require.NoError(t, err)

stream, err := s.host.h.NewStream(s.ctx, b.host.id(), s.host.protocolID+blockAnnounceID)
stream, err := s.host.p2pHost.NewStream(s.ctx, b.host.id(), s.host.protocolID+blockAnnounceID)
require.NoError(t, err)

// try invalid handshake
Expand Down Expand Up @@ -250,7 +250,7 @@ func Test_HandshakeTimeout(t *testing.T) {
info := newNotificationsProtocol(nodeA.host.protocolID+blockAnnounceID,
nodeA.getBlockAnnounceHandshake, testHandshakeDecoder, nodeA.validateBlockAnnounceHandshake)

nodeB.host.h.SetStreamHandler(info.protocolID, func(stream libp2pnetwork.Stream) {
nodeB.host.p2pHost.SetStreamHandler(info.protocolID, func(stream libp2pnetwork.Stream) {
// should not respond to a handshake message
})

Expand All @@ -267,7 +267,7 @@ func Test_HandshakeTimeout(t *testing.T) {
// clear handshake data from connection handler
time.Sleep(time.Millisecond * 100)
info.peersData.deleteOutboundHandshakeData(nodeB.host.id())
connAToB := nodeA.host.h.Network().ConnsToPeer(nodeB.host.id())
connAToB := nodeA.host.p2pHost.Network().ConnsToPeer(nodeB.host.id())
for _, stream := range connAToB[0].GetStreams() {
_ = stream.Close()
}
Expand All @@ -289,7 +289,7 @@ func Test_HandshakeTimeout(t *testing.T) {
require.Nil(t, data)

// a stream should be open until timeout
connAToB = nodeA.host.h.Network().ConnsToPeer(nodeB.host.id())
connAToB = nodeA.host.p2pHost.Network().ConnsToPeer(nodeB.host.id())
require.Len(t, connAToB, 1)
require.Len(t, connAToB[0].GetStreams(), 1)

Expand All @@ -301,7 +301,7 @@ func Test_HandshakeTimeout(t *testing.T) {
require.Nil(t, data)

// stream should be closed
connAToB = nodeA.host.h.Network().ConnsToPeer(nodeB.host.id())
connAToB = nodeA.host.p2pHost.Network().ConnsToPeer(nodeB.host.id())
require.Len(t, connAToB, 1)
require.Len(t, connAToB[0].GetStreams(), 0)
}
Expand Down Expand Up @@ -343,7 +343,7 @@ func TestCreateNotificationsMessageHandler_HandleTransaction(t *testing.T) {
require.NoError(t, err)

txnProtocolID := srvc1.host.protocolID + transactionsID
stream, err := srvc1.host.h.NewStream(srvc1.ctx, srvc2.host.id(), txnProtocolID)
stream, err := srvc1.host.p2pHost.NewStream(srvc1.ctx, srvc2.host.id(), txnProtocolID)
require.NoError(t, err)

// create info and handler
Expand Down
Loading

0 comments on commit df09d45

Please sign in to comment.