diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e703f8..d2b298b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog], and this project adheres to the keys of `iter.Seq2` sequences. - Added `sets.NewFromValues()` (and variants) which construct set types from the values of `iter.Seq2` sequences. +- Added `Intersection()` method to all set types. ### Changed diff --git a/collections/maps/map.go b/collections/maps/map.go index f2de2e6..d2cf051 100644 --- a/collections/maps/map.go +++ b/collections/maps/map.go @@ -110,13 +110,13 @@ func (m *Map[K, V]) TryGet(k K) (V, bool) { // Clone returns a shallow copy of the map. func (m *Map[K, V]) Clone() *Map[K, V] { - var x Map[K, V] + var out Map[K, V] if m != nil { - x.elements = maps.Clone(m.elements) + out.elements = maps.Clone(m.elements) } - return &x + return &out } // Merge returns a new map containing all key/value pairs from s and x. @@ -145,34 +145,34 @@ func (m *Map[K, V]) Merge(x *Map[K, V]) *Map[K, V] { // Select returns a new map containing all key/value pairs from m for which the // given predicate returns true. func (m *Map[K, V]) Select(pred func(K, V) bool) *Map[K, V] { - var x Map[K, V] + var out Map[K, V] if m != nil { for k, v := range m.elements { if pred(k, v) { - x.Set(k, v) + out.Set(k, v) } } } - return &x + return &out } // Project constructs a new map by applying the given transform function to each // key/value pair in the map. If the transform function returns false, the key // is omitted from the resulting map. func (m *Map[K, V]) Project(transform func(K, V) (K, V, bool)) *Map[K, V] { - var x Map[K, V] + var out Map[K, V] if m != nil { for k, v := range m.elements { if k, v, ok := transform(k, v); ok { - x.Set(k, v) + out.Set(k, v) } } } - return &x + return &out } // All returns a sequence that yields all key/value pairs in the map in no diff --git a/collections/maps/orderedimpl.go b/collections/maps/orderedimpl.go index 56cf34a..7ba1509 100644 --- a/collections/maps/orderedimpl.go +++ b/collections/maps/orderedimpl.go @@ -238,17 +238,17 @@ func orderedProject[K, V any, M ordered[K, V, I], I any]( m M, transform func(K, V) (K, V, bool), ) M { - var x M = m.new(nil) + var out M = m.new(nil) if m != nil { for _, pair := range *m.ptr() { if k, v, ok := transform(pair.Key, pair.Value); ok { - orderedSet(x, k, v) + orderedSet(out, k, v) } } } - return x + return out } func orderedAll[K, V any, M ordered[K, V, I], I any]( diff --git a/collections/maps/proto.go b/collections/maps/proto.go index b50c00b..0011e9c 100644 --- a/collections/maps/proto.go +++ b/collections/maps/proto.go @@ -114,13 +114,13 @@ func (m *Proto[K, V]) TryGet(k K) (V, bool) { // Clone returns a shallow copy of the map. func (m *Proto[K, V]) Clone() *Proto[K, V] { - var x Proto[K, V] + var out Proto[K, V] if m != nil { - x.elements = *m.elements.Clone() + out.elements = *m.elements.Clone() } - return &x + return &out } // Merge returns a new map containing all key/value pairs from s and x. @@ -143,27 +143,27 @@ func (m *Proto[K, V]) Merge(x *Proto[K, V]) *Proto[K, V] { // Select returns a new map containing all key/value pairs from m for which the // given predicate returns true. func (m *Proto[K, V]) Select(pred func(K, V) bool) *Proto[K, V] { - var x Proto[K, V] + var out Proto[K, V] if m != nil { - x.elements = *m.elements.Select( + out.elements = *m.elements.Select( func(s string, v V) bool { return pred(m.unmarshal(s), v) }, ) } - return &x + return &out } // Project constructs a new map by applying the given transform function to each // key/value pair in the map. If the transform function returns false, the key // is omitted from the resulting map. func (m *Proto[K, V]) Project(transform func(K, V) (K, V, bool)) *Proto[K, V] { - var x Proto[K, V] + var out Proto[K, V] if m != nil { - x.elements = *m.elements.Project( + out.elements = *m.elements.Project( func(k string, v V) (string, V, bool) { if k, v, ok := transform(m.unmarshal(k), v); ok { return m.marshal(k), v, true @@ -173,7 +173,7 @@ func (m *Proto[K, V]) Project(transform func(K, V) (K, V, bool)) *Proto[K, V] { ) } - return &x + return &out } // All returns a sequence that yields all key/value pairs in the map in no diff --git a/collections/sets/contract_test.go b/collections/sets/contract_test.go index 4a89265..145d3f1 100644 --- a/collections/sets/contract_test.go +++ b/collections/sets/contract_test.go @@ -24,6 +24,7 @@ type contract[T, I any] interface { Clone() *I Union(*I) *I + Intersection(*I) *I Select(func(T) bool) *I All() iter.Seq[T] @@ -334,6 +335,40 @@ func testSet[ t.Fatal("the result of a union should never be nil") } }, + "intersection with an overlapping set": func(t *rapid.T) { + s := newDisjointSet(t, 0, 3) + + m := drawMember(t) + s.Add(m) + + expected = []T{m} + subject = subject.Intersection(s) + if subject == nil { + t.Fatal("the result of an intersection should never be nil") + } + }, + "intersection with disjoint set": func(t *rapid.T) { + s := newDisjointSet(t, 0, 3) + + expected = nil + subject = subject.Intersection(s) + if subject == nil { + t.Fatal("the result of an intersection should never be nil") + } + }, + "interesection with itself": func(t *rapid.T) { + subject = subject.Intersection(subject) + if subject == nil { + t.Fatal("the result of an intersection should never be nil") + } + }, + "intersection with a nil set": func(t *rapid.T) { + expected = nil + subject = subject.Intersection(nil) + if subject == nil { + t.Fatal("the result of an intersection should never be nil") + } + }, "select a subset": func(t *rapid.T) { subject = subject.Select( func(m T) bool { diff --git a/collections/sets/ordered.go b/collections/sets/ordered.go index 5e08cb4..80d6e97 100644 --- a/collections/sets/ordered.go +++ b/collections/sets/ordered.go @@ -95,6 +95,11 @@ func (s *Ordered[T]) Union(x *Ordered[T]) *Ordered[T] { return orderedUnion[T](s, x) } +// Intersection returns a set containing members that are in both s and x. +func (s *Ordered[T]) Intersection(x *Ordered[T]) *Ordered[T] { + return orderedIntersection[T](s, x) +} + // Select returns the subset of s containing members for which the given // predicate function returns true. func (s *Ordered[T]) Select(pred func(T) bool) *Ordered[T] { diff --git a/collections/sets/orderedbycomparator.go b/collections/sets/orderedbycomparator.go index e58faff..6fa65a8 100644 --- a/collections/sets/orderedbycomparator.go +++ b/collections/sets/orderedbycomparator.go @@ -142,6 +142,11 @@ func (s *OrderedByComparator[T, C]) Union(x *OrderedByComparator[T, C]) *Ordered return orderedUnion[T](s, x) } +// Intersection returns a set containing members that are in both s and x. +func (s *OrderedByComparator[T, C]) Intersection(x *OrderedByComparator[T, C]) *OrderedByComparator[T, C] { + return orderedIntersection[T](s, x) +} + // Select returns the subset of s containing members for which the given // predicate function returns true. func (s *OrderedByComparator[T, C]) Select(pred func(T) bool) *OrderedByComparator[T, C] { diff --git a/collections/sets/orderedbymember.go b/collections/sets/orderedbymember.go index e09be20..b2c2931 100644 --- a/collections/sets/orderedbymember.go +++ b/collections/sets/orderedbymember.go @@ -97,6 +97,11 @@ func (s *OrderedByMember[T]) Union(x *OrderedByMember[T]) *OrderedByMember[T] { return orderedUnion[T](s, x) } +// Intersection returns a set containing members that are in both s and x. +func (s *OrderedByMember[T]) Intersection(x *OrderedByMember[T]) *OrderedByMember[T] { + return orderedIntersection[T](s, x) +} + // Select returns the subset of s containing members for which the given // predicate function returns true. func (s *OrderedByMember[T]) Select(pred func(T) bool) *OrderedByMember[T] { diff --git a/collections/sets/orderedimpl.go b/collections/sets/orderedimpl.go index cba9899..118a1a6 100644 --- a/collections/sets/orderedimpl.go +++ b/collections/sets/orderedimpl.go @@ -234,16 +234,15 @@ func orderedClone[T any, S ordered[T, I], I any]( func orderedUnion[T any, S ordered[T, I], I any]( x, y S, ) S { - var membersX, membersY []T - - if x != nil { - membersX = *x.ptr() + if x == nil { + return orderedClone[T](y) } - if y != nil { - membersY = *y.ptr() + if y == nil { + return orderedClone[T](x) } + membersX, membersY := *x.ptr(), *y.ptr() indexX, indexY := 0, 0 lenX, lenY := len(membersX), len(membersY) @@ -289,6 +288,41 @@ func orderedUnion[T any, S ordered[T, I], I any]( return x.new(members) } +func orderedIntersection[T any, S ordered[T, I], I any]( + x, y S, +) S { + if x == nil || y == nil { + return x.new(nil) + } + + big, small := *x.ptr(), *y.ptr() + if len(small) > len(big) { + big, small = small, big + } + + if len(small) == 0 { + return x.new(nil) + } + + members := make([]T, 0, len(small)) + + for _, m := range small { + if len(big) == 0 { + break + } + + i, ok := slices.BinarySearchFunc(big, m, x.cmp) + if ok { + members = append(members, m) + big = big[i+1:] + } else { + big = big[i:] + } + } + + return x.new(members) +} + func orderedSelect[T any, S ordered[T, I], I any]( s S, pred func(T) bool, diff --git a/collections/sets/proto.go b/collections/sets/proto.go index 1e36b69..0fd9d61 100644 --- a/collections/sets/proto.go +++ b/collections/sets/proto.go @@ -163,13 +163,13 @@ func (s *Proto[T]) IsStrictSubset(x *Proto[T]) bool { // Clone returns a shallow copy of the set. func (s *Proto[T]) Clone() *Proto[T] { - var x Proto[T] + var out Proto[T] if s != nil { - x.members = *s.members.Clone() + out.members = *s.members.Clone() } - return &x + return &out } // Union returns a set containing all members of s and x. @@ -187,18 +187,29 @@ func (s *Proto[T]) Union(x *Proto[T]) *Proto[T] { } } +// Intersection returns a set containing members that are in both s and x. +func (s *Proto[T]) Intersection(x *Proto[T]) *Proto[T] { + var out Proto[T] + + if s != nil && x != nil { + out.members = *s.members.Intersection(&x.members) + } + + return &out +} + // Select returns the subset of s containing members for which the given // predicate function returns true. func (s *Proto[T]) Select(pred func(T) bool) *Proto[T] { - var x Proto[T] + var out Proto[T] if s != nil { - x.members = *s.members.Select(func(m string) bool { + out.members = *s.members.Select(func(m string) bool { return pred(s.unmarshal(m)) }) } - return &x + return &out } // All returns a sequence that yields all members of the set in no particular diff --git a/collections/sets/set.go b/collections/sets/set.go index b727036..ae6e635 100644 --- a/collections/sets/set.go +++ b/collections/sets/set.go @@ -167,13 +167,13 @@ func (s *Set[T]) IsStrictSubset(x *Set[T]) bool { // Clone returns a shallow copy of the set. func (s *Set[T]) Clone() *Set[T] { - var x Set[T] + var out Set[T] if s != nil { - x.members = maps.Clone(s.members) + out.members = maps.Clone(s.members) } - return &x + return &out } // Union returns a set containing all members of s and x. @@ -202,20 +202,44 @@ func (s *Set[T]) Union(x *Set[T]) *Set[T] { } } +// Intersection returns a set containing members that are in both s and x. +func (s *Set[T]) Intersection(x *Set[T]) *Set[T] { + if s == nil || x == nil { + return &Set[T]{} + } + + big, small := s.members, x.members + if len(small) > len(big) { + big, small = small, big + } + + members := make(map[T]struct{}, len(small)) + + for m := range small { + if _, ok := big[m]; ok { + members[m] = struct{}{} + } + } + + return &Set[T]{ + members: members, + } +} + // Select returns the subset of s containing members for which the given // predicate function returns true. func (s *Set[T]) Select(pred func(T) bool) *Set[T] { - var x Set[T] + var out Set[T] if s != nil { for m := range s.members { if pred(m) { - x.Add(m) + out.Add(m) } } } - return &x + return &out } // All returns a sequence that yields all members of the set in no particular