-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: race condition when adding new channel to NodeInfo (#735)
* fix: race condition when adding new channel to NodeInfo * chore: fix missing nodeInfo.Channels initialization * fix(sync): concurrent slice marshal/unmarshal json * fix: json marshal nodeInfo channels fails
- Loading branch information
Showing
13 changed files
with
482 additions
and
184 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
package sync | ||
|
||
import ( | ||
"encoding/json" | ||
"sync" | ||
) | ||
|
||
// ConcurrentSlice is a thread-safe slice. | ||
// | ||
// It is safe to use from multiple goroutines without additional locking. | ||
// It should be referenced by pointer. | ||
// | ||
// Initialize using NewConcurrentSlice(). | ||
type ConcurrentSlice[T any] struct { | ||
mtx sync.RWMutex | ||
items []T | ||
} | ||
|
||
// NewConcurrentSlice creates a new thread-safe slice. | ||
func NewConcurrentSlice[T any](initial ...T) *ConcurrentSlice[T] { | ||
return &ConcurrentSlice[T]{ | ||
items: initial, | ||
} | ||
} | ||
|
||
// Append adds an element to the slice | ||
func (s *ConcurrentSlice[T]) Append(val ...T) { | ||
s.mtx.Lock() | ||
defer s.mtx.Unlock() | ||
|
||
s.items = append(s.items, val...) | ||
} | ||
|
||
// Reset removes all elements from the slice | ||
func (s *ConcurrentSlice[T]) Reset() { | ||
s.mtx.Lock() | ||
defer s.mtx.Unlock() | ||
|
||
s.items = []T{} | ||
} | ||
|
||
// Get returns the value at the given index | ||
func (s *ConcurrentSlice[T]) Get(index int) T { | ||
s.mtx.RLock() | ||
defer s.mtx.RUnlock() | ||
|
||
return s.items[index] | ||
} | ||
|
||
// Set updates the value at the given index. | ||
// If the index is greater than the length of the slice, it panics. | ||
// If the index is equal to the length of the slice, the value is appended. | ||
// Otherwise, the value at the index is updated. | ||
func (s *ConcurrentSlice[T]) Set(index int, val T) { | ||
s.mtx.Lock() | ||
defer s.mtx.Unlock() | ||
|
||
if index > len(s.items) { | ||
panic("index out of range") | ||
} else if index == len(s.items) { | ||
s.items = append(s.items, val) | ||
return | ||
} | ||
|
||
s.items[index] = val | ||
} | ||
|
||
// ToSlice returns a copy of the underlying slice | ||
func (s *ConcurrentSlice[T]) ToSlice() []T { | ||
s.mtx.RLock() | ||
defer s.mtx.RUnlock() | ||
|
||
slice := make([]T, len(s.items)) | ||
copy(slice, s.items) | ||
return slice | ||
} | ||
|
||
// Len returns the length of the slice | ||
func (s *ConcurrentSlice[T]) Len() int { | ||
s.mtx.RLock() | ||
defer s.mtx.RUnlock() | ||
|
||
return len(s.items) | ||
} | ||
|
||
// Copy returns a new deep copy of concurrentSlice with the same elements | ||
func (s *ConcurrentSlice[T]) Copy() ConcurrentSlice[T] { | ||
s.mtx.RLock() | ||
defer s.mtx.RUnlock() | ||
|
||
return ConcurrentSlice[T]{ | ||
items: s.ToSlice(), | ||
} | ||
} | ||
|
||
// MarshalJSON implements the json.Marshaler interface. | ||
func (cs *ConcurrentSlice[T]) MarshalJSON() ([]byte, error) { | ||
cs.mtx.RLock() | ||
defer cs.mtx.RUnlock() | ||
|
||
return json.Marshal(cs.items) | ||
} | ||
|
||
// UnmarshalJSON implements the json.Unmarshaler interface. | ||
func (cs *ConcurrentSlice[T]) UnmarshalJSON(data []byte) error { | ||
var items []T | ||
if err := json.Unmarshal(data, &items); err != nil { | ||
return err | ||
} | ||
|
||
cs.mtx.Lock() | ||
defer cs.mtx.Unlock() | ||
|
||
cs.items = items | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
package sync | ||
|
||
import ( | ||
"encoding/json" | ||
"sync" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestConcurrentSlice(t *testing.T) { | ||
s := NewConcurrentSlice[int](1, 2, 3) | ||
|
||
// Test Append | ||
s.Append(4) | ||
if s.Len() != 4 { | ||
t.Errorf("Expected length of slice to be 4, got %d", s.Len()) | ||
} | ||
|
||
// Test Get | ||
if s.Get(3) != 4 { | ||
t.Errorf("Expected element at index 3 to be 4, got %d", s.Get(3)) | ||
} | ||
|
||
// Test Set | ||
s.Set(1, 5) | ||
|
||
// Test ToSlice | ||
slice := s.ToSlice() | ||
if len(slice) != 4 || slice[3] != 4 || slice[1] != 5 { | ||
t.Errorf("Expected ToSlice to return [1 5 3 4], got %v", slice) | ||
} | ||
|
||
// Test Reset | ||
s.Reset() | ||
if s.Len() != 0 { | ||
t.Errorf("Expected length of slice to be 0 after Reset, got %d", s.Len()) | ||
} | ||
|
||
// Test Copy | ||
s.Append(5) | ||
copy := s.Copy() | ||
if copy.Len() != 1 || copy.Get(0) != 5 { | ||
t.Errorf("Expected Copy to return a new slice with [5], got %v", copy.ToSlice()) | ||
} | ||
} | ||
|
||
func TestConcurrentSlice_Concurrency(t *testing.T) { | ||
s := NewConcurrentSlice[int]() | ||
|
||
var wg sync.WaitGroup | ||
for i := 0; i < 100; i++ { | ||
wg.Add(1) | ||
go func(val int) { | ||
defer wg.Done() | ||
s.Append(val) | ||
}(i) | ||
} | ||
|
||
wg.Wait() | ||
|
||
assert.Equal(t, 100, s.Len()) | ||
|
||
if s.Len() != 100 { | ||
t.Errorf("Expected length of slice to be 100, got %d", s.Len()) | ||
} | ||
|
||
for i := 0; i < 100; i++ { | ||
assert.Contains(t, s.ToSlice(), i) | ||
} | ||
} | ||
|
||
func TestConcurrentSlice_MarshalUnmarshalJSON(t *testing.T) { | ||
type node struct { | ||
Channels *ConcurrentSlice[uint16] | ||
} | ||
cs := NewConcurrentSlice[uint16](1, 2, 3) | ||
|
||
node1 := node{ | ||
Channels: cs, | ||
} | ||
|
||
// Marshal to JSON | ||
data, err := json.Marshal(node1) | ||
assert.NoError(t, err, "Failed to marshal concurrentSlice") | ||
|
||
// Unmarshal from JSON | ||
node2 := node{ | ||
// Channels: NewConcurrentSlice[uint16](), | ||
} | ||
|
||
err = json.Unmarshal(data, &node2) | ||
assert.NoError(t, err, "Failed to unmarshal concurrentSlice") | ||
|
||
assert.EqualValues(t, node1.Channels.ToSlice(), node2.Channels.ToSlice()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.