diff --git a/app/peerinfo/adhoc_test.go b/app/peerinfo/adhoc_test.go new file mode 100644 index 000000000..3dfb00fa8 --- /dev/null +++ b/app/peerinfo/adhoc_test.go @@ -0,0 +1,36 @@ +// Copyright © 2022-2023 Obol Labs Inc. Licensed under the terms of a Business Source License 1.1 + +package peerinfo_test + +import ( + "context" + "testing" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/stretchr/testify/require" + + "github.com/obolnetwork/charon/app/peerinfo" + "github.com/obolnetwork/charon/p2p" + "github.com/obolnetwork/charon/testutil" +) + +func TestDoOnce(t *testing.T) { + server := testutil.CreateHost(t, testutil.AvailableAddr(t)) + client := testutil.CreateHost(t, testutil.AvailableAddr(t)) + + client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + + version := "v0" + lockHash := []byte("123") + gitHash := "abc" + // Register the server handler that either + _ = peerinfo.New(server, []peer.ID{server.ID(), client.ID()}, version, lockHash, gitHash, p2p.SendReceive) + + info, _, ok, err := peerinfo.DoOnce(context.Background(), client, server.ID()) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, version, info.CharonVersion) + require.Equal(t, gitHash, info.GitHash) + require.Equal(t, lockHash, info.LockHash) +} diff --git a/p2p/receive_test.go b/p2p/receive_test.go index 1251f4cef..e30e405b1 100644 --- a/p2p/receive_test.go +++ b/p2p/receive_test.go @@ -110,7 +110,7 @@ func testSendReceive(t *testing.T, delimitedClient, delimitedServer bool) { t.Run("server error", func(t *testing.T) { _, err := sendReceive(-1) - require.ErrorContains(t, err, "no response") + require.ErrorContains(t, err, "no or zero response received") }) t.Run("ok", func(t *testing.T) { @@ -122,6 +122,6 @@ func testSendReceive(t *testing.T, delimitedClient, delimitedServer bool) { t.Run("empty response", func(t *testing.T) { _, err := sendReceive(101) - require.ErrorContains(t, err, "no response") + require.ErrorContains(t, err, "no or zero response received") }) } diff --git a/p2p/sender.go b/p2p/sender.go index c010b4811..df751574d 100644 --- a/p2p/sender.go +++ b/p2p/sender.go @@ -191,6 +191,10 @@ func defaultSendRecvOpts(pID protocol.ID) sendRecvOpts { func SendReceive(ctx context.Context, tcpNode host.Host, peerID peer.ID, req, resp proto.Message, pID protocol.ID, opts ...SendRecvOption, ) error { + if !isZeroProto(resp) { + return errors.New("bug: response proto must be zero value") + } + o := defaultSendRecvOpts(pID) for _, opt := range opts { opt(&o) @@ -223,10 +227,19 @@ func SendReceive(ctx context.Context, tcpNode host.Host, peerID peer.ID, return errors.Wrap(err, "close write", z.Any("protocol", s.Protocol())) } + zeroResp := proto.Clone(resp) + if err = reader.ReadMsg(resp); err != nil { return errors.Wrap(err, "read response", z.Any("protocol", s.Protocol())) } + // TODO(corver): Remove this once we use length-delimited protocols. + // This was added since legacy stream delimited readers couldn't distinguish between + // no response and a zero response. + if proto.Equal(resp, zeroResp) { + return errors.New("no or zero response received", z.Any("protocol", s.Protocol())) + } + if err = s.Close(); err != nil { return errors.Wrap(err, "close stream", z.Any("protocol", s.Protocol())) } @@ -288,13 +301,11 @@ func (w legacyReadWriter) WriteMsg(m proto.Message) error { func (w legacyReadWriter) ReadMsg(m proto.Message) error { b, err := io.ReadAll(w.stream) if err != nil { - return errors.Wrap(err, "read response") - } else if len(b) == 0 { - return errors.New("peer errored, no response") + return errors.Wrap(err, "read proto") } if err = proto.Unmarshal(b, m); err != nil { - return errors.Wrap(err, "unmarshal response") + return errors.Wrap(err, "unmarshal proto") } return nil @@ -325,3 +336,18 @@ func protocolPrefix(pIDs ...protocol.ID) protocol.ID { return prefix } + +// isZeroProto returns true if the provided proto message is zero. +// +// Note this function is inefficient for the negative case (i.e. when the message is not zero) +// as it copies the input argument. +func isZeroProto(m proto.Message) bool { + if m == nil { + return false + } + + clone := proto.Clone(m) + proto.Reset(clone) + + return proto.Equal(m, clone) +} diff --git a/p2p/sender_internal_test.go b/p2p/sender_internal_test.go index 620aec28c..bf5fa4a93 100644 --- a/p2p/sender_internal_test.go +++ b/p2p/sender_internal_test.go @@ -8,13 +8,18 @@ import ( "testing" "time" + fuzz "github.com/google/gofuzz" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/obolnetwork/charon/app/errors" + pbv1 "github.com/obolnetwork/charon/core/corepb/v1" ) func TestSenderAddResult(t *testing.T) { @@ -65,7 +70,7 @@ func TestSenderRetry(t *testing.T) { ctx := context.Background() h := new(testHost) - err := sender.SendReceive(ctx, h, "", nil, nil, "") + err := sender.SendReceive(ctx, h, "", nil, new(pbv1.Duty), "") require.ErrorIs(t, err, network.ErrReset) require.Equal(t, 2, h.Count()) @@ -103,3 +108,26 @@ func TestProtocolPrefix(b *testing.T) { require.EqualValues(b, "charon/peer_info/1.*", protocolPrefix("charon/peer_info/1.0.0", "charon/peer_info/1.1.0")) require.EqualValues(b, "charon/peer_info/*", protocolPrefix("charon/peer_info/1.0.0", "charon/peer_info/2.0.0", "charon/peer_info/3.0.0")) } + +func TestIsZeroProto(t *testing.T) { + for _, msg := range []proto.Message{ + new(pbv1.Duty), + new(pbv1.ConsensusMsg), + new(timestamppb.Timestamp), + } { + require.False(t, isZeroProto(nil)) + require.True(t, isZeroProto(msg)) + fuzz.New().NilChance(0).Fuzz(msg) + require.False(t, isZeroProto(msg)) + + anyMsg, err := anypb.New(msg) + require.NoError(t, err) + require.False(t, isZeroProto(anyMsg)) + } +} + +func TestNonZeroResponse(t *testing.T) { + ctx := context.Background() + err := SendReceive(ctx, nil, "", nil, &pbv1.Duty{Slot: 1}, "") + require.ErrorContains(t, err, "bug: response proto must be zero value") +}