Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dkg/sync: add AwaitAllConnected for server #743

Merged
merged 3 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions dkg/sync/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
// result is the outcome of one client→server sync round trip.
// The scraped diff retained both the old `error string` field and the new
// `error error` field; only the post-diff typed-error field is kept here,
// since duplicate fields do not compile.
type result struct {
	rtt       time.Duration // Round-trip time of the request/response exchange.
	timestamp string        // Server-reported sync timestamp (string form of the proto timestamp).
	error     error         // Non-nil when the exchange failed; nil means connected.
}

type Client struct {
Expand All @@ -49,14 +49,16 @@ type Client struct {
// AwaitConnected blocks until the connection with the server has been established or returns an error.
func (c *Client) AwaitConnected() error {
for res := range c.results {
if res.error == InvalidSig {
if errors.Is(res.error, errors.New(InvalidSig)) {
return errors.New("invalid cluster definition")
} else if res.error == "" {
} else if res.error == nil {
// We are connected
break
}
}

log.Info(c.ctx, "Client connected to Server 🎉", z.Any("client", p2p.PeerName(c.tcpNode.ID())))

return nil
}

Expand All @@ -79,12 +81,12 @@ func sendHashSignature(ctx context.Context, hashSig []byte, s network.Stream) re
wb, err := proto.Marshal(msg)
if err != nil {
log.Error(ctx, "Marshal msg", err)
return result{error: err.Error()}
return result{error: err}
}

if _, err = s.Write(wb); err != nil {
log.Error(ctx, "Write msg to stream", err)
return result{error: err.Error()}
return result{error: err}
}

buf := bufio.NewReader(s)
Expand All @@ -93,7 +95,7 @@ func sendHashSignature(ctx context.Context, hashSig []byte, s network.Stream) re
n, err := buf.Read(rb)
if err != nil {
log.Error(ctx, "Read server response from stream", err)
return result{error: err.Error()}
return result{error: err}
}

// The first `n` bytes that are read are the most important
Expand All @@ -102,15 +104,18 @@ func sendHashSignature(ctx context.Context, hashSig []byte, s network.Stream) re
resp := new(pb.MsgSyncResponse)
if err = proto.Unmarshal(rb, resp); err != nil {
log.Error(ctx, "Unmarshal server response", err)
return result{error: err.Error()}
return result{error: err}
}

if resp.Error != "" {
return result{error: errors.New(resp.Error)}
}

log.Debug(ctx, "Server response", z.Any("response", resp.SyncTimestamp))

return result{
rtt: time.Since(before),
timestamp: resp.SyncTimestamp.String(),
error: resp.Error,
}
}

Expand All @@ -121,7 +126,7 @@ func NewClient(ctx context.Context, tcpNode host.Host, server p2p.Peer, hashSig
if err != nil {
log.Error(ctx, "Open new stream with server", err)
ch := make(chan result, 1)
ch <- result{error: err.Error()}
ch <- result{error: err}
close(ch)

return Client{
Expand All @@ -147,7 +152,7 @@ func NewClient(ctx context.Context, tcpNode host.Host, server p2p.Peer, hashSig
return
}

if res.error == "" {
if res.error == nil {
tcpNode.Peerstore().RecordLatency(server.ID, res.rtt)
}

Expand Down
44 changes: 36 additions & 8 deletions dkg/sync/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package sync
import (
"bufio"
"context"
"sync"
"time"

"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/network"
Expand All @@ -39,15 +41,29 @@ const (
)

// Server accepts sync connections from known peers and reports when all of
// them have connected (see AwaitAllConnected).
type Server struct {
	// mu guards dedupResponse against concurrent writes from per-stream handlers
	// (noted in the handler as a temporary fix for a write race).
	mu sync.Mutex
	// ctx scopes the server's lifetime. NOTE(review): storing a context in a
	// struct is discouraged in Go; consider passing it per call — confirm intent.
	ctx context.Context
	// onFailure is invoked on failure; its exact semantics are defined by the
	// caller and not visible in this file.
	onFailure func()
	// tcpNode is this server's libp2p host.
	tcpNode host.Host
	// peers is the set of expected clients; AwaitAllConnected waits for one
	// result per entry.
	peers []p2p.Peer
	// dedupResponse marks peers whose successful response was already counted,
	// so duplicate messages are ignored.
	dedupResponse map[peer.ID]bool
	// receiveChan carries one result per newly connected peer; buffered to
	// len(peers) so handlers never block on send.
	receiveChan chan result
}

// AwaitAllConnected blocks until all peers have established a connection with this server or returns an error.
func (*Server) AwaitAllConnected() error {
func (s *Server) AwaitAllConnected() error {
var msgs []result
for len(msgs) < len(s.peers) {
Copy link
Contributor

@xenowits xenowits Jun 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a good practice imo to use value changing variables as loop vars. can do this instead:

for {
  if len(msgs) >= len(s.peers) { break }
    select {
       case <-s.ctx.Done():
	return s.ctx.Err()
       case msg := <-s.receiveChan:
	msgs = append(msgs, msg)
  }
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm looks fair then

select {
case <-s.ctx.Done():
return s.ctx.Err()
case msg := <-s.receiveChan:
msgs = append(msgs, msg)
}
}

log.Info(s.ctx, "All Clients Connected 🎉", z.Any("clients", len(msgs)))

return nil
}

Expand All @@ -65,6 +81,7 @@ func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, defHash
peers: peers,
onFailure: onFailure,
dedupResponse: make(map[peer.ID]bool),
receiveChan: make(chan result, len(peers)),
}

knownPeers := make(map[peer.ID]bool)
Expand All @@ -77,10 +94,11 @@ func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, defHash

// TODO(dhruv): introduce timeout to break the loop
for {
before := time.Now()
pID := s.Conn().RemotePeer()
if !knownPeers[pID] {
// Ignoring unknown peer
log.Warn(ctx, "Ignoring unknown peer", nil, z.Any("peer", p2p.PeerName(pID)))
log.Warn(ctx, "Ignoring unknown client", nil, z.Any("client", p2p.PeerName(pID)))
return
}

Expand All @@ -89,7 +107,7 @@ func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, defHash
// n is the number of bytes read from buffer, if n < MsgSize the other bytes will be 0
n, err := buf.Read(b)
if err != nil {
log.Error(ctx, "Read client msg from stream", err)
log.Error(ctx, "Read client msg from stream", err, z.Any("client", p2p.PeerName(pID)))
return
}

Expand All @@ -102,7 +120,7 @@ func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, defHash
return
}

log.Debug(ctx, "Message received from client", z.Any("server", p2p.PeerName(pID)))
log.Debug(ctx, "Message received from client", z.Any("client", p2p.PeerName(pID)))

pubkey, err := pID.ExtractPublicKey()
if err != nil {
Expand Down Expand Up @@ -133,18 +151,28 @@ func NewServer(ctx context.Context, tcpNode host.Host, peers []p2p.Peer, defHash

_, err = s.Write(resBytes)
if err != nil {
log.Error(ctx, "Send response to client", err)
log.Error(ctx, "Send response to client", err, z.Any("client", p2p.PeerName(pID)))
return
}

if server.dedupResponse[pID] {
log.Debug(ctx, "Ignoring duplicate message", z.Any("peer", pID))
log.Debug(ctx, "Ignoring duplicate message", z.Any("client", p2p.PeerName(pID)))
continue
}

server.dedupResponse[pID] = true
if resp.Error == "" && !server.dedupResponse[pID] {
// TODO(dhruv): This is temporary solution to avoid race condition of concurrent writes to map, figure out something permanent.
server.mu.Lock()
server.dedupResponse[pID] = true
server.mu.Unlock()

server.receiveChan <- result{
rtt: time.Since(before),
timestamp: msg.Timestamp.String(),
}
}

log.Debug(ctx, "Server sent response to client")
log.Debug(ctx, "Send response to client", z.Any("client", p2p.PeerName(pID)))
}
})

Expand Down
2 changes: 1 addition & 1 deletion dkg/sync/sync_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestNaiveServerClient(t *testing.T) {
for i := 0; i < 5; i++ {
actual, ok := <-client.results
require.True(t, ok)
require.Equal(t, "", actual.error)
require.NoError(t, actual.error)
t.Log("rtt is: ", actual.rtt)
}
}
Expand Down
51 changes: 50 additions & 1 deletion dkg/sync/sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ import (
"github.com/obolnetwork/charon/testutil"
)

//go:generate go test . -run=TestAwaitAllConnected -race

func TestAwaitConnected(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand All @@ -60,12 +62,59 @@ func TestAwaitConnected(t *testing.T) {
serverCtx := log.WithTopic(ctx, "server")
_ = sync.NewServer(serverCtx, serverHost, []p2p.Peer{{ID: clientHost.ID()}}, hash, nil)

clientCtx := log.WithTopic(ctx, "client")
clientCtx := log.WithTopic(context.Background(), "client")
client := sync.NewClient(clientCtx, clientHost, p2p.Peer{ID: serverHost.ID()}, hashSig, nil)

require.NoError(t, client.AwaitConnected())
}

// TestAwaitAllConnected verifies that the sync server unblocks once every
// configured client has connected and completed the handshake.
func TestAwaitAllConnected(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	const numClients = 3

	// Seed 0 is the server; seeds 1..numClients are the clients.
	srvHost, _ := newSyncHost(t, 0)

	peers := make([]p2p.Peer, 0, numClients)
	keys := make([]libp2pcrypto.PrivKey, 0, numClients)
	hosts := make([]host.Host, 0, numClients)
	for i := 1; i <= numClients; i++ {
		h, key := newSyncHost(t, int64(i))
		require.NotEqual(t, h.ID().String(), srvHost.ID().String())

		require.NoError(t, srvHost.Connect(ctx, peer.AddrInfo{
			ID:    h.ID(),
			Addrs: h.Addrs(),
		}))

		hosts = append(hosts, h)
		keys = append(keys, key)
		peers = append(peers, p2p.Peer{ID: h.ID()})
	}

	hash := testutil.RandomBytes32()
	server := sync.NewServer(log.WithTopic(ctx, "server"), srvHost, peers, hash, nil)

	clients := make([]sync.Client, 0, numClients)
	for i, key := range keys {
		sig, err := key.Sign(hash)
		require.NoError(t, err)

		clients = append(clients, sync.NewClient(log.WithTopic(context.Background(), "client"), hosts[i], p2p.Peer{ID: srvHost.ID()}, sig, nil))
	}

	for _, c := range clients {
		require.NoError(t, c.AwaitConnected())
	}

	require.NoError(t, server.AwaitAllConnected())
}

func newSyncHost(t *testing.T, seed int64) (host.Host, libp2pcrypto.PrivKey) {
t.Helper()

Expand Down