Skip to content

Commit

Permalink
implement connection gating support: intercept peer, address dials, u…
Browse files Browse the repository at this point in the history
…pgraded conns (#201)
  • Loading branch information
aarshkshah1992 authored May 15, 2020
1 parent 47faf65 commit dc499b7
Show file tree
Hide file tree
Showing 6 changed files with 242 additions and 100 deletions.
4 changes: 2 additions & 2 deletions p2p/net/swarm/addrs.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package swarm

import (
mafilter "github.com/libp2p/go-maddr-filter"
ma "github.com/multiformats/go-multiaddr"
mamask "github.com/whyrusleeping/multiaddr-filter"
)

// http://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml
var lowTimeoutFilters = mafilter.NewFilters()
var lowTimeoutFilters = ma.NewFilters()

func init() {
for _, p := range []string{
Expand Down
87 changes: 51 additions & 36 deletions p2p/net/swarm/swarm.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync/atomic"
"time"

"github.com/libp2p/go-libp2p-core/connmgr"
"github.com/libp2p/go-libp2p-core/metrics"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
Expand All @@ -19,9 +20,7 @@ import (
"github.com/jbenet/goprocess"
goprocessctx "github.com/jbenet/goprocess/context"

filter "github.com/libp2p/go-maddr-filter"
ma "github.com/multiformats/go-multiaddr"
mafilter "github.com/whyrusleeping/multiaddr-filter"
)

// DialTimeoutLocal is the maximum duration a Dial to local network address
Expand Down Expand Up @@ -87,29 +86,38 @@ type Swarm struct {
dsync *DialSync
backf DialBackoff
limiter *dialLimiter

// filters for addresses that shouldnt be dialed (or accepted)
Filters *filter.Filters
gater connmgr.ConnectionGater

proc goprocess.Process
ctx context.Context
bwc metrics.Reporter
}

// NewSwarm constructs a Swarm
func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter) *Swarm {
// NewSwarm constructs a Swarm.
//
// NOTE: go-libp2p will be moving to dependency injection soon. The variadic
// `extra` interface{} parameter facilitates the future migration. Supported
// elements are:
// - connmgr.ConnectionGater
func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter, extra ...interface{}) *Swarm {
s := &Swarm{
local: local,
peers: peers,
bwc: bwc,
Filters: filter.NewFilters(),
local: local,
peers: peers,
bwc: bwc,
}

s.conns.m = make(map[peer.ID][]*Conn)
s.listeners.m = make(map[transport.Listener]struct{})
s.transports.m = make(map[int]transport.Transport)
s.notifs.m = make(map[network.Notifiee]struct{})

for _, i := range extra {
switch v := i.(type) {
case connmgr.ConnectionGater:
s.gater = v
}
}

s.dsync = NewDialSync(s.doDial)
s.limiter = newDialLimiter(s.dialAddr)
s.proc = goprocessctx.WithContext(ctx)
Expand Down Expand Up @@ -168,33 +176,46 @@ func (s *Swarm) teardown() error {
return nil
}

// AddAddrFilter adds a multiaddr filter to the set of filters the swarm will use to determine which
// addresses not to dial to.
func (s *Swarm) AddAddrFilter(f string) error {
m, err := mafilter.NewMask(f)
if err != nil {
return err
}

s.Filters.AddDialFilter(m)
return nil
}

// Process returns the Process of the swarm
func (s *Swarm) Process() goprocess.Process {
return s.proc
}

func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, error) {
// The underlying transport (or the dialer) *should* filter it's own
// connections but we should double check anyways.
raddr := tc.RemoteMultiaddr()
if s.Filters.AddrBlocked(raddr) {
tc.Close()
return nil, ErrAddrFiltered
var (
p = tc.RemotePeer()
addr = tc.RemoteMultiaddr()
)

if s.gater != nil {
if allow := s.gater.InterceptAddrDial(p, addr); !allow {
err := tc.Close()
if err != nil {
log.Warnf("failed to close connection with peer %s and addr %s; err: %s", p.Pretty(), addr, err)
}
return nil, ErrAddrFiltered
}
}

p := tc.RemotePeer()
stat := network.Stat{Direction: dir}
c := &Conn{
conn: tc,
swarm: s,
stat: stat,
}

// we ONLY check upgraded connections here so we can send them a Disconnect message.
// If we do this in the Upgrader, we will not be able to do this.
if s.gater != nil {
if allow, _ := s.gater.InterceptUpgraded(c); !allow {
// TODO Send disconnect with reason here
err := tc.Close()
if err != nil {
log.Warnf("failed to close connection with peer %s and addr %s; err: %s", p.Pretty(), addr, err)
}
return nil, ErrGaterDisallowedConnection
}
}

// Add the public key.
if pk := tc.RemotePublicKey(); pk != nil {
Expand All @@ -214,12 +235,6 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
}

// Wrap and register the connection.
stat := network.Stat{Direction: dir}
c := &Conn{
conn: tc,
swarm: s,
stat: stat,
}
c.streams.m = make(map[*Stream]struct{})
s.conns.m[p] = append(s.conns.m[p], c)

Expand Down
17 changes: 14 additions & 3 deletions p2p/net/swarm/swarm_dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ var (
// ErrNoGoodAddresses is returned when we find addresses for a peer but
// can't use any of them.
ErrNoGoodAddresses = errors.New("no good addresses")

// ErrGaterDisallowedConnection is returned when the gater prevents us from
// forming a connection with a peer.
ErrGaterDisallowedConnection = errors.New("gater disallows connection to peer")
)

// DialAttempts governs how many times a goroutine will try to dial a given peer.
Expand Down Expand Up @@ -218,6 +222,11 @@ func (db *DialBackoff) cleanup() {
// This allows us to use various transport protocols, do NAT traversal/relay,
// etc. to achieve connection.
func (s *Swarm) DialPeer(ctx context.Context, p peer.ID) (network.Conn, error) {
if s.gater != nil && !s.gater.InterceptPeerDial(p) {
log.Debugf("gater disallowed outbound connection to peer %s", p.Pretty())
return nil, &DialError{Peer: p, Cause: ErrGaterDisallowedConnection}
}

return s.dialPeer(ctx, p)
}

Expand Down Expand Up @@ -339,7 +348,7 @@ func (s *Swarm) dial(ctx context.Context, p peer.ID) (*Conn, error) {
if len(peerAddrs) == 0 {
return nil, &DialError{Peer: p, Cause: ErrNoAddresses}
}
goodAddrs := s.filterKnownUndialables(peerAddrs)
goodAddrs := s.filterKnownUndialables(p, peerAddrs)
if len(goodAddrs) == 0 {
return nil, &DialError{Peer: p, Cause: ErrNoGoodAddresses}
}
Expand Down Expand Up @@ -393,7 +402,7 @@ func (s *Swarm) dial(ctx context.Context, p peer.ID) (*Conn, error) {
// IPv6 link-local addresses, addresses without a dial-capable transport,
// and addresses that we know to be our own.
// This is an optimization to avoid wasting time on dials that we know are going to fail.
func (s *Swarm) filterKnownUndialables(addrs []ma.Multiaddr) []ma.Multiaddr {
func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Multiaddr {
lisAddrs, _ := s.InterfaceListenAddresses()
var ourAddrs []ma.Multiaddr
for _, addr := range lisAddrs {
Expand All @@ -409,7 +418,9 @@ func (s *Swarm) filterKnownUndialables(addrs []ma.Multiaddr) []ma.Multiaddr {
s.canDial,
// TODO: Consider allowing link-local addresses
addrutil.AddrOverNonLocalIP,
addrutil.FilterNeg(s.Filters.AddrBlocked),
func(addr ma.Multiaddr) bool {
return s.gater == nil || s.gater.InterceptAddrDial(p, addr)
},
)
}

Expand Down
1 change: 1 addition & 0 deletions p2p/net/swarm/swarm_listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func (s *Swarm) AddListenAddr(a ma.Multiaddr) error {
}
return
}

log.Debugf("swarm listener accepted connection: %s", c)
s.refs.Add(1)
go func() {
Expand Down
146 changes: 96 additions & 50 deletions p2p/net/swarm/swarm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,21 @@ import (
"context"
"fmt"
"io"
"net"
"sync"
"testing"
"time"

logging "github.com/ipfs/go-log"
"github.com/libp2p/go-libp2p-core/control"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/peerstore"

ma "github.com/multiformats/go-multiaddr"

. "github.com/libp2p/go-libp2p-swarm"
. "github.com/libp2p/go-libp2p-swarm/testing"

logging "github.com/ipfs/go-log"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
)

var log = logging.Logger("swarm_test")
Expand Down Expand Up @@ -280,60 +281,105 @@ func TestConnHandler(t *testing.T) {
}
}

func TestAddrBlocking(t *testing.T) {
func TestConnectionGating(t *testing.T) {
ctx := context.Background()
swarms := makeSwarms(ctx, t, 2)

swarms[0].SetConnHandler(func(conn network.Conn) {
t.Errorf("no connections should happen! -- %s", conn)
})

_, block, err := net.ParseCIDR("127.0.0.1/8")
if err != nil {
t.Fatal(err)
}

swarms[1].Filters.AddDialFilter(block)

swarms[1].Peerstore().AddAddr(swarms[0].LocalPeer(), swarms[0].ListenAddresses()[0], peerstore.PermanentAddrTTL)
_, err = swarms[1].DialPeer(ctx, swarms[0].LocalPeer())
if err == nil {
t.Fatal("dial should have failed")
}

swarms[0].Peerstore().AddAddr(swarms[1].LocalPeer(), swarms[1].ListenAddresses()[0], peerstore.PermanentAddrTTL)
_, err = swarms[0].DialPeer(ctx, swarms[1].LocalPeer())
if err == nil {
t.Fatal("dial should have failed")
tcs := map[string]struct {
p1Gater func(gater *MockConnectionGater) *MockConnectionGater
p2Gater func(gater *MockConnectionGater) *MockConnectionGater

p1ConnectednessToP2 network.Connectedness
p2ConnectednessToP1 network.Connectedness
isP1OutboundErr bool
}{
"no gating": {
p1ConnectednessToP2: network.Connected,
p2ConnectednessToP1: network.Connected,
isP1OutboundErr: false,
},
"p1 gates outbound peer dial": {
p1Gater: func(c *MockConnectionGater) *MockConnectionGater {
c.PeerDial = func(p peer.ID) bool { return false }
return c
},
p1ConnectednessToP2: network.NotConnected,
p2ConnectednessToP1: network.NotConnected,
isP1OutboundErr: true,
},
"p1 gates outbound addr dialing": {
p1Gater: func(c *MockConnectionGater) *MockConnectionGater {
c.Dial = func(p peer.ID, addr ma.Multiaddr) bool { return false }
return c
},
p1ConnectednessToP2: network.NotConnected,
p2ConnectednessToP1: network.NotConnected,
isP1OutboundErr: true,
},
"p2 gates inbound peer dial before securing": {
p2Gater: func(c *MockConnectionGater) *MockConnectionGater {
c.Accept = func(c network.ConnMultiaddrs) bool { return false }
return c
},
p1ConnectednessToP2: network.NotConnected,
p2ConnectednessToP1: network.NotConnected,
isP1OutboundErr: true,
},
"p2 gates inbound peer dial before multiplexing": {
p1Gater: func(c *MockConnectionGater) *MockConnectionGater {
c.Secured = func(network.Direction, peer.ID, network.ConnMultiaddrs) bool { return false }
return c
},
p1ConnectednessToP2: network.NotConnected,
p2ConnectednessToP1: network.NotConnected,
isP1OutboundErr: true,
},
"p2 gates inbound peer dial after upgrading": {
p1Gater: func(c *MockConnectionGater) *MockConnectionGater {
c.Upgraded = func(c network.Conn) (bool, control.DisconnectReason) { return false, 0 }
return c
},
p1ConnectednessToP2: network.NotConnected,
p2ConnectednessToP1: network.NotConnected,
isP1OutboundErr: true,
},
"p2 gates outbound dials": {
p2Gater: func(c *MockConnectionGater) *MockConnectionGater {
c.PeerDial = func(p peer.ID) bool { return false }
return c
},
p1ConnectednessToP2: network.Connected,
p2ConnectednessToP1: network.Connected,
isP1OutboundErr: false,
},
}
}

func TestFilterBounds(t *testing.T) {
ctx := context.Background()
swarms := makeSwarms(ctx, t, 2)
for n, tc := range tcs {
t.Run(n, func(t *testing.T) {
p1Gater := DefaultMockConnectionGater()
p2Gater := DefaultMockConnectionGater()
if tc.p1Gater != nil {
p1Gater = tc.p1Gater(p1Gater)
}
if tc.p2Gater != nil {
p2Gater = tc.p2Gater(p2Gater)
}

conns := make(chan struct{}, 8)
swarms[0].SetConnHandler(func(conn network.Conn) {
conns <- struct{}{}
})
sw1 := GenSwarm(t, ctx, OptConnGater(p1Gater))
sw2 := GenSwarm(t, ctx, OptConnGater(p2Gater))

// Address that we wont be dialing from
_, block, err := net.ParseCIDR("192.0.0.1/8")
if err != nil {
t.Fatal(err)
}
p1 := sw1.LocalPeer()
p2 := sw2.LocalPeer()
sw1.Peerstore().AddAddr(p2, sw2.ListenAddresses()[0], peerstore.PermanentAddrTTL)
// 1 -> 2
_, err := sw1.DialPeer(ctx, p2)

// set filter on both sides, shouldnt matter
swarms[1].Filters.AddDialFilter(block)
swarms[0].Filters.AddDialFilter(block)
require.Equal(t, tc.isP1OutboundErr, err != nil, n)
require.Equal(t, tc.p1ConnectednessToP2, sw1.Connectedness(p2), n)

connectSwarms(t, ctx, swarms)
require.Eventually(t, func() bool {
return tc.p2ConnectednessToP1 == sw2.Connectedness(p1)
}, 2*time.Second, 100*time.Millisecond, n)
})

select {
case <-time.After(time.Second):
t.Fatal("should have gotten connection")
case <-conns:
t.Log("got connect")
}
}

Expand Down
Loading

0 comments on commit dc499b7

Please sign in to comment.