Skip to content

Commit

Permalink
Ensure p2p protocol matches new Starknet spec
Browse files Browse the repository at this point in the history
  • Loading branch information
wojciechos committed Jan 16, 2025
1 parent 0c6508c commit 3514225
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 28 deletions.
26 changes: 13 additions & 13 deletions p2p/p2p.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type Service struct {
database db.DB
}

func New(addr, publicAddr, version, peers, privKeyStr string, feederNode bool, bc *blockchain.Blockchain, snNetwork *utils.Network,
func New(addr, publicAddr, version, peers, privKeyStr string, feederNode bool, bc *blockchain.Blockchain, net *utils.Network,
log utils.SimpleLogger, database db.DB,
) (*Service, error) {
if addr == "" {
Expand Down Expand Up @@ -110,10 +110,10 @@ func New(addr, publicAddr, version, peers, privKeyStr string, feederNode bool, b
// Todo: try to understand what will happen if user passes a multiaddr with p2p public and a private key which doesn't match.
// For example, a user passes the following multiaddr: --p2p-addr=/ip4/0.0.0.0/tcp/7778/p2p/(SomePublicKey) and also passes a
// --p2p-private-key="SomePrivateKey". However, the private public key pair don't match, in this case what will happen?
return NewWithHost(p2pHost, peers, feederNode, bc, snNetwork, log, database)
return NewWithHost(p2pHost, peers, feederNode, bc, net, log, database)
}

func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchain.Blockchain, snNetwork *utils.Network,
func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchain.Blockchain, net *utils.Network,
log utils.SimpleLogger, database db.DB,
) (*Service, error) {
var (
Expand All @@ -139,19 +139,19 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai
}
}

p2pdht, err := makeDHT(p2phost, peersAddrInfoS)
p2pdht, err := MakeDHT(p2phost, peersAddrInfoS, net)
if err != nil {
return nil, err
}

// todo: reconsider initialising synchroniser here because if node is a feedernode we shouldn't not create an instance of it.

synchroniser := p2pSync.New(bc, p2phost, snNetwork, log)
synchroniser := p2pSync.New(bc, p2phost, net, log)
s := &Service{
synchroniser: synchroniser,
log: log,
host: p2phost,
network: snNetwork,
network: net,
dht: p2pdht,
feederNode: feederNode,
handler: p2pPeers.NewHandler(bc, log),
Expand All @@ -160,9 +160,9 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai
return s, nil
}

func makeDHT(p2phost host.Host, addrInfos []peer.AddrInfo) (*dht.IpfsDHT, error) {
func MakeDHT(p2phost host.Host, addrInfos []peer.AddrInfo, net *utils.Network) (*dht.IpfsDHT, error) {
return dht.New(context.Background(), p2phost,
dht.ProtocolPrefix(p2pSync.Prefix),
dht.ProtocolPrefix(p2pSync.DHTPrefixPID(net)),
dht.BootstrapPeers(addrInfos...),
dht.RoutingTableRefreshPeriod(routingTableRefreshPeriod),
dht.Mode(dht.ModeServer),
Expand Down Expand Up @@ -250,11 +250,11 @@ func (s *Service) Run(ctx context.Context) error {
}

func (s *Service) setProtocolHandlers() {
s.SetProtocolHandler(p2pSync.HeadersPID(), s.handler.HeadersHandler)
s.SetProtocolHandler(p2pSync.EventsPID(), s.handler.EventsHandler)
s.SetProtocolHandler(p2pSync.TransactionsPID(), s.handler.TransactionsHandler)
s.SetProtocolHandler(p2pSync.ClassesPID(), s.handler.ClassesHandler)
s.SetProtocolHandler(p2pSync.StateDiffPID(), s.handler.StateDiffHandler)
s.SetProtocolHandler(p2pSync.HeadersPID(s.network), s.handler.HeadersHandler)
s.SetProtocolHandler(p2pSync.EventsPID(s.network), s.handler.EventsHandler)
s.SetProtocolHandler(p2pSync.TransactionsPID(s.network), s.handler.TransactionsHandler)
s.SetProtocolHandler(p2pSync.ClassesPID(s.network), s.handler.ClassesHandler)
s.SetProtocolHandler(p2pSync.StateDiffPID(s.network), s.handler.StateDiffHandler)
}

func (s *Service) callAndLogErr(f func() error, msg string) {
Expand Down
36 changes: 36 additions & 0 deletions p2p/p2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import (
"github.com/NethermindEth/juno/p2p"
"github.com/NethermindEth/juno/utils"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
"github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -64,3 +67,36 @@ func TestLoadAndPersistPeers(t *testing.T) {
)
require.NoError(t, err)
}

func TestMakeDHTProtocolName(t *testing.T) {
net, err := mocknet.FullMeshLinked(1)
require.NoError(t, err)
testHost := net.Hosts()[0]

testCases := []struct {
name string
network *utils.Network
expected string
}{
{
name: "sepolia network",
network: &utils.Sepolia,
expected: "/starknet/SN_SEPOLIA/sync/kad/1.0.0",
},
{
name: "mainnet network",
network: &utils.Mainnet,
expected: "/starknet/SN_MAIN/sync/kad/1.0.0",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
dht, err := p2p.MakeDHT(testHost, nil, tc.network)
require.NoError(t, err)

protocols := dht.Host().Mux().Protocols()
assert.Contains(t, protocols, protocol.ID(tc.expected), "protocol list: %v", protocols)
})
}
}
10 changes: 5 additions & 5 deletions p2p/sync/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,22 @@ func (c *Client) RequestBlockHeaders(
ctx context.Context, req *gen.BlockHeadersRequest,
) (iter.Seq[*gen.BlockHeadersResponse], error) {
return requestAndReceiveStream[*gen.BlockHeadersRequest, *gen.BlockHeadersResponse](
ctx, c.newStream, HeadersPID(), req, c.log)
ctx, c.newStream, HeadersPID(c.network), req, c.log)
}

func (c *Client) RequestEvents(ctx context.Context, req *gen.EventsRequest) (iter.Seq[*gen.EventsResponse], error) {
return requestAndReceiveStream[*gen.EventsRequest, *gen.EventsResponse](ctx, c.newStream, EventsPID(), req, c.log)
return requestAndReceiveStream[*gen.EventsRequest, *gen.EventsResponse](ctx, c.newStream, EventsPID(c.network), req, c.log)
}

func (c *Client) RequestClasses(ctx context.Context, req *gen.ClassesRequest) (iter.Seq[*gen.ClassesResponse], error) {
return requestAndReceiveStream[*gen.ClassesRequest, *gen.ClassesResponse](ctx, c.newStream, ClassesPID(), req, c.log)
return requestAndReceiveStream[*gen.ClassesRequest, *gen.ClassesResponse](ctx, c.newStream, ClassesPID(c.network), req, c.log)
}

func (c *Client) RequestStateDiffs(ctx context.Context, req *gen.StateDiffsRequest) (iter.Seq[*gen.StateDiffsResponse], error) {
return requestAndReceiveStream[*gen.StateDiffsRequest, *gen.StateDiffsResponse](ctx, c.newStream, StateDiffPID(), req, c.log)
return requestAndReceiveStream[*gen.StateDiffsRequest, *gen.StateDiffsResponse](ctx, c.newStream, StateDiffPID(c.network), req, c.log)
}

func (c *Client) RequestTransactions(ctx context.Context, req *gen.TransactionsRequest) (iter.Seq[*gen.TransactionsResponse], error) {
return requestAndReceiveStream[*gen.TransactionsRequest, *gen.TransactionsResponse](
ctx, c.newStream, TransactionsPID(), req, c.log)
ctx, c.newStream, TransactionsPID(c.network), req, c.log)
}
26 changes: 16 additions & 10 deletions p2p/sync/ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,32 @@ package sync

import (
"github.com/libp2p/go-libp2p/core/protocol"

"github.com/NethermindEth/juno/utils"
)

const Prefix = "/starknet"

func HeadersPID() protocol.ID {
return Prefix + "/headers/0.1.0-rc.0"
func HeadersPID(network *utils.Network) protocol.ID {
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/headers/0.1.0-rc.0")
}

func EventsPID(network *utils.Network) protocol.ID {
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/events/0.1.0-rc.0")
}

func EventsPID() protocol.ID {
return Prefix + "/events/0.1.0-rc.0"
func TransactionsPID(network *utils.Network) protocol.ID {
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/transactions/0.1.0-rc.0")
}

func TransactionsPID() protocol.ID {
return Prefix + "/transactions/0.1.0-rc.0"
func ClassesPID(network *utils.Network) protocol.ID {
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/classes/0.1.0-rc.0")
}

func ClassesPID() protocol.ID {
return Prefix + "/classes/0.1.0-rc.0"
func StateDiffPID(network *utils.Network) protocol.ID {
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/state_diffs/0.1.0-rc.0")
}

func StateDiffPID() protocol.ID {
return Prefix + "/state_diffs/0.1.0-rc.0"
func DHTPrefixPID(network *utils.Network) protocol.ID {
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync")
}
67 changes: 67 additions & 0 deletions p2p/sync/ids_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package sync

import (
"testing"

"github.com/NethermindEth/juno/utils"
"github.com/stretchr/testify/assert"
)

func TestProtocolIDs(t *testing.T) {
testCases := []struct {
name string
network *utils.Network
pidFunc func(*utils.Network) string
expected string
}{
{
name: "HeadersPID with SN_MAIN",
network: &utils.Mainnet,
pidFunc: func(n *utils.Network) string { return string(HeadersPID(n)) },
expected: "/starknet/SN_MAIN/sync/headers/0.1.0-rc.0",
},
{
name: "EventsPID with SN_MAIN",
network: &utils.Mainnet,
pidFunc: func(n *utils.Network) string { return string(EventsPID(n)) },
expected: "/starknet/SN_MAIN/sync/events/0.1.0-rc.0",
},
{
name: "TransactionsPID with SN_MAIN",
network: &utils.Mainnet,
pidFunc: func(n *utils.Network) string { return string(TransactionsPID(n)) },
expected: "/starknet/SN_MAIN/sync/transactions/0.1.0-rc.0",
},
{
name: "ClassesPID with SN_MAIN",
network: &utils.Mainnet,
pidFunc: func(n *utils.Network) string { return string(ClassesPID(n)) },
expected: "/starknet/SN_MAIN/sync/classes/0.1.0-rc.0",
},
{
name: "StateDiffPID with SN_MAIN",
network: &utils.Mainnet,
pidFunc: func(n *utils.Network) string { return string(StateDiffPID(n)) },
expected: "/starknet/SN_MAIN/sync/state_diffs/0.1.0-rc.0",
},
{
name: "DHTPrefixPID with SN_MAIN",
network: &utils.Mainnet,
pidFunc: func(n *utils.Network) string { return string(DHTPrefixPID(n)) },
expected: "/starknet/SN_MAIN/sync",
},
{
name: "HeadersPID with SN_SEPOLIA",
network: &utils.Sepolia,
pidFunc: func(n *utils.Network) string { return string(HeadersPID(n)) },
expected: "/starknet/SN_SEPOLIA/sync/headers/0.1.0-rc.0",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := tc.pidFunc(tc.network)
assert.Equal(t, tc.expected, result)
})
}
}

0 comments on commit 3514225

Please sign in to comment.