Skip to content

Commit

Permalink
fewer pointers and a more memory-efficient Clear()
Browse files Browse the repository at this point in the history
`threadUnsafeSet` is already a map, which is a reference type, so passing it by value still lets us mutate the original object while reducing the number of pointers. That in turn reduces GC pressure, (potentially) avoids some useless operations, and (definitely) makes the code easier to read and understand.

Another point of the change is to "upgrade" the `Clear()` method a bit. The old implementation simply constructed a new `threadUnsafeSet` and mutated the original pointer to point at the new set. This incurs an extra allocation (I would expect a `Clear()` method to be allocation-free) and resets the already allocated map's size to the default one. So I replaced it with clearing the map in place via a `for key := range d { delete(d, key) }` loop, which the compiler implicitly replaces with the `mapclear()` function. This is actually a pretty expensive operation, but it performs no allocations, which may well make it cheaper than the alternative. In any case, a user can always construct a new instance of the set themselves, but cannot clear the underlying map manually when they really need to — for example, to reduce average memory usage, which is affected a lot by internally allocating a new map.
  • Loading branch information
flrdv authored and deckarep committed Mar 5, 2023
1 parent 543b3d7 commit 3045cfb
Showing 1 changed file with 73 additions and 67 deletions.
140 changes: 73 additions & 67 deletions threadunsafe.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,138 +35,143 @@ import (
// threadUnsafeSet is a generic set stored as a map whose keys are the set's
// elements and whose values are empty structs.
type threadUnsafeSet[T comparable] map[T]struct{}

// Compile-time assertions that threadUnsafeSet satisfies the Set interface:
// the first line checks the pointer type, the second the value type.
var _ Set[string] = (*threadUnsafeSet[string])(nil)
var _ Set[string] = (threadUnsafeSet[string])(nil)

// newThreadUnsafeSet returns an empty, non-nil threadUnsafeSet that is ready
// to have elements added to it.
func newThreadUnsafeSet[T comparable]() threadUnsafeSet[T] {
	set := make(threadUnsafeSet[T])
	return set
}

// Add inserts v into the set and reports whether the set grew, i.e. whether
// v was not already present.
func (s *threadUnsafeSet[T]) Add(v T) bool {
prevLen := len(*s)
(*s)[v] = struct{}{}
return prevLen != len(*s)
func (s threadUnsafeSet[T]) Add(v T) bool {
prevLen := len(s)
s[v] = struct{}{}
// Storing an existing key leaves the length unchanged, so a length
// difference means v is new.
return prevLen != len(s)
}

// add is the private version of Add: it inserts v into the set without
// reporting whether v was already present.
func (s *threadUnsafeSet[T]) add(v T) {
(*s)[v] = struct{}{}
func (s threadUnsafeSet[T]) add(v T) {
s[v] = struct{}{}
}

// Cardinality returns the number of elements currently in the set.
func (s *threadUnsafeSet[T]) Cardinality() int {
return len(*s)
func (s threadUnsafeSet[T]) Cardinality() int {
return len(s)
}

// Clear removes every element from the set. The pointer-receiver variant does
// so by pointing *s at a freshly constructed set; the value-receiver variant
// deletes each key from the existing map in place.
func (s *threadUnsafeSet[T]) Clear() {
*s = newThreadUnsafeSet[T]()
func (s threadUnsafeSet[T]) Clear() {
// Loops of this form are recognised by the compiler and replaced by a call
// to the runtime's mapclear() function, defined in
// https://github.com/golang/go/blob/29bbca5c2c1ad41b2a9747890d183b6dd3a4ace4/src/runtime/map.go#L993
for key := range s {
delete(s, key)
}
}

// Clone returns a new set containing every element of s. The copy is
// allocated with room for s.Cardinality() elements up front.
func (s *threadUnsafeSet[T]) Clone() Set[T] {
func (s threadUnsafeSet[T]) Clone() Set[T] {
clonedSet := make(threadUnsafeSet[T], s.Cardinality())
for elem := range *s {
for elem := range s {
clonedSet.add(elem)
}
return &clonedSet
return clonedSet
}

// Contains reports whether every value in v is present in the set. It returns
// false at the first missing value and true when v is empty.
func (s *threadUnsafeSet[T]) Contains(v ...T) bool {
func (s threadUnsafeSet[T]) Contains(v ...T) bool {
for _, val := range v {
if _, ok := (*s)[val]; !ok {
if _, ok := s[val]; !ok {
return false
}
}
return true
}

// contains is the private, single-element version of Contains: it reports
// whether v is present in the set.
func (s *threadUnsafeSet[T]) contains(v T) bool {
_, ok := (*s)[v]
func (s threadUnsafeSet[T]) contains(v T) (ok bool) {
_, ok = s[v]
return ok
}

// Difference returns a new set holding the elements of s that are not in
// other. The single-result type assertion on other panics if other holds a
// different concrete type.
func (s *threadUnsafeSet[T]) Difference(other Set[T]) Set[T] {
o := other.(*threadUnsafeSet[T])
func (s threadUnsafeSet[T]) Difference(other Set[T]) Set[T] {
o := other.(threadUnsafeSet[T])

diff := newThreadUnsafeSet[T]()
for elem := range *s {
for elem := range s {
if !o.contains(elem) {
diff.add(elem)
}
}
return &diff
return diff
}

// Each calls cb once for each element of the set, in map iteration order,
// and stops early as soon as cb returns true.
func (s *threadUnsafeSet[T]) Each(cb func(T) bool) {
for elem := range *s {
func (s threadUnsafeSet[T]) Each(cb func(T) bool) {
for elem := range s {
if cb(elem) {
break
}
}
}

// Equal reports whether s and other hold the same elements: their
// cardinalities must match and every element of s must be in other. The
// single-result type assertion on other panics if other holds a different
// concrete type.
func (s *threadUnsafeSet[T]) Equal(other Set[T]) bool {
o := other.(*threadUnsafeSet[T])
func (s threadUnsafeSet[T]) Equal(other Set[T]) bool {
o := other.(threadUnsafeSet[T])

// Sets of different sizes can never be equal, so skip the element scan.
if s.Cardinality() != other.Cardinality() {
return false
}
for elem := range *s {
for elem := range s {
if !o.contains(elem) {
return false
}
}
return true
}

// Intersect returns a new set holding the elements present in both s and
// other. The single-result type assertion on other panics if other holds a
// different concrete type.
func (s *threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] {
o := other.(*threadUnsafeSet[T])
func (s threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] {
o := other.(threadUnsafeSet[T])

intersection := newThreadUnsafeSet[T]()
// Loop over the smaller set and probe the other one: s when it has strictly
// fewer elements, otherwise other.
if s.Cardinality() < other.Cardinality() {
for elem := range *s {
for elem := range s {
if o.contains(elem) {
intersection.add(elem)
}
}
} else {
for elem := range *o {
for elem := range o {
if s.contains(elem) {
intersection.add(elem)
}
}
}
return &intersection
return intersection
}

// IsProperSubset reports whether s is a subset of other and has strictly
// fewer elements than other.
func (s *threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool {
func (s threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool {
return s.Cardinality() < other.Cardinality() && s.IsSubset(other)
}

// IsProperSuperset reports whether s is a superset of other and has strictly
// more elements than other.
func (s *threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool {
func (s threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool {
return s.Cardinality() > other.Cardinality() && s.IsSuperset(other)
}

// IsSubset reports whether every element of s is also in other. The
// single-result type assertion on other panics if other holds a different
// concrete type.
func (s *threadUnsafeSet[T]) IsSubset(other Set[T]) bool {
o := other.(*threadUnsafeSet[T])
func (s threadUnsafeSet[T]) IsSubset(other Set[T]) bool {
o := other.(threadUnsafeSet[T])
// A set with more elements than other cannot be a subset of it.
if s.Cardinality() > other.Cardinality() {
return false
}
for elem := range *s {
for elem := range s {
if !o.contains(elem) {
return false
}
}
return true
}

// IsSuperset reports whether s contains every element of other. It delegates
// to other.IsSubset(s).
func (s *threadUnsafeSet[T]) IsSuperset(other Set[T]) bool {
func (s threadUnsafeSet[T]) IsSuperset(other Set[T]) bool {
return other.IsSubset(s)
}

func (s *threadUnsafeSet[T]) Iter() <-chan T {
func (s threadUnsafeSet[T]) Iter() <-chan T {
ch := make(chan T)
go func() {
for elem := range *s {
for elem := range s {
ch <- elem
}
close(ch)
Expand All @@ -175,12 +180,12 @@ func (s *threadUnsafeSet[T]) Iter() <-chan T {
return ch
}

func (s *threadUnsafeSet[T]) Iterator() *Iterator[T] {
func (s threadUnsafeSet[T]) Iterator() *Iterator[T] {
iterator, ch, stopCh := newIterator[T]()

go func() {
L:
for elem := range *s {
for elem := range s {
select {
case <-stopCh:
break L
Expand All @@ -193,77 +198,78 @@ func (s *threadUnsafeSet[T]) Iterator() *Iterator[T] {
return iterator
}

// TODO: how can we make this properly , return T but can't return nil.
func (s *threadUnsafeSet[T]) Pop() (v T, ok bool) {
for item := range *s {
delete(*s, item)
// Pop removes one element from the set (the first one yielded by ranging over
// the map) and returns it together with true. If the set is empty it returns
// the zero value of T and false.
func (s threadUnsafeSet[T]) Pop() (v T, ok bool) {
for item := range s {
delete(s, item)
return item, true
}
return
return v, false
}

// Remove deletes v from the set. Deleting a value that is not present is a
// no-op.
func (s *threadUnsafeSet[T]) Remove(v T) {
delete(*s, v)
func (s threadUnsafeSet[T]) Remove(v T) {
delete(s, v)
}

// String returns the set rendered as "Set{e1, e2, ...}", formatting each
// element with the %v verb. Elements appear in map iteration order, so the
// output order is not stable between calls.
func (s *threadUnsafeSet[T]) String() string {
items := make([]string, 0, len(*s))
func (s threadUnsafeSet[T]) String() string {
items := make([]string, 0, len(s))

for elem := range *s {
for elem := range s {
items = append(items, fmt.Sprintf("%v", elem))
}
return fmt.Sprintf("Set{%s}", strings.Join(items, ", "))
}

// SymmetricDifference returns a new set holding the elements that are in
// exactly one of s and other. The single-result type assertion on other
// panics if other holds a different concrete type.
func (s *threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] {
o := other.(*threadUnsafeSet[T])
func (s threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] {
o := other.(threadUnsafeSet[T])

sd := newThreadUnsafeSet[T]()
// First pass: elements of s that are missing from other.
for elem := range *s {
for elem := range s {
if !o.contains(elem) {
sd.add(elem)
}
}
// Second pass: elements of other that are missing from s.
for elem := range *o {
for elem := range o {
if !s.contains(elem) {
sd.add(elem)
}
}
return &sd
return sd
}

// ToSlice returns the elements of the set in a newly allocated slice whose
// capacity is the set's cardinality. Elements appear in map iteration order.
func (s *threadUnsafeSet[T]) ToSlice() []T {
func (s threadUnsafeSet[T]) ToSlice() []T {
keys := make([]T, 0, s.Cardinality())
for elem := range *s {
for elem := range s {
keys = append(keys, elem)
}

return keys
}

// Union returns a new set containing every element of s and every element of
// other. The single-result type assertion on other panics if other holds a
// different concrete type.
func (s *threadUnsafeSet[T]) Union(other Set[T]) Set[T] {
o := other.(*threadUnsafeSet[T])
func (s threadUnsafeSet[T]) Union(other Set[T]) Set[T] {
o := other.(threadUnsafeSet[T])

// Size the result for the larger of the two inputs.
n := s.Cardinality()
if o.Cardinality() > n {
n = o.Cardinality()
}
unionedSet := make(threadUnsafeSet[T], n)

for elem := range *s {
for elem := range s {
unionedSet.add(elem)
}
for elem := range *o {
for elem := range o {
unionedSet.add(elem)
}
return &unionedSet
return unionedSet
}

// MarshalJSON creates a JSON array from the set, it marshals all elements
func (s *threadUnsafeSet[T]) MarshalJSON() ([]byte, error) {
func (s threadUnsafeSet[T]) MarshalJSON() ([]byte, error) {
items := make([]string, 0, s.Cardinality())

for elem := range *s {
for elem := range s {
b, err := json.Marshal(elem)
if err != nil {
return nil, err
Expand All @@ -277,7 +283,7 @@ func (s *threadUnsafeSet[T]) MarshalJSON() ([]byte, error) {

// UnmarshalJSON recreates a set from a JSON array, it only decodes
// primitive types. Numbers are decoded as json.Number.
func (s *threadUnsafeSet[T]) UnmarshalJSON(b []byte) error {
func (s threadUnsafeSet[T]) UnmarshalJSON(b []byte) error {
var i []any

d := json.NewDecoder(bytes.NewReader(b))
Expand Down

0 comments on commit 3045cfb

Please sign in to comment.