diff --git a/iterator_test.go b/iterator_test.go index 6a4a26f19..22af1f6a1 100644 --- a/iterator_test.go +++ b/iterator_test.go @@ -1,6 +1,7 @@ package iavl import ( + "math/rand" "sort" "testing" @@ -32,6 +33,56 @@ func TestIterator_NewIterator_NilTree_Failure(t *testing.T) { performTest(t, itr) require.ErrorIs(t, errFastIteratorNilNdbGiven, itr.Error()) }) + + t.Run("Unsaved Fast Iterator", func(t *testing.T) { + itr := NewUnsavedFastIterator(start, end, ascending, nil, map[string]*FastNode{}, map[string]interface{}{}) + performTest(t, itr) + require.ErrorIs(t, errFastIteratorNilNdbGiven, itr.Error()) + }) +} + +func TestUnsavedFastIterator_NewIterator_NilAdditions_Failure(t *testing.T) { + var start, end []byte = []byte{'a'}, []byte{'c'} + ascending := true + + performTest := func(t *testing.T, itr dbm.Iterator) { + require.NotNil(t, itr) + require.False(t, itr.Valid()) + actualsStart, actualEnd := itr.Domain() + require.Equal(t, start, actualsStart) + require.Equal(t, end, actualEnd) + require.Error(t, itr.Error()) + } + + t.Run("Nil additions given", func(t *testing.T) { + tree, err := NewMutableTree(dbm.NewMemDB(), 0) + require.NoError(t, err) + itr := NewUnsavedFastIterator(start, end, ascending, tree.ndb, nil, tree.unsavedFastNodeRemovals) + performTest(t, itr) + require.ErrorIs(t, errUnsavedFastIteratorNilAdditionsGiven, itr.Error()) + }) + + t.Run("Nil removals given", func(t *testing.T) { + tree, err := NewMutableTree(dbm.NewMemDB(), 0) + require.NoError(t, err) + itr := NewUnsavedFastIterator(start, end, ascending, tree.ndb, tree.unsavedFastNodeAdditions, nil) + performTest(t, itr) + require.ErrorIs(t, errUnsavedFastIteratorNilRemovalsGiven, itr.Error()) + }) + + t.Run("All nil", func(t *testing.T) { + itr := NewUnsavedFastIterator(start, end, ascending, nil, nil, nil) + performTest(t, itr) + require.ErrorIs(t, errFastIteratorNilNdbGiven, itr.Error()) + }) + + t.Run("Additions and removals are nil", func(t *testing.T) { + tree, err := NewMutableTree(dbm.NewMemDB(), 0) + require.NoError(t, err) + itr := NewUnsavedFastIterator(start, end, ascending, tree.ndb, nil, nil) + performTest(t, itr) + require.ErrorIs(t, errUnsavedFastIteratorNilAdditionsGiven, itr.Error()) + }) } func TestIterator_Empty_Invalid(t *testing.T) { @@ -57,6 +108,11 @@ func TestIterator_Empty_Invalid(t *testing.T) { itr, mirror := setupFastIteratorAndMirror(t, config) performTest(t, itr, mirror) }) + + t.Run("Unsaved Fast Iterator", func(t *testing.T) { + itr, mirror := setupUnsavedFastIterator(t, config) + performTest(t, itr, mirror) + }) } func TestIterator_Basic_Ranged_Ascending_Success(t *testing.T) { @@ -89,6 +145,12 @@ func TestIterator_Basic_Ranged_Ascending_Success(t *testing.T) { require.True(t, itr.Valid()) performTest(t, itr, mirror) }) + + t.Run("Unsaved Fast Iterator", func(t *testing.T) { + itr, mirror := setupUnsavedFastIterator(t, config) + require.True(t, itr.Valid()) + performTest(t, itr, mirror) + }) } func TestIterator_Basic_Ranged_Descending_Success(t *testing.T) { @@ -121,6 +183,12 @@ func TestIterator_Basic_Ranged_Descending_Success(t *testing.T) { require.True(t, itr.Valid()) performTest(t, itr, mirror) }) + + t.Run("Unsaved Fast Iterator", func(t *testing.T) { + itr, mirror := setupUnsavedFastIterator(t, config) + require.True(t, itr.Valid()) + performTest(t, itr, mirror) + }) } func TestIterator_Basic_Full_Ascending_Success(t *testing.T) { @@ -133,9 +201,6 @@ func TestIterator_Basic_Full_Ascending_Success(t *testing.T) { } performTest := func(t *testing.T, itr dbm.Iterator, mirror [][]string) { - - require.Equal(t, 25, len(mirror)) - actualStart, actualEnd := itr.Domain() require.Equal(t, config.startIterate, actualStart) require.Equal(t, config.endIterate, actualEnd) @@ -148,12 +213,21 @@ func TestIterator_Basic_Full_Ascending_Success(t *testing.T) { t.Run("Iterator", func(t *testing.T) { itr, mirror := setupIteratorAndMirror(t, config) require.True(t, itr.Valid()) + require.Equal(t, 25, len(mirror)) performTest(t, itr, mirror) }) t.Run("Fast Iterator", func(t *testing.T) { itr, mirror := setupFastIteratorAndMirror(t, config) require.True(t, itr.Valid()) + require.Equal(t, 25, len(mirror)) + performTest(t, itr, mirror) + }) + + t.Run("Unsaved Fast Iterator", func(t *testing.T) { + itr, mirror := setupUnsavedFastIterator(t, config) + require.True(t, itr.Valid()) + require.Equal(t, 25 - 25 / 4 + 1, len(mirror)) // to account for removals performTest(t, itr, mirror) }) } @@ -168,8 +242,6 @@ func TestIterator_Basic_Full_Descending_Success(t *testing.T) { } performTest := func(t *testing.T, itr dbm.Iterator, mirror [][]string) { - require.Equal(t, 25, len(mirror)) - actualStart, actualEnd := itr.Domain() require.Equal(t, config.startIterate, actualStart) require.Equal(t, config.endIterate, actualEnd) @@ -181,12 +253,21 @@ func TestIterator_Basic_Full_Descending_Success(t *testing.T) { t.Run("Iterator", func(t *testing.T) { itr, mirror := setupIteratorAndMirror(t, config) + require.Equal(t, 25, len(mirror)) require.True(t, itr.Valid()) performTest(t, itr, mirror) }) t.Run("Fast Iterator", func(t *testing.T) { itr, mirror := setupFastIteratorAndMirror(t, config) + require.Equal(t, 25, len(mirror)) + require.True(t, itr.Valid()) + performTest(t, itr, mirror) + }) + + t.Run("Unsaved Fast Iterator", func(t *testing.T) { + itr, mirror := setupUnsavedFastIterator(t, config) + require.Equal(t, 25 - 25 / 4 + 1, len(mirror)) // to account for removals require.True(t, itr.Valid()) performTest(t, itr, mirror) }) @@ -238,6 +319,12 @@ func TestIterator_WithDelete_Full_Ascending_Success(t *testing.T) { require.True(t, itr.Valid()) assertIterator(t, itr, sortedMirror, config.ascending) }) + + t.Run("Unsaved Fast Iterator", func(t *testing.T) { + itr := NewUnsavedFastIterator(config.startIterate, config.endIterate, config.ascending, immutableTree.ndb, tree.unsavedFastNodeAdditions, tree.unsavedFastNodeRemovals) + require.True(t, itr.Valid()) + assertIterator(t, itr, sortedMirror, config.ascending) + }) } func setupIteratorAndMirror(t *testing.T, config *iteratorTestConfig) (dbm.Iterator, [][]string) { @@ -245,6 +332,8 @@ func setupIteratorAndMirror(t *testing.T, config *iteratorTestConfig) (dbm.Itera require.NoError(t, err) mirror := setupMirrorForIterator(t, config, tree) + _, _, err = tree.SaveVersion() + require.NoError(t, err) immutableTree, err := tree.GetImmutable(tree.ndb.getLatestVersion()) require.NoError(t, err) @@ -258,7 +347,63 @@ func setupFastIteratorAndMirror(t *testing.T, config *iteratorTestConfig) (dbm.I require.NoError(t, err) mirror := setupMirrorForIterator(t, config, tree) + _, _, err = tree.SaveVersion() + require.NoError(t, err) itr := NewFastIterator(config.startIterate, config.endIterate, config.ascending, tree.ndb) return itr, mirror } + +func setupUnsavedFastIterator(t *testing.T, config *iteratorTestConfig) (dbm.Iterator, [][]string) { + tree, err := NewMutableTree(dbm.NewMemDB(), 0) + require.NoError(t, err) + + // For unsaved fast iterator, we would like to test the state where + // there are saved fast nodes as well as some unsaved additions and removals. + // So, we split the byte range in half where the first half is saved and the second half is unsaved. + breakpointByte := (config.endByteToSet + config.startByteToSet) / 2 + + firstHalfConfig := *config + firstHalfConfig.endByteToSet = breakpointByte // exclusive + + secondHalfConfig := *config + secondHalfConfig.startByteToSet = breakpointByte + + firstHalfMirror := setupMirrorForIterator(t, &firstHalfConfig, tree) + _, _, err = tree.SaveVersion() + require.NoError(t, err) + + // No unsaved additions or removals should be present after saving + require.Equal(t, 0, len(tree.unsavedFastNodeAdditions)) + require.Equal(t, 0, len(tree.unsavedFastNodeRemovals)) + + // Ensure that there are unsaved additions and removals present + secondHalfMirror := setupMirrorForIterator(t, &secondHalfConfig, tree) + + require.True(t, len(tree.unsavedFastNodeAdditions) >= len(secondHalfMirror)) + require.Equal(t, 0, len(tree.unsavedFastNodeRemovals)) + + // Merge the two halves + var mergedMirror [][]string + if config.ascending { + mergedMirror = append(firstHalfMirror, secondHalfMirror...) + } else { + mergedMirror = append(secondHalfMirror, firstHalfMirror...) + } + + if len(mergedMirror) > 0 { + // Remove random keys + for i := 0; i < len(mergedMirror) / 4; i++ { + randIndex := rand.Intn(len(mergedMirror)) + keyToRemove := mergedMirror[randIndex][0] + + _, removed := tree.Remove([]byte(keyToRemove)) + require.True(t, removed) + + mergedMirror = append(mergedMirror[:randIndex], mergedMirror[randIndex+1:]...) + } + } + + itr := NewUnsavedFastIterator(config.startIterate, config.endIterate, config.ascending, tree.ndb, tree.unsavedFastNodeAdditions, tree.unsavedFastNodeRemovals) + return itr, mergedMirror +} diff --git a/mutable_tree.go b/mutable_tree.go index 5a61a5a63..babe33117 100644 --- a/mutable_tree.go +++ b/mutable_tree.go @@ -167,79 +167,20 @@ func (t *MutableTree) Iterate(fn func(key []byte, value []byte) bool) (stopped b return t.ImmutableTree.Iterate(fn) } - // We need to ensure that we iterate over saved and unsaved state in order. - // The strategy is to sort unsaved nodes, the fast node on disk are already sorted. - // Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently. - unsavedFastNodesToSort := make([]string, 0, len(t.unsavedFastNodeAdditions)) - - for _, fastNode := range t.unsavedFastNodeAdditions { - unsavedFastNodesToSort = append(unsavedFastNodesToSort, string(fastNode.key)) - } - - sort.Strings(unsavedFastNodesToSort) - - itr := t.ImmutableTree.Iterator(nil, nil, true) - defer itr.Close() - nextUnsavedIdx := 0 - for itr.Valid() && nextUnsavedIdx < len(unsavedFastNodesToSort) { - diskKeyStr := string(itr.Key()) - - if t.unsavedFastNodeRemovals[string(diskKeyStr)] != nil { - // If next fast node from disk is to be removed, skip it. - itr.Next() - continue - } - - nextUnsavedKey := unsavedFastNodesToSort[nextUnsavedIdx] - nextUnsavedNode := t.unsavedFastNodeAdditions[nextUnsavedKey] - - if diskKeyStr >= nextUnsavedKey { - // Unsaved node is next - - if diskKeyStr == nextUnsavedKey { - // Unsaved update prevails over saved copy so we skip the copy from disk - itr.Next() - } - - if fn(nextUnsavedNode.key, nextUnsavedNode.value) { - return true - } - - nextUnsavedIdx++ - } else { - // Disk node is next - if fn(itr.Key(), itr.Value()) { - return true - } - - itr.Next() - } - } - - // if only nodes on disk are left, we can just iterate - for itr.Valid() { + itr := NewUnsavedFastIterator(nil, nil, true, t.ndb, t.unsavedFastNodeAdditions, t.unsavedFastNodeRemovals) + for ; itr.Valid(); itr.Next() { if fn(itr.Key(), itr.Value()) { return true } - itr.Next() - } - - // if only unsaved nodes are left, we can just iterate - for ; nextUnsavedIdx < len(unsavedFastNodesToSort); nextUnsavedIdx++ { - nextUnsavedKey := unsavedFastNodesToSort[nextUnsavedIdx] - nextUnsavedNode := t.unsavedFastNodeAdditions[nextUnsavedKey] - - if fn(nextUnsavedNode.key, nextUnsavedNode.value) { - return true - } } return false } -// Iterator is not supported and is therefore invalid for MutableTree. Get an ImmutableTree instead for a valid iterator. +// Iterator returns an iterator over the mutable tree. +// CONTRACT: no updates are made to the tree while an iterator is active. func (t *MutableTree) Iterator(start, end []byte, ascending bool) dbm.Iterator { - return NewIterator(start, end, ascending, nil) // this is an invalid iterator + return NewUnsavedFastIterator(start, end, ascending, t.ndb, t.unsavedFastNodeAdditions, t.unsavedFastNodeRemovals) } func (tree *MutableTree) set(key []byte, value []byte) (orphans []*Node, updated bool) { diff --git a/testutils_test.go b/testutils_test.go index b34f6dc68..4a9efa306 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -271,12 +271,10 @@ func setupMirrorForIterator(t *testing.T, config *iteratorTestConfig, tree *Muta curByte-- } } - _, _, err := tree.SaveVersion() - require.NoError(t, err) return mirror } -// assertIterator confirms that the iterato returns the expected values desribed by mirror in the same order. +// assertIterator confirms that the iterator returns the expected values desribed by mirror in the same order. // mirror is a slice containing slices of the form [key, value]. In other words, key at index 0 and value at index 1. // mirror should be sorted in either ascending or descending order depending on the value of ascending parameter. func assertIterator(t *testing.T, itr dbm.Iterator, mirror [][]string, ascending bool) { diff --git a/unsaved_fast_iterator.go b/unsaved_fast_iterator.go new file mode 100644 index 000000000..667b568cd --- /dev/null +++ b/unsaved_fast_iterator.go @@ -0,0 +1,231 @@ +package iavl + +import ( + "bytes" + "errors" + "sort" + + dbm "github.com/tendermint/tm-db" +) + +var ( + errUnsavedFastIteratorNilAdditionsGiven = errors.New("unsaved fast iterator must be created with unsaved additions but they were nil") + + errUnsavedFastIteratorNilRemovalsGiven = errors.New("unsaved fast iterator must be created with unsaved removals but they were nil") +) + +// UnsavedFastIterator is a dbm.Iterator for ImmutableTree +// it iterates over the latest state via fast nodes, +// taking advantage of keys being located in sequence in the underlying database. +type UnsavedFastIterator struct { + start, end []byte + + valid bool + + ascending bool + + err error + + ndb *nodeDB + + unsavedFastNodeAdditions map[string]*FastNode + + unsavedFastNodeRemovals map[string]interface{} + + unsavedFastNodesToSort []string + + nextKey []byte + + nextVal []byte + + nextUnsavedNodeIdx int + + fastIterator dbm.Iterator +} + +var _ dbm.Iterator = &UnsavedFastIterator{} + +func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsavedFastNodeAdditions map[string]*FastNode, unsavedFastNodeRemovals map[string]interface{}) *UnsavedFastIterator { + + iter := &UnsavedFastIterator{ + start: start, + end: end, + err: nil, + ascending: ascending, + ndb: ndb, + unsavedFastNodeAdditions: unsavedFastNodeAdditions, + unsavedFastNodeRemovals: unsavedFastNodeRemovals, + unsavedFastNodesToSort: make([]string, 0), + nextKey: nil, + nextVal: nil, + nextUnsavedNodeIdx: 0, + fastIterator: NewFastIterator(start, end, ascending, ndb), + } + + // We need to ensure that we iterate over saved and unsaved state in order. + // The strategy is to sort unsaved nodes, the fast node on disk are already sorted. + // Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently. + for _, fastNode := range unsavedFastNodeAdditions { + if start != nil && bytes.Compare(fastNode.key, start) < 0 { + continue + } + + if end != nil && bytes.Compare(fastNode.key, end) >= 0 { + continue + } + + iter.unsavedFastNodesToSort = append(iter.unsavedFastNodesToSort, string(fastNode.key)) + } + + sort.Slice(iter.unsavedFastNodesToSort, func(i, j int) bool { + if ascending{ + return iter.unsavedFastNodesToSort[i] < iter.unsavedFastNodesToSort[j] + } else { + return iter.unsavedFastNodesToSort[i] > iter.unsavedFastNodesToSort[j] + } + }) + + if iter.ndb == nil { + iter.err = errFastIteratorNilNdbGiven + iter.valid = false + return iter + } + + if iter.unsavedFastNodeAdditions == nil { + iter.err = errUnsavedFastIteratorNilAdditionsGiven + iter.valid = false + return iter + } + + if iter.unsavedFastNodeRemovals == nil { + iter.err = errUnsavedFastIteratorNilRemovalsGiven + iter.valid = false + return iter + } + + // Move to the first elemenet + iter.Next() + + return iter +} + +// Domain implements dbm.Iterator. +// Maps the underlying nodedb iterator domain, to the 'logical' keys involved. +func (iter *UnsavedFastIterator) Domain() ([]byte, []byte) { + return iter.start, iter.end +} + +// Valid implements dbm.Iterator. +func (iter *UnsavedFastIterator) Valid() bool { + if iter.start != nil && iter.end != nil { + if bytes.Compare(iter.end, iter.start) != 1 { + return false + } + } + + return iter.fastIterator.Valid() || iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) || iter.nextKey != nil || iter.nextVal != nil +} + +// Key implements dbm.Iterator +func (iter *UnsavedFastIterator) Key() []byte { + return iter.nextKey +} + +// Value implements dbm.Iterator +func (iter *UnsavedFastIterator) Value() []byte { + return iter.nextVal +} + +// Next implements dbm.Iterator +func (iter *UnsavedFastIterator) Next() { + if iter.ndb == nil { + iter.err = errFastIteratorNilNdbGiven + iter.valid = false + return + } + + if iter.fastIterator.Valid() && iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) { + diskKeyStr := string(iter.fastIterator.Key()) + + if iter.unsavedFastNodeRemovals[diskKeyStr] != nil { + // If next fast node from disk is to be removed, skip it. + iter.fastIterator.Next() + iter.Next() + return + } + + nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx] + nextUnsavedNode := iter.unsavedFastNodeAdditions[nextUnsavedKey] + + var isUnsavedNext bool + if iter.ascending { + isUnsavedNext = bytes.Compare([]byte(diskKeyStr), []byte(nextUnsavedKey)) >= 0 + } else { + isUnsavedNext = bytes.Compare([]byte(diskKeyStr), []byte(nextUnsavedKey)) <= 0 + } + + if isUnsavedNext { + // Unsaved node is next + + if diskKeyStr == nextUnsavedKey { + // Unsaved update prevails over saved copy so we skip the copy from disk + iter.fastIterator.Next() + } + + iter.nextKey = nextUnsavedNode.key + iter.nextVal = nextUnsavedNode.value + + iter.nextUnsavedNodeIdx++ + return + } else { + // Disk node is next + iter.nextKey = iter.fastIterator.Key() + iter.nextVal = iter.fastIterator.Value() + + iter.fastIterator.Next() + return + } + } + + // if only nodes on disk are left, we return them + if iter.fastIterator.Valid() { + if iter.unsavedFastNodeRemovals[string(iter.fastIterator.Key())] != nil { + // If next fast node from disk is to be removed, skip it. + iter.fastIterator.Next() + iter.Next() + return + } + + iter.nextKey = iter.fastIterator.Key() + iter.nextVal = iter.fastIterator.Value() + + iter.fastIterator.Next() + return + } + + // if only unsaved nodes are left, we can just iterate + if iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) { + nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx] + nextUnsavedNode := iter.unsavedFastNodeAdditions[nextUnsavedKey] + + iter.nextKey = nextUnsavedNode.key + iter.nextVal = nextUnsavedNode.value + + iter.nextUnsavedNodeIdx++ + return + } + + iter.nextKey = nil + iter.nextVal = nil +} + +// Close implements dbm.Iterator +func (iter *UnsavedFastIterator) Close() error { + iter.valid = false + return iter.fastIterator.Close() +} + +// Error implements dbm.Iterator +func (iter *UnsavedFastIterator) Error() error { + return iter.err +}