From e00dc23653148bdbe32f47fb9aeea1f75b0d2553 Mon Sep 17 00:00:00 2001 From: Ryan Clarke Date: Tue, 3 Dec 2024 17:33:36 -0500 Subject: [PATCH] Use pointer receivers for unsafeset (#143) * Revert "less pointers and more memory-efficient Clear()" This reverts commit 3045cfb32f3754a7f9e6d9bab57dab766477b382. * Revert "less pointers" This reverts commit 7aad8e9df6335a8711a6561eb1705a481ad35ec1. * Switch to pointer receivers, fixes UnmarshalJSON * Expand test coverage for JSON marshaler interface --- threadsafe.go | 24 +++---- threadsafe_test.go | 12 +--- threadunsafe.go | 122 ++++++++++++++++---------------- threadunsafe_test.go | 164 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 239 insertions(+), 83 deletions(-) create mode 100644 threadunsafe_test.go diff --git a/threadsafe.go b/threadsafe.go index ad7a834..93f20c8 100644 --- a/threadsafe.go +++ b/threadsafe.go @@ -29,7 +29,7 @@ import "sync" type threadSafeSet[T comparable] struct { sync.RWMutex - uss threadUnsafeSet[T] + uss *threadUnsafeSet[T] } func newThreadSafeSet[T comparable]() *threadSafeSet[T] { @@ -123,7 +123,7 @@ func (t *threadSafeSet[T]) Union(other Set[T]) Set[T] { t.RLock() o.RLock() - unsafeUnion := t.uss.Union(o.uss).(threadUnsafeSet[T]) + unsafeUnion := t.uss.Union(o.uss).(*threadUnsafeSet[T]) ret := &threadSafeSet[T]{uss: unsafeUnion} t.RUnlock() o.RUnlock() @@ -136,7 +136,7 @@ func (t *threadSafeSet[T]) Intersect(other Set[T]) Set[T] { t.RLock() o.RLock() - unsafeIntersection := t.uss.Intersect(o.uss).(threadUnsafeSet[T]) + unsafeIntersection := t.uss.Intersect(o.uss).(*threadUnsafeSet[T]) ret := &threadSafeSet[T]{uss: unsafeIntersection} t.RUnlock() o.RUnlock() @@ -149,7 +149,7 @@ func (t *threadSafeSet[T]) Difference(other Set[T]) Set[T] { t.RLock() o.RLock() - unsafeDifference := t.uss.Difference(o.uss).(threadUnsafeSet[T]) + unsafeDifference := t.uss.Difference(o.uss).(*threadUnsafeSet[T]) ret := &threadSafeSet[T]{uss: unsafeDifference} t.RUnlock() o.RUnlock() @@ -162,7 +162,7 @@ func (t *threadSafeSet[T]) SymmetricDifference(other Set[T]) Set[T] { t.RLock() o.RLock() - unsafeDifference := t.uss.SymmetricDifference(o.uss).(threadUnsafeSet[T]) + unsafeDifference := t.uss.SymmetricDifference(o.uss).(*threadUnsafeSet[T]) ret := &threadSafeSet[T]{uss: unsafeDifference} t.RUnlock() o.RUnlock() @@ -177,7 +177,7 @@ func (t *threadSafeSet[T]) Clear() { func (t *threadSafeSet[T]) Remove(v T) { t.Lock() - delete(t.uss, v) + delete(*t.uss, v) t.Unlock() } @@ -190,12 +190,12 @@ func (t *threadSafeSet[T]) RemoveAll(i ...T) { func (t *threadSafeSet[T]) Cardinality() int { t.RLock() defer t.RUnlock() - return len(t.uss) + return len(*t.uss) } func (t *threadSafeSet[T]) Each(cb func(T) bool) { t.RLock() - for elem := range t.uss { + for elem := range *t.uss { if cb(elem) { break } @@ -208,7 +208,7 @@ func (t *threadSafeSet[T]) Iter() <-chan T { go func() { t.RLock() - for elem := range t.uss { + for elem := range *t.uss { ch <- elem } close(ch) @@ -224,7 +224,7 @@ func (t *threadSafeSet[T]) Iterator() *Iterator[T] { go func() { t.RLock() L: - for elem := range t.uss { + for elem := range *t.uss { select { case <-stopCh: break L @@ -253,7 +253,7 @@ func (t *threadSafeSet[T]) Equal(other Set[T]) bool { func (t *threadSafeSet[T]) Clone() Set[T] { t.RLock() - unsafeClone := t.uss.Clone().(threadUnsafeSet[T]) + unsafeClone := t.uss.Clone().(*threadUnsafeSet[T]) ret := &threadSafeSet[T]{uss: unsafeClone} t.RUnlock() return ret @@ -275,7 +275,7 @@ func (t *threadSafeSet[T]) Pop() (T, bool) { func (t *threadSafeSet[T]) ToSlice() []T { keys := make([]T, 0, t.Cardinality()) t.RLock() - for elem := range t.uss { + for elem := range *t.uss { keys = append(keys, elem) } t.RUnlock() diff --git a/threadsafe_test.go b/threadsafe_test.go index 071cdb5..ca998c9 100644 --- a/threadsafe_test.go +++ b/threadsafe_test.go @@ -584,17 +584,7 @@ func Test_UnmarshalJSON(t *testing.T) { t.Errorf("Expected no difference, got: %v", expected.Difference(actual)) } } -func TestThreadUnsafeSet_UnmarshalJSON(t *testing.T) { - expected := NewThreadUnsafeSet[int64](1, 2, 3) - actual := NewThreadUnsafeSet[int64]() - err := actual.UnmarshalJSON([]byte(`[1, 2, 3]`)) - if err != nil { - t.Errorf("Error should be nil: %v", err) - } - if !expected.Equal(actual) { - t.Errorf("Expected no difference, got: %v", expected.Difference(actual)) - } -} + func Test_MarshalJSON(t *testing.T) { expected := NewSet( []string{ diff --git a/threadunsafe.go b/threadunsafe.go index 8b17b01..7e3243b 100644 --- a/threadunsafe.go +++ b/threadunsafe.go @@ -34,14 +34,16 @@ import ( type threadUnsafeSet[T comparable] map[T]struct{} // Assert concrete type:threadUnsafeSet adheres to Set interface. -var _ Set[string] = (threadUnsafeSet[string])(nil) +var _ Set[string] = (*threadUnsafeSet[string])(nil) -func newThreadUnsafeSet[T comparable]() threadUnsafeSet[T] { - return make(threadUnsafeSet[T]) +func newThreadUnsafeSet[T comparable]() *threadUnsafeSet[T] { + t := make(threadUnsafeSet[T]) + return &t } -func newThreadUnsafeSetWithSize[T comparable](cardinality int) threadUnsafeSet[T] { - return make(threadUnsafeSet[T], cardinality) +func newThreadUnsafeSetWithSize[T comparable](cardinality int) *threadUnsafeSet[T] { + t := make(threadUnsafeSet[T], cardinality) + return &t } func (s threadUnsafeSet[T]) Add(v T) bool { @@ -50,57 +52,57 @@ func (s threadUnsafeSet[T]) Add(v T) bool { return prevLen != len(s) } -func (s threadUnsafeSet[T]) Append(v ...T) int { - prevLen := len(s) +func (s *threadUnsafeSet[T]) Append(v ...T) int { + prevLen := len(*s) for _, val := range v { - (s)[val] = struct{}{} + (*s)[val] = struct{}{} } - return len(s) - prevLen + return len(*s) - prevLen } // private version of Add which doesn't return a value -func (s threadUnsafeSet[T]) add(v T) { - s[v] = struct{}{} +func (s *threadUnsafeSet[T]) add(v T) { + (*s)[v] = struct{}{} } -func (s threadUnsafeSet[T]) Cardinality() int { - return len(s) +func (s *threadUnsafeSet[T]) Cardinality() int { + return len(*s) } -func (s threadUnsafeSet[T]) Clear() { +func (s *threadUnsafeSet[T]) Clear() { // Constructions like this are optimised by compiler, and replaced by // mapclear() function, defined in // https://github.com/golang/go/blob/29bbca5c2c1ad41b2a9747890d183b6dd3a4ace4/src/runtime/map.go#L993) - for key := range s { - delete(s, key) + for key := range *s { + delete(*s, key) } } -func (s threadUnsafeSet[T]) Clone() Set[T] { +func (s *threadUnsafeSet[T]) Clone() Set[T] { clonedSet := newThreadUnsafeSetWithSize[T](s.Cardinality()) - for elem := range s { + for elem := range *s { clonedSet.add(elem) } return clonedSet } -func (s threadUnsafeSet[T]) Contains(v ...T) bool { +func (s *threadUnsafeSet[T]) Contains(v ...T) bool { for _, val := range v { - if _, ok := s[val]; !ok { + if _, ok := (*s)[val]; !ok { return false } } return true } -func (s threadUnsafeSet[T]) ContainsOne(v T) bool { - _, ok := s[v] +func (s *threadUnsafeSet[T]) ContainsOne(v T) bool { + _, ok := (*s)[v] return ok } -func (s threadUnsafeSet[T]) ContainsAny(v ...T) bool { +func (s *threadUnsafeSet[T]) ContainsAny(v ...T) bool { for _, val := range v { - if _, ok := s[val]; ok { + if _, ok := (*s)[val]; ok { return true } } @@ -108,16 +110,16 @@ func (s threadUnsafeSet[T]) ContainsAny(v ...T) bool { } // private version of Contains for a single element v -func (s threadUnsafeSet[T]) contains(v T) (ok bool) { - _, ok = s[v] +func (s *threadUnsafeSet[T]) contains(v T) (ok bool) { + _, ok = (*s)[v] return ok } -func (s threadUnsafeSet[T]) Difference(other Set[T]) Set[T] { - o := other.(threadUnsafeSet[T]) +func (s *threadUnsafeSet[T]) Difference(other Set[T]) Set[T] { + o := other.(*threadUnsafeSet[T]) diff := newThreadUnsafeSet[T]() - for elem := range s { + for elem := range *s { if !o.contains(elem) { diff.add(elem) } @@ -125,21 +127,21 @@ func (s threadUnsafeSet[T]) Difference(other Set[T]) Set[T] { return diff } -func (s threadUnsafeSet[T]) Each(cb func(T) bool) { - for elem := range s { +func (s *threadUnsafeSet[T]) Each(cb func(T) bool) { + for elem := range *s { if cb(elem) { break } } } -func (s threadUnsafeSet[T]) Equal(other Set[T]) bool { - o := other.(threadUnsafeSet[T]) +func (s *threadUnsafeSet[T]) Equal(other Set[T]) bool { + o := other.(*threadUnsafeSet[T]) if s.Cardinality() != other.Cardinality() { return false } - for elem := range s { + for elem := range *s { if !o.contains(elem) { return false } @@ -147,19 +149,19 @@ func (s threadUnsafeSet[T]) Equal(other Set[T]) bool { return true } -func (s threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] { - o := other.(threadUnsafeSet[T]) +func (s *threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] { + o := other.(*threadUnsafeSet[T]) intersection := newThreadUnsafeSet[T]() // loop over smaller set if s.Cardinality() < other.Cardinality() { - for elem := range s { + for elem := range *s { if o.contains(elem) { intersection.add(elem) } } } else { - for elem := range o { + for elem := range *o { if s.contains(elem) { intersection.add(elem) } @@ -168,24 +170,24 @@ func (s threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] { return intersection } -func (s threadUnsafeSet[T]) IsEmpty() bool { +func (s *threadUnsafeSet[T]) IsEmpty() bool { return s.Cardinality() == 0 } -func (s threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool { +func (s *threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool { return s.Cardinality() < other.Cardinality() && s.IsSubset(other) } -func (s threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool { +func (s *threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool { return s.Cardinality() > other.Cardinality() && s.IsSuperset(other) } -func (s threadUnsafeSet[T]) IsSubset(other Set[T]) bool { - o := other.(threadUnsafeSet[T]) +func (s *threadUnsafeSet[T]) IsSubset(other Set[T]) bool { + o := other.(*threadUnsafeSet[T]) if s.Cardinality() > other.Cardinality() { return false } - for elem := range s { + for elem := range *s { if !o.contains(elem) { return false } @@ -193,14 +195,14 @@ func (s threadUnsafeSet[T]) IsSubset(other Set[T]) bool { return true } -func (s threadUnsafeSet[T]) IsSuperset(other Set[T]) bool { +func (s *threadUnsafeSet[T]) IsSuperset(other Set[T]) bool { return other.IsSubset(s) } -func (s threadUnsafeSet[T]) Iter() <-chan T { +func (s *threadUnsafeSet[T]) Iter() <-chan T { ch := make(chan T) go func() { - for elem := range s { + for elem := range *s { ch <- elem } close(ch) @@ -209,12 +211,12 @@ func (s threadUnsafeSet[T]) Iter() <-chan T { return ch } -func (s threadUnsafeSet[T]) Iterator() *Iterator[T] { +func (s *threadUnsafeSet[T]) Iterator() *Iterator[T] { iterator, ch, stopCh := newIterator[T]() go func() { L: - for elem := range s { + for elem := range *s { select { case <-stopCh: break L @@ -229,9 +231,9 @@ func (s threadUnsafeSet[T]) Iterator() *Iterator[T] { // Pop returns a popped item in case set is not empty, or nil-value of T // if set is already empty -func (s threadUnsafeSet[T]) Pop() (v T, ok bool) { - for item := range s { - delete(s, item) +func (s *threadUnsafeSet[T]) Pop() (v T, ok bool) { + for item := range *s { + delete(*s, item) return item, true } return v, false @@ -256,16 +258,16 @@ func (s threadUnsafeSet[T]) String() string { return fmt.Sprintf("Set{%s}", strings.Join(items, ", ")) } -func (s threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] { - o := other.(threadUnsafeSet[T]) +func (s *threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] { + o := other.(*threadUnsafeSet[T]) sd := newThreadUnsafeSet[T]() - for elem := range s { + for elem := range *s { if !o.contains(elem) { sd.add(elem) } } - for elem := range o { + for elem := range *o { if !s.contains(elem) { sd.add(elem) } @@ -283,7 +285,7 @@ func (s threadUnsafeSet[T]) ToSlice() []T { } func (s threadUnsafeSet[T]) Union(other Set[T]) Set[T] { - o := other.(threadUnsafeSet[T]) + o := other.(*threadUnsafeSet[T]) n := s.Cardinality() if o.Cardinality() > n { @@ -294,10 +296,10 @@ func (s threadUnsafeSet[T]) Union(other Set[T]) Set[T] { for elem := range s { unionedSet.add(elem) } - for elem := range o { + for elem := range *o { unionedSet.add(elem) } - return unionedSet + return &unionedSet } // MarshalJSON creates a JSON array from the set, it marshals all elements @@ -318,7 +320,7 @@ func (s threadUnsafeSet[T]) MarshalJSON() ([]byte, error) { // UnmarshalJSON recreates a set from a JSON array, it only decodes // primitive types. Numbers are decoded as json.Number. -func (s threadUnsafeSet[T]) UnmarshalJSON(b []byte) error { +func (s *threadUnsafeSet[T]) UnmarshalJSON(b []byte) error { var i []T err := json.Unmarshal(b, &i) if err != nil { diff --git a/threadunsafe_test.go b/threadunsafe_test.go new file mode 100644 index 0000000..c670305 --- /dev/null +++ b/threadunsafe_test.go @@ -0,0 +1,164 @@ +/* +Open Source Initiative OSI - The MIT License (MIT):Licensing + +The MIT License (MIT) +Copyright (c) 2013 - 2022 Ralph Caraveo (deckarep@gmail.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +package mapset + +import ( + "encoding/json" + "testing" +) + +func TestThreadUnsafeSet_MarshalJSON(t *testing.T) { + expected := NewThreadUnsafeSet[int64](1, 2, 3) + actual := newThreadUnsafeSet[int64]() + + // test Marshal from Set method + b, err := expected.MarshalJSON() + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + + err = json.Unmarshal(b, actual) + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + + if !expected.Equal(actual) { + t.Errorf("Expected no difference, got: %v", expected.Difference(actual)) + } + + // test Marshal from json package + b, err = json.Marshal(expected) + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + + err = json.Unmarshal(b, actual) + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + + if !expected.Equal(actual) { + t.Errorf("Expected no difference, got: %v", expected.Difference(actual)) + } +} + +func TestThreadUnsafeSet_UnmarshalJSON(t *testing.T) { + expected := NewThreadUnsafeSet[int64](1, 2, 3) + actual := NewThreadUnsafeSet[int64]() + + // test Unmarshal from Set method + err := actual.UnmarshalJSON([]byte(`[1, 2, 3]`)) + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + if !expected.Equal(actual) { + t.Errorf("Expected no difference, got: %v", expected.Difference(actual)) + } + + // test Unmarshal from json package + actual = NewThreadUnsafeSet[int64]() + err = json.Unmarshal([]byte(`[1, 2, 3]`), actual) + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + if !expected.Equal(actual) { + t.Errorf("Expected no difference, got: %v", expected.Difference(actual)) + } +} + +func TestThreadUnsafeSet_MarshalJSON_Struct(t *testing.T) { + expected := &testStruct{"test", NewThreadUnsafeSet("a")} + + b, err := json.Marshal(&testStruct{"test", NewThreadUnsafeSet("a")}) + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + + actual := &testStruct{} + err = json.Unmarshal(b, actual) + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + + if !expected.Set.Equal(actual.Set) { + t.Errorf("Expected no difference, got: %v", expected.Set.Difference(actual.Set)) + } +} +func TestThreadUnsafeSet_UnmarshalJSON_Struct(t *testing.T) { + expected := &testStruct{"test", NewThreadUnsafeSet("a", "b", "c")} + actual := &testStruct{} + + err := json.Unmarshal([]byte(`{"other":"test", "set":["a", "b", "c"]}`), actual) + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + if !expected.Set.Equal(actual.Set) { + t.Errorf("Expected no difference, got: %v", expected.Set.Difference(actual.Set)) + } + + expectedComplex := NewThreadUnsafeSet(struct{ Val string }{Val: "a"}, struct{ Val string }{Val: "b"}) + actualComplex := NewThreadUnsafeSet[struct{ Val string }]() + + err = actualComplex.UnmarshalJSON([]byte(`[{"Val": "a"}, {"Val": "b"}]`)) + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + if !expectedComplex.Equal(actualComplex) { + t.Errorf("Expected no difference, got: %v", expectedComplex.Difference(actualComplex)) + } + + actualComplex = NewThreadUnsafeSet[struct{ Val string }]() + err = json.Unmarshal([]byte(`[{"Val": "a"}, {"Val": "b"}]`), actualComplex) + if err != nil { + t.Errorf("Error should be nil: %v", err) + } + if !expectedComplex.Equal(actualComplex) { + t.Errorf("Expected no difference, got: %v", expectedComplex.Difference(actualComplex)) + } +} + +// this serves as an example of how to correctly unmarshal a struct with a Set property +type testStruct struct { + Other string + Set Set[string] +} + +func (t *testStruct) UnmarshalJSON(b []byte) error { + raw := struct { + Other string + Set []string + }{} + + err := json.Unmarshal(b, &raw) + if err != nil { + return err + } + + t.Other = raw.Other + t.Set = NewThreadUnsafeSet(raw.Set...) + + return nil +}