diff --git a/dot/network/discovery.go b/dot/network/discovery.go index 2226cfd2b7..d4773c9c5e 100644 --- a/dot/network/discovery.go +++ b/dot/network/discovery.go @@ -35,13 +35,15 @@ var ( startDHTTimeout = time.Second * 10 initialAdvertisementTimeout = time.Millisecond tryAdvertiseTimeout = time.Second * 30 - connectToPeersTimeout = time.Minute + connectToPeersTimeout = time.Minute * 5 + findPeersTimeout = time.Minute ) // discovery handles discovery of new peers via the kademlia DHT type discovery struct { ctx context.Context dht *dual.DHT + rd *libp2pdiscovery.RoutingDiscovery h libp2phost.Host bootnodes []peer.AddrInfo ds *badger.Datastore @@ -117,7 +119,7 @@ func (d *discovery) stop() error { } func (d *discovery) discoverAndAdvertise() error { - rd := libp2pdiscovery.NewRoutingDiscovery(d.dht) + d.rd = libp2pdiscovery.NewRoutingDiscovery(d.dht) err := d.dht.Bootstrap(d.ctx) if err != nil { @@ -126,79 +128,83 @@ func (d *discovery) discoverAndAdvertise() error { // wait to connect to bootstrap peers time.Sleep(time.Second) + go d.advertise() + go d.checkPeerCount() - go func() { - ttl := initialAdvertisementTimeout + logger.Debug("DHT discovery started!") + return nil +} - for { - select { - case <-time.After(ttl): - logger.Debug("advertising ourselves in the DHT...") - err := d.dht.Bootstrap(d.ctx) - if err != nil { - logger.Warn("failed to bootstrap DHT", "error", err) - continue - } +func (d *discovery) advertise() { + ttl := initialAdvertisementTimeout + + for { + select { + case <-time.After(ttl): + logger.Debug("advertising ourselves in the DHT...") + err := d.dht.Bootstrap(d.ctx) + if err != nil { + logger.Warn("failed to bootstrap DHT", "error", err) + continue + } - ttl, err = rd.Advertise(d.ctx, string(d.pid)) - if err != nil { - logger.Debug("failed to advertise in the DHT", "error", err) - ttl = tryAdvertiseTimeout - } - case <-d.ctx.Done(): - return + ttl, err = d.rd.Advertise(d.ctx, string(d.pid)) + if err != nil { + logger.Debug("failed to advertise in the DHT", "error", err) + ttl = tryAdvertiseTimeout } + case <-d.ctx.Done(): + return } - }() + } +} - go func() { - logger.Debug("attempting to find DHT peers...") - peerCh, err := rd.FindPeers(d.ctx, string(d.pid)) - if err != nil { - logger.Warn("failed to begin finding peers via DHT", "err", err) +func (d *discovery) checkPeerCount() { + for { + select { + case <-d.ctx.Done(): return + case <-time.After(connectToPeersTimeout): + if len(d.h.Network().Peers()) > d.minPeers { + continue + } + + ctx, cancel := context.WithTimeout(d.ctx, findPeersTimeout) + defer cancel() + d.findPeers(ctx) } + } +} - peersToTry := make(map[*peer.AddrInfo]struct{}) +func (d *discovery) findPeers(ctx context.Context) { + logger.Debug("attempting to find DHT peers...") + peerCh, err := d.rd.FindPeers(d.ctx, string(d.pid)) + if err != nil { + logger.Warn("failed to begin finding peers via DHT", "err", err) + return + } - for { - select { - case <-d.ctx.Done(): - return - case <-time.After(connectToPeersTimeout): - if len(d.h.Network().Peers()) > d.minPeers { - continue - } + for { + select { + case <-ctx.Done(): + return + case peer := <-peerCh: + if peer.ID == d.h.ID() || peer.ID == "" { + continue + } - // reconnect to peers if peer count is low - for p := range peersToTry { - err = d.h.Connect(d.ctx, *p) - if err != nil { - logger.Trace("failed to connect to discovered peer", "peer", p.ID, "err", err) - delete(peersToTry, p) - } - } - case peer := <-peerCh: - if peer.ID == d.h.ID() || peer.ID == "" { - continue - } + logger.Trace("found new peer via DHT", "peer", peer.ID) - logger.Trace("found new peer via DHT", "peer", peer.ID) - - // found a peer, try to connect if we need more peers - if len(d.h.Network().Peers()) < d.maxPeers { - err = d.h.Connect(d.ctx, peer) - if err != nil { - logger.Trace("failed to connect to discovered peer", "peer", peer.ID, "err", err) - } - } else { - d.h.Peerstore().AddAddrs(peer.ID, peer.Addrs, peerstore.PermanentAddrTTL) - peersToTry[&peer] = struct{}{} + // found a peer, try to connect if we need more peers + if len(d.h.Network().Peers()) < d.maxPeers { + err = d.h.Connect(d.ctx, peer) + if err != nil { + logger.Trace("failed to connect to discovered peer", "peer", peer.ID, "err", err) } + } else { + d.h.Peerstore().AddAddrs(peer.ID, peer.Addrs, peerstore.PermanentAddrTTL) + return } } - }() - - logger.Debug("DHT discovery started!") - return nil + } } diff --git a/dot/network/service.go b/dot/network/service.go index 5c86e153af..effb2b7506 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -45,7 +45,7 @@ const ( blockAnnounceID = "/block-announces/1" transactionsID = "/transactions/1" - maxMessageSize = 1024 * 1024 // 1mb for now + maxMessageSize = 1024 * 63 // 63kb for now ) var ( @@ -139,10 +139,10 @@ func NewService(cfg *Config) (*Service, error) { var bufPool *sizedBufferPool if cfg.noPreAllocate { bufPool = &sizedBufferPool{ - c: make(chan *[maxMessageSize]byte, cfg.MaxPeers*3), + c: make(chan *[maxMessageSize]byte, cfg.MinPeers*3), } } else { - bufPool = newSizedBufferPool((cfg.MaxPeers-cfg.MinPeers)*3/2, (cfg.MaxPeers+1)*3) + bufPool = newSizedBufferPool(cfg.MinPeers*3, cfg.MaxPeers*3) } network := &Service{ @@ -474,20 +474,15 @@ func (s *Service) IsStopped() bool { // SendMessage implementation of interface to handle receiving messages func (s *Service) SendMessage(msg NotificationsMessage) { - if s.host == nil { - return - } - if s.IsStopped() { - return - } - if msg == nil { - logger.Debug("Received nil message from core service") + if s.host == nil || msg == nil || s.IsStopped() { return } + logger.Debug( - "Broadcasting message from core service", + "gossiping message", "host", s.host.id(), "type", msg.Type(), + "message", msg, ) // check if the message is part of a notifications protocol diff --git a/dot/network/utils.go b/dot/network/utils.go index 74935e3e47..771cabda08 100644 --- a/dot/network/utils.go +++ b/dot/network/utils.go @@ -17,7 +17,6 @@ package network import ( - "bufio" crand "crypto/rand" "encoding/hex" "errors" @@ -151,15 +150,20 @@ func uint64ToLEB128(in uint64) []byte { return out } -func readLEB128ToUint64(r *bufio.Reader) (uint64, error) { +func readLEB128ToUint64(r io.Reader, buf []byte) (uint64, error) { + if len(buf) == 0 { + return 0, errors.New("buffer has length 0") + } + var out uint64 var shift uint for { - b, err := r.ReadByte() + _, err := r.Read(buf) if err != nil { return 0, err } + b := buf[0] out |= uint64(0x7F&b) << shift if b&0x80 == 0 { break @@ -175,13 +179,11 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) { return 0, errors.New("stream is nil") } - r := bufio.NewReader(stream) - var ( tot int ) - length, err := readLEB128ToUint64(r) + length, err := readLEB128ToUint64(stream, buf[:1]) if err == io.EOF { return 0, err } else if err != nil { @@ -192,21 +194,21 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) { return 0, nil // msg length of 0 is allowed, for example transactions handshake } - // TODO: check if length > len(buf), if so probably log.Crit + if length > uint64(len(buf)) { + logger.Warn("received message with size greater than allocated message buffer", "length", length, "buffer size", len(buf)) + _ = stream.Close() + return 0, fmt.Errorf("message size greater than allocated message buffer: got %d", length) + } + if length > maxBlockResponseSize { - logger.Warn("received message with size greater than maxBlockResponseSize, discarding", "length", length) - for { - _, err = r.Discard(int(maxBlockResponseSize)) - if err != nil { - break - } - } + logger.Warn("received message with size greater than maxBlockResponseSize, closing stream", "length", length) + _ = stream.Close() return 0, fmt.Errorf("message size greater than maximum: got %d", length) } tot = 0 for i := 0; i < maxReads; i++ { - n, err := r.Read(buf[tot:]) + n, err := stream.Read(buf[tot:]) if err != nil { return n + tot, err } diff --git a/go.mod b/go.mod index 30a8b0fb26..84bfa0a2dd 100644 --- a/go.mod +++ b/go.mod @@ -50,7 +50,6 @@ require ( github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416 github.com/perlin-network/life v0.0.0-20191203030451-05c0e0f7eaea github.com/rs/cors v1.7.0 // indirect - github.com/sirupsen/logrus v1.6.0 github.com/stretchr/testify v1.7.0 github.com/syndtr/goleveldb v1.0.1-0.20200815110645-5c35d600f0ca // indirect github.com/urfave/cli v1.20.0