Skip to content

Commit

Permalink
dkg/sync: add AwaitAllConnected for server (#743)
Browse files Browse the repository at this point in the history
Implements AwaitAllConnected method for sync.Server.

category: feature
ticket: #684
  • Loading branch information
dB2510 authored Jun 24, 2022
1 parent 81d333e commit 40ba42d
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 20 deletions.
25 changes: 15 additions & 10 deletions dkg/sync/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
// result is the outcome of a single sync exchange with the server.
//
// Note: the diff rendering carried both the old field (error string) and the
// new field (error error); only the final typed-error field is kept, since a
// struct cannot declare the same field twice.
type result struct {
	rtt       time.Duration // round-trip time of the request/response exchange
	timestamp string        // server-reported sync timestamp (string form)
	error     error         // non-nil if the exchange failed
}

type Client struct {
Expand All @@ -49,14 +49,16 @@ type Client struct {
// AwaitConnected blocks until the connection with the server has been established or returns an error.
func (c *Client) AwaitConnected() error {
for res := range c.results {
if res.error == InvalidSig {
if errors.Is(res.error, errors.New(InvalidSig)) {
return errors.New("invalid cluster definition")
} else if res.error == "" {
} else if res.error == nil {
// We are connected
break
}
}

log.Info(c.ctx, "Client connected to Server 🎉", z.Any("client", p2p.PeerName(c.tcpNode.ID())))

return nil
}

Expand All @@ -79,12 +81,12 @@ func sendHashSignature(ctx context.Context, hashSig []byte, s network.Stream) re
wb, err := proto.Marshal(msg)
if err != nil {
log.Error(ctx, "Marshal msg", err)
return result{error: err.Error()}
return result{error: err}
}

if _, err = s.Write(wb); err != nil {
log.Error(ctx, "Write msg to stream", err)
return result{error: err.Error()}
return result{error: err}
}

buf := bufio.NewReader(s)
Expand All @@ -93,7 +95,7 @@ func sendHashSignature(ctx context.Context, hashSig []byte, s network.Stream) re
n, err := buf.Read(rb)
if err != nil {
log.Error(ctx, "Read server response from stream", err)
return result{error: err.Error()}
return result{error: err}
}

// The first `n` bytes that are read are the most important
Expand All @@ -102,15 +104,18 @@ func sendHashSignature(ctx context.Context, hashSig []byte, s network.Stream) re
resp := new(pb.MsgSyncResponse)
if err = proto.Unmarshal(rb, resp); err != nil {
log.Error(ctx, "Unmarshal server response", err)
return result{error: err.Error()}
return result{error: err}
}

if resp.Error != "" {
return result{error: errors.New(resp.Error)}
}

log.Debug(ctx, "Server response", z.Any("response", resp.SyncTimestamp))

return result{
rtt: time.Since(before),
timestamp: resp.SyncTimestamp.String(),
error: resp.Error,
}
}

Expand All @@ -121,7 +126,7 @@ func NewClient(ctx context.Context, tcpNode host.Host, server p2p.Peer, hashSig
if err != nil {
log.Error(ctx, "Open new stream with server", err)
ch := make(chan result, 1)
ch <- result{error: err.Error()}
ch <- result{error: err}
close(ch)

return Client{
Expand All @@ -147,7 +152,7 @@ func NewClient(ctx context.Context, tcpNode host.Host, server p2p.Peer, hashSig
return
}

if res.error == "" {
if res.error == nil {
tcpNode.Peerstore().RecordLatency(server.ID, res.rtt)
}

Expand Down
44 changes: 36 additions & 8 deletions dkg/sync/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package sync
import (
"bufio"
"context"
"sync"
"time"

"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/network"
Expand All @@ -39,15 +41,29 @@ const (
)

// Server is the server side of the DKG sync protocol; it accepts and
// tracks handshakes from the expected set of client peers.
type Server struct {
	mu            sync.Mutex       // guards dedupResponse (stream handlers may run concurrently)
	ctx           context.Context  // NOTE(review): ctx stored in struct; lifetime tied to the server — confirm intended
	onFailure     func()           // failure callback supplied by the caller; semantics defined elsewhere — TODO confirm
	tcpNode       host.Host        // libp2p host this server listens on
	peers         []p2p.Peer       // expected client peers
	dedupResponse map[peer.ID]bool // peers already counted, so each client is reported at most once
	receiveChan   chan result      // buffered to len(peers); one result per successfully connected client
}

// AwaitAllConnected blocks until all peers have established a connection with this server or returns an error.
func (*Server) AwaitAllConnected() error {
func (s *Server) AwaitAllConnected() error {
var msgs []result
for len(msgs) < len(s.peers) {
select {
case <-s.ctx.Done():
return s.ctx.Err()
case msg := <-s.receiveChan:
msgs = append(msgs, msg)
}
}

log.Info(s.ctx, "All Clients Connected 🎉", z.Any("clients", len(msgs)))

return nil
}

Expand All @@ -65,6 +81,7 @@ func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, defHash
peers: peers,
onFailure: onFailure,
dedupResponse: make(map[peer.ID]bool),
receiveChan: make(chan result, len(peers)),
}

knownPeers := make(map[peer.ID]bool)
Expand All @@ -77,10 +94,11 @@ func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, defHash

// TODO(dhruv): introduce timeout to break the loop
for {
before := time.Now()
pID := s.Conn().RemotePeer()
if !knownPeers[pID] {
// Ignoring unknown peer
log.Warn(ctx, "Ignoring unknown peer", nil, z.Any("peer", p2p.PeerName(pID)))
log.Warn(ctx, "Ignoring unknown client", nil, z.Any("client", p2p.PeerName(pID)))
return
}

Expand All @@ -89,7 +107,7 @@ func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, defHash
// n is the number of bytes read from buffer, if n < MsgSize the other bytes will be 0
n, err := buf.Read(b)
if err != nil {
log.Error(ctx, "Read client msg from stream", err)
log.Error(ctx, "Read client msg from stream", err, z.Any("client", p2p.PeerName(pID)))
return
}

Expand All @@ -102,7 +120,7 @@ func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, defHash
return
}

log.Debug(ctx, "Message received from client", z.Any("server", p2p.PeerName(pID)))
log.Debug(ctx, "Message received from client", z.Any("client", p2p.PeerName(pID)))

pubkey, err := pID.ExtractPublicKey()
if err != nil {
Expand Down Expand Up @@ -133,18 +151,28 @@ func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, defHash

_, err = s.Write(resBytes)
if err != nil {
log.Error(ctx, "Send response to client", err)
log.Error(ctx, "Send response to client", err, z.Any("client", p2p.PeerName(pID)))
return
}

if server.dedupResponse[pID] {
log.Debug(ctx, "Ignoring duplicate message", z.Any("peer", pID))
log.Debug(ctx, "Ignoring duplicate message", z.Any("client", p2p.PeerName(pID)))
continue
}

server.dedupResponse[pID] = true
if resp.Error == "" && !server.dedupResponse[pID] {
// TODO(dhruv): This is temporary solution to avoid race condition of concurrent writes to map, figure out something permanent.
server.mu.Lock()
server.dedupResponse[pID] = true
server.mu.Unlock()

server.receiveChan <- result{
rtt: time.Since(before),
timestamp: msg.Timestamp.String(),
}
}

log.Debug(ctx, "Server sent response to client")
log.Debug(ctx, "Send response to client", z.Any("client", p2p.PeerName(pID)))
}
})

Expand Down
2 changes: 1 addition & 1 deletion dkg/sync/sync_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestNaiveServerClient(t *testing.T) {
for i := 0; i < 5; i++ {
actual, ok := <-client.results
require.True(t, ok)
require.Equal(t, "", actual.error)
require.NoError(t, actual.error)
t.Log("rtt is: ", actual.rtt)
}
}
Expand Down
51 changes: 50 additions & 1 deletion dkg/sync/sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ import (
"github.com/obolnetwork/charon/testutil"
)

//go:generate go test . -run=TestAwaitAllConnected -race

func TestAwaitConnected(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand All @@ -60,12 +62,59 @@ func TestAwaitConnected(t *testing.T) {
serverCtx := log.WithTopic(ctx, "server")
_ = sync.NewServer(serverCtx, serverHost, []p2p.Peer{{ID: clientHost.ID()}}, hash, nil)

clientCtx := log.WithTopic(ctx, "client")
clientCtx := log.WithTopic(context.Background(), "client")
client := sync.NewClient(clientCtx, clientHost, p2p.Peer{ID: serverHost.ID()}, hashSig, nil)

require.NoError(t, client.AwaitConnected())
}

// TestAwaitAllConnected verifies that Server.AwaitAllConnected returns without
// error once every expected client has completed the sync handshake.
func TestAwaitAllConnected(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	const numClients = 3
	seed := 0
	serverHost, _ := newSyncHost(t, int64(seed))
	var (
		peers       []p2p.Peer
		keys        []libp2pcrypto.PrivKey
		clientHosts []host.Host
	)
	for i := 0; i < numClients; i++ {
		// Each client host gets a distinct deterministic seed.
		seed++
		clientHost, key := newSyncHost(t, int64(seed))
		require.NotEqual(t, clientHost.ID().String(), serverHost.ID().String())

		// Connect the server host to each client host up front so streams
		// can be opened without extra peer discovery.
		err := serverHost.Connect(ctx, peer.AddrInfo{
			ID:    clientHost.ID(),
			Addrs: clientHost.Addrs(),
		})
		require.NoError(t, err)

		clientHosts = append(clientHosts, clientHost)
		keys = append(keys, key)
		peers = append(peers, p2p.Peer{ID: clientHost.ID()})
	}

	hash := testutil.RandomBytes32()
	server := sync.NewServer(log.WithTopic(ctx, "server"), serverHost, peers, hash, nil)

	var clients []sync.Client
	for i := 0; i < numClients; i++ {
		// Each client signs the definition hash with its own private key.
		hashSig, err := keys[i].Sign(hash)
		require.NoError(t, err)

		// NOTE(review): clients use a fresh background context rather than the
		// test ctx — presumably so client shutdown is independent; confirm.
		client := sync.NewClient(log.WithTopic(context.Background(), "client"), clientHosts[i], p2p.Peer{ID: serverHost.ID()}, hashSig, nil)
		clients = append(clients, client)
	}

	// Wait for every client to report a successful handshake first, then
	// assert the server observed all of them.
	for i := 0; i < numClients; i++ {
		require.NoError(t, clients[i].AwaitConnected())
	}

	require.NoError(t, server.AwaitAllConnected())
}

func newSyncHost(t *testing.T, seed int64) (host.Host, libp2pcrypto.PrivKey) {
t.Helper()

Expand Down

0 comments on commit 40ba42d

Please sign in to comment.