diff --git a/peer/network.go b/peer/network.go index 8c658a402f..6aaa962b3b 100644 --- a/peer/network.go +++ b/peer/network.go @@ -16,6 +16,7 @@ import ( "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/network/p2p" "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils" @@ -87,6 +88,7 @@ type network struct { outstandingRequestHandlers map[uint32]message.ResponseHandler // maps avalanchego requestID => message.ResponseHandler activeAppRequests *semaphore.Weighted // controls maximum number of active outbound requests activeCrossChainRequests *semaphore.Weighted // controls maximum number of active outbound cross chain requests + router *p2p.Router // handles messages being sent to the generic networking SDK appSender common.AppSender // avalanchego AppSender for sending messages codec codec.Manager // Codec used for parsing messages crossChainCodec codec.Manager // Codec used for parsing cross chain messages @@ -99,11 +101,18 @@ type network struct { // Set to true when Shutdown is called, after which all operations on this // struct are no-ops. + // + // Invariant: Even though `closed` is an atomic, `lock` is required to be + // held when sending requests to guarantee that the network isn't closed + // during these calls. This is because closing the network cancels all + // outstanding requests, which means we must guarantee never to register a + // request that will never be fulfilled or cancelled. closed utils.Atomic[bool] } -func NewNetwork(appSender common.AppSender, codec codec.Manager, crossChainCodec codec.Manager, self ids.NodeID, maxActiveAppRequests int64, maxActiveCrossChainRequests int64) Network { +func NewNetwork(router *p2p.Router, appSender common.AppSender, codec codec.Manager, crossChainCodec codec.Manager, self ids.NodeID, maxActiveAppRequests int64, maxActiveCrossChainRequests int64) Network { return &network{ + router: router, appSender: appSender, codec: codec, crossChainCodec: crossChainCodec, @@ -172,10 +181,7 @@ func (n *network) sendAppRequest(nodeID ids.NodeID, request []byte, responseHand log.Debug("sending request to peer", "nodeID", nodeID, "requestLen", len(request)) n.peers.TrackPeer(nodeID) - // generate requestID - requestID := n.requestIDGen - n.requestIDGen++ - + requestID := n.nextRequestID() n.outstandingRequestHandlers[requestID] = responseHandler nodeIDs := set.NewSet[ids.NodeID](1) @@ -209,10 +215,7 @@ func (n *network) SendCrossChainRequest(chainID ids.ID, request []byte, handler return nil } - // generate requestID - requestID := n.requestIDGen - n.requestIDGen++ - + requestID := n.nextRequestID() n.outstandingRequestHandlers[requestID] = handler // Send cross chain request to [chainID]. @@ -272,19 +275,12 @@ func (n *network) CrossChainAppRequest(ctx context.Context, requestingChainID id // If [requestID] is not known, this function will emit a log and return a nil error. // If the response handler returns an error it is propagated as a fatal error. func (n *network) CrossChainAppRequestFailed(ctx context.Context, respondingChainID ids.ID, requestID uint32) error { - n.lock.Lock() - defer n.lock.Unlock() - - if n.closed.Get() { - return nil - } - log.Debug("received CrossChainAppRequestFailed from chain", "respondingChainID", respondingChainID, "requestID", requestID) handler, exists := n.markRequestFulfilled(requestID) if !exists { - // Should never happen since the engine should be managing outstanding requests - log.Error("received CrossChainAppRequestFailed to unknown request", "respondingChainID", respondingChainID, "requestID", requestID) + // Can happen after the network has been closed. + log.Debug("received CrossChainAppRequestFailed to unknown request", "respondingChainID", respondingChainID, "requestID", requestID) return nil } @@ -299,19 +295,12 @@ func (n *network) CrossChainAppRequestFailed(ctx context.Context, respondingChai // If [requestID] is not known, this function will emit a log and return a nil error. // If the response handler returns an error it is propagated as a fatal error. func (n *network) CrossChainAppResponse(ctx context.Context, respondingChainID ids.ID, requestID uint32, response []byte) error { - n.lock.Lock() - defer n.lock.Unlock() - - if n.closed.Get() { - return nil - } - log.Debug("received CrossChainAppResponse from responding chain", "respondingChainID", respondingChainID, "requestID", requestID) handler, exists := n.markRequestFulfilled(requestID) if !exists { - // Should never happen since the engine should be managing outstanding requests - log.Error("received CrossChainAppResponse to unknown request", "respondingChainID", respondingChainID, "requestID", requestID, "responseLen", len(response)) + // Can happen after the network has been closed. + log.Debug("received CrossChainAppResponse to unknown request", "respondingChainID", respondingChainID, "requestID", requestID, "responseLen", len(response)) return nil } @@ -335,8 +324,8 @@ func (n *network) AppRequest(ctx context.Context, nodeID ids.NodeID, requestID u var req message.Request if _, err := n.codec.Unmarshal(request, &req); err != nil { - log.Debug("failed to unmarshal app request", "nodeID", nodeID, "requestID", requestID, "requestLen", len(request), "err", err) - return nil + log.Debug("forwarding AppRequest to SDK router", "nodeID", nodeID, "requestID", requestID, "requestLen", len(request), "err", err) + return n.router.AppRequest(ctx, nodeID, requestID, deadline, request) } bufferedDeadline, err := calculateTimeUntilDeadline(deadline, n.appStats) @@ -366,21 +355,13 @@ func (n *network) AppRequest(ctx context.Context, nodeID ids.NodeID, requestID u // Error returned by this function is expected to be treated as fatal by the engine // If [requestID] is not known, this function will emit a log and return a nil error. // If the response handler returns an error it is propagated as a fatal error. -func (n *network) AppResponse(_ context.Context, nodeID ids.NodeID, requestID uint32, response []byte) error { - n.lock.Lock() - defer n.lock.Unlock() - - if n.closed.Get() { - return nil - } - +func (n *network) AppResponse(ctx context.Context, nodeID ids.NodeID, requestID uint32, response []byte) error { log.Debug("received AppResponse from peer", "nodeID", nodeID, "requestID", requestID) handler, exists := n.markRequestFulfilled(requestID) if !exists { - // Should never happen since the engine should be managing outstanding requests - log.Error("received AppResponse to unknown request", "nodeID", nodeID, "requestID", requestID, "responseLen", len(response)) - return nil + log.Debug("forwarding AppResponse to SDK router", "nodeID", nodeID, "requestID", requestID, "responseLen", len(response)) + return n.router.AppResponse(ctx, nodeID, requestID, response) } // We must release the slot @@ -395,21 +376,13 @@ func (n *network) AppResponse(_ context.Context, nodeID ids.NodeID, requestID ui // - request times out before a response is provided // error returned by this function is expected to be treated as fatal by the engine // returns error only when the response handler returns an error -func (n *network) AppRequestFailed(_ context.Context, nodeID ids.NodeID, requestID uint32) error { - n.lock.Lock() - defer n.lock.Unlock() - - if n.closed.Get() { - return nil - } - +func (n *network) AppRequestFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32) error { log.Debug("received AppRequestFailed from peer", "nodeID", nodeID, "requestID", requestID) handler, exists := n.markRequestFulfilled(requestID) if !exists { - // Should never happen since the engine should be managing outstanding requests - log.Error("received AppRequestFailed to unknown request", "nodeID", nodeID, "requestID", requestID) - return nil + log.Debug("forwarding AppRequestFailed to SDK router", "nodeID", nodeID, "requestID", requestID) + return n.router.AppRequestFailed(ctx, nodeID, requestID) } // We must release the slot @@ -442,8 +415,11 @@ func calculateTimeUntilDeadline(deadline time.Time, stats stats.RequestHandlerSt // markRequestFulfilled fetches the handler for [requestID] and marks the request with [requestID] as having been fulfilled. // This is called by either [AppResponse] or [AppRequestFailed]. -// Assumes that the write lock is held. +// Assumes that the write lock is not held. func (n *network) markRequestFulfilled(requestID uint32) (message.ResponseHandler, bool) { + n.lock.Lock() + defer n.lock.Unlock() + handler, exists := n.outstandingRequestHandlers[requestID] if !exists { return nil, false @@ -467,10 +443,6 @@ func (n *network) Gossip(gossip []byte) error { // error returned by this function is expected to be treated as fatal by the engine // returns error if request could not be parsed as message.Request or when the requestHandler returns an error func (n *network) AppGossip(_ context.Context, nodeID ids.NodeID, gossipBytes []byte) error { - if n.closed.Get() { - return nil - } - var gossipMsg message.GossipMessage if _, err := n.codec.Unmarshal(gossipBytes, &gossipMsg); err != nil { log.Debug("could not parse app gossip", "nodeID", nodeID, "gossipLen", len(gossipBytes), "err", err) @@ -564,3 +536,15 @@ func (n *network) TrackBandwidth(nodeID ids.NodeID, bandwidth float64) { n.peers.TrackBandwidth(nodeID, bandwidth) } + +// invariant: peer/network must use explicitly even request ids. +// for this reason, [n.requestID] is initialized as zero and incremented by 2. +// This is for backwards-compatibility while the SDK router exists with the +// legacy coreth handlers to avoid a (very) narrow edge case where request ids +// can overlap, resulting in a dropped timeout. +func (n *network) nextRequestID() uint32 { + next := n.requestIDGen + n.requestIDGen += 2 + + return next +} diff --git a/peer/network_test.go b/peer/network_test.go index 3e1c32f492..add57616de 100644 --- a/peer/network_test.go +++ b/peer/network_test.go @@ -12,7 +12,9 @@ import ( "testing" "time" + "github.com/ava-labs/avalanchego/network/p2p" "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/set" ethcommon "github.com/ethereum/go-ethereum/common" @@ -49,11 +51,13 @@ var ( _ message.CrossChainRequest = &ExampleCrossChainRequest{} _ message.CrossChainRequestHandler = &testCrossChainHandler{} + + _ p2p.Handler = &testSDKHandler{} ) func TestNetworkDoesNotConnectToItself(t *testing.T) { selfNodeID := ids.GenerateTestNodeID() - n := NewNetwork(nil, nil, nil, selfNodeID, 1, 1) + n := NewNetwork(p2p.NewRouter(logging.NoLog{}, nil), nil, nil, nil, selfNodeID, 1, 1) assert.NoError(t, n.Connected(context.Background(), selfNodeID, defaultPeerVersion)) assert.EqualValues(t, 0, n.Size()) } @@ -89,7 +93,7 @@ func TestRequestAnyRequestsRoutingAndResponse(t *testing.T) { codecManager := buildCodec(t, HelloRequest{}, HelloResponse{}) crossChainCodecManager := buildCodec(t, ExampleCrossChainRequest{}, ExampleCrossChainResponse{}) - net = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 16, 16) + net = NewNetwork(p2p.NewRouter(logging.NoLog{}, nil), sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 16, 16) net.SetRequestHandler(&HelloGreetingRequestHandler{codec: codecManager}) client := NewNetworkClient(net) nodeID := ids.GenerateTestNodeID() @@ -164,7 +168,7 @@ func TestRequestRequestsRoutingAndResponse(t *testing.T) { codecManager := buildCodec(t, HelloRequest{}, HelloResponse{}) crossChainCodecManager := buildCodec(t, ExampleCrossChainRequest{}, ExampleCrossChainResponse{}) - net = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 16, 16) + net = NewNetwork(p2p.NewRouter(logging.NoLog{}, nil), sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 16, 16) net.SetRequestHandler(&HelloGreetingRequestHandler{codec: codecManager}) client := NewNetworkClient(net) @@ -244,7 +248,7 @@ func TestAppRequestOnShutdown(t *testing.T) { codecManager := buildCodec(t, HelloRequest{}, HelloResponse{}) crossChainCodecManager := buildCodec(t, ExampleCrossChainRequest{}, ExampleCrossChainResponse{}) - net = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) + net = NewNetwork(p2p.NewRouter(logging.NoLog{}, nil), sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) client := NewNetworkClient(net) nodeID := ids.GenerateTestNodeID() require.NoError(t, net.Connected(context.Background(), nodeID, defaultPeerVersion)) @@ -293,7 +297,7 @@ func TestRequestMinVersion(t *testing.T) { } // passing nil as codec works because the net.AppRequest is never called - net = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 16) + net = NewNetwork(p2p.NewRouter(logging.NoLog{}, nil), sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 16) client := NewNetworkClient(net) requestMessage := TestMessage{Message: "this is a request"} requestBytes, err := message.RequestToBytes(codecManager, requestMessage) @@ -356,7 +360,7 @@ func TestOnRequestHonoursDeadline(t *testing.T) { processingDuration: 500 * time.Millisecond, } - net = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) + net = NewNetwork(p2p.NewRouter(logging.NoLog{}, nil), sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) net.SetRequestHandler(requestHandler) nodeID := ids.GenerateTestNodeID() @@ -396,7 +400,7 @@ func TestGossip(t *testing.T) { } gossipHandler := &testGossipHandler{} - clientNetwork = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) + clientNetwork = NewNetwork(p2p.NewRouter(logging.NoLog{}, nil), sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) clientNetwork.SetGossipHandler(gossipHandler) assert.NoError(t, clientNetwork.Connected(context.Background(), nodeID, defaultPeerVersion)) @@ -423,7 +427,7 @@ func TestHandleInvalidMessages(t *testing.T) { requestID := uint32(1) sender := testAppSender{} - clientNetwork := NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) + clientNetwork := NewNetwork(p2p.NewRouter(logging.NoLog{}, nil), sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) clientNetwork.SetGossipHandler(message.NoopMempoolGossipHandler{}) clientNetwork.SetRequestHandler(&testRequestHandler{}) @@ -457,12 +461,11 @@ func TestHandleInvalidMessages(t *testing.T) { assert.NoError(t, clientNetwork.AppRequest(context.Background(), nodeID, requestID, time.Now().Add(time.Second), garbageResponse)) assert.NoError(t, clientNetwork.AppRequest(context.Background(), nodeID, requestID, time.Now().Add(time.Second), emptyResponse)) assert.NoError(t, clientNetwork.AppRequest(context.Background(), nodeID, requestID, time.Now().Add(time.Second), nilResponse)) - assert.NoError(t, clientNetwork.AppResponse(context.Background(), nodeID, requestID, gossipMsg)) - assert.NoError(t, clientNetwork.AppResponse(context.Background(), nodeID, requestID, requestMessage)) - assert.NoError(t, clientNetwork.AppResponse(context.Background(), nodeID, requestID, garbageResponse)) - assert.NoError(t, clientNetwork.AppResponse(context.Background(), nodeID, requestID, emptyResponse)) - assert.NoError(t, clientNetwork.AppResponse(context.Background(), nodeID, requestID, nilResponse)) - assert.NoError(t, clientNetwork.AppRequestFailed(context.Background(), nodeID, requestID)) + assert.ErrorIs(t, p2p.ErrUnrequestedResponse, clientNetwork.AppResponse(context.Background(), nodeID, requestID, gossipMsg)) + assert.ErrorIs(t, p2p.ErrUnrequestedResponse, clientNetwork.AppResponse(context.Background(), nodeID, requestID, requestMessage)) + assert.ErrorIs(t, p2p.ErrUnrequestedResponse, clientNetwork.AppResponse(context.Background(), nodeID, requestID, garbageResponse)) + assert.ErrorIs(t, p2p.ErrUnrequestedResponse, clientNetwork.AppResponse(context.Background(), nodeID, requestID, emptyResponse)) + assert.ErrorIs(t, p2p.ErrUnrequestedResponse, clientNetwork.AppResponse(context.Background(), nodeID, requestID, nilResponse)) } func TestNetworkPropagatesRequestHandlerError(t *testing.T) { @@ -473,7 +476,7 @@ func TestNetworkPropagatesRequestHandlerError(t *testing.T) { requestID := uint32(1) sender := testAppSender{} - clientNetwork := NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) + clientNetwork := NewNetwork(p2p.NewRouter(logging.NoLog{}, nil), sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) clientNetwork.SetGossipHandler(message.NoopMempoolGossipHandler{}) clientNetwork.SetRequestHandler(&testRequestHandler{err: errors.New("fail")}) // Return an error from the request handler @@ -513,7 +516,7 @@ func TestCrossChainAppRequest(t *testing.T) { }, } - net = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) + net = NewNetwork(p2p.NewRouter(logging.NoLog{}, nil), sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) net.SetCrossChainRequestHandler(&testCrossChainHandler{codec: crossChainCodecManager}) client := NewNetworkClient(net) @@ -568,7 +571,7 @@ func TestCrossChainRequestRequestsRoutingAndResponse(t *testing.T) { codecManager := buildCodec(t, TestMessage{}) crossChainCodecManager := buildCodec(t, ExampleCrossChainRequest{}, ExampleCrossChainResponse{}) - net = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) + net = NewNetwork(p2p.NewRouter(logging.NoLog{}, nil), sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) net.SetCrossChainRequestHandler(&testCrossChainHandler{codec: crossChainCodecManager}) client := NewNetworkClient(net) @@ -628,7 +631,7 @@ func TestCrossChainRequestOnShutdown(t *testing.T) { } codecManager := buildCodec(t, TestMessage{}) crossChainCodecManager := buildCodec(t, ExampleCrossChainRequest{}, ExampleCrossChainResponse{}) - net = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) + net = NewNetwork(p2p.NewRouter(logging.NoLog{}, nil), sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) client := NewNetworkClient(net) exampleCrossChainRequest := ExampleCrossChainRequest{ @@ -649,6 +652,48 @@ func TestCrossChainRequestOnShutdown(t *testing.T) { require.True(t, called) } +func TestSDKRouting(t *testing.T) { + require := require.New(t) + sender := &testAppSender{ + sendAppRequestFn: func(s set.Set[ids.NodeID], u uint32, bytes []byte) error { + return nil + }, + sendAppResponseFn: func(id ids.NodeID, u uint32, bytes []byte) error { + return nil + }, + } + protocol := 0 + handler := &testSDKHandler{} + router := p2p.NewRouter(logging.NoLog{}, sender) + _, err := router.RegisterAppProtocol(uint64(protocol), handler) + require.NoError(err) + + networkCodec := codec.NewManager(0) + crossChainCodec := codec.NewManager(0) + + network := NewNetwork( + router, + nil, + networkCodec, + crossChainCodec, + ids.EmptyNodeID, + 1, + 1, + ) + + nodeID := ids.GenerateTestNodeID() + foobar := append([]byte{byte(protocol)}, []byte("foobar")...) + err = network.AppRequest(context.Background(), nodeID, 0, time.Time{}, foobar) + require.NoError(err) + require.True(handler.appRequested) + + err = network.AppResponse(context.Background(), ids.GenerateTestNodeID(), 0, foobar) + require.ErrorIs(err, p2p.ErrUnrequestedResponse) + + err = network.AppRequestFailed(context.Background(), nodeID, 0) + require.ErrorIs(err, p2p.ErrUnrequestedResponse) +} + func buildCodec(t *testing.T, types ...interface{}) codec.Manager { codecManager := codec.NewDefaultManager() c := linearcodec.NewDefault() @@ -850,3 +895,22 @@ type testCrossChainHandler struct { func (t *testCrossChainHandler) HandleCrossChainRequest(ctx context.Context, requestingChainID ids.ID, requestID uint32, exampleRequest message.CrossChainRequest) ([]byte, error) { return t.codec.Marshal(message.Version, ExampleCrossChainResponse{Response: "this is an example response"}) } + +type testSDKHandler struct { + appRequested bool +} + +func (t *testSDKHandler) AppGossip(ctx context.Context, nodeID ids.NodeID, gossipBytes []byte) error { + // TODO implement me + panic("implement me") +} + +func (t *testSDKHandler) AppRequest(ctx context.Context, nodeID ids.NodeID, deadline time.Time, requestBytes []byte) ([]byte, error) { + t.appRequested = true + return nil, nil +} + +func (t *testSDKHandler) CrossChainAppRequest(ctx context.Context, chainID ids.ID, deadline time.Time, requestBytes []byte) ([]byte, error) { + // TODO implement me + panic("implement me") +} diff --git a/plugin/evm/vm.go b/plugin/evm/vm.go index a62245587d..2e5a3f81b6 100644 --- a/plugin/evm/vm.go +++ b/plugin/evm/vm.go @@ -17,6 +17,7 @@ import ( "time" avalanchegoMetrics "github.com/ava-labs/avalanchego/api/metrics" + "github.com/ava-labs/avalanchego/network/p2p" "github.com/ava-labs/coreth/consensus/dummy" corethConstants "github.com/ava-labs/coreth/constants" @@ -276,6 +277,8 @@ type VM struct { client peer.NetworkClient networkCodec codec.Manager + router *p2p.Router + // Metrics multiGatherer avalanchegoMetrics.MultiGatherer @@ -506,8 +509,9 @@ func (vm *VM) Initialize( } // initialize peer network + vm.router = p2p.NewRouter(vm.ctx.Log, appSender) vm.networkCodec = message.Codec - vm.Network = peer.NewNetwork(appSender, vm.networkCodec, message.CrossChainCodec, chainCtx.NodeID, vm.config.MaxOutboundActiveRequests, vm.config.MaxOutboundActiveCrossChainRequests) + vm.Network = peer.NewNetwork(vm.router, appSender, vm.networkCodec, message.CrossChainCodec, chainCtx.NodeID, vm.config.MaxOutboundActiveRequests, vm.config.MaxOutboundActiveCrossChainRequests) vm.client = peer.NewNetworkClient(vm.Network) if err := vm.initializeChain(lastAcceptedHash); err != nil {