Skip to content

Commit

Permalink
fix: race condition when adding new channel to NodeInfo (#735)
Browse files Browse the repository at this point in the history
* fix: race condition when adding new channel to NodeInfo

* chore: fix missing nodeInfo.Channels initialization

* fix(sync): concurrent slice marshal/unmarshal json

* fix: json marshal nodeInfo channels fails
  • Loading branch information
lklimek authored Feb 9, 2024
1 parent a8f1ae6 commit fce9e90
Show file tree
Hide file tree
Showing 13 changed files with 482 additions and 184 deletions.
116 changes: 116 additions & 0 deletions internal/libs/sync/concurrent_slice.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package sync

import (
"encoding/json"
"sync"
)

// ConcurrentSlice is a thread-safe slice.
//
// It is safe to use from multiple goroutines without additional locking.
// It should be referenced by pointer.
//
// Initialize using NewConcurrentSlice().
type ConcurrentSlice[T any] struct {
mtx sync.RWMutex
items []T
}

// NewConcurrentSlice creates a new thread-safe slice.
func NewConcurrentSlice[T any](initial ...T) *ConcurrentSlice[T] {
return &ConcurrentSlice[T]{
items: initial,
}
}

// Append adds an element to the slice
func (s *ConcurrentSlice[T]) Append(val ...T) {
s.mtx.Lock()
defer s.mtx.Unlock()

s.items = append(s.items, val...)
}

// Reset removes all elements from the slice
func (s *ConcurrentSlice[T]) Reset() {
s.mtx.Lock()
defer s.mtx.Unlock()

s.items = []T{}
}

// Get returns the value at the given index
func (s *ConcurrentSlice[T]) Get(index int) T {
s.mtx.RLock()
defer s.mtx.RUnlock()

return s.items[index]
}

// Set updates the value at the given index.
// If the index is greater than the length of the slice, it panics.
// If the index is equal to the length of the slice, the value is appended.
// Otherwise, the value at the index is updated.
func (s *ConcurrentSlice[T]) Set(index int, val T) {
s.mtx.Lock()
defer s.mtx.Unlock()

if index > len(s.items) {
panic("index out of range")
} else if index == len(s.items) {
s.items = append(s.items, val)
return
}

s.items[index] = val
}

// ToSlice returns a copy of the underlying slice
func (s *ConcurrentSlice[T]) ToSlice() []T {
s.mtx.RLock()
defer s.mtx.RUnlock()

slice := make([]T, len(s.items))
copy(slice, s.items)
return slice
}

// Len returns the length of the slice
func (s *ConcurrentSlice[T]) Len() int {
s.mtx.RLock()
defer s.mtx.RUnlock()

return len(s.items)
}

// Copy returns a new deep copy of concurrentSlice with the same elements
func (s *ConcurrentSlice[T]) Copy() ConcurrentSlice[T] {
s.mtx.RLock()
defer s.mtx.RUnlock()

return ConcurrentSlice[T]{
items: s.ToSlice(),
}
}

// MarshalJSON implements the json.Marshaler interface.
func (cs *ConcurrentSlice[T]) MarshalJSON() ([]byte, error) {
cs.mtx.RLock()
defer cs.mtx.RUnlock()

return json.Marshal(cs.items)
}

// UnmarshalJSON implements the json.Unmarshaler interface.
func (cs *ConcurrentSlice[T]) UnmarshalJSON(data []byte) error {
var items []T
if err := json.Unmarshal(data, &items); err != nil {
return err
}

cs.mtx.Lock()
defer cs.mtx.Unlock()

cs.items = items
return nil
}
96 changes: 96 additions & 0 deletions internal/libs/sync/concurrent_slice_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package sync

import (
"encoding/json"
"sync"
"testing"

"github.com/stretchr/testify/assert"
)

func TestConcurrentSlice(t *testing.T) {
s := NewConcurrentSlice[int](1, 2, 3)

// Test Append
s.Append(4)
if s.Len() != 4 {
t.Errorf("Expected length of slice to be 4, got %d", s.Len())
}

// Test Get
if s.Get(3) != 4 {
t.Errorf("Expected element at index 3 to be 4, got %d", s.Get(3))
}

// Test Set
s.Set(1, 5)

// Test ToSlice
slice := s.ToSlice()
if len(slice) != 4 || slice[3] != 4 || slice[1] != 5 {
t.Errorf("Expected ToSlice to return [1 5 3 4], got %v", slice)
}

// Test Reset
s.Reset()
if s.Len() != 0 {
t.Errorf("Expected length of slice to be 0 after Reset, got %d", s.Len())
}

// Test Copy
s.Append(5)
copy := s.Copy()
if copy.Len() != 1 || copy.Get(0) != 5 {
t.Errorf("Expected Copy to return a new slice with [5], got %v", copy.ToSlice())
}
}

func TestConcurrentSlice_Concurrency(t *testing.T) {
s := NewConcurrentSlice[int]()

var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(val int) {
defer wg.Done()
s.Append(val)
}(i)
}

wg.Wait()

assert.Equal(t, 100, s.Len())

if s.Len() != 100 {
t.Errorf("Expected length of slice to be 100, got %d", s.Len())
}

for i := 0; i < 100; i++ {
assert.Contains(t, s.ToSlice(), i)
}
}

