Skip to content

Commit

Permalink
p2p: use netip.Addr where possible (ethereum#29891)
Browse files Browse the repository at this point in the history
enode.Node was recently changed to store a cache of endpoint information. The IP address in the cache is a netip.Addr. I chose that type over net.IP because it is just better. netip.Addr is meant to be used as a value type. Copying it does not allocate, it can be compared with ==, and can be used as a map key.

This PR changes most uses of Node.IP() into Node.IPAddr(), which returns the cached value directly without allocating.
While there are still some public APIs left where net.IP is used, I have converted all code used internally by p2p/discover to the new types. So this does change some public Go API, but hopefully not APIs any external code actually uses.

There weren't supposed to be any semantic differences resulting from this refactoring, however it does introduce one: In package p2p/netutil we treated the 0.0.0.0/8 network (addresses 0.x.y.z) as LAN, but netip.Addr.IsPrivate() doesn't. The treatment of this particular IP address range is controversial, with some software supporting it and others not. IANA lists it as special-purpose and invalid as a destination for a long time, so I don't know why I put it into the LAN list. It has now been marked as special in p2p/netutil as well.
  • Loading branch information
fjl authored and stwiname committed Sep 9, 2024
1 parent 9d6ed30 commit c0b26f8
Show file tree
Hide file tree
Showing 23 changed files with 392 additions and 313 deletions.
3 changes: 2 additions & 1 deletion cmd/devp2p/internal/ethtest/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ func (s *Suite) dial() (*Conn, error) {
// dialAs attempts to dial a given node and perform a handshake using the given
// private key.
func (s *Suite) dialAs(key *ecdsa.PrivateKey) (*Conn, error) {
fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP()))
tcpEndpoint, _ := s.Dest.TCPEndpoint()
fd, err := net.Dial("tcp", tcpEndpoint.String())
if err != nil {
return nil, err
}
Expand Down
6 changes: 4 additions & 2 deletions cmd/devp2p/internal/v4test/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ func newTestEnv(remote string, listen1, listen2 string) *testenv {
if err != nil {
panic(err)
}
if node.IP() == nil || node.UDP() == 0 {
if !node.IPAddr().IsValid() || node.UDP() == 0 {
var ip net.IP
var tcpPort, udpPort int
if ip = node.IP(); ip == nil {
if node.IPAddr().IsValid() {
ip = node.IPAddr().AsSlice()
} else {
ip = net.ParseIP("127.0.0.1")
}
if tcpPort = node.TCP(); tcpPort == 0 {
Expand Down
6 changes: 3 additions & 3 deletions cmd/devp2p/nodesetcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package main
import (
"errors"
"fmt"
"net"
"net/netip"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -205,11 +205,11 @@ func trueFilter(args []string) (nodeFilter, error) {
}

func ipFilter(args []string) (nodeFilter, error) {
_, cidr, err := net.ParseCIDR(args[0])
prefix, err := netip.ParsePrefix(args[0])
if err != nil {
return nil, err
}
f := func(n nodeJSON) bool { return cidr.Contains(n.N.IP()) }
f := func(n nodeJSON) bool { return prefix.Contains(n.N.IPAddr()) }
return f, nil
}

Expand Down
6 changes: 5 additions & 1 deletion cmd/devp2p/rlpxcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ var (

func rlpxPing(ctx *cli.Context) error {
n := getNodeArg(ctx)
fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", n.IP(), n.TCP()))
tcpEndpoint, ok := n.TCPEndpoint()
if !ok {
return fmt.Errorf("node has no TCP endpoint")
}
fd, err := net.Dial("tcp", tcpEndpoint.String())
if err != nil {
return err
}
Expand Down
27 changes: 13 additions & 14 deletions p2p/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,8 @@ type tcpDialer struct {
}

func (t tcpDialer) Dial(ctx context.Context, dest *enode.Node) (net.Conn, error) {
return t.d.DialContext(ctx, "tcp", nodeAddr(dest).String())
}

func nodeAddr(n *enode.Node) net.Addr {
return &net.TCPAddr{IP: n.IP(), Port: n.TCP()}
addr, _ := dest.TCPEndpoint()
return t.d.DialContext(ctx, "tcp", addr.String())
}

// checkDial errors:
Expand Down Expand Up @@ -243,7 +240,7 @@ loop:
select {
case node := <-nodesCh:
if err := d.checkDial(node); err != nil {
d.log.Trace("Discarding dial candidate", "id", node.ID(), "ip", node.IP(), "reason", err)
d.log.Trace("Discarding dial candidate", "id", node.ID(), "ip", node.IPAddr(), "reason", err)
} else {
d.startDial(newDialTask(node, dynDialedConn))
}
Expand Down Expand Up @@ -277,7 +274,7 @@ loop:
case node := <-d.addStaticCh:
id := node.ID()
_, exists := d.static[id]
d.log.Trace("Adding static node", "id", id, "ip", node.IP(), "added", !exists)
d.log.Trace("Adding static node", "id", id, "ip", node.IPAddr(), "added", !exists)
if exists {
continue loop
}
Expand Down Expand Up @@ -376,7 +373,7 @@ func (d *dialScheduler) checkDial(n *enode.Node) error {
if n.ID() == d.self {
return errSelf
}
if n.IP() != nil && n.TCP() == 0 {
if n.IPAddr().IsValid() && n.TCP() == 0 {
// This check can trigger if a non-TCP node is found
// by discovery. If there is no IP, the node is a static
// node and the actual endpoint will be resolved later in dialTask.
Expand All @@ -388,7 +385,7 @@ func (d *dialScheduler) checkDial(n *enode.Node) error {
if _, ok := d.peers[n.ID()]; ok {
return errAlreadyConnected
}
if d.netRestrict != nil && !d.netRestrict.Contains(n.IP()) {
if d.netRestrict != nil && !d.netRestrict.ContainsAddr(n.IPAddr()) {
return errNetRestrict
}
if d.history.contains(string(n.ID().Bytes())) {
Expand Down Expand Up @@ -439,7 +436,7 @@ func (d *dialScheduler) removeFromStaticPool(idx int) {
// startDial runs the given dial task in a separate goroutine.
func (d *dialScheduler) startDial(task *dialTask) {
node := task.dest()
d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IP(), "flag", task.flags)
d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IPAddr(), "flag", task.flags)
hkey := string(node.ID().Bytes())
d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration))
d.dialing[node.ID()] = task
Expand Down Expand Up @@ -492,7 +489,7 @@ func (t *dialTask) run(d *dialScheduler) {
}

func (t *dialTask) needResolve() bool {
return t.flags&staticDialedConn != 0 && t.dest().IP() == nil
return t.flags&staticDialedConn != 0 && !t.dest().IPAddr().IsValid()
}

// resolve attempts to find the current endpoint for the destination
Expand Down Expand Up @@ -526,7 +523,8 @@ func (t *dialTask) resolve(d *dialScheduler) bool {
// The node was found.
t.resolveDelay = initialResolveDelay
t.destPtr.Store(resolved)
d.log.Debug("Resolved node", "id", resolved.ID(), "addr", &net.TCPAddr{IP: resolved.IP(), Port: resolved.TCP()})
resAddr, _ := resolved.TCPEndpoint()
d.log.Debug("Resolved node", "id", resolved.ID(), "addr", resAddr)
return true
}

Expand All @@ -535,7 +533,8 @@ func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
dialMeter.Mark(1)
fd, err := d.dialer.Dial(d.ctx, dest)
if err != nil {
d.log.Trace("Dial error", "id", dest.ID(), "addr", nodeAddr(dest), "conn", t.flags, "err", cleanupDialErr(err))
addr, _ := dest.TCPEndpoint()
d.log.Trace("Dial error", "id", dest.ID(), "addr", addr, "conn", t.flags, "err", cleanupDialErr(err))
dialConnectionError.Mark(1)
return &dialError{err}
}
Expand All @@ -545,7 +544,7 @@ func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
func (t *dialTask) String() string {
node := t.dest()
id := node.ID()
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IP(), node.TCP())
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IPAddr(), node.TCP())
}

func cleanupDialErr(err error) error {
Expand Down
52 changes: 26 additions & 26 deletions p2p/discover/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ package discover
import (
"context"
"fmt"
"net"
"net/netip"
"slices"
"sync"
"time"
Expand Down Expand Up @@ -207,8 +207,8 @@ func (tab *Table) setFallbackNodes(nodes []*enode.Node) error {
if err := n.ValidateComplete(); err != nil {
return fmt.Errorf("bad bootstrap node %q: %v", n, err)
}
if tab.cfg.NetRestrict != nil && !tab.cfg.NetRestrict.Contains(n.IP()) {
tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IP())
if tab.cfg.NetRestrict != nil && !tab.cfg.NetRestrict.ContainsAddr(n.IPAddr()) {
tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IPAddr())
continue
}
nursery = append(nursery, n)
Expand Down Expand Up @@ -448,7 +448,7 @@ func (tab *Table) loadSeedNodes() {
for i := range seeds {
seed := seeds[i]
if tab.log.Enabled(context.Background(), log.LevelTrace) {
age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP()))
age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IPAddr()))
addr, _ := seed.UDPEndpoint()
tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", addr, "age", age)
}
Expand All @@ -474,31 +474,31 @@ func (tab *Table) bucketAtDistance(d int) *bucket {
return tab.buckets[d-bucketMinDistance-1]
}

func (tab *Table) addIP(b *bucket, ip net.IP) bool {
if len(ip) == 0 {
func (tab *Table) addIP(b *bucket, ip netip.Addr) bool {
if !ip.IsValid() || ip.IsUnspecified() {
return false // Nodes without IP cannot be added.
}
if netutil.IsLAN(ip) {
if netutil.AddrIsLAN(ip) {
return true
}
if !tab.ips.Add(ip) {
if !tab.ips.AddAddr(ip) {
tab.log.Debug("IP exceeds table limit", "ip", ip)
return false
}
if !b.ips.Add(ip) {
if !b.ips.AddAddr(ip) {
tab.log.Debug("IP exceeds bucket limit", "ip", ip)
tab.ips.Remove(ip)
tab.ips.RemoveAddr(ip)
return false
}
return true
}

func (tab *Table) removeIP(b *bucket, ip net.IP) {
if netutil.IsLAN(ip) {
func (tab *Table) removeIP(b *bucket, ip netip.Addr) {
if netutil.AddrIsLAN(ip) {
return
}
tab.ips.Remove(ip)
b.ips.Remove(ip)
tab.ips.RemoveAddr(ip)
b.ips.RemoveAddr(ip)
}

// handleAddNode adds the node in the request to the table, if there is space.
Expand All @@ -524,7 +524,7 @@ func (tab *Table) handleAddNode(req addNodeOp) bool {
tab.addReplacement(b, req.node)
return false
}
if !tab.addIP(b, req.node.IP()) {
if !tab.addIP(b, req.node.IPAddr()) {
// Can't add: IP limit reached.
return false
}
Expand All @@ -547,15 +547,15 @@ func (tab *Table) addReplacement(b *bucket, n *enode.Node) {
// TODO: update ENR
return
}
if !tab.addIP(b, n.IP()) {
if !tab.addIP(b, n.IPAddr()) {
return
}

wn := &tableNode{Node: n, addedToTable: time.Now()}
var removed *tableNode
b.replacements, removed = pushNode(b.replacements, wn, maxReplacements)
if removed != nil {
tab.removeIP(b, removed.IP())
tab.removeIP(b, removed.IPAddr())
}
}

Expand Down Expand Up @@ -595,20 +595,20 @@ func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *tableNode {
// Remove the node.
n := b.entries[index]
b.entries = slices.Delete(b.entries, index, index+1)
tab.removeIP(b, n.IP())
tab.removeIP(b, n.IPAddr())
tab.nodeRemoved(b, n)

// Add replacement.
if len(b.replacements) == 0 {
tab.log.Debug("Removed dead node", "b", b.index, "id", n.ID(), "ip", n.IP())
tab.log.Debug("Removed dead node", "b", b.index, "id", n.ID(), "ip", n.IPAddr())
return nil
}
rindex := tab.rand.Intn(len(b.replacements))
rep := b.replacements[rindex]
b.replacements = slices.Delete(b.replacements, rindex, rindex+1)
b.entries = append(b.entries, rep)
tab.nodeAdded(b, rep)
tab.log.Debug("Replaced dead node", "b", b.index, "id", n.ID(), "ip", n.IP(), "r", rep.ID(), "rip", rep.IP())
tab.log.Debug("Replaced dead node", "b", b.index, "id", n.ID(), "ip", n.IPAddr(), "r", rep.ID(), "rip", rep.IPAddr())
return rep
}

Expand All @@ -635,10 +635,10 @@ func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool)
ipchanged := newRecord.IPAddr() != n.IPAddr()
portchanged := newRecord.UDP() != n.UDP()
if ipchanged {
tab.removeIP(b, n.IP())
if !tab.addIP(b, newRecord.IP()) {
tab.removeIP(b, n.IPAddr())
if !tab.addIP(b, newRecord.IPAddr()) {
// It doesn't fit with the limit, put the previous record back.
tab.addIP(b, n.IP())
tab.addIP(b, n.IPAddr())
return n, false
}
}
Expand All @@ -657,11 +657,11 @@ func (tab *Table) handleTrackRequest(op trackRequestOp) {
var fails int
if op.success {
// Reset failure counter because it counts _consecutive_ failures.
tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), 0)
tab.db.UpdateFindFails(op.node.ID(), op.node.IPAddr(), 0)
} else {
fails = tab.db.FindFails(op.node.ID(), op.node.IP())
fails = tab.db.FindFails(op.node.ID(), op.node.IPAddr())
fails++
tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), fails)
tab.db.UpdateFindFails(op.node.ID(), op.node.IPAddr(), fails)
}

tab.mutex.Lock()
Expand Down
14 changes: 4 additions & 10 deletions p2p/discover/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func checkIPLimitInvariant(t *testing.T, tab *Table) {
tabset := netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}
for _, b := range tab.buckets {
for _, n := range b.entries {
tabset.Add(n.IP())
tabset.AddAddr(n.IPAddr())
}
}
if tabset.String() != tab.ips.String() {
Expand Down Expand Up @@ -268,7 +268,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
}
for _, id := range gen([]enode.ID{}, rand).([]enode.ID) {
r := new(enr.Record)
r.Set(enr.IP(genIP(rand)))
r.Set(enr.IPv4Addr(netutil.RandomAddr(rand, true)))
n := enode.SignNull(r, id)
t.All = append(t.All, n)
}
Expand Down Expand Up @@ -385,11 +385,11 @@ func checkBucketContent(t *testing.T, tab *Table, nodes []*enode.Node) {
}
t.Log("wrong bucket content. have nodes:")
for _, n := range b.entries {
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP())
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IPAddr())
}
t.Log("want nodes:")
for _, n := range nodes {
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP())
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IPAddr())
}
t.FailNow()

Expand Down Expand Up @@ -483,12 +483,6 @@ func gen(typ interface{}, rand *rand.Rand) interface{} {
return v.Interface()
}

func genIP(rand *rand.Rand) net.IP {
ip := make(net.IP, 4)
rand.Read(ip)
return ip
}

func quickcfg() *quick.Config {
return &quick.Config{
MaxCount: 5000,
Expand Down
5 changes: 3 additions & 2 deletions p2p/discover/table_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ func idAtDistance(a enode.ID, n int) (b enode.ID) {
return b
}

// intIP returns a LAN IP address based on i.
func intIP(i int) net.IP {
return net.IP{byte(i), 0, 2, byte(i)}
return net.IP{10, 0, byte(i >> 8), byte(i & 0xFF)}
}

// fillBucket inserts nodes into the given bucket until it is full.
Expand Down Expand Up @@ -254,7 +255,7 @@ NotEqual:
}

func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool {
return n1.ID() == n2.ID() && n1.IP().Equal(n2.IP())
return n1.ID() == n2.ID() && n1.IPAddr() == n2.IPAddr()
}

func sortByID[N nodeType](nodes []N) {
Expand Down
Loading

0 comments on commit c0b26f8

Please sign in to comment.