Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

p2p: use netip.Addr where possible #29891

Merged
merged 11 commits into from
Jun 5, 2024
3 changes: 2 additions & 1 deletion cmd/devp2p/internal/ethtest/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ func (s *Suite) dial() (*Conn, error) {
// dialAs attempts to dial a given node and perform a handshake using the given
// private key.
func (s *Suite) dialAs(key *ecdsa.PrivateKey) (*Conn, error) {
fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP()))
tcpEndpoint, _ := s.Dest.TCPEndpoint()
fd, err := net.Dial("tcp", tcpEndpoint.String())
if err != nil {
return nil, err
}
Expand Down
6 changes: 4 additions & 2 deletions cmd/devp2p/internal/v4test/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ func newTestEnv(remote string, listen1, listen2 string) *testenv {
if err != nil {
panic(err)
}
if node.IP() == nil || node.UDP() == 0 {
if !node.IPAddr().IsValid() || node.UDP() == 0 {
var ip net.IP
var tcpPort, udpPort int
if ip = node.IP(); ip == nil {
if node.IPAddr().IsValid() {
ip = node.IPAddr().AsSlice()
} else {
ip = net.ParseIP("127.0.0.1")
}
if tcpPort = node.TCP(); tcpPort == 0 {
Expand Down
6 changes: 3 additions & 3 deletions cmd/devp2p/nodesetcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package main
import (
"errors"
"fmt"
"net"
"net/netip"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -205,11 +205,11 @@ func trueFilter(args []string) (nodeFilter, error) {
}

func ipFilter(args []string) (nodeFilter, error) {
_, cidr, err := net.ParseCIDR(args[0])
prefix, err := netip.ParsePrefix(args[0])
if err != nil {
return nil, err
}
f := func(n nodeJSON) bool { return cidr.Contains(n.N.IP()) }
f := func(n nodeJSON) bool { return prefix.Contains(n.N.IPAddr()) }
return f, nil
}

Expand Down
6 changes: 5 additions & 1 deletion cmd/devp2p/rlpxcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ var (

func rlpxPing(ctx *cli.Context) error {
n := getNodeArg(ctx)
fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", n.IP(), n.TCP()))
tcpEndpoint, ok := n.TCPEndpoint()
if !ok {
return fmt.Errorf("node has no TCP endpoint")
}
fd, err := net.Dial("tcp", tcpEndpoint.String())
if err != nil {
return err
}
Expand Down
27 changes: 13 additions & 14 deletions p2p/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,8 @@ type tcpDialer struct {
}

func (t tcpDialer) Dial(ctx context.Context, dest *enode.Node) (net.Conn, error) {
return t.d.DialContext(ctx, "tcp", nodeAddr(dest).String())
}

func nodeAddr(n *enode.Node) net.Addr {
return &net.TCPAddr{IP: n.IP(), Port: n.TCP()}
addr, _ := dest.TCPEndpoint()
return t.d.DialContext(ctx, "tcp", addr.String())
}

// checkDial errors:
Expand Down Expand Up @@ -243,7 +240,7 @@ loop:
select {
case node := <-nodesCh:
if err := d.checkDial(node); err != nil {
d.log.Trace("Discarding dial candidate", "id", node.ID(), "ip", node.IP(), "reason", err)
d.log.Trace("Discarding dial candidate", "id", node.ID(), "ip", node.IPAddr(), "reason", err)
} else {
d.startDial(newDialTask(node, dynDialedConn))
}
Expand Down Expand Up @@ -277,7 +274,7 @@ loop:
case node := <-d.addStaticCh:
id := node.ID()
_, exists := d.static[id]
d.log.Trace("Adding static node", "id", id, "ip", node.IP(), "added", !exists)
d.log.Trace("Adding static node", "id", id, "ip", node.IPAddr(), "added", !exists)
if exists {
continue loop
}
Expand Down Expand Up @@ -376,7 +373,7 @@ func (d *dialScheduler) checkDial(n *enode.Node) error {
if n.ID() == d.self {
return errSelf
}
if n.IP() != nil && n.TCP() == 0 {
if n.IPAddr().IsValid() && n.TCP() == 0 {
// This check can trigger if a non-TCP node is found
// by discovery. If there is no IP, the node is a static
// node and the actual endpoint will be resolved later in dialTask.
Expand All @@ -388,7 +385,7 @@ func (d *dialScheduler) checkDial(n *enode.Node) error {
if _, ok := d.peers[n.ID()]; ok {
return errAlreadyConnected
}
if d.netRestrict != nil && !d.netRestrict.Contains(n.IP()) {
if d.netRestrict != nil && !d.netRestrict.ContainsAddr(n.IPAddr()) {
return errNetRestrict
}
if d.history.contains(string(n.ID().Bytes())) {
Expand Down Expand Up @@ -439,7 +436,7 @@ func (d *dialScheduler) removeFromStaticPool(idx int) {
// startDial runs the given dial task in a separate goroutine.
func (d *dialScheduler) startDial(task *dialTask) {
node := task.dest()
d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IP(), "flag", task.flags)
d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IPAddr(), "flag", task.flags)
hkey := string(node.ID().Bytes())
d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration))
d.dialing[node.ID()] = task
Expand Down Expand Up @@ -492,7 +489,7 @@ func (t *dialTask) run(d *dialScheduler) {
}

func (t *dialTask) needResolve() bool {
return t.flags&staticDialedConn != 0 && t.dest().IP() == nil
return t.flags&staticDialedConn != 0 && !t.dest().IPAddr().IsValid()
}

// resolve attempts to find the current endpoint for the destination
Expand Down Expand Up @@ -526,7 +523,8 @@ func (t *dialTask) resolve(d *dialScheduler) bool {
// The node was found.
t.resolveDelay = initialResolveDelay
t.destPtr.Store(resolved)
d.log.Debug("Resolved node", "id", resolved.ID(), "addr", &net.TCPAddr{IP: resolved.IP(), Port: resolved.TCP()})
resAddr, _ := resolved.TCPEndpoint()
d.log.Debug("Resolved node", "id", resolved.ID(), "addr", resAddr)
return true
}

Expand All @@ -535,7 +533,8 @@ func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
dialMeter.Mark(1)
fd, err := d.dialer.Dial(d.ctx, dest)
if err != nil {
d.log.Trace("Dial error", "id", dest.ID(), "addr", nodeAddr(dest), "conn", t.flags, "err", cleanupDialErr(err))
addr, _ := dest.TCPEndpoint()
d.log.Trace("Dial error", "id", dest.ID(), "addr", addr, "conn", t.flags, "err", cleanupDialErr(err))
dialConnectionError.Mark(1)
return &dialError{err}
}
Expand All @@ -545,7 +544,7 @@ func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
func (t *dialTask) String() string {
node := t.dest()
id := node.ID()
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IP(), node.TCP())
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IPAddr(), node.TCP())
}

func cleanupDialErr(err error) error {
Expand Down
52 changes: 26 additions & 26 deletions p2p/discover/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ package discover
import (
"context"
"fmt"
"net"
"net/netip"
"slices"
"sync"
"time"
Expand Down Expand Up @@ -207,8 +207,8 @@ func (tab *Table) setFallbackNodes(nodes []*enode.Node) error {
if err := n.ValidateComplete(); err != nil {
return fmt.Errorf("bad bootstrap node %q: %v", n, err)
}
if tab.cfg.NetRestrict != nil && !tab.cfg.NetRestrict.Contains(n.IP()) {
tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IP())
if tab.cfg.NetRestrict != nil && !tab.cfg.NetRestrict.ContainsAddr(n.IPAddr()) {
tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IPAddr())
continue
}
nursery = append(nursery, n)
Expand Down Expand Up @@ -448,7 +448,7 @@ func (tab *Table) loadSeedNodes() {
for i := range seeds {
seed := seeds[i]
if tab.log.Enabled(context.Background(), log.LevelTrace) {
age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP()))
age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IPAddr()))
addr, _ := seed.UDPEndpoint()
tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", addr, "age", age)
}
Expand All @@ -474,31 +474,31 @@ func (tab *Table) bucketAtDistance(d int) *bucket {
return tab.buckets[d-bucketMinDistance-1]
}

func (tab *Table) addIP(b *bucket, ip net.IP) bool {
if len(ip) == 0 {
func (tab *Table) addIP(b *bucket, ip netip.Addr) bool {
if !ip.IsValid() || ip.IsUnspecified() {
return false // Nodes without IP cannot be added.
}
if netutil.IsLAN(ip) {
if netutil.AddrIsLAN(ip) {
return true
}
if !tab.ips.Add(ip) {
if !tab.ips.AddAddr(ip) {
tab.log.Debug("IP exceeds table limit", "ip", ip)
return false
}
if !b.ips.Add(ip) {
if !b.ips.AddAddr(ip) {
tab.log.Debug("IP exceeds bucket limit", "ip", ip)
tab.ips.Remove(ip)
tab.ips.RemoveAddr(ip)
return false
}
return true
}

func (tab *Table) removeIP(b *bucket, ip net.IP) {
if netutil.IsLAN(ip) {
func (tab *Table) removeIP(b *bucket, ip netip.Addr) {
if netutil.AddrIsLAN(ip) {
return
}
tab.ips.Remove(ip)
b.ips.Remove(ip)
tab.ips.RemoveAddr(ip)
b.ips.RemoveAddr(ip)
}

// handleAddNode adds the node in the request to the table, if there is space.
Expand All @@ -524,7 +524,7 @@ func (tab *Table) handleAddNode(req addNodeOp) bool {
tab.addReplacement(b, req.node)
return false
}
if !tab.addIP(b, req.node.IP()) {
if !tab.addIP(b, req.node.IPAddr()) {
// Can't add: IP limit reached.
return false
}
Expand All @@ -547,15 +547,15 @@ func (tab *Table) addReplacement(b *bucket, n *enode.Node) {
// TODO: update ENR
return
}
if !tab.addIP(b, n.IP()) {
if !tab.addIP(b, n.IPAddr()) {
return
}

wn := &tableNode{Node: n, addedToTable: time.Now()}
var removed *tableNode
b.replacements, removed = pushNode(b.replacements, wn, maxReplacements)
if removed != nil {
tab.removeIP(b, removed.IP())
tab.removeIP(b, removed.IPAddr())
}
}

Expand Down Expand Up @@ -595,20 +595,20 @@ func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *tableNode {
// Remove the node.
n := b.entries[index]
b.entries = slices.Delete(b.entries, index, index+1)
tab.removeIP(b, n.IP())
tab.removeIP(b, n.IPAddr())
tab.nodeRemoved(b, n)

// Add replacement.
if len(b.replacements) == 0 {
tab.log.Debug("Removed dead node", "b", b.index, "id", n.ID(), "ip", n.IP())
tab.log.Debug("Removed dead node", "b", b.index, "id", n.ID(), "ip", n.IPAddr())
return nil
}
rindex := tab.rand.Intn(len(b.replacements))
rep := b.replacements[rindex]
b.replacements = slices.Delete(b.replacements, rindex, rindex+1)
b.entries = append(b.entries, rep)
tab.nodeAdded(b, rep)
tab.log.Debug("Replaced dead node", "b", b.index, "id", n.ID(), "ip", n.IP(), "r", rep.ID(), "rip", rep.IP())
tab.log.Debug("Replaced dead node", "b", b.index, "id", n.ID(), "ip", n.IPAddr(), "r", rep.ID(), "rip", rep.IPAddr())
return rep
}

Expand All @@ -635,10 +635,10 @@ func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool)
ipchanged := newRecord.IPAddr() != n.IPAddr()
portchanged := newRecord.UDP() != n.UDP()
if ipchanged {
tab.removeIP(b, n.IP())
if !tab.addIP(b, newRecord.IP()) {
tab.removeIP(b, n.IPAddr())
if !tab.addIP(b, newRecord.IPAddr()) {
// It doesn't fit with the limit, put the previous record back.
tab.addIP(b, n.IP())
tab.addIP(b, n.IPAddr())
return n, false
}
}
Expand All @@ -657,11 +657,11 @@ func (tab *Table) handleTrackRequest(op trackRequestOp) {
var fails int
if op.success {
// Reset failure counter because it counts _consecutive_ failures.
tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), 0)
tab.db.UpdateFindFails(op.node.ID(), op.node.IPAddr(), 0)
} else {
fails = tab.db.FindFails(op.node.ID(), op.node.IP())
fails = tab.db.FindFails(op.node.ID(), op.node.IPAddr())
fails++
tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), fails)
tab.db.UpdateFindFails(op.node.ID(), op.node.IPAddr(), fails)
}

tab.mutex.Lock()
Expand Down
14 changes: 4 additions & 10 deletions p2p/discover/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func checkIPLimitInvariant(t *testing.T, tab *Table) {
tabset := netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}
for _, b := range tab.buckets {
for _, n := range b.entries {
tabset.Add(n.IP())
tabset.AddAddr(n.IPAddr())
}
}
if tabset.String() != tab.ips.String() {
Expand Down Expand Up @@ -268,7 +268,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
}
for _, id := range gen([]enode.ID{}, rand).([]enode.ID) {
r := new(enr.Record)
r.Set(enr.IP(genIP(rand)))
r.Set(enr.IPv4Addr(netutil.RandomAddr(rand, true)))
n := enode.SignNull(r, id)
t.All = append(t.All, n)
}
Expand Down Expand Up @@ -385,11 +385,11 @@ func checkBucketContent(t *testing.T, tab *Table, nodes []*enode.Node) {
}
t.Log("wrong bucket content. have nodes:")
for _, n := range b.entries {
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP())
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IPAddr())
}
t.Log("want nodes:")
for _, n := range nodes {
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP())
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IPAddr())
}
t.FailNow()

Expand Down Expand Up @@ -483,12 +483,6 @@ func gen(typ interface{}, rand *rand.Rand) interface{} {
return v.Interface()
}

func genIP(rand *rand.Rand) net.IP {
ip := make(net.IP, 4)
rand.Read(ip)
return ip
}

func quickcfg() *quick.Config {
return &quick.Config{
MaxCount: 5000,
Expand Down
5 changes: 3 additions & 2 deletions p2p/discover/table_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ func idAtDistance(a enode.ID, n int) (b enode.ID) {
return b
}

// intIP returns a LAN IP address based on i.
func intIP(i int) net.IP {
return net.IP{byte(i), 0, 2, byte(i)}
return net.IP{10, 0, byte(i >> 8), byte(i & 0xFF)}
}

// fillBucket inserts nodes into the given bucket until it is full.
Expand Down Expand Up @@ -254,7 +255,7 @@ NotEqual:
}

func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool {
return n1.ID() == n2.ID() && n1.IP().Equal(n2.IP())
return n1.ID() == n2.ID() && n1.IPAddr() == n2.IPAddr()
}

func sortByID[N nodeType](nodes []N) {
Expand Down
Loading
Loading