Skip to content

Commit

Permalink
Set Insertion Rework (#4999)
Browse files Browse the repository at this point in the history
This commit introduces lazy key slice sorting for the Set type, similar to what was done for Object types in #4830. After this change, sorting of the Set type's key slice will be delayed until just-before-use, identically to how lazy key slice sorting is done for the Object type.

This will move the sorting overhead from construction-time for Sets over to evaluation-time, allowing much more efficient construction and use of enormous (500k+ item) Sets. This appears to be a performance-neutral change overall, while dramatically improving performance for the "large set" edge case.

Signed-off-by: Philip Conrad <[email protected]>
  • Loading branch information
philipaconrad authored Aug 15, 2022
1 parent eb27eda commit 9d2b1ad
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 58 deletions.
51 changes: 27 additions & 24 deletions ast/term.go
Original file line number Diff line number Diff line change
Expand Up @@ -1352,10 +1352,11 @@ func newset(n int) *set {
keys = make([]*Term, 0, n)
}
return &set{
elems: make(map[int]*Term, n),
keys: keys,
hash: 0,
ground: true,
elems: make(map[int]*Term, n),
keys: keys,
hash: 0,
ground: true,
numInserts: 0,
}
}

Expand All @@ -1368,10 +1369,11 @@ func SetTerm(t ...*Term) *Term {
}

type set struct {
elems map[int]*Term
keys []*Term
hash int
ground bool
elems map[int]*Term
keys []*Term
hash int
ground bool
numInserts int // number of inserts since last sorting.
}

// Copy returns a deep copy of s.
Expand Down Expand Up @@ -1401,7 +1403,7 @@ func (s *set) String() string {
}
var b strings.Builder
b.WriteRune('{')
for i := range s.keys {
for i := range s.sortedKeys() {
if i > 0 {
b.WriteString(", ")
}
Expand All @@ -1411,6 +1413,14 @@ func (s *set) String() string {
return b.String()
}

func (s *set) sortedKeys() []*Term {
if s.numInserts > 0 {
sort.Sort(termSlice(s.keys))
s.numInserts = 0
}
return s.keys
}

// Compare compares s to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (s *set) Compare(other Value) int {
Expand All @@ -1422,7 +1432,7 @@ func (s *set) Compare(other Value) int {
return 1
}
t := other.(*set)
return termSliceCompare(s.keys, t.keys)
return termSliceCompare(s.sortedKeys(), t.sortedKeys())
}

// Find returns the set or dereferences the element itself.
Expand Down Expand Up @@ -1488,7 +1498,7 @@ func (s *set) Add(t *Term) {
// Iter calls f on each element in s. If f returns an error, iteration stops
// and the return value is the error.
func (s *set) Iter(f func(*Term) error) error {
for i := range s.keys {
for i := range s.sortedKeys() {
if err := f(s.keys[i]); err != nil {
return err
}
Expand Down Expand Up @@ -1564,20 +1574,19 @@ func (s *set) MarshalJSON() ([]byte, error) {
if s.keys == nil {
return []byte(`[]`), nil
}
return json.Marshal(s.keys)
return json.Marshal(s.sortedKeys())
}

// Sorted returns an Array that contains the sorted elements of s.
func (s *set) Sorted() *Array {
cpy := make([]*Term, len(s.keys))
copy(cpy, s.keys)
sort.Sort(termSlice(cpy))
copy(cpy, s.sortedKeys())
return NewArray(cpy...)
}

// Slice returns a slice of terms contained in the set.
func (s *set) Slice() []*Term {
return s.keys
return s.sortedKeys()
}

func (s *set) insert(x *Term) {
Expand Down Expand Up @@ -1670,15 +1679,9 @@ func (s *set) insert(x *Term) {
}

s.elems[insertHash] = x
i := sort.Search(len(s.keys), func(i int) bool { return Compare(x, s.keys[i]) < 0 })
if i < len(s.keys) {
// insert at position `i`:
s.keys = append(s.keys, nil) // add some space
copy(s.keys[i+1:], s.keys[i:]) // move things over
s.keys[i] = x // drop it in position
} else {
s.keys = append(s.keys, x)
}
// O(1) insertion, but we'll have to re-sort the keys later.
s.keys = append(s.keys, x)
s.numInserts++ // Track insertions since the last re-sorting.

s.hash += hash
s.ground = s.ground && x.IsGround()
Expand Down
47 changes: 45 additions & 2 deletions ast/term_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,25 @@ func BenchmarkObjectCreationAndLookup(b *testing.B) {
}
}

func BenchmarkSetCreationAndLookup(b *testing.B) {
sizes := []int{5, 50, 500, 5000, 50000, 500000}
for _, n := range sizes {
b.Run(fmt.Sprint(n), func(b *testing.B) {
set := NewSet()
for i := 0; i < n; i++ {
set.Add(StringTerm(fmt.Sprint(i)))
}
key := StringTerm(fmt.Sprint(n - 1))
for i := 0; i < b.N; i++ {
present := set.Contains(key)
if !present {
b.Fatal("expected hit")
}
}
})
}
}

func BenchmarkSetIntersection(b *testing.B) {
sizes := []int{5, 50, 500, 5000}
for _, n := range sizes {
Expand Down Expand Up @@ -288,11 +307,10 @@ func BenchmarkArrayString(b *testing.B) {
}

func BenchmarkSetString(b *testing.B) {
sizes := []int{5, 50, 500, 5000}
sizes := []int{5, 50, 500, 5000, 50000}

for _, n := range sizes {
b.Run(fmt.Sprint(n), func(b *testing.B) {

val := NewSet()
for i := 0; i < n; i++ {
val.Add(IntNumberTerm(i))
Expand All @@ -307,3 +325,28 @@ func BenchmarkSetString(b *testing.B) {
})
}
}

func BenchmarkSetMarshalJSON(b *testing.B) {
var err error
sizes := []int{5, 50, 500, 5000, 50000}

for _, n := range sizes {
b.Run(fmt.Sprint(n), func(b *testing.B) {
set := NewSet()
for i := 0; i < n; i++ {
set.Add(StringTerm(fmt.Sprint(i)))
}

b.Run("json.Marshal", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
bs, err = json.Marshal(set)
if err != nil {
b.Fatal(err)
}
}
})
})
}

}
32 changes: 0 additions & 32 deletions ast/term_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"encoding/json"
"fmt"
"reflect"
"sort"
"strings"
"testing"

Expand Down Expand Up @@ -263,37 +262,6 @@ func TestObjectFilter(t *testing.T) {
}
}

func TestSetInsertKeepsKeysSorting(t *testing.T) {
keysSorted := func(s *set) func(int, int) bool {
return func(i, j int) bool {
return Compare(s.keys[i], s.keys[j]) < 0
}
}

s0 := NewSet(
StringTerm("d"),
StringTerm("b"),
StringTerm("a"),
)
s := s0.(*set)
act := sort.SliceIsSorted(s.keys, keysSorted(s))
if exp := true; act != exp {
t.Errorf("SliceIsSorted: expected %v, got %v", exp, act)
for i := range s.keys {
t.Logf("elem[%d]: %v", i, s.keys[i])
}
}

s0.Add(StringTerm("c"))
act = sort.SliceIsSorted(s.keys, keysSorted(s))
if exp := true; act != exp {
t.Errorf("SliceIsSorted: expected %v, got %v", exp, act)
for i := range s.keys {
t.Logf("elem[%d]: %v", i, s.keys[i])
}
}
}

func TestTermBadJSON(t *testing.T) {

input := `{
Expand Down

0 comments on commit 9d2b1ad

Please sign in to comment.