Skip to content

Commit

Permalink
Merge pull request #417 from projectdiscovery/feat-sync-slice
Browse files Browse the repository at this point in the history
Add Sync Slice
  • Loading branch information
Mzack9999 authored May 20, 2024
2 parents 3908d4b + bbc5522 commit 53befed
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 0 deletions.
77 changes: 77 additions & 0 deletions slice/sync_slice.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package sliceutil

import "sync"

// SyncSlice provides a thread-safe slice for elements of any comparable type.
type SyncSlice[K comparable] struct {
Slice []K
mu *sync.RWMutex
}

// NewSyncSlice initializes a new instance of SyncSlice.
func NewSyncSlice[K comparable]() *SyncSlice[K] {
return &SyncSlice[K]{mu: &sync.RWMutex{}}
}

// Append adds elements to the end of the slice in a thread-safe manner.
func (ss *SyncSlice[K]) Append(items ...K) {
ss.mu.Lock()
defer ss.mu.Unlock()

ss.Slice = append(ss.Slice, items...)
}

// Each iterates over all elements in the slice and applies the function f to each element.
// Iteration is done in a read-locked context to prevent data race.
func (ss *SyncSlice[K]) Each(f func(i int, k K) error) {
ss.mu.RLock()
defer ss.mu.RUnlock()

for i, k := range ss.Slice {
if err := f(i, k); err != nil {
break
}
}
}

// Empty clears the slice by reinitializing it in a thread-safe manner.
func (ss *SyncSlice[K]) Empty() {
ss.mu.Lock()
defer ss.mu.Unlock()

ss.Slice = make([]K, 0)
}

// Len returns the number of elements in the slice in a thread-safe manner.
func (ss *SyncSlice[K]) Len() int {
ss.mu.RLock()
defer ss.mu.RUnlock()

return len(ss.Slice)
}

// Get retrieves an element by index from the slice safely.
// Returns the element and true if index is within bounds, otherwise returns zero value and false.
func (ss *SyncSlice[K]) Get(index int) (K, bool) {
ss.mu.RLock()
defer ss.mu.RUnlock()

if index < 0 || index >= len(ss.Slice) {
var zero K
return zero, false
}
return ss.Slice[index], true
}

// Put updates the element at the specified index in the slice in a thread-safe manner.
// Returns true if the index is within bounds, otherwise false.
func (ss *SyncSlice[K]) Put(index int, value K) bool {
ss.mu.Lock()
defer ss.mu.Unlock()

if index < 0 || index >= len(ss.Slice) {
return false
}
ss.Slice[index] = value
return true
}
117 changes: 117 additions & 0 deletions slice/sync_slice_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package sliceutil

import (
"sync"
"testing"
"time"
)

func TestSimpleUsage(t *testing.T) {
ss := NewSyncSlice[int]()
expected := 10
for i := 0; i < expected; i++ {
ss.Append(i)
}
value, ok := ss.Get(5)
if !ok {
t.Errorf("Failed to get value at index 5")
} else if value != 5 {
t.Errorf("Expected value 5 at index 5, got %d", value)
}

success := ss.Put(5, 20)
if !success {
t.Errorf("Failed to put value at index 5")
}

value, ok = ss.Get(5)
if !ok {
t.Errorf("Failed to get value at index 5 after put")
} else if value != 20 {
t.Errorf("Expected value 20 at index 5 after put, got %d", value)
}
if ss.Len() != expected {
t.Errorf("Expected slice length %d, got %d", expected, ss.Len())
}
ss.Empty()
if ss.Len() != 0 {
t.Errorf("Expected slice length 0 after emptying, got %d", ss.Len())
}
}

func TestConcurrentAppend(t *testing.T) {
ss := NewSyncSlice[int]()
var wg sync.WaitGroup
count := 1000

for i := 0; i < count; i++ {
wg.Add(1)
go func(val int) {
defer wg.Done()
ss.Append(val)

if val%10 == 0 {
ss.Put(val, val*2) // Double the value at positions that are multiples of 10
}
if val%5 == 0 {
retrievedVal, _ := ss.Get(val) // Attempt to get the value at positions that are multiples of 5
_ = retrievedVal // Use the retrieved value to ensure it's not optimized away
}
}(i)
}
wg.Wait()

if ss.Len() != count {
t.Errorf("Expected slice length %d after concurrent append, got %d", count, ss.Len())
}
}

func TestConcurrentReadWriteAndIteration(t *testing.T) {
ss := NewSyncSlice[int]()
var wg sync.WaitGroup
readWriteCount := 1000

wg.Add(3) // Adding three groups: writer, reader, iterator

// Writer goroutine
go func() {
defer wg.Done()
for i := 0; i < readWriteCount; i++ {
ss.Append(i) // Write
}
}()

// Reader goroutine
go func() {
defer wg.Done()

time.Sleep(250 * time.Millisecond)

for i := 0; i < readWriteCount; i++ {
if value, ok := ss.Get(i % ss.Len()); !ok {
t.Errorf("Failed to get value at index %d", i%ss.Len())
} else {
_ = value // Use the value to ensure it's not optimized away
}
}
}()

// Iterator goroutine
go func() {
defer wg.Done()
for repeat := 0; repeat < 1000; repeat++ { // Repeat the iteration 1000 times
ss.Each(func(index int, value int) error {
// Simulate some processing
_ = index
_ = value
return nil
})
}
}()

wg.Wait()

if ss.Len() != readWriteCount {
t.Errorf("Expected slice length %d after concurrent read/write, got %d", readWriteCount, ss.Len())
}
}

0 comments on commit 53befed

Please sign in to comment.