-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #417 from projectdiscovery/feat-sync-slice
Add Sync Slice
- Loading branch information
Showing
2 changed files
with
194 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
package sliceutil | ||
|
||
import "sync" | ||
|
||
// SyncSlice provides a thread-safe slice for elements of any comparable type. | ||
type SyncSlice[K comparable] struct { | ||
Slice []K | ||
mu *sync.RWMutex | ||
} | ||
|
||
// NewSyncSlice initializes a new instance of SyncSlice. | ||
func NewSyncSlice[K comparable]() *SyncSlice[K] { | ||
return &SyncSlice[K]{mu: &sync.RWMutex{}} | ||
} | ||
|
||
// Append adds elements to the end of the slice in a thread-safe manner. | ||
func (ss *SyncSlice[K]) Append(items ...K) { | ||
ss.mu.Lock() | ||
defer ss.mu.Unlock() | ||
|
||
ss.Slice = append(ss.Slice, items...) | ||
} | ||
|
||
// Each iterates over all elements in the slice and applies the function f to each element. | ||
// Iteration is done in a read-locked context to prevent data race. | ||
func (ss *SyncSlice[K]) Each(f func(i int, k K) error) { | ||
ss.mu.RLock() | ||
defer ss.mu.RUnlock() | ||
|
||
for i, k := range ss.Slice { | ||
if err := f(i, k); err != nil { | ||
break | ||
} | ||
} | ||
} | ||
|
||
// Empty clears the slice by reinitializing it in a thread-safe manner. | ||
func (ss *SyncSlice[K]) Empty() { | ||
ss.mu.Lock() | ||
defer ss.mu.Unlock() | ||
|
||
ss.Slice = make([]K, 0) | ||
} | ||
|
||
// Len returns the number of elements in the slice in a thread-safe manner. | ||
func (ss *SyncSlice[K]) Len() int { | ||
ss.mu.RLock() | ||
defer ss.mu.RUnlock() | ||
|
||
return len(ss.Slice) | ||
} | ||
|
||
// Get retrieves an element by index from the slice safely. | ||
// Returns the element and true if index is within bounds, otherwise returns zero value and false. | ||
func (ss *SyncSlice[K]) Get(index int) (K, bool) { | ||
ss.mu.RLock() | ||
defer ss.mu.RUnlock() | ||
|
||
if index < 0 || index >= len(ss.Slice) { | ||
var zero K | ||
return zero, false | ||
} | ||
return ss.Slice[index], true | ||
} | ||
|
||
// Put updates the element at the specified index in the slice in a thread-safe manner. | ||
// Returns true if the index is within bounds, otherwise false. | ||
func (ss *SyncSlice[K]) Put(index int, value K) bool { | ||
ss.mu.Lock() | ||
defer ss.mu.Unlock() | ||
|
||
if index < 0 || index >= len(ss.Slice) { | ||
return false | ||
} | ||
ss.Slice[index] = value | ||
return true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
package sliceutil | ||
|
||
import ( | ||
"sync" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestSimpleUsage(t *testing.T) { | ||
ss := NewSyncSlice[int]() | ||
expected := 10 | ||
for i := 0; i < expected; i++ { | ||
ss.Append(i) | ||
} | ||
value, ok := ss.Get(5) | ||
if !ok { | ||
t.Errorf("Failed to get value at index 5") | ||
} else if value != 5 { | ||
t.Errorf("Expected value 5 at index 5, got %d", value) | ||
} | ||
|
||
success := ss.Put(5, 20) | ||
if !success { | ||
t.Errorf("Failed to put value at index 5") | ||
} | ||
|
||
value, ok = ss.Get(5) | ||
if !ok { | ||
t.Errorf("Failed to get value at index 5 after put") | ||
} else if value != 20 { | ||
t.Errorf("Expected value 20 at index 5 after put, got %d", value) | ||
} | ||
if ss.Len() != expected { | ||
t.Errorf("Expected slice length %d, got %d", expected, ss.Len()) | ||
} | ||
ss.Empty() | ||
if ss.Len() != 0 { | ||
t.Errorf("Expected slice length 0 after emptying, got %d", ss.Len()) | ||
} | ||
} | ||
|
||
func TestConcurrentAppend(t *testing.T) { | ||
ss := NewSyncSlice[int]() | ||
var wg sync.WaitGroup | ||
count := 1000 | ||
|
||
for i := 0; i < count; i++ { | ||
wg.Add(1) | ||
go func(val int) { | ||
defer wg.Done() | ||
ss.Append(val) | ||
|
||
if val%10 == 0 { | ||
ss.Put(val, val*2) // Double the value at positions that are multiples of 10 | ||
} | ||
if val%5 == 0 { | ||
retrievedVal, _ := ss.Get(val) // Attempt to get the value at positions that are multiples of 5 | ||
_ = retrievedVal // Use the retrieved value to ensure it's not optimized away | ||
} | ||
}(i) | ||
} | ||
wg.Wait() | ||
|
||
if ss.Len() != count { | ||
t.Errorf("Expected slice length %d after concurrent append, got %d", count, ss.Len()) | ||
} | ||
} | ||
|
||
func TestConcurrentReadWriteAndIteration(t *testing.T) { | ||
ss := NewSyncSlice[int]() | ||
var wg sync.WaitGroup | ||
readWriteCount := 1000 | ||
|
||
wg.Add(3) // Adding three groups: writer, reader, iterator | ||
|
||
// Writer goroutine | ||
go func() { | ||
defer wg.Done() | ||
for i := 0; i < readWriteCount; i++ { | ||
ss.Append(i) // Write | ||
} | ||
}() | ||
|
||
// Reader goroutine | ||
go func() { | ||
defer wg.Done() | ||
|
||
time.Sleep(250 * time.Millisecond) | ||
|
||
for i := 0; i < readWriteCount; i++ { | ||
if value, ok := ss.Get(i % ss.Len()); !ok { | ||
t.Errorf("Failed to get value at index %d", i%ss.Len()) | ||
} else { | ||
_ = value // Use the value to ensure it's not optimized away | ||
} | ||
} | ||
}() | ||
|
||
// Iterator goroutine | ||
go func() { | ||
defer wg.Done() | ||
for repeat := 0; repeat < 1000; repeat++ { // Repeat the iteration 1000 times | ||
ss.Each(func(index int, value int) error { | ||
// Simulate some processing | ||
_ = index | ||
_ = value | ||
return nil | ||
}) | ||
} | ||
}() | ||
|
||
wg.Wait() | ||
|
||
if ss.Len() != readWriteCount { | ||
t.Errorf("Expected slice length %d after concurrent read/write, got %d", readWriteCount, ss.Len()) | ||
} | ||
} |