Skip to content

Commit

Permalink
Use pointer receivers for unsafeset (#143)
Browse files Browse the repository at this point in the history
* Revert "less pointers and more memory-efficient Clear()"

This reverts commit 3045cfb.

* Revert "less pointers"

This reverts commit 7aad8e9.

* Switch to pointer receivers, fixes UnmarshalJSON

* Expand test coverage for JSON marshaler interface
  • Loading branch information
ryclarke authored Dec 3, 2024
1 parent b710ba4 commit e00dc23
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 83 deletions.
24 changes: 12 additions & 12 deletions threadsafe.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import "sync"

type threadSafeSet[T comparable] struct {
sync.RWMutex
uss threadUnsafeSet[T]
uss *threadUnsafeSet[T]
}

func newThreadSafeSet[T comparable]() *threadSafeSet[T] {
Expand Down Expand Up @@ -123,7 +123,7 @@ func (t *threadSafeSet[T]) Union(other Set[T]) Set[T] {
t.RLock()
o.RLock()

unsafeUnion := t.uss.Union(o.uss).(threadUnsafeSet[T])
unsafeUnion := t.uss.Union(o.uss).(*threadUnsafeSet[T])
ret := &threadSafeSet[T]{uss: unsafeUnion}
t.RUnlock()
o.RUnlock()
Expand All @@ -136,7 +136,7 @@ func (t *threadSafeSet[T]) Intersect(other Set[T]) Set[T] {
t.RLock()
o.RLock()

unsafeIntersection := t.uss.Intersect(o.uss).(threadUnsafeSet[T])
unsafeIntersection := t.uss.Intersect(o.uss).(*threadUnsafeSet[T])
ret := &threadSafeSet[T]{uss: unsafeIntersection}
t.RUnlock()
o.RUnlock()
Expand All @@ -149,7 +149,7 @@ func (t *threadSafeSet[T]) Difference(other Set[T]) Set[T] {
t.RLock()
o.RLock()

unsafeDifference := t.uss.Difference(o.uss).(threadUnsafeSet[T])
unsafeDifference := t.uss.Difference(o.uss).(*threadUnsafeSet[T])
ret := &threadSafeSet[T]{uss: unsafeDifference}
t.RUnlock()
o.RUnlock()
Expand All @@ -162,7 +162,7 @@ func (t *threadSafeSet[T]) SymmetricDifference(other Set[T]) Set[T] {
t.RLock()
o.RLock()

unsafeDifference := t.uss.SymmetricDifference(o.uss).(threadUnsafeSet[T])
unsafeDifference := t.uss.SymmetricDifference(o.uss).(*threadUnsafeSet[T])
ret := &threadSafeSet[T]{uss: unsafeDifference}
t.RUnlock()
o.RUnlock()
Expand All @@ -177,7 +177,7 @@ func (t *threadSafeSet[T]) Clear() {

func (t *threadSafeSet[T]) Remove(v T) {
t.Lock()
delete(t.uss, v)
delete(*t.uss, v)
t.Unlock()
}

Expand All @@ -190,12 +190,12 @@ func (t *threadSafeSet[T]) RemoveAll(i ...T) {
func (t *threadSafeSet[T]) Cardinality() int {
t.RLock()
defer t.RUnlock()
return len(t.uss)
return len(*t.uss)
}

func (t *threadSafeSet[T]) Each(cb func(T) bool) {
t.RLock()
for elem := range t.uss {
for elem := range *t.uss {
if cb(elem) {
break
}
Expand All @@ -208,7 +208,7 @@ func (t *threadSafeSet[T]) Iter() <-chan T {
go func() {
t.RLock()

for elem := range t.uss {
for elem := range *t.uss {
ch <- elem
}
close(ch)
Expand All @@ -224,7 +224,7 @@ func (t *threadSafeSet[T]) Iterator() *Iterator[T] {
go func() {
t.RLock()
L:
for elem := range t.uss {
for elem := range *t.uss {
select {
case <-stopCh:
break L
Expand Down Expand Up @@ -253,7 +253,7 @@ func (t *threadSafeSet[T]) Equal(other Set[T]) bool {
func (t *threadSafeSet[T]) Clone() Set[T] {
t.RLock()

unsafeClone := t.uss.Clone().(threadUnsafeSet[T])
unsafeClone := t.uss.Clone().(*threadUnsafeSet[T])
ret := &threadSafeSet[T]{uss: unsafeClone}
t.RUnlock()
return ret
Expand All @@ -275,7 +275,7 @@ func (t *threadSafeSet[T]) Pop() (T, bool) {
func (t *threadSafeSet[T]) ToSlice() []T {
keys := make([]T, 0, t.Cardinality())
t.RLock()
for elem := range t.uss {
for elem := range *t.uss {
keys = append(keys, elem)
}
t.RUnlock()
Expand Down
12 changes: 1 addition & 11 deletions threadsafe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -584,17 +584,7 @@ func Test_UnmarshalJSON(t *testing.T) {
t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
}
}
func TestThreadUnsafeSet_UnmarshalJSON(t *testing.T) {
expected := NewThreadUnsafeSet[int64](1, 2, 3)
actual := NewThreadUnsafeSet[int64]()
err := actual.UnmarshalJSON([]byte(`[1, 2, 3]`))
if err != nil {
t.Errorf("Error should be nil: %v", err)
}
if !expected.Equal(actual) {
t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
}
}

func Test_MarshalJSON(t *testing.T) {
expected := NewSet(
[]string{
Expand Down
Loading

0 comments on commit e00dc23

Please sign in to comment.