Skip to content

Commit

Permalink
Add Intersection().
Browse files Browse the repository at this point in the history
  • Loading branch information
jmalloc committed Oct 3, 2024
1 parent 6edcf0e commit 6a346f6
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog], and this project adheres to
the keys of `iter.Seq2` sequences.
- Added `sets.NewFromValues()` (and variants) which construct set types from
the values of `iter.Seq2` sequences.
- Added `Intersection()` method to all set types.

### Changed

Expand Down
18 changes: 9 additions & 9 deletions collections/maps/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ func (m *Map[K, V]) TryGet(k K) (V, bool) {

// Clone returns a shallow copy of the map.
func (m *Map[K, V]) Clone() *Map[K, V] {
var x Map[K, V]
var out Map[K, V]

if m != nil {
x.elements = maps.Clone(m.elements)
out.elements = maps.Clone(m.elements)
}

return &x
return &out
}

// Merge returns a new map containing all key/value pairs from s and x.
Expand Down Expand Up @@ -145,34 +145,34 @@ func (m *Map[K, V]) Merge(x *Map[K, V]) *Map[K, V] {
// Select returns a new map containing all key/value pairs from m for which the
// given predicate returns true.
func (m *Map[K, V]) Select(pred func(K, V) bool) *Map[K, V] {
var x Map[K, V]
var out Map[K, V]

if m != nil {
for k, v := range m.elements {
if pred(k, v) {
x.Set(k, v)
out.Set(k, v)
}
}
}

return &x
return &out
}

// Project constructs a new map by applying the given transform function to each
// key/value pair in the map. If the transform function returns false, the key
// is omitted from the resulting map.
func (m *Map[K, V]) Project(transform func(K, V) (K, V, bool)) *Map[K, V] {
var x Map[K, V]
var out Map[K, V]

if m != nil {
for k, v := range m.elements {
if k, v, ok := transform(k, v); ok {
x.Set(k, v)
out.Set(k, v)
}
}
}

return &x
return &out
}

// All returns a sequence that yields all key/value pairs in the map in no
Expand Down
6 changes: 3 additions & 3 deletions collections/maps/orderedimpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,17 +238,17 @@ func orderedProject[K, V any, M ordered[K, V, I], I any](
m M,
transform func(K, V) (K, V, bool),
) M {
var x M = m.new(nil)
var out M = m.new(nil)

if m != nil {
for _, pair := range *m.ptr() {
if k, v, ok := transform(pair.Key, pair.Value); ok {
orderedSet(x, k, v)
orderedSet(out, k, v)
}
}
}

return x
return out
}

func orderedAll[K, V any, M ordered[K, V, I], I any](
Expand Down
18 changes: 9 additions & 9 deletions collections/maps/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ func (m *Proto[K, V]) TryGet(k K) (V, bool) {

// Clone returns a shallow copy of the map.
func (m *Proto[K, V]) Clone() *Proto[K, V] {
var x Proto[K, V]
var out Proto[K, V]

if m != nil {
x.elements = *m.elements.Clone()
out.elements = *m.elements.Clone()
}

return &x
return &out
}

// Merge returns a new map containing all key/value pairs from s and x.
Expand All @@ -143,27 +143,27 @@ func (m *Proto[K, V]) Merge(x *Proto[K, V]) *Proto[K, V] {
// Select returns a new map containing all key/value pairs from m for which the
// given predicate returns true.
func (m *Proto[K, V]) Select(pred func(K, V) bool) *Proto[K, V] {
var x Proto[K, V]
var out Proto[K, V]

if m != nil {
x.elements = *m.elements.Select(
out.elements = *m.elements.Select(
func(s string, v V) bool {
return pred(m.unmarshal(s), v)
},
)
}

return &x
return &out
}

// Project constructs a new map by applying the given transform function to each
// key/value pair in the map. If the transform function returns false, the key
// is omitted from the resulting map.
func (m *Proto[K, V]) Project(transform func(K, V) (K, V, bool)) *Proto[K, V] {
var x Proto[K, V]
var out Proto[K, V]

if m != nil {
x.elements = *m.elements.Project(
out.elements = *m.elements.Project(
func(k string, v V) (string, V, bool) {
if k, v, ok := transform(m.unmarshal(k), v); ok {
return m.marshal(k), v, true
Expand All @@ -173,7 +173,7 @@ func (m *Proto[K, V]) Project(transform func(K, V) (K, V, bool)) *Proto[K, V] {
)
}

return &x
return &out
}

// All returns a sequence that yields all key/value pairs in the map in no
Expand Down
35 changes: 35 additions & 0 deletions collections/sets/contract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type contract[T, I any] interface {

Clone() *I
Union(*I) *I
Intersection(*I) *I
Select(func(T) bool) *I

All() iter.Seq[T]
Expand Down Expand Up @@ -334,6 +335,40 @@ func testSet[
t.Fatal("the result of a union should never be nil")
}
},
"intersection with an overlapping set": func(t *rapid.T) {
s := newDisjointSet(t, 0, 3)

m := drawMember(t)
s.Add(m)

expected = []T{m}
subject = subject.Intersection(s)
if subject == nil {
t.Fatal("the result of an intersection should never be nil")
}
},
"intersection with disjoint set": func(t *rapid.T) {
s := newDisjointSet(t, 0, 3)

expected = nil
subject = subject.Intersection(s)
if subject == nil {
t.Fatal("the result of an intersection should never be nil")
}
},
"interesection with itself": func(t *rapid.T) {
subject = subject.Intersection(subject)
if subject == nil {
t.Fatal("the result of an intersection should never be nil")
}
},
"intersection with a nil set": func(t *rapid.T) {
expected = nil
subject = subject.Intersection(nil)
if subject == nil {
t.Fatal("the result of an intersection should never be nil")
}
},
"select a subset": func(t *rapid.T) {
subject = subject.Select(
func(m T) bool {
Expand Down
5 changes: 5 additions & 0 deletions collections/sets/ordered.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ func (s *Ordered[T]) Union(x *Ordered[T]) *Ordered[T] {
return orderedUnion[T](s, x)
}

// Intersection returns a set containing members that are in both s and x.
func (s *Ordered[T]) Intersection(x *Ordered[T]) *Ordered[T] {
return orderedIntersection[T](s, x)
}

// Select returns the subset of s containing members for which the given
// predicate function returns true.
func (s *Ordered[T]) Select(pred func(T) bool) *Ordered[T] {
Expand Down
5 changes: 5 additions & 0 deletions collections/sets/orderedbycomparator.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ func (s *OrderedByComparator[T, C]) Union(x *OrderedByComparator[T, C]) *Ordered
return orderedUnion[T](s, x)
}

// Intersection returns a set containing members that are in both s and x.
func (s *OrderedByComparator[T, C]) Intersection(x *OrderedByComparator[T, C]) *OrderedByComparator[T, C] {
return orderedIntersection[T](s, x)
}

// Select returns the subset of s containing members for which the given
// predicate function returns true.
func (s *OrderedByComparator[T, C]) Select(pred func(T) bool) *OrderedByComparator[T, C] {
Expand Down
5 changes: 5 additions & 0 deletions collections/sets/orderedbymember.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ func (s *OrderedByMember[T]) Union(x *OrderedByMember[T]) *OrderedByMember[T] {
return orderedUnion[T](s, x)
}

// Intersection returns a set containing members that are in both s and x.
func (s *OrderedByMember[T]) Intersection(x *OrderedByMember[T]) *OrderedByMember[T] {
return orderedIntersection[T](s, x)
}

// Select returns the subset of s containing members for which the given
// predicate function returns true.
func (s *OrderedByMember[T]) Select(pred func(T) bool) *OrderedByMember[T] {
Expand Down
46 changes: 40 additions & 6 deletions collections/sets/orderedimpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,16 +234,15 @@ func orderedClone[T any, S ordered[T, I], I any](
func orderedUnion[T any, S ordered[T, I], I any](
x, y S,
) S {
var membersX, membersY []T

if x != nil {
membersX = *x.ptr()
if x == nil {
return orderedClone[T](y)
}

if y != nil {
membersY = *y.ptr()
if y == nil {
return orderedClone[T](x)
}

membersX, membersY := *x.ptr(), *y.ptr()
indexX, indexY := 0, 0
lenX, lenY := len(membersX), len(membersY)

Expand Down Expand Up @@ -289,6 +288,41 @@ func orderedUnion[T any, S ordered[T, I], I any](
return x.new(members)
}

func orderedIntersection[T any, S ordered[T, I], I any](
x, y S,
) S {
if x == nil || y == nil {
return x.new(nil)
}

big, small := *x.ptr(), *y.ptr()
if len(small) > len(big) {
big, small = small, big
}

if len(small) == 0 {
return x.new(nil)
}

members := make([]T, 0, len(small))

for _, m := range small {
if len(big) == 0 {
break
}

i, ok := slices.BinarySearchFunc(big, m, x.cmp)
if ok {
members = append(members, m)
big = big[i+1:]
} else {
big = big[i:]
}
}

return x.new(members)
}

func orderedSelect[T any, S ordered[T, I], I any](
s S,
pred func(T) bool,
Expand Down
23 changes: 17 additions & 6 deletions collections/sets/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,13 @@ func (s *Proto[T]) IsStrictSubset(x *Proto[T]) bool {

// Clone returns a shallow copy of the set.
func (s *Proto[T]) Clone() *Proto[T] {
var x Proto[T]
var out Proto[T]

if s != nil {
x.members = *s.members.Clone()
out.members = *s.members.Clone()
}

return &x
return &out
}

// Union returns a set containing all members of s and x.
Expand All @@ -187,18 +187,29 @@ func (s *Proto[T]) Union(x *Proto[T]) *Proto[T] {
}
}

// Intersection returns a set containing members that are in both s and x.
func (s *Proto[T]) Intersection(x *Proto[T]) *Proto[T] {
var out Proto[T]

if s != nil && x != nil {
out.members = *s.members.Intersection(&x.members)
}

return &out
}

// Select returns the subset of s containing members for which the given
// predicate function returns true.
func (s *Proto[T]) Select(pred func(T) bool) *Proto[T] {
var x Proto[T]
var out Proto[T]

if s != nil {
x.members = *s.members.Select(func(m string) bool {
out.members = *s.members.Select(func(m string) bool {
return pred(s.unmarshal(m))
})
}

return &x
return &out
}

// All returns a sequence that yields all members of the set in no particular
Expand Down
Loading

0 comments on commit 6a346f6

Please sign in to comment.