Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

Add Peer Limit to Swarm #146

Closed
wants to merge 15 commits into from
26 changes: 23 additions & 3 deletions swarm.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ var ErrSwarmClosed = errors.New("swarm closed")
// transport is misbehaving.
var ErrAddrFiltered = errors.New("address filtered")

// ErrPeerLimitExceeded is returned when we exceed the specified peer limit.
var ErrPeerLimitExceeded = errors.New("number of peers over the peer limit, rejecting connection")

// ErrDialTimeout is returned when one a dial times out due to the global timeout
var ErrDialTimeout = errors.New("dial timed out")

Expand Down Expand Up @@ -83,6 +86,9 @@ type Swarm struct {
connh atomic.Value
streamh atomic.Value

//peerLimiter
peerLimit int

// dialing helpers
dsync *DialSync
backf DialBackoff
Expand All @@ -97,7 +103,7 @@ type Swarm struct {
}

// NewSwarm constructs a Swarm
func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter) *Swarm {
func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter, opts ...Option) *Swarm {
s := &Swarm{
local: local,
peers: peers,
Expand All @@ -115,6 +121,8 @@ func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc
s.proc = goprocessctx.WithContextAndTeardown(ctx, s.teardown)
s.ctx = goprocessctx.OnClosingContext(s.proc)

s.ApplyOptions(opts...)

return s
}

Expand Down Expand Up @@ -181,6 +189,20 @@ func (s *Swarm) Process() goprocess.Process {
}

func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, error) {
s.conns.RLock()
numOfPeers := len(s.conns.m)
s.conns.RUnlock()

p := tc.RemotePeer()
nonZeroLimit := s.peerLimit > 0

// Check if the connection would exceed our specified peer limit.
if nonZeroLimit && numOfPeers >= int(s.peerLimit) && s.Connectedness(p) != network.Connected {
tc.Close()
vyzo marked this conversation as resolved.
Show resolved Hide resolved
log.Debugf("rejecting connection from peer %s", p)
return nil, ErrPeerLimitExceeded
}

// The underlying transport (or the dialer) *should* filter it's own
// connections but we should double check anyways.
raddr := tc.RemoteMultiaddr()
Expand All @@ -189,8 +211,6 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
return nil, ErrAddrFiltered
}

p := tc.RemotePeer()

// Add the public key.
if pk := tc.RemotePublicKey(); pk != nil {
s.peers.AddPubKey(p, pk)
Expand Down
16 changes: 16 additions & 0 deletions swarm_options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package swarm

// Option is a Swarm Option that can be given to a Swarm Constructor(`NewSwarm`).
type Option func(s *Swarm)
vyzo marked this conversation as resolved.
Show resolved Hide resolved

func (s *Swarm) ApplyOptions(opts ...Option) {
for _, opt := range opts {
opt(s)
}
}

func SwarmPeerLimit(limit int) Option {
return func(s *Swarm) {
s.peerLimit = limit
}
}
59 changes: 54 additions & 5 deletions swarm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func makeDialOnlySwarm(ctx context.Context, t *testing.T) *Swarm {
return swarm
}

func makeSwarms(ctx context.Context, t *testing.T, num int, opts ...Option) []*Swarm {
func makeSwarms(ctx context.Context, t *testing.T, num int, opts ...TestOption) []*Swarm {
vyzo marked this conversation as resolved.
Show resolved Hide resolved
swarms := make([]*Swarm, 0, num)

for i := 0; i < num; i++ {
Expand All @@ -73,14 +73,20 @@ func makeSwarms(ctx context.Context, t *testing.T, num int, opts ...Option) []*S
return swarms
}

func dialAddress(ctx context.Context, s *Swarm, dst peer.ID, addr ma.Multiaddr) error {
s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL)
if _, err := s.DialPeer(ctx, dst); err != nil {
return fmt.Errorf("error swarm dialing to peer: %v", err)
}
return nil
}

func connectSwarms(t *testing.T, ctx context.Context, swarms []*Swarm) {

var wg sync.WaitGroup
connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) {
// TODO: make a DialAddr func.
s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL)
if _, err := s.DialPeer(ctx, dst); err != nil {
t.Fatal("error swarm dialing to peer", err)
if err := dialAddress(ctx, s, dst, addr); err != nil {
t.Fatal(err)
}
wg.Done()
}
Expand Down Expand Up @@ -346,3 +352,46 @@ func TestNoDial(t *testing.T) {
t.Fatal("should have failed with ErrNoConn")
}
}

func TestPeerLimit(t *testing.T) {
t.Parallel()

ctx := context.Background()
swarms := makeSwarms(ctx, t, 5)
peerLimit := 3
limitOpt := SwarmPeerLimit(peerLimit)
swarms[0].ApplyOptions(limitOpt)

gotconn := make(chan struct{}, 10)
swarms[0].SetConnHandler(func(conn network.Conn) {
gotconn <- struct{}{}
})
peerRejected := false
for i := 1; i < len(swarms); i++ {
err := dialAddress(ctx, swarms[0], swarms[i].LocalPeer(), swarms[i].ListenAddresses()[0])
if err != nil {
peerRejected = true
}

}
<-time.After(time.Millisecond)

if !peerRejected {
t.Error("No peer was rejected despite peer limit being exceeded")
}

swarms[0].SetConnHandler(nil)
for i := 0; i < peerLimit; i++ {
select {
case <-time.After(time.Second):
t.Fatal("failed to get connections")
case <-gotconn:
}
}

select {
case <-gotconn:
t.Fatalf("should have connected to %d swarms, got an extra.", peerLimit)
default:
}
}
12 changes: 6 additions & 6 deletions testing/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/libp2p/go-libp2p-core/metrics"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peerstore"
"github.com/libp2p/go-libp2p-testing/net"
tnet "github.com/libp2p/go-libp2p-testing/net"
"github.com/libp2p/go-tcp-transport"

goprocess "github.com/jbenet/goprocess"
Expand All @@ -26,16 +26,16 @@ type config struct {
dialOnly bool
}

// Option is an option that can be passed when constructing a test swarm.
type Option func(*testing.T, *config)
// TestOption is an option that can be passed when constructing a test swarm.
type TestOption func(*testing.T, *config)

// OptDisableReuseport disables reuseport in this test swarm.
var OptDisableReuseport Option = func(_ *testing.T, c *config) {
var OptDisableReuseport TestOption = func(_ *testing.T, c *config) {
c.disableReuseport = true
}

// OptDialOnly prevents the test swarm from listening.
var OptDialOnly Option = func(_ *testing.T, c *config) {
var OptDialOnly TestOption = func(_ *testing.T, c *config) {
c.dialOnly = true
}

Expand All @@ -61,7 +61,7 @@ func GenUpgrader(n *swarm.Swarm) *tptu.Upgrader {
}

// GenSwarm generates a new test swarm.
func GenSwarm(t *testing.T, ctx context.Context, opts ...Option) *swarm.Swarm {
func GenSwarm(t *testing.T, ctx context.Context, opts ...TestOption) *swarm.Swarm {
var cfg config
for _, o := range opts {
o(t, &cfg)
Expand Down