diff --git a/swarm.go b/swarm.go index f7c57717..915fe2ac 100644 --- a/swarm.go +++ b/swarm.go @@ -40,6 +40,9 @@ var ErrSwarmClosed = errors.New("swarm closed") // transport is misbehaving. var ErrAddrFiltered = errors.New("address filtered") +// ErrPeerLimitExceeded is returned when we exceed the specified peer limit. +var ErrPeerLimitExceeded = errors.New("number of peers over the peer limit, rejecting connection") + // ErrDialTimeout is returned when one a dial times out due to the global timeout var ErrDialTimeout = errors.New("dial timed out") @@ -83,6 +86,9 @@ type Swarm struct { connh atomic.Value streamh atomic.Value + //peerLimiter + peerLimit int + // dialing helpers dsync *DialSync backf DialBackoff @@ -97,7 +103,7 @@ type Swarm struct { } // NewSwarm constructs a Swarm -func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter) *Swarm { +func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter, opts ...Option) *Swarm { s := &Swarm{ local: local, peers: peers, @@ -115,6 +121,8 @@ func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc s.proc = goprocessctx.WithContextAndTeardown(ctx, s.teardown) s.ctx = goprocessctx.OnClosingContext(s.proc) + s.ApplyOptions(opts...) + return s } @@ -181,6 +189,20 @@ func (s *Swarm) Process() goprocess.Process { } func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, error) { + s.conns.RLock() + numOfPeers := len(s.conns.m) + s.conns.RUnlock() + + p := tc.RemotePeer() + nonZeroLimit := s.peerLimit > 0 + + // Check if the connection would exceed our specified peer limit. + if nonZeroLimit && numOfPeers >= int(s.peerLimit) && s.Connectedness(p) != network.Connected { + tc.Close() + log.Debugf("rejecting connection from peer %s", p) + return nil, ErrPeerLimitExceeded + } + // The underlying transport (or the dialer) *should* filter it's own // connections but we should double check anyways. raddr := tc.RemoteMultiaddr() @@ -189,8 +211,6 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, return nil, ErrAddrFiltered } - p := tc.RemotePeer() - // Add the public key. if pk := tc.RemotePublicKey(); pk != nil { s.peers.AddPubKey(p, pk) diff --git a/swarm_options.go b/swarm_options.go new file mode 100644 index 00000000..f7c82940 --- /dev/null +++ b/swarm_options.go @@ -0,0 +1,16 @@ +package swarm + +// Option is a Swarm Option that can be given to a Swarm Constructor(`NewSwarm`). +type Option func(s *Swarm) + +func (s *Swarm) ApplyOptions(opts ...Option) { + for _, opt := range opts { + opt(s) + } +} + +func SwarmPeerLimit(limit int) Option { + return func(s *Swarm) { + s.peerLimit = limit + } +} diff --git a/swarm_test.go b/swarm_test.go index b155373f..44c659ab 100644 --- a/swarm_test.go +++ b/swarm_test.go @@ -61,7 +61,7 @@ func makeDialOnlySwarm(ctx context.Context, t *testing.T) *Swarm { return swarm } -func makeSwarms(ctx context.Context, t *testing.T, num int, opts ...Option) []*Swarm { +func makeSwarms(ctx context.Context, t *testing.T, num int, opts ...TestOption) []*Swarm { swarms := make([]*Swarm, 0, num) for i := 0; i < num; i++ { @@ -73,14 +73,20 @@ func makeSwarms(ctx context.Context, t *testing.T, num int, opts ...Option) []*S return swarms } +func dialAddress(ctx context.Context, s *Swarm, dst peer.ID, addr ma.Multiaddr) error { + s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) + if _, err := s.DialPeer(ctx, dst); err != nil { + return fmt.Errorf("error swarm dialing to peer: %v", err) + } + return nil +} + func connectSwarms(t *testing.T, ctx context.Context, swarms []*Swarm) { var wg sync.WaitGroup connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { - // TODO: make a DialAddr func. - s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) - if _, err := s.DialPeer(ctx, dst); err != nil { - t.Fatal("error swarm dialing to peer", err) + if err := dialAddress(ctx, s, dst, addr); err != nil { + t.Fatal(err) } wg.Done() } @@ -346,3 +352,46 @@ func TestNoDial(t *testing.T) { t.Fatal("should have failed with ErrNoConn") } } + +func TestPeerLimit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + swarms := makeSwarms(ctx, t, 5) + peerLimit := 3 + limitOpt := SwarmPeerLimit(peerLimit) + swarms[0].ApplyOptions(limitOpt) + + gotconn := make(chan struct{}, 10) + swarms[0].SetConnHandler(func(conn network.Conn) { + gotconn <- struct{}{} + }) + peerRejected := false + for i := 1; i < len(swarms); i++ { + err := dialAddress(ctx, swarms[0], swarms[i].LocalPeer(), swarms[i].ListenAddresses()[0]) + if err != nil { + peerRejected = true + } + + } + <-time.After(time.Millisecond) + + if !peerRejected { + t.Error("No peer was rejected despite peer limit being exceeded") + } + + swarms[0].SetConnHandler(nil) + for i := 0; i < peerLimit; i++ { + select { + case <-time.After(time.Second): + t.Fatal("failed to get connections") + case <-gotconn: + } + } + + select { + case <-gotconn: + t.Fatalf("should have connected to %d swarms, got an extra.", peerLimit) + default: + } +} diff --git a/testing/testing.go b/testing/testing.go index 5e396024..ad2e33fe 100644 --- a/testing/testing.go +++ b/testing/testing.go @@ -7,7 +7,7 @@ import ( "github.com/libp2p/go-libp2p-core/metrics" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peerstore" - "github.com/libp2p/go-libp2p-testing/net" + tnet "github.com/libp2p/go-libp2p-testing/net" "github.com/libp2p/go-tcp-transport" goprocess "github.com/jbenet/goprocess" @@ -26,16 +26,16 @@ type config struct { dialOnly bool } -// Option is an option that can be passed when constructing a test swarm. -type Option func(*testing.T, *config) +// TestOption is an option that can be passed when constructing a test swarm. +type TestOption func(*testing.T, *config) // OptDisableReuseport disables reuseport in this test swarm. -var OptDisableReuseport Option = func(_ *testing.T, c *config) { +var OptDisableReuseport TestOption = func(_ *testing.T, c *config) { c.disableReuseport = true } // OptDialOnly prevents the test swarm from listening. -var OptDialOnly Option = func(_ *testing.T, c *config) { +var OptDialOnly TestOption = func(_ *testing.T, c *config) { c.dialOnly = true } @@ -61,7 +61,7 @@ func GenUpgrader(n *swarm.Swarm) *tptu.Upgrader { } // GenSwarm generates a new test swarm. -func GenSwarm(t *testing.T, ctx context.Context, opts ...Option) *swarm.Swarm { +func GenSwarm(t *testing.T, ctx context.Context, opts ...TestOption) *swarm.Swarm { var cfg config for _, o := range opts { o(t, &cfg)