func TestConcurrentSlice_MarshalUnmarshalJSON(t *testing.T) {
type node struct {
Channels *ConcurrentSlice[uint16]
}
cs := NewConcurrentSlice[uint16](1, 2, 3)

node1 := node{
Channels: cs,
}

// Marshal to JSON
data, err := json.Marshal(node1)
assert.NoError(t, err, "Failed to marshal concurrentSlice")

// Unmarshal from JSON
node2 := node{
// Channels: NewConcurrentSlice[uint16](),
}

err = json.Unmarshal(data, &node2)
assert.NoError(t, err, "Failed to unmarshal concurrentSlice")

assert.EqualValues(t, node1.Channels.ToSlice(), node2.Channels.ToSlice())
}
5 changes: 3 additions & 2 deletions internal/p2p/p2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package p2p_test
import (
"github.com/dashpay/tenderdash/crypto"
"github.com/dashpay/tenderdash/crypto/ed25519"
tmsync "github.com/dashpay/tenderdash/internal/libs/sync"
"github.com/dashpay/tenderdash/internal/p2p"
"github.com/dashpay/tenderdash/types"
)
Expand All @@ -25,7 +26,7 @@ var (
ListenAddr: "0.0.0.0:0",
Network: "test",
Moniker: string(selfID),
Channels: []byte{0x01, 0x02},
Channels: tmsync.NewConcurrentSlice[uint16](0x01, 0x02),
}

peerKey crypto.PrivKey = ed25519.GenPrivKeyFromSecret([]byte{0x84, 0xd7, 0x01, 0xbf, 0x83, 0x20, 0x1c, 0xfe})
Expand All @@ -35,6 +36,6 @@ var (
ListenAddr: "0.0.0.0:0",
Network: "test",
Moniker: string(peerID),
Channels: []byte{0x01, 0x02},
Channels: tmsync.NewConcurrentSlice[uint16](0x01, 0x02),
}
)
2 changes: 2 additions & 0 deletions internal/p2p/p2ptest/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/dashpay/tenderdash/config"
"github.com/dashpay/tenderdash/crypto"
"github.com/dashpay/tenderdash/crypto/ed25519"
tmsync "github.com/dashpay/tenderdash/internal/libs/sync"
"github.com/dashpay/tenderdash/internal/p2p"
p2pclient "github.com/dashpay/tenderdash/internal/p2p/client"
"github.com/dashpay/tenderdash/libs/log"
Expand Down Expand Up @@ -272,6 +273,7 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, proTxHash crypto.P
ListenAddr: "0.0.0.0:0", // FIXME: We have to fake this for now.
Moniker: string(nodeID),
ProTxHash: proTxHash.Copy(),
Channels: tmsync.NewConcurrentSlice[uint16](),
}

transport := n.memoryNetwork.CreateTransport(nodeID)
Expand Down
10 changes: 5 additions & 5 deletions internal/p2p/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) {
return
}

r.routePeer(ctx, peerInfo.NodeID, conn, toChannelIDs(peerInfo.Channels))
r.routePeer(ctx, peerInfo.NodeID, conn, toChannelIDs(peerInfo.Channels.ToSlice()))
}

// dialPeers maintains outbound connections to peers by dialing them.
Expand Down Expand Up @@ -589,7 +589,7 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) {
}

// routePeer (also) calls connection close
go r.routePeer(ctx, address.NodeID, conn, toChannelIDs(peerInfo.Channels))
go r.routePeer(ctx, address.NodeID, conn, toChannelIDs(peerInfo.Channels.ToSlice()))
}

