diff --git a/app/version/version.go b/app/version/version.go index ef8c0dea9..b8e4205f0 100644 --- a/app/version/version.go +++ b/app/version/version.go @@ -5,7 +5,9 @@ package version import ( "context" "runtime/debug" + "strings" + "github.com/obolnetwork/charon/app/errors" "github.com/obolnetwork/charon/app/log" "github.com/obolnetwork/charon/app/z" ) @@ -53,3 +55,13 @@ func LogInfo(ctx context.Context, msg string) { z.Str("git_commit_time", gitTimestamp), ) } + +// Minor returns the minor version of the provided version string. +func Minor(version string) (string, error) { + split := strings.Split(version, ".") + if len(split) < 2 { + return "", errors.New("invalid version string") + } + + return strings.Join(split[:2], "."), nil +} diff --git a/app/version/version_test.go b/app/version/version_test.go new file mode 100644 index 000000000..b910012e1 --- /dev/null +++ b/app/version/version_test.go @@ -0,0 +1,39 @@ +// Copyright © 2022-2023 Obol Labs Inc. Licensed under the terms of a Business Source License 1.1 + +package version_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/obolnetwork/charon/app/version" +) + +func TestMinor(t *testing.T) { + minor, err := version.Minor("v0.1.2") + require.NoError(t, err) + require.Equal(t, "v0.1", minor) + + minor, err = version.Minor("1.2.3") + require.NoError(t, err) + require.Equal(t, "1.2", minor) + + minor, err = version.Minor("version 1000.2000.3000") + require.NoError(t, err) + require.Equal(t, "version 1000.2000", minor) + + minor, err = version.Minor("v0.1") + require.NoError(t, err) + require.Equal(t, "v0.1", minor) + + minor, err = version.Minor("v0.1.2.3") + require.NoError(t, err) + require.Equal(t, "v0.1", minor) + + _, err = version.Minor("0") + require.ErrorContains(t, err, "invalid version string") + + _, err = version.Minor("foo") + require.ErrorContains(t, err, "invalid version string") +} diff --git a/dkg/dkg.go b/dkg/dkg.go index 850f7f5ff..2adb776d1 100644 --- a/dkg/dkg.go +++ b/dkg/dkg.go @@ -270,17 +270,24 @@ func setupP2P(ctx context.Context, key *k1.PrivateKey, p2pConf p2p.Config, peers // startSyncProtocol sets up a sync protocol server and clients for each peer and returns a shutdown function // when all peers are connected. -func startSyncProtocol(ctx context.Context, tcpNode host.Host, key *k1.PrivateKey, defHash []byte, peerIDs []peer.ID, - onFailure func(), testCallback func(connected int, id peer.ID), +func startSyncProtocol(ctx context.Context, tcpNode host.Host, key *k1.PrivateKey, defHash []byte, + peerIDs []peer.ID, onFailure func(), testCallback func(connected int, id peer.ID), ) (func(context.Context) error, error) { // Sign definition hash with charon-enr-private-key // Note: libp2p signing does another hash of the defHash. + hashSig, err := ((*libp2pcrypto.Secp256k1PrivateKey)(key)).Sign(defHash) if err != nil { return nil, errors.Wrap(err, "sign definition hash") } - server := sync.NewServer(tcpNode, len(peerIDs)-1, defHash) + // DKG compatibility is minor version dependent. + minorVersion, err := version.Minor(version.Version) + if err != nil { + return nil, errors.Wrap(err, "get version") + } + + server := sync.NewServer(tcpNode, len(peerIDs)-1, defHash, minorVersion) server.Start(ctx) var clients []*sync.Client @@ -290,7 +297,7 @@ func startSyncProtocol(ctx context.Context, tcpNode host.Host, key *k1.PrivateKe } ctx := log.WithCtx(ctx, z.Str("peer", p2p.PeerName(pID))) - client := sync.NewClient(tcpNode, pID, hashSig) + client := sync.NewClient(tcpNode, pID, hashSig, minorVersion) clients = append(clients, client) go func() { diff --git a/dkg/dkgpb/v1/sync.pb.go b/dkg/dkgpb/v1/sync.pb.go index 2f9ea5b7f..733c5fd4f 100644 --- a/dkg/dkgpb/v1/sync.pb.go +++ b/dkg/dkgpb/v1/sync.pb.go @@ -29,6 +29,7 @@ type MsgSync struct { Timestamp *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"` HashSignature []byte `protobuf:"bytes,2,opt,name=hash_signature,json=hashSignature,proto3" json:"hash_signature,omitempty"` Shutdown bool `protobuf:"varint,3,opt,name=shutdown,proto3" json:"shutdown,omitempty"` + Version string `protobuf:"bytes,4,opt,name=version,proto3" json:"version,omitempty"` } func (x *MsgSync) Reset() { @@ -84,6 +85,13 @@ func (x *MsgSync) GetShutdown() bool { return false } +func (x *MsgSync) GetVersion() string { + if x != nil { + return x.Version + } + return "" +} + type MsgSyncResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -146,7 +154,7 @@ var file_dkg_dkgpb_v1_sync_proto_rawDesc = []byte{ 0x79, 0x6e, 0x63, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, 0x64, 0x6b, 0x67, 0x2e, 0x64, 0x6b, 0x67, 0x70, 0x62, 0x2e, 0x76, 0x31, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, - 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x86, 0x01, 0x0a, 0x07, 0x4d, 0x73, 0x67, + 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa0, 0x01, 0x0a, 0x07, 0x4d, 0x73, 0x67, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, @@ -155,17 +163,18 @@ var file_dkg_dkgpb_v1_sync_proto_rawDesc = []byte{ 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0d, 0x68, 0x61, 0x73, 0x68, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x68, 0x75, 0x74, 0x64, 0x6f, 0x77, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, 0x68, 0x75, 0x74, 0x64, 0x6f, 0x77, - 0x6e, 0x22, 0x6a, 0x0a, 0x0f, 0x4d, 0x73, 0x67, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x41, 0x0a, 0x0e, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x74, 0x69, 0x6d, - 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, - 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, - 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0d, 0x73, 0x79, 0x6e, 0x63, 0x54, 0x69, - 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x42, 0x2c, 0x5a, - 0x2a, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6f, 0x62, 0x6f, 0x6c, - 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2f, 0x63, 0x68, 0x61, 0x72, 0x6f, 0x6e, 0x2f, 0x64, - 0x6b, 0x67, 0x2f, 0x64, 0x6b, 0x67, 0x70, 0x62, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x33, + 0x6e, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x6a, 0x0a, 0x0f, 0x4d, + 0x73, 0x67, 0x53, 0x79, 0x6e, 0x63, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x41, + 0x0a, 0x0e, 0x73, 0x79, 0x6e, 0x63, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, + 0x6d, 0x70, 0x52, 0x0d, 0x73, 0x79, 0x6e, 0x63, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, + 0x70, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x42, 0x2c, 0x5a, 0x2a, 0x67, 0x69, 0x74, 0x68, 0x75, + 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6f, 0x62, 0x6f, 0x6c, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x2f, 0x63, 0x68, 0x61, 0x72, 0x6f, 0x6e, 0x2f, 0x64, 0x6b, 0x67, 0x2f, 0x64, 0x6b, 0x67, + 0x70, 0x62, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/dkg/dkgpb/v1/sync.proto b/dkg/dkgpb/v1/sync.proto index 8f3155396..4d44e0309 100644 --- a/dkg/dkgpb/v1/sync.proto +++ b/dkg/dkgpb/v1/sync.proto @@ -10,6 +10,7 @@ message MsgSync { google.protobuf.Timestamp timestamp = 1; bytes hash_signature = 2; bool shutdown = 3; + string version = 4; } message MsgSyncResponse { diff --git a/dkg/sync/client.go b/dkg/sync/client.go index 990cb980d..fcc6d07d4 100644 --- a/dkg/sync/client.go +++ b/dkg/sync/client.go @@ -21,7 +21,7 @@ import ( ) // NewClient returns a new Client instance. -func NewClient(tcpNode host.Host, peer peer.ID, hashSig []byte) *Client { +func NewClient(tcpNode host.Host, peer peer.ID, hashSig []byte, version string) *Client { return &Client{ tcpNode: tcpNode, peer: peer, @@ -29,6 +29,7 @@ func NewClient(tcpNode host.Host, peer peer.ID, hashSig []byte) *Client { shutdown: make(chan struct{}), done: make(chan struct{}), reconnect: true, + version: version, } } @@ -45,6 +46,7 @@ type Client struct { // Immutable state hashSig []byte + version string tcpNode host.Host peer peer.ID } @@ -167,6 +169,7 @@ func (c *Client) sendMsg(stream network.Stream, shutdown bool) (*pb.MsgSyncRespo Timestamp: timestamppb.Now(), HashSignature: c.hashSig, Shutdown: shutdown, + Version: c.version, } if err := writeSizedProto(stream, msg); err != nil { diff --git a/dkg/sync/server.go b/dkg/sync/server.go index c2c012970..189d12805 100644 --- a/dkg/sync/server.go +++ b/dkg/sync/server.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" @@ -25,30 +26,36 @@ import ( ) const ( - protocolID = "/charon/dkg/sync/1.0.0/" - errInvalidSig = "invalid signature" + protocolID = "/charon/dkg/sync/1.0.0/" + errInvalidSig = "invalid signature" + errInvalidVersion = "invalid version" ) // NewServer returns a new Server instance. -func NewServer(tcpNode host.Host, allCount int, defHash []byte) *Server { +func NewServer(tcpNode host.Host, allCount int, defHash []byte, version string) *Server { return &Server{ defHash: defHash, tcpNode: tcpNode, allCount: allCount, shutdown: make(map[peer.ID]struct{}), connected: make(map[peer.ID]struct{}), + version: version, } } // Server implements the server side of the sync protocol. It accepts connections from clients, verifies // definition hash signatures, and supports waiting for shutdown by all clients. type Server struct { + // Immutable state + tcpNode host.Host + defHash []byte + version string + allCount int // Excluding self + + // Mutable state mu sync.Mutex shutdown map[peer.ID]struct{} connected map[peer.ID]struct{} - defHash []byte - allCount int // Excluding self - tcpNode host.Host errResponse bool // To return error and exit anywhere in the server flow } @@ -73,6 +80,14 @@ func (s *Server) AwaitAllConnected(ctx context.Context) error { } } +// setError sets the shared error state for the server. +func (s *Server) setError() { + s.mu.Lock() + defer s.mu.Unlock() + + s.errResponse = true +} + // isError checks if there was any error in between the server flow. func (s *Server) isError() bool { s.mu.Lock() @@ -179,20 +194,13 @@ func (s *Server) handleStream(ctx context.Context, stream network.Stream) error SyncTimestamp: msg.Timestamp, } - // Verify definition hash - // Note: libp2p verify does another hash of defHash. - ok, err := pubkey.Verify(s.defHash, msg.HashSignature) + var ok bool + resp.Error, ok, err = s.validReq(ctx, pubkey, msg) if err != nil { - return errors.Wrap(err, "verify sig hash") + return err } else if !ok { - resp.Error = errInvalidSig - - s.mu.Lock() - s.errResponse = true - s.mu.Unlock() - - log.Error(ctx, "Received mismatching cluster definition hash from peer", nil) - } else if ok && !s.isConnected(pID) { + s.setError() + } else if !s.isConnected(pID) { count := s.setConnected(pID) log.Info(ctx, fmt.Sprintf("Connected to peer %d of %d", count, s.allCount)) } @@ -209,6 +217,29 @@ func (s *Server) handleStream(ctx context.Context, stream network.Stream) error } } +// validReq returns an error message and false if the request version or definition hash are invalid. +// Else it returns true or an error. +func (s *Server) validReq(ctx context.Context, pubkey crypto.PubKey, msg *pb.MsgSync) (string, bool, error) { + if msg.Version != s.version { + log.Error(ctx, "Received mismatching charon version from peer", nil, + z.Str("expect", s.version), + z.Str("got", msg.Version), + ) + + return errInvalidVersion, false, nil + } + + ok, err := pubkey.Verify(s.defHash, msg.HashSignature) + if err != nil { // Note: libp2p verify does another hash of defHash. + return "", false, errors.Wrap(err, "verify sig hash") + } else if !ok { + log.Error(ctx, "Received mismatching cluster definition hash from peer", nil) + return errInvalidSig, false, nil + } + + return "", true, nil +} + // Start registers sync protocol with the libp2p host. func (s *Server) Start(ctx context.Context) { s.tcpNode.SetStreamHandler(protocolID, func(stream network.Stream) { diff --git a/dkg/sync/sync_test.go b/dkg/sync/sync_test.go index 6adf6fee0..036a7e0c8 100644 --- a/dkg/sync/sync_test.go +++ b/dkg/sync/sync_test.go @@ -16,25 +16,38 @@ import ( "github.com/stretchr/testify/require" "github.com/obolnetwork/charon/app/log" + "github.com/obolnetwork/charon/app/version" "github.com/obolnetwork/charon/dkg/sync" "github.com/obolnetwork/charon/testutil" ) func TestSyncProtocol(t *testing.T) { + versions := make(map[int]string) + for i := 0; i < 5; i++ { + versions[i] = version.Version + } + t.Run("2", func(t *testing.T) { - testCluster(t, 2) + testCluster(t, 2, versions, false) }) t.Run("3", func(t *testing.T) { - testCluster(t, 3) + testCluster(t, 3, versions, false) }) t.Run("5", func(t *testing.T) { - testCluster(t, 5) + testCluster(t, 5, versions, false) }) } -func testCluster(t *testing.T, n int) { +func TestInvalidVersion(t *testing.T) { + testCluster(t, 2, map[int]string{ + 0: "1.0", + 1: "2.0", + }, true) +} + +func testCluster(t *testing.T, n int, versions map[int]string, expectErr bool) { t.Helper() ctx, cancel := context.WithCancel(context.Background()) @@ -53,7 +66,7 @@ func testCluster(t *testing.T, n int) { tcpNodes = append(tcpNodes, tcpNode) keys = append(keys, key) - server := sync.NewServer(tcpNode, n-1, hash) + server := sync.NewServer(tcpNode, n-1, hash, versions[i]) servers = append(servers, server) } @@ -71,12 +84,16 @@ func testCluster(t *testing.T, n int) { hashSig, err := keys[i].Sign(hash) require.NoError(t, err) - client := sync.NewClient(tcpNodes[i], tcpNodes[j].ID(), hashSig) + client := sync.NewClient(tcpNodes[i], tcpNodes[j].ID(), hashSig, versions[i]) clients = append(clients, client) ctx := log.WithTopic(ctx, fmt.Sprintf("client%d_%d", i, j)) go func() { err := client.Run(ctx) + if expectErr { + require.Error(t, err) + return + } require.NoError(t, err) }() } @@ -91,7 +108,15 @@ func testCluster(t *testing.T, n int) { t.Log("server.AwaitAllConnected") for _, server := range servers { err := server.AwaitAllConnected(ctx) - require.NoError(t, err) + if expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + } + + if expectErr { + return } t.Log("client.IsConnected")