From 3514225d5fb5e1c33bbdeea0cd6749e7563c3657 Mon Sep 17 00:00:00 2001 From: wojo Date: Wed, 15 Jan 2025 17:49:44 +0700 Subject: [PATCH] Ensure p2p protocol matches new Starknet spec --- p2p/p2p.go | 26 ++++++++--------- p2p/p2p_test.go | 36 ++++++++++++++++++++++++ p2p/sync/client.go | 10 +++---- p2p/sync/ids.go | 26 ++++++++++------- p2p/sync/ids_test.go | 67 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 137 insertions(+), 28 deletions(-) create mode 100644 p2p/sync/ids_test.go diff --git a/p2p/p2p.go b/p2p/p2p.go index ddb1ed955..e072814d8 100644 --- a/p2p/p2p.go +++ b/p2p/p2p.go @@ -52,7 +52,7 @@ type Service struct { database db.DB } -func New(addr, publicAddr, version, peers, privKeyStr string, feederNode bool, bc *blockchain.Blockchain, snNetwork *utils.Network, +func New(addr, publicAddr, version, peers, privKeyStr string, feederNode bool, bc *blockchain.Blockchain, net *utils.Network, log utils.SimpleLogger, database db.DB, ) (*Service, error) { if addr == "" { @@ -110,10 +110,10 @@ func New(addr, publicAddr, version, peers, privKeyStr string, feederNode bool, b // Todo: try to understand what will happen if user passes a multiaddr with p2p public and a private key which doesn't match. // For example, a user passes the following multiaddr: --p2p-addr=/ip4/0.0.0.0/tcp/7778/p2p/(SomePublicKey) and also passes a // --p2p-private-key="SomePrivateKey". However, the private public key pair don't match, in this case what will happen? - return NewWithHost(p2pHost, peers, feederNode, bc, snNetwork, log, database) + return NewWithHost(p2pHost, peers, feederNode, bc, net, log, database) } -func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchain.Blockchain, snNetwork *utils.Network, +func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchain.Blockchain, net *utils.Network, log utils.SimpleLogger, database db.DB, ) (*Service, error) { var ( @@ -139,19 +139,19 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai } } - p2pdht, err := makeDHT(p2phost, peersAddrInfoS) + p2pdht, err := MakeDHT(p2phost, peersAddrInfoS, net) if err != nil { return nil, err } // todo: reconsider initialising synchroniser here because if node is a feedernode we shouldn't not create an instance of it. - synchroniser := p2pSync.New(bc, p2phost, snNetwork, log) + synchroniser := p2pSync.New(bc, p2phost, net, log) s := &Service{ synchroniser: synchroniser, log: log, host: p2phost, - network: snNetwork, + network: net, dht: p2pdht, feederNode: feederNode, handler: p2pPeers.NewHandler(bc, log), @@ -160,9 +160,9 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai return s, nil } -func makeDHT(p2phost host.Host, addrInfos []peer.AddrInfo) (*dht.IpfsDHT, error) { +func MakeDHT(p2phost host.Host, addrInfos []peer.AddrInfo, net *utils.Network) (*dht.IpfsDHT, error) { return dht.New(context.Background(), p2phost, - dht.ProtocolPrefix(p2pSync.Prefix), + dht.ProtocolPrefix(p2pSync.DHTPrefixPID(net)), dht.BootstrapPeers(addrInfos...), dht.RoutingTableRefreshPeriod(routingTableRefreshPeriod), dht.Mode(dht.ModeServer), @@ -250,11 +250,11 @@ func (s *Service) Run(ctx context.Context) error { } func (s *Service) setProtocolHandlers() { - s.SetProtocolHandler(p2pSync.HeadersPID(), s.handler.HeadersHandler) - s.SetProtocolHandler(p2pSync.EventsPID(), s.handler.EventsHandler) - s.SetProtocolHandler(p2pSync.TransactionsPID(), s.handler.TransactionsHandler) - s.SetProtocolHandler(p2pSync.ClassesPID(), s.handler.ClassesHandler) - s.SetProtocolHandler(p2pSync.StateDiffPID(), s.handler.StateDiffHandler) + s.SetProtocolHandler(p2pSync.HeadersPID(s.network), s.handler.HeadersHandler) + s.SetProtocolHandler(p2pSync.EventsPID(s.network), s.handler.EventsHandler) + s.SetProtocolHandler(p2pSync.TransactionsPID(s.network), s.handler.TransactionsHandler) + s.SetProtocolHandler(p2pSync.ClassesPID(s.network), s.handler.ClassesHandler) + s.SetProtocolHandler(p2pSync.StateDiffPID(s.network), s.handler.StateDiffHandler) } func (s *Service) callAndLogErr(f func() error, msg string) { diff --git a/p2p/p2p_test.go b/p2p/p2p_test.go index 54b19d590..0e10a4c04 100644 --- a/p2p/p2p_test.go +++ b/p2p/p2p_test.go @@ -8,7 +8,10 @@ import ( "github.com/NethermindEth/juno/p2p" "github.com/NethermindEth/juno/utils" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -64,3 +67,36 @@ func TestLoadAndPersistPeers(t *testing.T) { ) require.NoError(t, err) } + +func TestMakeDHTProtocolName(t *testing.T) { + net, err := mocknet.FullMeshLinked(1) + require.NoError(t, err) + testHost := net.Hosts()[0] + + testCases := []struct { + name string + network *utils.Network + expected string + }{ + { + name: "sepolia network", + network: &utils.Sepolia, + expected: "/starknet/SN_SEPOLIA/sync/kad/1.0.0", + }, + { + name: "mainnet network", + network: &utils.Mainnet, + expected: "/starknet/SN_MAIN/sync/kad/1.0.0", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dht, err := p2p.MakeDHT(testHost, nil, tc.network) + require.NoError(t, err) + + protocols := dht.Host().Mux().Protocols() + assert.Contains(t, protocols, protocol.ID(tc.expected), "protocol list: %v", protocols) + }) + } +} diff --git a/p2p/sync/client.go b/p2p/sync/client.go index 5f688a037..5d4271df8 100644 --- a/p2p/sync/client.go +++ b/p2p/sync/client.go @@ -104,22 +104,22 @@ func (c *Client) RequestBlockHeaders( ctx context.Context, req *gen.BlockHeadersRequest, ) (iter.Seq[*gen.BlockHeadersResponse], error) { return requestAndReceiveStream[*gen.BlockHeadersRequest, *gen.BlockHeadersResponse]( - ctx, c.newStream, HeadersPID(), req, c.log) + ctx, c.newStream, HeadersPID(c.network), req, c.log) } func (c *Client) RequestEvents(ctx context.Context, req *gen.EventsRequest) (iter.Seq[*gen.EventsResponse], error) { - return requestAndReceiveStream[*gen.EventsRequest, *gen.EventsResponse](ctx, c.newStream, EventsPID(), req, c.log) + return requestAndReceiveStream[*gen.EventsRequest, *gen.EventsResponse](ctx, c.newStream, EventsPID(c.network), req, c.log) } func (c *Client) RequestClasses(ctx context.Context, req *gen.ClassesRequest) (iter.Seq[*gen.ClassesResponse], error) { - return requestAndReceiveStream[*gen.ClassesRequest, *gen.ClassesResponse](ctx, c.newStream, ClassesPID(), req, c.log) + return requestAndReceiveStream[*gen.ClassesRequest, *gen.ClassesResponse](ctx, c.newStream, ClassesPID(c.network), req, c.log) } func (c *Client) RequestStateDiffs(ctx context.Context, req *gen.StateDiffsRequest) (iter.Seq[*gen.StateDiffsResponse], error) { - return requestAndReceiveStream[*gen.StateDiffsRequest, *gen.StateDiffsResponse](ctx, c.newStream, StateDiffPID(), req, c.log) + return requestAndReceiveStream[*gen.StateDiffsRequest, *gen.StateDiffsResponse](ctx, c.newStream, StateDiffPID(c.network), req, c.log) } func (c *Client) RequestTransactions(ctx context.Context, req *gen.TransactionsRequest) (iter.Seq[*gen.TransactionsResponse], error) { return requestAndReceiveStream[*gen.TransactionsRequest, *gen.TransactionsResponse]( - ctx, c.newStream, TransactionsPID(), req, c.log) + ctx, c.newStream, TransactionsPID(c.network), req, c.log) } diff --git a/p2p/sync/ids.go b/p2p/sync/ids.go index 284875e36..cfdfedb8f 100644 --- a/p2p/sync/ids.go +++ b/p2p/sync/ids.go @@ -2,26 +2,32 @@ package sync import ( "github.com/libp2p/go-libp2p/core/protocol" + + "github.com/NethermindEth/juno/utils" ) const Prefix = "/starknet" -func HeadersPID() protocol.ID { - return Prefix + "/headers/0.1.0-rc.0" +func HeadersPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/headers/0.1.0-rc.0") +} + +func EventsPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/events/0.1.0-rc.0") } -func EventsPID() protocol.ID { - return Prefix + "/events/0.1.0-rc.0" +func TransactionsPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/transactions/0.1.0-rc.0") } -func TransactionsPID() protocol.ID { - return Prefix + "/transactions/0.1.0-rc.0" +func ClassesPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/classes/0.1.0-rc.0") } -func ClassesPID() protocol.ID { - return Prefix + "/classes/0.1.0-rc.0" +func StateDiffPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/state_diffs/0.1.0-rc.0") } -func StateDiffPID() protocol.ID { - return Prefix + "/state_diffs/0.1.0-rc.0" +func DHTPrefixPID(network *utils.Network) protocol.ID { + return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync") } diff --git a/p2p/sync/ids_test.go b/p2p/sync/ids_test.go new file mode 100644 index 000000000..8ad6ae6d3 --- /dev/null +++ b/p2p/sync/ids_test.go @@ -0,0 +1,67 @@ +package sync + +import ( + "testing" + + "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/assert" +) + +func TestProtocolIDs(t *testing.T) { + testCases := []struct { + name string + network *utils.Network + pidFunc func(*utils.Network) string + expected string + }{ + { + name: "HeadersPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(HeadersPID(n)) }, + expected: "/starknet/SN_MAIN/sync/headers/0.1.0-rc.0", + }, + { + name: "EventsPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(EventsPID(n)) }, + expected: "/starknet/SN_MAIN/sync/events/0.1.0-rc.0", + }, + { + name: "TransactionsPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(TransactionsPID(n)) }, + expected: "/starknet/SN_MAIN/sync/transactions/0.1.0-rc.0", + }, + { + name: "ClassesPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(ClassesPID(n)) }, + expected: "/starknet/SN_MAIN/sync/classes/0.1.0-rc.0", + }, + { + name: "StateDiffPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(StateDiffPID(n)) }, + expected: "/starknet/SN_MAIN/sync/state_diffs/0.1.0-rc.0", + }, + { + name: "DHTPrefixPID with SN_MAIN", + network: &utils.Mainnet, + pidFunc: func(n *utils.Network) string { return string(DHTPrefixPID(n)) }, + expected: "/starknet/SN_MAIN/sync", + }, + { + name: "HeadersPID with SN_SEPOLIA", + network: &utils.Sepolia, + pidFunc: func(n *utils.Network) string { return string(HeadersPID(n)) }, + expected: "/starknet/SN_SEPOLIA/sync/headers/0.1.0-rc.0", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := tc.pidFunc(tc.network) + assert.Equal(t, tc.expected, result) + }) + } +}