diff --git a/app/peerinfo/adhoc.go b/app/peerinfo/adhoc.go index 1d9bd1a66..057c0d102 100644 --- a/app/peerinfo/adhoc.go +++ b/app/peerinfo/adhoc.go @@ -24,8 +24,8 @@ func DoOnce(ctx context.Context, tcpNode host.Host, peerID peer.ID) (*pbv1.PeerI req := new(pbv1.PeerInfo) // TODO(corver): Populate request fields and make them required. resp := new(pbv1.PeerInfo) - err := p2p.SendReceive(ctx, tcpNode, peerID, req, resp, protocolID2, - p2p.WithSendReceiveRTT(rttCallback)) + err := p2p.SendReceive(ctx, tcpNode, peerID, req, resp, protocolID1, + p2p.WithSendReceiveRTT(rttCallback), p2p.WithDelimitedProtocol(protocolID2)) if err != nil { return nil, 0, false, err } diff --git a/app/peerinfo/peerinfo.go b/app/peerinfo/peerinfo.go index 93e865de6..4e1fda820 100644 --- a/app/peerinfo/peerinfo.go +++ b/app/peerinfo/peerinfo.go @@ -25,13 +25,15 @@ import ( ) const ( - period = time.Minute + period = time.Minute + + protocolID1 protocol.ID = "/charon/peerinfo/1.0.0" protocolID2 protocol.ID = "/charon/peerinfo/2.0.0" ) // Protocols returns the supported protocols of this package in order of precedence. func Protocols() []protocol.ID { - return []protocol.ID{protocolID2} + return []protocol.ID{protocolID2, protocolID1} } type ( @@ -80,7 +82,7 @@ func newInternal(tcpNode host.Host, peers []peer.ID, version version.SemVer, loc startTime := timestamppb.New(nowFunc()) // Register a simple handler that returns our info and ignores the request. - registerHandler("peerinfo", tcpNode, protocolID2, + registerHandler("peerinfo", tcpNode, protocolID1, func() proto.Message { return new(pbv1.PeerInfo) }, func(context.Context, peer.ID, proto.Message) (proto.Message, bool, error) { return &pbv1.PeerInfo{ @@ -91,6 +93,7 @@ func newInternal(tcpNode host.Host, peers []peer.ID, version version.SemVer, loc StartedAt: startTime, }, true, nil }, + p2p.WithDelimitedProtocol(protocolID2), ) // Create log filters @@ -170,8 +173,8 @@ func (p *PeerInfo) sendOnce(ctx context.Context, now time.Time) { } resp := new(pbv1.PeerInfo) - err := p.sendFunc(ctx, p.tcpNode, peerID, req, resp, protocolID2, - p2p.WithSendReceiveRTT(rttCallback)) + err := p.sendFunc(ctx, p.tcpNode, peerID, req, resp, protocolID1, + p2p.WithSendReceiveRTT(rttCallback), p2p.WithDelimitedProtocol(protocolID2)) if err != nil { return // Logging handled by send func. } else if resp.SentAt == nil || resp.StartedAt == nil { diff --git a/core/consensus/component.go b/core/consensus/component.go index 6d2d841d5..c0d0814c0 100644 --- a/core/consensus/component.go +++ b/core/consensus/component.go @@ -28,12 +28,13 @@ import ( const ( recvBuffer = 100 // Allow buffering some initial messages when this node is late to start an instance. + protocolID1 = "/charon/consensus/qbft/1.0.0" protocolID2 = "/charon/consensus/qbft/2.0.0" ) // Protocols returns the supported protocols of this package in order of precedence. func Protocols() []protocol.ID { - return []protocol.ID{protocolID2} + return []protocol.ID{protocolID2, protocolID1} } type subscriber func(ctx context.Context, duty core.Duty, value proto.Message) error @@ -226,9 +227,9 @@ func (c *Component) SubscribePriority(fn func(ctx context.Context, duty core.Dut // Start registers the libp2p receive handler and starts a goroutine that cleans state. This should only be called once. func (c *Component) Start(ctx context.Context) { - p2p.RegisterHandler("qbft", c.tcpNode, protocolID2, + p2p.RegisterHandler("qbft", c.tcpNode, protocolID1, func() proto.Message { return new(pbv1.ConsensusMsg) }, - c.handle) + c.handle, p2p.WithDelimitedProtocol(protocolID2)) go func() { for { diff --git a/core/consensus/transport.go b/core/consensus/transport.go index efcaebdc8..bed829b78 100644 --- a/core/consensus/transport.go +++ b/core/consensus/transport.go @@ -17,6 +17,7 @@ import ( "github.com/obolnetwork/charon/core" pbv1 "github.com/obolnetwork/charon/core/corepb/v1" "github.com/obolnetwork/charon/core/qbft" + "github.com/obolnetwork/charon/p2p" ) // transport encapsulates receiving and broadcasting for a consensus instance/duty. @@ -128,7 +129,8 @@ func (t *transport) Broadcast(ctx context.Context, typ qbft.MsgType, duty core.D continue } - err = t.component.sender.SendAsync(ctx, t.component.tcpNode, protocolID2, p.ID, msg.ToConsensusMsg()) + err = t.component.sender.SendAsync(ctx, t.component.tcpNode, protocolID1, p.ID, msg.ToConsensusMsg(), + p2p.WithDelimitedProtocol(protocolID2)) if err != nil { return err } diff --git a/core/parsigex/parsigex.go b/core/parsigex/parsigex.go index cac32738a..c4be18ba6 100644 --- a/core/parsigex/parsigex.go +++ b/core/parsigex/parsigex.go @@ -21,12 +21,13 @@ import ( ) const ( + protocolID1 = "/charon/parsigex/1.0.0" protocolID2 = "/charon/parsigex/2.0.0" ) // Protocols returns the supported protocols of this package in order of precedence. func Protocols() []protocol.ID { - return []protocol.ID{protocolID2} + return []protocol.ID{protocolID2, protocolID1} } func NewParSigEx(tcpNode host.Host, sendFunc p2p.SendFunc, peerIdx int, peers []peer.ID, verifyFunc func(context.Context, core.Duty, core.PubKey, core.ParSignedData) error) *ParSigEx { @@ -39,7 +40,7 @@ func NewParSigEx(tcpNode host.Host, sendFunc p2p.SendFunc, peerIdx int, peers [] } newReq := func() proto.Message { return new(pbv1.ParSigExMsg) } - p2p.RegisterHandler("parsigex", tcpNode, protocolID2, newReq, parSigEx.handle) + p2p.RegisterHandler("parsigex", tcpNode, protocolID1, newReq, parSigEx.handle, p2p.WithDelimitedProtocol(protocolID2)) return parSigEx } @@ -114,7 +115,7 @@ func (m *ParSigEx) Broadcast(ctx context.Context, duty core.Duty, set core.ParSi continue } - if err := m.sendFunc(ctx, m.tcpNode, protocolID2, p, &msg); err != nil { + if err := m.sendFunc(ctx, m.tcpNode, protocolID1, p, &msg, p2p.WithDelimitedProtocol(protocolID2)); err != nil { return err } } diff --git a/core/priority/prioritiser.go b/core/priority/prioritiser.go index 934646b67..180b24295 100644 --- a/core/priority/prioritiser.go +++ b/core/priority/prioritiser.go @@ -37,12 +37,13 @@ import ( ) const ( + protocolID1 = "charon/priority/1.1.0" protocolID2 = "charon/priority/2.0.0" ) // Protocols returns the supported protocols of this package in order of precedence. func Protocols() []protocol.ID { - return []protocol.ID{protocolID2} + return []protocol.ID{protocolID2, protocolID1} } // Topic groups priorities in an instance. @@ -117,7 +118,7 @@ func newInternal(tcpNode host.Host, peers []peer.ID, minRequired int, }) // Register prioritiser protocol handler. - registerHandlerFunc("priority", tcpNode, protocolID2, + registerHandlerFunc("priority", tcpNode, protocolID1, func() proto.Message { return new(pbv1.PriorityMsg) }, func(ctx context.Context, pID peer.ID, msg proto.Message) (proto.Message, bool, error) { prioMsg, ok := msg.(*pbv1.PriorityMsg) @@ -132,7 +133,8 @@ func newInternal(tcpNode host.Host, peers []peer.ID, minRequired int, } return resp, true, nil - }) + }, + p2p.WithDelimitedProtocol(protocolID2)) return p } @@ -331,7 +333,7 @@ func exchange(ctx context.Context, tcpNode host.Host, peers []peer.ID, msgValida go func(pID peer.ID) { response := new(pbv1.PriorityMsg) - err := sendFunc(ctx, tcpNode, pID, own, response, protocolID2) + err := sendFunc(ctx, tcpNode, pID, own, response, protocolID1, p2p.WithDelimitedProtocol(protocolID2)) if err != nil { // No need to log, since transport will do it. return diff --git a/dkg/bcast/client.go b/dkg/bcast/client.go index 073337ac5..853648995 100644 --- a/dkg/bcast/client.go +++ b/dkg/bcast/client.go @@ -66,7 +66,7 @@ func (c *client) Broadcast(ctx context.Context, msgID string, msg proto.Message) fork, join, cancel := forkjoin.New(ctx, func(ctx context.Context, pID peer.ID) (*pb.BCastSigResponse, error) { sigResp := new(pb.BCastSigResponse) - err := c.sendRecvFunc(ctx, c.tcpNode, pID, sigReq, sigResp, protocolIDSig) + err := c.sendRecvFunc(ctx, c.tcpNode, pID, sigReq, sigResp, protocolIDSig, p2p.WithDelimitedProtocol(protocolIDSig)) return sigResp, err }) @@ -133,7 +133,7 @@ func (c *client) Broadcast(ctx context.Context, msgID string, msg proto.Message) continue // Skip self. } - err := c.sendFunc(ctx, c.tcpNode, protocolIDMsg, pID, bcastMsg) + err := c.sendFunc(ctx, c.tcpNode, protocolIDMsg, pID, bcastMsg, p2p.WithDelimitedProtocol(protocolIDMsg)) if err != nil { return errors.Wrap(err, "send message") } diff --git a/dkg/bcast/server.go b/dkg/bcast/server.go index 4f16ceaf3..a7492f204 100644 --- a/dkg/bcast/server.go +++ b/dkg/bcast/server.go @@ -29,11 +29,13 @@ func newServer(tcpNode host.Host, signFunc signFunc, verifyFunc verifyFunc) *ser p2p.RegisterHandler("bcast", tcpNode, protocolIDSig, func() proto.Message { return new(pb.BCastSigRequest) }, s.handleSigRequest, + p2p.WithDelimitedProtocol(protocolIDSig), ) p2p.RegisterHandler("bcast", tcpNode, protocolIDMsg, func() proto.Message { return new(pb.BCastMessage) }, s.handleMessage, + p2p.WithDelimitedProtocol(protocolIDMsg), ) return s diff --git a/dkg/frostp2p.go b/dkg/frostp2p.go index 60e74a416..4f56a2035 100644 --- a/dkg/frostp2p.go +++ b/dkg/frostp2p.go @@ -55,6 +55,7 @@ func newFrostP2P(tcpNode host.Host, peers map[peer.ID]cluster.NodeIdx, bcastComp p2p.RegisterHandler("frost", tcpNode, round1P2PID, func() proto.Message { return new(pb.FrostRound1P2P) }, newP2PCallback(tcpNode, peers, round1P2PRecv, numVals), + p2p.WithDelimitedProtocol(round1P2PID), ) bcastCallback := newBcastCallback(peers, round1CastsRecv, round2CastsRecv, threshold, numVals) @@ -237,7 +238,7 @@ func (f *frostP2P) Round1(ctx context.Context, castR1 map[msgKey]frost.Round1Bca return nil, nil, errors.New("bug: unexpected p2p message to self") } - err := p2p.Send(ctx, f.tcpNode, round1P2PID, pID, p2pMsg) + err := p2p.Send(ctx, f.tcpNode, round1P2PID, pID, p2pMsg, p2p.WithDelimitedProtocol(round1P2PID)) if err != nil { return nil, nil, err } diff --git a/p2p/receive_test.go b/p2p/receive_test.go index 9649483fb..a31da47d5 100644 --- a/p2p/receive_test.go +++ b/p2p/receive_test.go @@ -20,21 +20,68 @@ import ( ) func TestSendReceive(t *testing.T) { + tests := []struct { + name string + delimitedClient bool + delimitedServer bool + }{ + { + name: "non-delimited client and server", + delimitedClient: false, + delimitedServer: false, + }, + { + name: "delimited client and server", + delimitedClient: true, + delimitedServer: true, + }, + { + name: "delimited client and non-delimited server", + delimitedClient: true, + delimitedServer: false, + }, + { + name: "non-delimited client and delimited server", + delimitedClient: false, + delimitedServer: true, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testSendReceive(t, test.delimitedClient, test.delimitedServer) + }) + } +} + +func testSendReceive(t *testing.T, delimitedClient, delimitedServer bool) { + t.Helper() + var ( - pID = protocol.ID("delimited") + pID1 = protocol.ID("undelimited") + pID2 = protocol.ID("delimited") errNegative = errors.New("negative slot") ctx = context.Background() server = testutil.CreateHost(t, testutil.AvailableAddr(t)) client = testutil.CreateHost(t, testutil.AvailableAddr(t)) ) + var serverOpt []p2p.SendRecvOption + if delimitedServer { + serverOpt = append(serverOpt, p2p.WithDelimitedProtocol(pID2)) + } + + var clientOpt []p2p.SendRecvOption + if delimitedClient { + clientOpt = append(clientOpt, p2p.WithDelimitedProtocol(pID2)) + } + client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) // Register the server handler that either: // - Errors is slot is negative // - Echos the duty request if slot is even // - Returns nothing is slot is odd - p2p.RegisterHandler("server", server, pID, + p2p.RegisterHandler("server", server, pID1, func() proto.Message { return new(pbv1.Duty) }, func(ctx context.Context, peerID peer.ID, req proto.Message) (proto.Message, bool, error) { log.Info(ctx, "See protocol logging field") @@ -51,18 +98,23 @@ func TestSendReceive(t *testing.T) { return nil, false, nil } }, + serverOpt..., ) sendReceive := func(slot int64) (*pbv1.Duty, error) { resp := new(pbv1.Duty) - err := p2p.SendReceive(ctx, client, server.ID(), &pbv1.Duty{Slot: slot}, resp, pID) + err := p2p.SendReceive(ctx, client, server.ID(), &pbv1.Duty{Slot: slot}, resp, pID1, clientOpt...) return resp, err } t.Run("server error", func(t *testing.T) { _, err := sendReceive(-1) - require.ErrorContains(t, err, "read response: EOF") + if delimitedClient && delimitedServer { + require.ErrorContains(t, err, "read response: EOF") + } else { + require.ErrorContains(t, err, "no or zero response received") + } }) t.Run("ok", func(t *testing.T) { @@ -74,6 +126,10 @@ func TestSendReceive(t *testing.T) { t.Run("empty response", func(t *testing.T) { _, err := sendReceive(101) - require.ErrorContains(t, err, "read response: EOF") + if delimitedClient && delimitedServer { + require.ErrorContains(t, err, "read response: EOF") + } else { + require.ErrorContains(t, err, "no or zero response received") + } }) } diff --git a/p2p/sender.go b/p2p/sender.go index 16363346e..845b86951 100644 --- a/p2p/sender.go +++ b/p2p/sender.go @@ -4,6 +4,7 @@ package p2p import ( "context" + "io" "sync" "time" @@ -155,15 +156,24 @@ func WithSendReceiveRTT(callback func(time.Duration)) func(*sendRecvOpts) { } } -// defaultSendRecvOpts returns the default sendRecvOpts, it uses the delimited read-writers and noop rtt callback. +// WithDelimitedProtocol returns an option that adds a length delimited read/writer for the provide protocol. +func WithDelimitedProtocol(pID protocol.ID) func(*sendRecvOpts) { + return func(opts *sendRecvOpts) { + opts.protocols = append([]protocol.ID{pID}, opts.protocols...) // Add to front + opts.writersByProtocol[pID] = func(s network.Stream) pbio.Writer { return pbio.NewDelimitedWriter(s) } + opts.readersByProtocol[pID] = func(s network.Stream) pbio.Reader { return pbio.NewDelimitedReader(s, maxMsgSize) } + } +} + +// defaultSendRecvOpts returns the default sendRecvOpts, it uses the legacy writers and noop rtt callback. func defaultSendRecvOpts(pID protocol.ID) sendRecvOpts { return sendRecvOpts{ protocols: []protocol.ID{pID}, writersByProtocol: map[protocol.ID]func(s network.Stream) pbio.Writer{ - pID: func(s network.Stream) pbio.Writer { return pbio.NewDelimitedWriter(s) }, + pID: func(s network.Stream) pbio.Writer { return legacyReadWriter{s} }, }, readersByProtocol: map[protocol.ID]func(s network.Stream) pbio.Reader{ - pID: func(s network.Stream) pbio.Reader { return pbio.NewDelimitedReader(s, maxMsgSize) }, + pID: func(s network.Stream) pbio.Reader { return legacyReadWriter{s} }, }, rttCallback: func(time.Duration) {}, } @@ -212,10 +222,19 @@ func SendReceive(ctx context.Context, tcpNode host.Host, peerID peer.ID, return errors.Wrap(err, "close write", z.Any("protocol", s.Protocol())) } + zeroResp := proto.Clone(resp) + if err = reader.ReadMsg(resp); err != nil { return errors.Wrap(err, "read response", z.Any("protocol", s.Protocol())) } + // TODO(corver): Remove this once we only use length-delimited protocols. + // This was added since legacy stream delimited readers couldn't distinguish between + // no response and a zero response. + if proto.Equal(resp, zeroResp) { + return errors.New("no or zero response received", z.Any("protocol", s.Protocol())) + } + if err = s.Close(); err != nil { return errors.Wrap(err, "close stream", z.Any("protocol", s.Protocol())) } @@ -255,6 +274,38 @@ func Send(ctx context.Context, tcpNode host.Host, protoID protocol.ID, peerID pe return nil } +// legacyReadWriter implements pbio.Reader and pbio.Writer without length delimited encoding. +type legacyReadWriter struct { + stream network.Stream +} + +// WriteMsg writes a protobuf message to the stream. +func (w legacyReadWriter) WriteMsg(m proto.Message) error { + b, err := proto.Marshal(m) + if err != nil { + return errors.Wrap(err, "marshal proto") + } + + _, err = w.stream.Write(b) + + return err +} + +// ReadMsg reads a single protobuf message from the whole stream. +// The stream must be closed after the message was sent. +func (w legacyReadWriter) ReadMsg(m proto.Message) error { + b, err := io.ReadAll(w.stream) + if err != nil { + return errors.Wrap(err, "read proto") + } + + if err = proto.Unmarshal(b, m); err != nil { + return errors.Wrap(err, "unmarshal proto") + } + + return nil +} + // protocolPrefix returns the common prefix of the provided protocol IDs. func protocolPrefix(pIDs ...protocol.ID) protocol.ID { if len(pIDs) == 0 {