func (r *Router) getOrMakeQueue(peerID types.NodeID, channels ChannelIDSet) queue {
Expand Down Expand Up @@ -943,9 +943,9 @@ func (cs ChannelIDSet) Contains(id ChannelID) bool {
return ok
}

func toChannelIDs(bytes []byte) ChannelIDSet {
c := make(map[ChannelID]struct{}, len(bytes))
for _, b := range bytes {
func toChannelIDs(ids []uint16) ChannelIDSet {
c := make(map[ChannelID]struct{}, len(ids))
for _, b := range ids {
c[ChannelID(b)] = struct{}{}
}
return c
Expand Down
7 changes: 5 additions & 2 deletions internal/p2p/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
dbm "github.com/tendermint/tm-db"

"github.com/dashpay/tenderdash/crypto"
tmsync "github.com/dashpay/tenderdash/internal/libs/sync"
"github.com/dashpay/tenderdash/internal/p2p"
"github.com/dashpay/tenderdash/internal/p2p/mocks"
"github.com/dashpay/tenderdash/internal/p2p/p2ptest"
Expand Down Expand Up @@ -303,6 +304,7 @@ func TestRouter_AcceptPeers(t *testing.T) {
ListenAddr: "0.0.0.0:0",
Network: "other-network",
Moniker: string(peerID),
Channels: tmsync.NewConcurrentSlice[uint16](),
},
peerKey.PubKey(),
false,
Expand Down Expand Up @@ -504,6 +506,7 @@ func TestRouter_DialPeers(t *testing.T) {
ListenAddr: "0.0.0.0:0",
Network: "other-network",
Moniker: string(peerID),
Channels: tmsync.NewConcurrentSlice[uint16](),
},
peerKey.PubKey(),
nil,
Expand Down Expand Up @@ -766,7 +769,7 @@ func TestRouter_ChannelCompatability(t *testing.T) {
ListenAddr: "0.0.0.0:0",
Network: "test",
Moniker: string(peerID),
Channels: []byte{0x03},
Channels: tmsync.NewConcurrentSlice[uint16](0x03),
}

mockConnection := &mocks.Connection{}
Expand Down Expand Up @@ -817,7 +820,7 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) {
ListenAddr: "0.0.0.0:0",
Network: "test",
Moniker: string(peerID),
Channels: []byte{0x02},
Channels: tmsync.NewConcurrentSlice[uint16](0x02),
}

mockConnection := &mocks.Connection{}
Expand Down
13 changes: 8 additions & 5 deletions internal/p2p/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (
"github.com/stretchr/testify/require"

"github.com/dashpay/tenderdash/crypto/ed25519"
tmsync "github.com/dashpay/tenderdash/internal/libs/sync"
"github.com/dashpay/tenderdash/internal/p2p"
"github.com/dashpay/tenderdash/libs/bytes"
"github.com/dashpay/tenderdash/types"
)

Expand Down Expand Up @@ -283,15 +283,18 @@ func TestConnection_Handshake(t *testing.T) {
ListenAddr: "listenaddr",
Network: "network",
Version: "1.2.3",
Channels: bytes.HexBytes([]byte{0xf0, 0x0f}),
Channels: tmsync.NewConcurrentSlice[uint16](0xf0, 0x0f),
Moniker: "moniker",
Other: types.NodeInfoOther{
TxIndex: "txindex",
RPCAddress: "rpc.domain.com",
},
}
bKey := ed25519.GenPrivKey()
bInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(bKey.PubKey())}
bInfo := types.NodeInfo{
NodeID: types.NodeIDFromPubKey(bKey.PubKey()),
Channels: tmsync.NewConcurrentSlice[uint16](),
}

errCh := make(chan error, 1)
go func() {
Expand Down Expand Up @@ -641,13 +644,13 @@ func dialAcceptHandshake(ctx context.Context, t *testing.T, a, b p2p.Transport)
errCh := make(chan error, 1)
go func() {
privKey := ed25519.GenPrivKey()
nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())}
nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey()), Channels: tmsync.NewConcurrentSlice[uint16]()}
_, _, err := ba.Handshake(ctx, 0, nodeInfo, privKey)
errCh <- err
}()

privKey := ed25519.GenPrivKey()
nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())}
nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey()), Channels: tmsync.NewConcurrentSlice[uint16]()}
_, _, err := ab.Handshake(ctx, 0, nodeInfo, privKey)
require.NoError(t, err)

Expand Down
Loading

0 comments on commit fce9e90

Please sign in to comment.