From 3045cfb32f3754a7f9e6d9bab57dab766477b382 Mon Sep 17 00:00:00 2001 From: a very fake floordiv Date: Tue, 8 Nov 2022 22:15:42 +0100 Subject: [PATCH] fewer pointers and a more memory-efficient Clear() `threadUnsafeSet` is already a map, which is a reference type, so passing it by value still lets us mutate the original object while reducing the number of pointers. That, in turn, decreases GC pressure, (potentially) avoids some useless operations and (definitely) makes the code easier to read and understand. The other point of this change is to slightly "upgrade" the `Clear()` method. The old implementation simply constructed a new `threadUnsafeSet` and mutated the original pointer to point at the new set. This incurs an extra allocation (I would expect a `Clear()` method to be allocation-free) and resets the already-allocated map's size back to the default one. So I replaced this with clearing the map in place using the runtime's `mapclear()` function (which the compiler implicitly substitutes for `for key := range d { delete(d, key) }`). This is actually a fairly expensive operation, but it performs no allocations, which may make it even cheaper than allocating a new map. In any case, a user can always construct a new instance of the set themselves, but cannot clear the underlying map manually when they really need to - for example, to reduce the average amount of memory used, which is affected a lot by internally allocating a new map. --- threadunsafe.go | 140 +++++++++++++++++++++++++----------------------- 1 file changed, 73 insertions(+), 67 deletions(-) diff --git a/threadunsafe.go b/threadunsafe.go index dfc5c8f..36e2fd9 100644 --- a/threadunsafe.go +++ b/threadunsafe.go @@ -35,42 +35,47 @@ import ( type threadUnsafeSet[T comparable] map[T]struct{} // Assert concrete type:threadUnsafeSet adheres to Set interface. 
-var _ Set[string] = (*threadUnsafeSet[string])(nil) +var _ Set[string] = (threadUnsafeSet[string])(nil) func newThreadUnsafeSet[T comparable]() threadUnsafeSet[T] { return make(threadUnsafeSet[T]) } -func (s *threadUnsafeSet[T]) Add(v T) bool { - prevLen := len(*s) - (*s)[v] = struct{}{} - return prevLen != len(*s) +func (s threadUnsafeSet[T]) Add(v T) bool { + prevLen := len(s) + s[v] = struct{}{} + return prevLen != len(s) } // private version of Add which doesn't return a value -func (s *threadUnsafeSet[T]) add(v T) { - (*s)[v] = struct{}{} +func (s threadUnsafeSet[T]) add(v T) { + s[v] = struct{}{} } -func (s *threadUnsafeSet[T]) Cardinality() int { - return len(*s) +func (s threadUnsafeSet[T]) Cardinality() int { + return len(s) } -func (s *threadUnsafeSet[T]) Clear() { - *s = newThreadUnsafeSet[T]() +func (s threadUnsafeSet[T]) Clear() { + // Constructions like this are optimised by compiler, and replaced by + // mapclear() function, defined in + // https://github.com/golang/go/blob/29bbca5c2c1ad41b2a9747890d183b6dd3a4ace4/src/runtime/map.go#L993) + for key := range s { + delete(s, key) + } } -func (s *threadUnsafeSet[T]) Clone() Set[T] { +func (s threadUnsafeSet[T]) Clone() Set[T] { clonedSet := make(threadUnsafeSet[T], s.Cardinality()) - for elem := range *s { + for elem := range s { clonedSet.add(elem) } - return &clonedSet + return clonedSet } -func (s *threadUnsafeSet[T]) Contains(v ...T) bool { +func (s threadUnsafeSet[T]) Contains(v ...T) bool { for _, val := range v { - if _, ok := (*s)[val]; !ok { + if _, ok := s[val]; !ok { return false } } @@ -78,38 +83,38 @@ func (s *threadUnsafeSet[T]) Contains(v ...T) bool { } // private version of Contains for a single element v -func (s *threadUnsafeSet[T]) contains(v T) bool { - _, ok := (*s)[v] +func (s threadUnsafeSet[T]) contains(v T) (ok bool) { + _, ok = s[v] return ok } -func (s *threadUnsafeSet[T]) Difference(other Set[T]) Set[T] { - o := other.(*threadUnsafeSet[T]) +func (s threadUnsafeSet[T]) 
Difference(other Set[T]) Set[T] { + o := other.(threadUnsafeSet[T]) diff := newThreadUnsafeSet[T]() - for elem := range *s { + for elem := range s { if !o.contains(elem) { diff.add(elem) } } - return &diff + return diff } -func (s *threadUnsafeSet[T]) Each(cb func(T) bool) { - for elem := range *s { +func (s threadUnsafeSet[T]) Each(cb func(T) bool) { + for elem := range s { if cb(elem) { break } } } -func (s *threadUnsafeSet[T]) Equal(other Set[T]) bool { - o := other.(*threadUnsafeSet[T]) +func (s threadUnsafeSet[T]) Equal(other Set[T]) bool { + o := other.(threadUnsafeSet[T]) if s.Cardinality() != other.Cardinality() { return false } - for elem := range *s { + for elem := range s { if !o.contains(elem) { return false } @@ -117,41 +122,41 @@ func (s *threadUnsafeSet[T]) Equal(other Set[T]) bool { return true } -func (s *threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] { - o := other.(*threadUnsafeSet[T]) +func (s threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] { + o := other.(threadUnsafeSet[T]) intersection := newThreadUnsafeSet[T]() // loop over smaller set if s.Cardinality() < other.Cardinality() { - for elem := range *s { + for elem := range s { if o.contains(elem) { intersection.add(elem) } } } else { - for elem := range *o { + for elem := range o { if s.contains(elem) { intersection.add(elem) } } } - return &intersection + return intersection } -func (s *threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool { +func (s threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool { return s.Cardinality() < other.Cardinality() && s.IsSubset(other) } -func (s *threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool { +func (s threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool { return s.Cardinality() > other.Cardinality() && s.IsSuperset(other) } -func (s *threadUnsafeSet[T]) IsSubset(other Set[T]) bool { - o := other.(*threadUnsafeSet[T]) +func (s threadUnsafeSet[T]) IsSubset(other Set[T]) bool { + o := other.(threadUnsafeSet[T]) if s.Cardinality() > 
other.Cardinality() { return false } - for elem := range *s { + for elem := range s { if !o.contains(elem) { return false } @@ -159,14 +164,14 @@ func (s *threadUnsafeSet[T]) IsSubset(other Set[T]) bool { return true } -func (s *threadUnsafeSet[T]) IsSuperset(other Set[T]) bool { +func (s threadUnsafeSet[T]) IsSuperset(other Set[T]) bool { return other.IsSubset(s) } -func (s *threadUnsafeSet[T]) Iter() <-chan T { +func (s threadUnsafeSet[T]) Iter() <-chan T { ch := make(chan T) go func() { - for elem := range *s { + for elem := range s { ch <- elem } close(ch) @@ -175,12 +180,12 @@ func (s *threadUnsafeSet[T]) Iter() <-chan T { return ch } -func (s *threadUnsafeSet[T]) Iterator() *Iterator[T] { +func (s threadUnsafeSet[T]) Iterator() *Iterator[T] { iterator, ch, stopCh := newIterator[T]() go func() { L: - for elem := range *s { + for elem := range s { select { case <-stopCh: break L @@ -193,56 +198,57 @@ func (s *threadUnsafeSet[T]) Iterator() *Iterator[T] { return iterator } -// TODO: how can we make this properly , return T but can't return nil. 
-func (s *threadUnsafeSet[T]) Pop() (v T, ok bool) { - for item := range *s { - delete(*s, item) +// Pop returns a popped item in case set is not empty, or nil-value of T +// if set is already empty +func (s threadUnsafeSet[T]) Pop() (v T, ok bool) { + for item := range s { + delete(s, item) return item, true } - return + return v, false } -func (s *threadUnsafeSet[T]) Remove(v T) { - delete(*s, v) +func (s threadUnsafeSet[T]) Remove(v T) { + delete(s, v) } -func (s *threadUnsafeSet[T]) String() string { - items := make([]string, 0, len(*s)) +func (s threadUnsafeSet[T]) String() string { + items := make([]string, 0, len(s)) - for elem := range *s { + for elem := range s { items = append(items, fmt.Sprintf("%v", elem)) } return fmt.Sprintf("Set{%s}", strings.Join(items, ", ")) } -func (s *threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] { - o := other.(*threadUnsafeSet[T]) +func (s threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] { + o := other.(threadUnsafeSet[T]) sd := newThreadUnsafeSet[T]() - for elem := range *s { + for elem := range s { if !o.contains(elem) { sd.add(elem) } } - for elem := range *o { + for elem := range o { if !s.contains(elem) { sd.add(elem) } } - return &sd + return sd } -func (s *threadUnsafeSet[T]) ToSlice() []T { +func (s threadUnsafeSet[T]) ToSlice() []T { keys := make([]T, 0, s.Cardinality()) - for elem := range *s { + for elem := range s { keys = append(keys, elem) } return keys } -func (s *threadUnsafeSet[T]) Union(other Set[T]) Set[T] { - o := other.(*threadUnsafeSet[T]) +func (s threadUnsafeSet[T]) Union(other Set[T]) Set[T] { + o := other.(threadUnsafeSet[T]) n := s.Cardinality() if o.Cardinality() > n { @@ -250,20 +256,20 @@ func (s *threadUnsafeSet[T]) Union(other Set[T]) Set[T] { } unionedSet := make(threadUnsafeSet[T], n) - for elem := range *s { + for elem := range s { unionedSet.add(elem) } - for elem := range *o { + for elem := range o { unionedSet.add(elem) } - return &unionedSet + return unionedSet 
} // MarshalJSON creates a JSON array from the set, it marshals all elements -func (s *threadUnsafeSet[T]) MarshalJSON() ([]byte, error) { +func (s threadUnsafeSet[T]) MarshalJSON() ([]byte, error) { items := make([]string, 0, s.Cardinality()) - for elem := range *s { + for elem := range s { b, err := json.Marshal(elem) if err != nil { return nil, err @@ -277,7 +283,7 @@ func (s *threadUnsafeSet[T]) MarshalJSON() ([]byte, error) { // UnmarshalJSON recreates a set from a JSON array, it only decodes // primitive types. Numbers are decoded as json.Number. -func (s *threadUnsafeSet[T]) UnmarshalJSON(b []byte) error { +func (s threadUnsafeSet[T]) UnmarshalJSON(b []byte) error { var i []any d := json.NewDecoder(bytes.NewReader(b))