diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index 77e03ca9e7e4..217905eb737e 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -52,6 +52,7 @@ func newTestTable(t transport) (*Table, *enode.DB) { func nodeAtDistance(base enode.ID, ld int, ip net.IP) *node { var r enr.Record r.Set(enr.IP(ip)) + r.Set(enr.UDP(30303)) return wrapNode(enode.SignNull(&r, idAtDistance(base, ld))) } diff --git a/p2p/discover/v5_talk.go b/p2p/discover/v5_talk.go new file mode 100644 index 000000000000..c1f67879402c --- /dev/null +++ b/p2p/discover/v5_talk.go @@ -0,0 +1,113 @@ +// Copyright 2023 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "net" + "sync" + "time" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discover/v5wire" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +// This is a limit for the number of concurrent talk requests. +const maxActiveTalkRequests = 1024 + +// This is the timeout for acquiring a handler execution slot for a talk request. +// The timeout should be short enough to fit within the request timeout. +const talkHandlerLaunchTimeout = 400 * time.Millisecond + +// TalkRequestHandler callback processes a talk request and returns a response. +// +// Note that talk handlers are expected to come up with a response very quickly, within at +// most 200ms or so. If the handler takes longer than that, the remote end may time out +// and wont receive the response. +type TalkRequestHandler func(enode.ID, *net.UDPAddr, []byte) []byte + +type talkSystem struct { + transport *UDPv5 + + mutex sync.Mutex + handlers map[string]TalkRequestHandler + slots chan struct{} + lastLog time.Time + dropCount int +} + +func newTalkSystem(transport *UDPv5) *talkSystem { + t := &talkSystem{ + transport: transport, + handlers: make(map[string]TalkRequestHandler), + slots: make(chan struct{}, maxActiveTalkRequests), + } + for i := 0; i < cap(t.slots); i++ { + t.slots <- struct{}{} + } + return t +} + +// register adds a protocol handler. +func (t *talkSystem) register(protocol string, handler TalkRequestHandler) { + t.mutex.Lock() + t.handlers[protocol] = handler + t.mutex.Unlock() +} + +// handleRequest handles a talk request. +func (t *talkSystem) handleRequest(id enode.ID, addr *net.UDPAddr, req *v5wire.TalkRequest) { + t.mutex.Lock() + handler, ok := t.handlers[req.Protocol] + t.mutex.Unlock() + + if !ok { + resp := &v5wire.TalkResponse{ReqID: req.ReqID} + t.transport.sendResponse(id, addr, resp) + return + } + + // Wait for a slot to become available, then run the handler. + timeout := time.NewTimer(talkHandlerLaunchTimeout) + defer timeout.Stop() + select { + case <-t.slots: + go func() { + defer func() { t.slots <- struct{}{} }() + respMessage := handler(id, addr, req.Message) + resp := &v5wire.TalkResponse{ReqID: req.ReqID, Message: respMessage} + t.transport.sendFromAnotherThread(id, addr, resp) + }() + case <-timeout.C: + // Couldn't get it in time, drop the request. + if time.Since(t.lastLog) > 5*time.Second { + log.Warn("Dropping TALKREQ due to overload", "ndrop", t.dropCount) + t.lastLog = time.Now() + t.dropCount++ + } + case <-t.transport.closeCtx.Done(): + // Transport closed, drop the request. + } +} + +// wait blocks until all active requests have finished, and prevents new request +// handlers from being launched. +func (t *talkSystem) wait() { + for i := 0; i < cap(t.slots); i++ { + <-t.slots + } +} diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go index 38f5b3b652cf..4b4bc9195acd 100644 --- a/p2p/discover/v5_udp.go +++ b/p2p/discover/v5_udp.go @@ -74,8 +74,7 @@ type UDPv5 struct { logcontext []interface{} // talkreq handler registry - trlock sync.Mutex - trhandlers map[string]TalkRequestHandler + talk *talkSystem // channels into dispatch packetInCh chan ReadPacket @@ -83,6 +82,7 @@ type UDPv5 struct { callCh chan *callV5 callDoneCh chan *callV5 respTimeoutCh chan *callTimeout + sendCh chan sendRequest unhandled chan<- ReadPacket // state of dispatch @@ -98,12 +98,18 @@ type UDPv5 struct { wg sync.WaitGroup } -// TalkRequestHandler callback processes a talk request and optionally returns a reply -type TalkRequestHandler func(enode.ID, *net.UDPAddr, []byte) []byte +type sendRequest struct { + destID enode.ID + destAddr *net.UDPAddr + msg v5wire.Packet +} // callV5 represents a remote procedure call against another node. type callV5 struct { - node *enode.Node + id enode.ID + addr *net.UDPAddr + node *enode.Node // This is required to perform handshakes. + packet v5wire.Packet responseType byte // expected packet type of response reqid []byte @@ -150,12 +156,12 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { log: cfg.Log, validSchemes: cfg.ValidSchemes, clock: cfg.Clock, - trhandlers: make(map[string]TalkRequestHandler), // channels into dispatch packetInCh: make(chan ReadPacket, 1), readNextCh: make(chan struct{}, 1), callCh: make(chan *callV5), callDoneCh: make(chan *callV5), + sendCh: make(chan sendRequest), respTimeoutCh: make(chan *callTimeout), unhandled: cfg.Unhandled, // state of dispatch @@ -163,11 +169,11 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { activeCallByNode: make(map[enode.ID]*callV5), activeCallByAuth: make(map[v5wire.Nonce]*callV5), callQueue: make(map[enode.ID][]*callV5), - // shutdown closeCtx: closeCtx, cancelCloseCtx: cancelCloseCtx, } + t.talk = newTalkSystem(t) tab, err := newTable(t, t.db, cfg.Bootnodes, cfg.Log) if err != nil { return nil, err @@ -186,6 +192,7 @@ func (t *UDPv5) Close() { t.closeOnce.Do(func() { t.cancelCloseCtx() t.conn.Close() + t.talk.wait() t.wg.Wait() t.tab.close() }) @@ -241,15 +248,26 @@ func (t *UDPv5) LocalNode() *enode.LocalNode { // whenever a request for the given protocol is received and should return the response // data or nil. func (t *UDPv5) RegisterTalkHandler(protocol string, handler TalkRequestHandler) { - t.trlock.Lock() - defer t.trlock.Unlock() - t.trhandlers[protocol] = handler + t.talk.register(protocol, handler) } -// TalkRequest sends a talk request to n and waits for a response. +// TalkRequest sends a talk request to a node and waits for a response. func (t *UDPv5) TalkRequest(n *enode.Node, protocol string, request []byte) ([]byte, error) { req := &v5wire.TalkRequest{Protocol: protocol, Message: request} - resp := t.call(n, v5wire.TalkResponseMsg, req) + resp := t.callToNode(n, v5wire.TalkResponseMsg, req) + defer t.callDone(resp) + select { + case respMsg := <-resp.ch: + return respMsg.(*v5wire.TalkResponse).Message, nil + case err := <-resp.err: + return nil, err + } +} + +// TalkRequest sends a talk request to a node and waits for a response. +func (t *UDPv5) TalkRequestToID(id enode.ID, addr *net.UDPAddr, protocol string, request []byte) ([]byte, error) { + req := &v5wire.TalkRequest{Protocol: protocol, Message: request} + resp := t.callToID(id, addr, v5wire.TalkResponseMsg, req) defer t.callDone(resp) select { case respMsg := <-resp.ch: @@ -340,7 +358,7 @@ func lookupDistances(target, dest enode.ID) (dists []uint) { // ping calls PING on a node and waits for a PONG response. func (t *UDPv5) ping(n *enode.Node) (uint64, error) { req := &v5wire.Ping{ENRSeq: t.localNode.Node().Seq()} - resp := t.call(n, v5wire.PongMsg, req) + resp := t.callToNode(n, v5wire.PongMsg, req) defer t.callDone(resp) select { @@ -365,7 +383,7 @@ func (t *UDPv5) RequestENR(n *enode.Node) (*enode.Node, error) { // findnode calls FINDNODE on a node and waits for responses. func (t *UDPv5) findnode(n *enode.Node, distances []uint) ([]*enode.Node, error) { - resp := t.call(n, v5wire.NodesMsg, &v5wire.Findnode{Distances: distances}) + resp := t.callToNode(n, v5wire.NodesMsg, &v5wire.Findnode{Distances: distances}) return t.waitForNodes(resp, distances) } @@ -408,17 +426,17 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, s if err != nil { return nil, err } - if err := netutil.CheckRelayIP(c.node.IP(), node.IP()); err != nil { + if err := netutil.CheckRelayIP(c.addr.IP, node.IP()); err != nil { return nil, err } if t.netrestrict != nil && !t.netrestrict.Contains(node.IP()) { return nil, errors.New("not contained in netrestrict list") } - if c.node.UDP() <= 1024 { + if node.UDP() <= 1024 { return nil, errLowPort } if distances != nil { - nd := enode.LogDist(c.node.ID(), node.ID()) + nd := enode.LogDist(c.id, node.ID()) if !containsUint(uint(nd), distances) { return nil, errors.New("does not match any requested distance") } @@ -439,17 +457,28 @@ func containsUint(x uint, xs []uint) bool { return false } -// call sends the given call and sets up a handler for response packets (of message type -// responseType). Responses are dispatched to the call's response channel. -func (t *UDPv5) call(node *enode.Node, responseType byte, packet v5wire.Packet) *callV5 { - c := &callV5{ - node: node, - packet: packet, - responseType: responseType, - reqid: make([]byte, 8), - ch: make(chan v5wire.Packet, 1), - err: make(chan error, 1), - } +// callToNode sends the given call and sets up a handler for response packets (of message +// type responseType). Responses are dispatched to the call's response channel. +func (t *UDPv5) callToNode(n *enode.Node, responseType byte, req v5wire.Packet) *callV5 { + addr := &net.UDPAddr{IP: n.IP(), Port: int(n.UDP())} + c := &callV5{id: n.ID(), addr: addr, node: n} + t.initCall(c, responseType, req) + return c +} + +// callToID is like callToNode, but for cases where the node record is not available. +func (t *UDPv5) callToID(id enode.ID, addr *net.UDPAddr, responseType byte, req v5wire.Packet) *callV5 { + c := &callV5{id: id, addr: addr} + t.initCall(c, responseType, req) + return c +} + +func (t *UDPv5) initCall(c *callV5, responseType byte, packet v5wire.Packet) { + c.packet = packet + c.responseType = responseType + c.reqid = make([]byte, 8) + c.ch = make(chan v5wire.Packet, 1) + c.err = make(chan error, 1) // Assign request ID. crand.Read(c.reqid) packet.SetRequestID(c.reqid) @@ -459,7 +488,6 @@ func (t *UDPv5) call(node *enode.Node, responseType byte, packet v5wire.Packet) case <-t.closeCtx.Done(): c.err <- errClosed } - return c } // callDone tells dispatch that the active call is done. @@ -501,26 +529,27 @@ func (t *UDPv5) dispatch() { for { select { case c := <-t.callCh: - id := c.node.ID() - t.callQueue[id] = append(t.callQueue[id], c) - t.sendNextCall(id) + t.callQueue[c.id] = append(t.callQueue[c.id], c) + t.sendNextCall(c.id) case ct := <-t.respTimeoutCh: - active := t.activeCallByNode[ct.c.node.ID()] + active := t.activeCallByNode[ct.c.id] if ct.c == active && ct.timer == active.timeout { ct.c.err <- errTimeout } case c := <-t.callDoneCh: - id := c.node.ID() - active := t.activeCallByNode[id] + active := t.activeCallByNode[c.id] if active != c { panic("BUG: callDone for inactive call") } c.timeout.Stop() delete(t.activeCallByAuth, c.nonce) - delete(t.activeCallByNode, id) - t.sendNextCall(id) + delete(t.activeCallByNode, c.id) + t.sendNextCall(c.id) + + case r := <-t.sendCh: + t.send(r.destID, r.destAddr, r.msg, nil) case p := <-t.packetInCh: t.handlePacket(p.Data, p.Addr) @@ -590,8 +619,7 @@ func (t *UDPv5) sendCall(c *callV5) { delete(t.activeCallByAuth, c.nonce) } - addr := &net.UDPAddr{IP: c.node.IP(), Port: c.node.UDP()} - newNonce, _ := t.send(c.node.ID(), addr, c.packet, c.challenge) + newNonce, _ := t.send(c.id, c.addr, c.packet, c.challenge) c.nonce = newNonce t.activeCallByAuth[newNonce] = c t.startResponseTimeout(c) @@ -604,6 +632,13 @@ func (t *UDPv5) sendResponse(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.P return err } +func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet) { + select { + case t.sendCh <- sendRequest{toID, toAddr, packet}: + case <-t.closeCtx.Done(): + } +} + // send sends a packet to the given node. func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet, c *v5wire.Whoareyou) (v5wire.Nonce, error) { addr := toAddr.String() @@ -691,7 +726,7 @@ func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr *net.UDPAddr, p v5w t.log.Debug(fmt.Sprintf("Unsolicited/late %s response", p.Name()), "id", fromID, "addr", fromAddr) return false } - if !fromAddr.IP.Equal(ac.node.IP()) || fromAddr.Port != ac.node.UDP() { + if !fromAddr.IP.Equal(ac.addr.IP) || fromAddr.Port != ac.addr.Port { t.log.Debug(fmt.Sprintf("%s from wrong endpoint", p.Name()), "id", fromID, "addr", fromAddr) return false } @@ -733,7 +768,7 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr) case *v5wire.Nodes: t.handleCallResponse(fromID, fromAddr, p) case *v5wire.TalkRequest: - t.handleTalkRequest(fromID, fromAddr, p) + t.talk.handleRequest(fromID, fromAddr, p) case *v5wire.TalkResponse: t.handleCallResponse(fromID, fromAddr, p) } @@ -763,6 +798,12 @@ func (t *UDPv5) handleWhoareyou(p *v5wire.Whoareyou, fromID enode.ID, fromAddr * return } + if c.node == nil { + // Can't perform handshake because we don't have the ENR. + t.log.Debug("Can't handle "+p.Name(), "addr", fromAddr, "err", "call has no ENR") + c.err <- errors.New("remote wants handshake, but call has no ENR") + return + } // Resend the call that was answered by WHOAREYOU. t.log.Trace("<< "+p.Name(), "id", c.node.ID(), "addr", fromAddr) c.handshakeCount++ @@ -876,17 +917,3 @@ func packNodes(reqid []byte, nodes []*enode.Node) []*v5wire.Nodes { } return resp } - -// handleTalkRequest runs the talk request handler of the requested protocol. -func (t *UDPv5) handleTalkRequest(fromID enode.ID, fromAddr *net.UDPAddr, p *v5wire.TalkRequest) { - t.trlock.Lock() - handler := t.trhandlers[p.Protocol] - t.trlock.Unlock() - - var response []byte - if handler != nil { - response = handler(fromID, fromAddr, p.Message) - } - resp := &v5wire.TalkResponse{ReqID: p.ReqID, Message: response} - t.sendResponse(fromID, fromAddr, resp) -} diff --git a/p2p/discover/v5_udp_test.go b/p2p/discover/v5_udp_test.go index 481bb1cdc389..9aa5f975e897 100644 --- a/p2p/discover/v5_udp_test.go +++ b/p2p/discover/v5_udp_test.go @@ -186,7 +186,7 @@ func TestUDPv5_findnodeHandling(t *testing.T) { // This request gets all the distance-253 nodes. test.packetIn(&v5wire.Findnode{ReqID: []byte{4}, Distances: []uint{253}}) - test.expectNodes([]byte{4}, 1, nodes253) + test.expectNodes([]byte{4}, 2, nodes253) // This request gets all the distance-249 nodes and some more at 248 because // the bucket at 249 is not full. @@ -515,6 +515,27 @@ func TestUDPv5_talkRequest(t *testing.T) { if err := <-done; err != nil { t.Fatal(err) } + + // Also check requesting without ENR. + go func() { + _, err := test.udp.TalkRequestToID(remote.ID(), test.remoteaddr, "test", []byte("test request 2")) + done <- err + }() + test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) { + if p.Protocol != "test" { + t.Errorf("wrong protocol ID in talk request: %q", p.Protocol) + } + if string(p.Message) != "test request 2" { + t.Errorf("wrong message talk request: %q", p.Message) + } + test.packetInFrom(test.remotekey, test.remoteaddr, &v5wire.TalkResponse{ + ReqID: p.ReqID, + Message: []byte("test response 2"), + }) + }) + if err := <-done; err != nil { + t.Fatal(err) + } } // This test checks that lookupDistances works. diff --git a/p2p/discover/v5wire/msg.go b/p2p/discover/v5wire/msg.go index fb8e1e12c294..401db2f6c587 100644 --- a/p2p/discover/v5wire/msg.go +++ b/p2p/discover/v5wire/msg.go @@ -216,7 +216,7 @@ func (p *TalkRequest) RequestID() []byte { return p.ReqID } func (p *TalkRequest) SetRequestID(id []byte) { p.ReqID = id } func (p *TalkRequest) AppendLogInfo(ctx []interface{}) []interface{} { - return append(ctx, "proto", p.Protocol, "reqid", hexutil.Bytes(p.ReqID), "len", len(p.Message)) + return append(ctx, "proto", p.Protocol, "req", hexutil.Bytes(p.ReqID), "len", len(p.Message)) } func (*TalkResponse) Name() string { return "TALKRESP/v5" } @@ -225,5 +225,5 @@ func (p *TalkResponse) RequestID() []byte { return p.ReqID } func (p *TalkResponse) SetRequestID(id []byte) { p.ReqID = id } func (p *TalkResponse) AppendLogInfo(ctx []interface{}) []interface{} { - return append(ctx, "req", p.ReqID, "len", len(p.Message)) + return append(ctx, "req", hexutil.Bytes(p.ReqID), "len", len(p.Message)) }