diff --git a/dkg/sync/client.go b/dkg/sync/client.go
index a3371812e..fab19532f 100644
--- a/dkg/sync/client.go
+++ b/dkg/sync/client.go
@@ -35,7 +35,7 @@ import (
 type result struct {
 	rtt       time.Duration
 	timestamp string
-	error     string
+	error     error
 }
 
 type Client struct {
@@ -49,14 +49,16 @@ type Client struct {
 // AwaitConnected blocks until the connection with the server has been established or returns an error.
 func (c *Client) AwaitConnected() error {
 	for res := range c.results {
-		if res.error == InvalidSig {
+		if res.error != nil && res.error.Error() == InvalidSig {
 			return errors.New("invalid cluster definition")
-		} else if res.error == "" {
+		} else if res.error == nil {
 			// We are connected
 			break
 		}
 	}
 
+	log.Info(c.ctx, "Client connected to Server 🎉", z.Any("client", p2p.PeerName(c.tcpNode.ID())))
+
 	return nil
 }
 
@@ -79,12 +81,12 @@ func sendHashSignature(ctx context.Context, hashSig []byte, s network.Stream) re
 	wb, err := proto.Marshal(msg)
 	if err != nil {
 		log.Error(ctx, "Marshal msg", err)
-		return result{error: err.Error()}
+		return result{error: err}
 	}
 
 	if _, err = s.Write(wb); err != nil {
 		log.Error(ctx, "Write msg to stream", err)
-		return result{error: err.Error()}
+		return result{error: err}
 	}
 
 	buf := bufio.NewReader(s)
@@ -93,7 +95,7 @@
 	n, err := buf.Read(rb)
 	if err != nil {
 		log.Error(ctx, "Read server response from stream", err)
-		return result{error: err.Error()}
+		return result{error: err}
 	}
 
 	// The first `n` bytes that are read are the most important
@@ -102,7 +104,11 @@
 	resp := new(pb.MsgSyncResponse)
 	if err = proto.Unmarshal(rb, resp); err != nil {
 		log.Error(ctx, "Unmarshal server response", err)
-		return result{error: err.Error()}
+		return result{error: err}
+	}
+
+	if resp.Error != "" {
+		return result{error: errors.New(resp.Error)}
 	}
 
 	log.Debug(ctx, "Server response", z.Any("response", resp.SyncTimestamp))
@@ -110,7 +116,6 @@
 	return result{
 		rtt:       time.Since(before),
 		timestamp: resp.SyncTimestamp.String(),
-		error:     resp.Error,
 	}
 }
 
@@ -121,7 +126,7 @@ func NewClient(ctx context.Context, tcpNode host.Host, server p2p.Peer, hashSig
 	if err != nil {
 		log.Error(ctx, "Open new stream with server", err)
 		ch := make(chan result, 1)
-		ch <- result{error: err.Error()}
+		ch <- result{error: err}
 		close(ch)
 
 		return Client{
@@ -147,7 +152,7 @@
 			return
 		}
 
-		if res.error == "" {
+		if res.error == nil {
 			tcpNode.Peerstore().RecordLatency(server.ID, res.rtt)
 		}
 
diff --git a/dkg/sync/server.go b/dkg/sync/server.go
index c0cbd27a3..2354b2730 100644
--- a/dkg/sync/server.go
+++ b/dkg/sync/server.go
@@ -20,6 +20,8 @@ package sync
 import (
 	"bufio"
 	"context"
+	"sync"
+	"time"
 
 	"github.com/libp2p/go-libp2p-core/host"
 	"github.com/libp2p/go-libp2p-core/network"
@@ -39,15 +41,29 @@ const (
 )
 
 type Server struct {
+	mu            sync.Mutex
 	ctx           context.Context
 	onFailure     func()
 	tcpNode       host.Host
 	peers         []p2p.Peer
 	dedupResponse map[peer.ID]bool
+	receiveChan   chan result
 }
 
 // AwaitAllConnected blocks until all peers have established a connection with this server or returns an error.
-func (*Server) AwaitAllConnected() error {
+func (s *Server) AwaitAllConnected() error {
+	var msgs []result
+	for len(msgs) < len(s.peers) {
+		select {
+		case <-s.ctx.Done():
+			return s.ctx.Err()
+		case msg := <-s.receiveChan:
+			msgs = append(msgs, msg)
+		}
+	}
+
+	log.Info(s.ctx, "All Clients Connected 🎉", z.Any("clients", len(msgs)))
+
 	return nil
 }
 
@@ -65,6 +81,7 @@ func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, defHash
 		peers:         peers,
 		onFailure:     onFailure,
 		dedupResponse: make(map[peer.ID]bool),
+		receiveChan:   make(chan result, len(peers)),
 	}
 
 	knownPeers := make(map[peer.ID]bool)
@@ -77,10 +94,11 @@
 
 		// TODO(dhruv): introduce timeout to break the loop
 		for {
+			before := time.Now()
 			pID := s.Conn().RemotePeer()
 			if !knownPeers[pID] {
 				// Ignoring unknown peer
-				log.Warn(ctx, "Ignoring unknown peer", nil, z.Any("peer", p2p.PeerName(pID)))
+				log.Warn(ctx, "Ignoring unknown client", nil, z.Any("client", p2p.PeerName(pID)))
 				return
 			}
 
@@ -89,7 +107,7 @@
 			// n is the number of bytes read from buffer, if n < MsgSize the other bytes will be 0
 			n, err := buf.Read(b)
 			if err != nil {
-				log.Error(ctx, "Read client msg from stream", err)
+				log.Error(ctx, "Read client msg from stream", err, z.Any("client", p2p.PeerName(pID)))
 				return
 			}
 
@@ -102,7 +120,7 @@
 				return
 			}
 
-			log.Debug(ctx, "Message received from client", z.Any("server", p2p.PeerName(pID)))
+			log.Debug(ctx, "Message received from client", z.Any("client", p2p.PeerName(pID)))
 
 			pubkey, err := pID.ExtractPublicKey()
 			if err != nil {
@@ -133,18 +151,28 @@
 
 			_, err = s.Write(resBytes)
 			if err != nil {
-				log.Error(ctx, "Send response to client", err)
+				log.Error(ctx, "Send response to client", err, z.Any("client", p2p.PeerName(pID)))
 				return
 			}
 
 			if server.dedupResponse[pID] {
-				log.Debug(ctx, "Ignoring duplicate message", z.Any("peer", pID))
+				log.Debug(ctx, "Ignoring duplicate message", z.Any("client", p2p.PeerName(pID)))
 				continue
 			}
 
-			server.dedupResponse[pID] = true
+			if resp.Error == "" && !server.dedupResponse[pID] {
+				// TODO(dhruv): This is a temporary solution to avoid a race condition with concurrent map writes; figure out something permanent.
+				server.mu.Lock()
+				server.dedupResponse[pID] = true
+				server.mu.Unlock()
+
+				server.receiveChan <- result{
+					rtt:       time.Since(before),
+					timestamp: msg.Timestamp.String(),
+				}
+			}
 
-			log.Debug(ctx, "Server sent response to client")
+			log.Debug(ctx, "Send response to client", z.Any("client", p2p.PeerName(pID)))
 		}
 	})
 
diff --git a/dkg/sync/sync_internal_test.go b/dkg/sync/sync_internal_test.go
index 4438e56ec..605dd41a6 100644
--- a/dkg/sync/sync_internal_test.go
+++ b/dkg/sync/sync_internal_test.go
@@ -65,7 +65,7 @@ func TestNaiveServerClient(t *testing.T) {
 	for i := 0; i < 5; i++ {
 		actual, ok := <-client.results
 		require.True(t, ok)
-		require.Equal(t, "", actual.error)
+		require.NoError(t, actual.error)
 		t.Log("rtt is: ", actual.rtt)
 	}
 }
diff --git a/dkg/sync/sync_test.go b/dkg/sync/sync_test.go
index f0de4b298..8f1e89e35 100644
--- a/dkg/sync/sync_test.go
+++ b/dkg/sync/sync_test.go
@@ -36,6 +36,8 @@ import (
 	"github.com/obolnetwork/charon/testutil"
 )
 
+//go:generate go test . -run=TestAwaitAllConnected -race
+
 func TestAwaitConnected(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
@@ -60,12 +62,59 @@ func TestAwaitConnected(t *testing.T) {
 	serverCtx := log.WithTopic(ctx, "server")
 	_ = sync.NewServer(serverCtx, serverHost, []p2p.Peer{{ID: clientHost.ID()}}, hash, nil)
 
-	clientCtx := log.WithTopic(ctx, "client")
+	clientCtx := log.WithTopic(context.Background(), "client")
 	client := sync.NewClient(clientCtx, clientHost, p2p.Peer{ID: serverHost.ID()}, hashSig, nil)
 
 	require.NoError(t, client.AwaitConnected())
 }
 
+func TestAwaitAllConnected(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	const numClients = 3
+	seed := 0
+	serverHost, _ := newSyncHost(t, int64(seed))
+	var (
+		peers       []p2p.Peer
+		keys        []libp2pcrypto.PrivKey
+		clientHosts []host.Host
+	)
+	for i := 0; i < numClients; i++ {
+		seed++
+		clientHost, key := newSyncHost(t, int64(seed))
+		require.NotEqual(t, clientHost.ID().String(), serverHost.ID().String())
+
+		err := serverHost.Connect(ctx, peer.AddrInfo{
+			ID:    clientHost.ID(),
+			Addrs: clientHost.Addrs(),
+		})
+		require.NoError(t, err)
+
+		clientHosts = append(clientHosts, clientHost)
+		keys = append(keys, key)
+		peers = append(peers, p2p.Peer{ID: clientHost.ID()})
+	}
+
+	hash := testutil.RandomBytes32()
+	server := sync.NewServer(log.WithTopic(ctx, "server"), serverHost, peers, hash, nil)
+
+	var clients []sync.Client
+	for i := 0; i < numClients; i++ {
+		hashSig, err := keys[i].Sign(hash)
+		require.NoError(t, err)
+
+		client := sync.NewClient(log.WithTopic(context.Background(), "client"), clientHosts[i], p2p.Peer{ID: serverHost.ID()}, hashSig, nil)
+		clients = append(clients, client)
+	}
+
+	for i := 0; i < numClients; i++ {
+		require.NoError(t, clients[i].AwaitConnected())
+	}
+
+	require.NoError(t, server.AwaitAllConnected())
+}
+
 func newSyncHost(t *testing.T, seed int64) (host.Host, libp2pcrypto.PrivKey) {
 	t.Helper()
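
For reviewers, a rough usage sketch (illustrative only, not part of this patch) of how the handshake above is expected to gate a DKG run: every node serves its peers and also dials each peer as a client, and the ceremony proceeds only once AwaitConnected succeeds for each outbound handshake and AwaitAllConnected has seen every inbound one. The helper awaitPeers and its inputs are hypothetical placeholders; it only assumes the NewServer/NewClient signatures shown in this diff.

package dkg

import (
	"context"

	libp2pcrypto "github.com/libp2p/go-libp2p-core/crypto"
	"github.com/libp2p/go-libp2p-core/host"

	"github.com/obolnetwork/charon/app/log"
	"github.com/obolnetwork/charon/dkg/sync"
	"github.com/obolnetwork/charon/p2p"
)

// awaitPeers is a hypothetical helper sketching how the sync handshake could be
// wired into DKG startup: it blocks until this node has exchanged definition-hash
// signatures with every peer, i.e. everyone agrees on the same cluster definition.
func awaitPeers(ctx context.Context, tcpNode host.Host, key libp2pcrypto.PrivKey,
	peers []p2p.Peer, defHash []byte,
) error {
	// Act as a server for all peers: verify their signatures over defHash.
	server := sync.NewServer(log.WithTopic(ctx, "sync"), tcpNode, peers, defHash, nil)

	// Act as a client towards each peer: send our own signature over defHash.
	hashSig, err := key.Sign(defHash)
	if err != nil {
		return err
	}

	var clients []sync.Client
	for _, p := range peers {
		clients = append(clients, sync.NewClient(log.WithTopic(ctx, "sync"), tcpNode, p, hashSig, nil))
	}

	// Block until every outbound handshake is acknowledged...
	for _, c := range clients {
		if err := c.AwaitConnected(); err != nil {
			return err
		}
	}

	// ...and until every inbound handshake has arrived, so all nodes are in sync
	// before the DKG ceremony starts.
	return server.AwaitAllConnected()
}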