diff --git a/.gitignore b/.gitignore index dff7e1bb1..8743c144c 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,6 @@ cpu*.out mem*.out cpu*.pdf mem*.pdf + +# IDE files +.idea/* \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 436e07b1c..a6063291b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,12 @@ ### Improvements +- [\#468](https://github.com/cosmos/iavl/pull/468) Fast storage optimization for queries and iterations + +## 0.17.3 (December 1, 2021) + +### Improvements + - [\#445](https://github.com/cosmos/iavl/pull/445) Bump github.com/tendermint/tendermint to v0.35.0 - [\#452](https://github.com/cosmos/iavl/pull/452) Optimization: remove unnecessary (*bytes.Buffer).Reset right after creating buffer. - [\#453](https://github.com/cosmos/iavl/pull/453),[\#456](https://github.com/cosmos/iavl/pull/456) Optimization: buffer reuse diff --git a/basic_test.go b/basic_test.go index 9a3b1de64..8245dc12a 100644 --- a/basic_test.go +++ b/basic_test.go @@ -35,71 +35,126 @@ func TestBasic(t *testing.T) { // Test 0x00 { - idx, val := tree.Get([]byte{0x00}) + key := []byte{0x00} + expected := "" + + idx, val := tree.GetWithIndex(key) if val != nil { - t.Errorf("Expected no value to exist") + t.Error("Expected no value to exist") } if idx != 0 { t.Errorf("Unexpected idx %x", idx) } - if string(val) != "" { - t.Errorf("Unexpected value %v", string(val)) + if string(val) != expected { + t.Errorf("Unexpected value %s", val) + } + + val = tree.Get(key) + if val != nil { + t.Error("Fast method - expected no value to exist") + } + if string(val) != expected { + t.Errorf("Fast method - Unexpected value %s", val) } } // Test "1" { - idx, val := tree.Get([]byte("1")) + key := []byte("1") + expected := "one" + + idx, val := tree.GetWithIndex(key) if val == nil { - t.Errorf("Expected value to exist") + t.Error("Expected value to exist") } if idx != 0 { t.Errorf("Unexpected idx %x", idx) } - if string(val) != "one" { - t.Errorf("Unexpected value %v", string(val)) + if string(val) != expected { + t.Errorf("Unexpected value %s", val) + } + + val = tree.Get(key) + if val == nil { + t.Error("Fast method - expected value to exist") + } + if string(val) != expected { + t.Errorf("Fast method - Unexpected value %s", val) } } // Test "2" { - idx, val := tree.Get([]byte("2")) + key := []byte("2") + expected := "TWO" + + idx, val := tree.GetWithIndex(key) if val == nil { - t.Errorf("Expected value to exist") + t.Error("Expected value to exist") } if idx != 1 { t.Errorf("Unexpected idx %x", idx) } - if string(val) != "TWO" { - t.Errorf("Unexpected value %v", string(val)) + if string(val) != expected { + t.Errorf("Unexpected value %s", val) + } + + val = tree.Get(key) + if val == nil { + t.Error("Fast method - expected value to exist") + } + if string(val) != expected { + t.Errorf("Fast method - Unexpected value %s", val) } } // Test "4" { - idx, val := tree.Get([]byte("4")) + key := []byte("4") + expected := "" + + idx, val := tree.GetWithIndex(key) if val != nil { - t.Errorf("Expected no value to exist") + t.Error("Expected no value to exist") } if idx != 2 { t.Errorf("Unexpected idx %x", idx) } - if string(val) != "" { - t.Errorf("Unexpected value %v", string(val)) + if string(val) != expected { + t.Errorf("Unexpected value %s", val) + } + + val = tree.Get(key) + if val != nil { + t.Error("Fast method - expected no value to exist") + } + if string(val) != expected { + t.Errorf("Fast method - Unexpected value %s", val) } } // Test "6" { - idx, val := tree.Get([]byte("6")) + key := []byte("6") + expected := "" + + idx, val := tree.GetWithIndex(key) if val != nil { - t.Errorf("Expected no value to exist") + t.Error("Expected no value to exist") } if idx != 3 { t.Errorf("Unexpected idx %x", idx) } - if string(val) != "" { - t.Errorf("Unexpected value %v", string(val)) + if string(val) != expected { + t.Errorf("Unexpected value %s", val) + } + + val = tree.Get(key) + if val != nil { + t.Error("Fast method - expected no value to exist") + } + if string(val) != expected { + t.Errorf("Fast method - Unexpected value %s", val) } } } @@ -252,7 +307,7 @@ func TestIntegration(t *testing.T) { if has := tree.Has([]byte(randstr(12))); has { t.Error("Table has extra key") } - if _, val := tree.Get([]byte(r.key)); string(val) != r.value { + if val := tree.Get([]byte(r.key)); string(val) != r.value { t.Error("wrong value") } } @@ -270,7 +325,7 @@ func TestIntegration(t *testing.T) { if has := tree.Has([]byte(randstr(12))); has { t.Error("Table has extra key") } - _, val := tree.Get([]byte(r.key)) + val := tree.Get([]byte(r.key)) if string(val) != r.value { t.Error("wrong value") } @@ -388,7 +443,7 @@ func TestPersistence(t *testing.T) { require.NoError(t, err) t2.Load() for key, value := range records { - _, t2value := t2.Get([]byte(key)) + t2value := t2.Get([]byte(key)) if string(t2value) != value { t.Fatalf("Invalid value. Expected %v, got %v", value, t2value) } diff --git a/benchmarks/bench_test.go b/benchmarks/bench_test.go index 8d361e18a..b7cd6553b 100644 --- a/benchmarks/bench_test.go +++ b/benchmarks/bench_test.go @@ -19,7 +19,6 @@ func randBytes(length int) []byte { key := make([]byte, length) // math.rand.Read always returns err=nil // we do not need cryptographic randomness for this test: - //nolint:gosec rand.Read(key) return key } @@ -57,14 +56,18 @@ func commitTree(b *testing.B, t *iavl.MutableTree) { } } -func runQueries(b *testing.B, t *iavl.MutableTree, keyLen int) { +// queries random keys against live state. Keys are almost certainly not in the tree. +func runQueriesFast(b *testing.B, t *iavl.MutableTree, keyLen int) { + require.True(b, t.IsFastCacheEnabled()) for i := 0; i < b.N; i++ { q := randBytes(keyLen) t.Get(q) } } -func runKnownQueries(b *testing.B, t *iavl.MutableTree, keys [][]byte) { +// queries keys that are known to be in state +func runKnownQueriesFast(b *testing.B, t *iavl.MutableTree, keys [][]byte) { + require.True(b, t.IsFastCacheEnabled()) // to ensure fast storage is enabled l := int32(len(keys)) for i := 0; i < b.N; i++ { q := keys[rand.Int31n(l)] @@ -72,6 +75,76 @@ func runKnownQueries(b *testing.B, t *iavl.MutableTree, keys [][]byte) { } } +func runQueriesSlow(b *testing.B, t *iavl.MutableTree, keyLen int) { + b.StopTimer() + // Save version to get an old immutable tree to query against, + // Fast storage is not enabled on old tree versions, allowing us to bench the desired behavior. + _, version, err := t.SaveVersion() + require.NoError(b, err) + + itree, err := t.GetImmutable(version - 1) + require.NoError(b, err) + require.False(b, itree.IsFastCacheEnabled()) // to ensure fast storage is not enabled + + b.StartTimer() + for i := 0; i < b.N; i++ { + q := randBytes(keyLen) + itree.GetWithIndex(q) + } +} + +func runKnownQueriesSlow(b *testing.B, t *iavl.MutableTree, keys [][]byte) { + b.StopTimer() + // Save version to get an old immutable tree to query against, + // Fast storage is not enabled on old tree versions, allowing us to bench the desired behavior. + _, version, err := t.SaveVersion() + require.NoError(b, err) + + itree, err := t.GetImmutable(version - 1) + require.NoError(b, err) + require.False(b, itree.IsFastCacheEnabled()) // to ensure fast storage is not enabled + b.StartTimer() + l := int32(len(keys)) + for i := 0; i < b.N; i++ { + q := keys[rand.Int31n(l)] + index, value := itree.GetWithIndex(q) + require.True(b, index >= 0, "the index must not be negative") + require.NotNil(b, value, "the value should exist") + } +} + +func runIterationFast(b *testing.B, t *iavl.MutableTree, expectedSize int) { + require.True(b, t.IsFastCacheEnabled()) // to ensure fast storage is enabled + for i := 0; i < b.N; i++ { + itr := t.ImmutableTree.Iterator(nil, nil, false) + iterate(b, itr, expectedSize) + require.Nil(b, itr.Close(), ".Close should not error out") + } +} + +func runIterationSlow(b *testing.B, t *iavl.MutableTree, expectedSize int) { + for i := 0; i < b.N; i++ { + itr := iavl.NewIterator(nil, nil, false, t.ImmutableTree) // create slow iterator directly + iterate(b, itr, expectedSize) + require.Nil(b, itr.Close(), ".Close should not error out") + } +} + +func iterate(b *testing.B, itr db.Iterator, expectedSize int) { + b.StartTimer() + keyValuePairs := make([][][]byte, 0, expectedSize) + for i := 0; i < expectedSize && itr.Valid(); i++ { + itr.Next() + keyValuePairs = append(keyValuePairs, [][]byte{itr.Key(), itr.Value()}) + } + b.StopTimer() + if g, w := len(keyValuePairs), expectedSize; g != w { + b.Errorf("iteration count mismatch: got=%d, want=%d", g, w) + } else { + b.Logf("completed %d iterations", len(keyValuePairs)) + } +} + // func runInsert(b *testing.B, t *iavl.MutableTree, keyLen, dataLen, blockSize int) *iavl.MutableTree { // for i := 1; i <= b.N; i++ { // t.Set(randBytes(keyLen), randBytes(dataLen)) @@ -132,7 +205,7 @@ func runBlock(b *testing.B, t *iavl.MutableTree, keyLen, dataLen, blockSize int, data := randBytes(dataLen) // perform query and write on check and then real - // check.Get(key) + // check.GetFast(key) // check.Set(key, data) real.Get(key) real.Set(key, data) @@ -175,11 +248,11 @@ func BenchmarkMedium(b *testing.B) { benchmarks := []benchmark{ {"memdb", 100000, 100, 16, 40}, {"goleveldb", 100000, 100, 16, 40}, - {"cleveldb", 100000, 100, 16, 40}, + // {"cleveldb", 100000, 100, 16, 40}, // FIXME: idk why boltdb is too slow !? // {"boltdb", 100000, 100, 16, 40}, - {"rocksdb", 100000, 100, 16, 40}, - {"badgerdb", 100000, 100, 16, 40}, + // {"rocksdb", 100000, 100, 16, 40}, + // {"badgerdb", 100000, 100, 16, 40}, } runBenchmarks(b, benchmarks) } @@ -188,10 +261,10 @@ func BenchmarkSmall(b *testing.B) { benchmarks := []benchmark{ {"memdb", 1000, 100, 4, 10}, {"goleveldb", 1000, 100, 4, 10}, - {"cleveldb", 1000, 100, 4, 10}, - {"boltdb", 1000, 100, 4, 10}, - {"rocksdb", 1000, 100, 4, 10}, - {"badgerdb", 1000, 100, 4, 10}, + // {"cleveldb", 1000, 100, 4, 10}, + // {"boltdb", 1000, 100, 4, 10}, + // {"rocksdb", 1000, 100, 4, 10}, + // {"badgerdb", 1000, 100, 4, 10}, } runBenchmarks(b, benchmarks) } @@ -202,8 +275,8 @@ func BenchmarkLarge(b *testing.B) { {"goleveldb", 1000000, 100, 16, 40}, // FIXME: idk why boltdb is too slow !? // {"boltdb", 1000000, 100, 16, 40}, - {"rocksdb", 1000000, 100, 16, 40}, - {"badgerdb", 1000000, 100, 16, 40}, + // {"rocksdb", 1000000, 100, 16, 40}, + // {"badgerdb", 1000000, 100, 16, 40}, } runBenchmarks(b, benchmarks) } @@ -287,14 +360,38 @@ func runSuite(b *testing.B, d db.DB, initSize, blockSize, keyLen, dataLen int) { b.ResetTimer() - b.Run("query-miss", func(sub *testing.B) { + b.Run("query-no-in-tree-guarantee-fast", func(sub *testing.B) { + sub.ReportAllocs() + runQueriesFast(sub, t, keyLen) + }) + b.Run("query-no-in-tree-guarantee-slow", func(sub *testing.B) { + sub.ReportAllocs() + runQueriesSlow(sub, t, keyLen) + }) + // + b.Run("query-hits-fast", func(sub *testing.B) { sub.ReportAllocs() - runQueries(sub, t, keyLen) + runKnownQueriesFast(sub, t, keys) }) - b.Run("query-hits", func(sub *testing.B) { + b.Run("query-hits-slow", func(sub *testing.B) { sub.ReportAllocs() - runKnownQueries(sub, t, keys) + runKnownQueriesSlow(sub, t, keys) }) + // + // Iterations for BenchmarkLevelDBLargeData timeout bencher in CI so + // we must skip them. + if b.Name() != "BenchmarkLevelDBLargeData" { + b.Run("iteration-fast", func(sub *testing.B) { + sub.ReportAllocs() + runIterationFast(sub, t, initSize) + }) + b.Run("iteration-slow", func(sub *testing.B) { + sub.ReportAllocs() + runIterationSlow(sub, t, initSize) + }) + } + + // b.Run("update", func(sub *testing.B) { sub.ReportAllocs() t = runUpdate(sub, t, dataLen, blockSize, keys) diff --git a/benchmarks/hash_test.go b/benchmarks/hash_test.go index 4c2a4a4dc..4ed6c527d 100644 --- a/benchmarks/hash_test.go +++ b/benchmarks/hash_test.go @@ -1,4 +1,3 @@ -// nolint: errcheck,scopelint package benchmarks import ( @@ -8,6 +7,7 @@ import ( "testing" "github.com/cosmos/iavl" + "github.com/stretchr/testify/require" _ "crypto/sha256" @@ -18,9 +18,9 @@ import ( func BenchmarkHash(b *testing.B) { fmt.Printf("%s\n", iavl.GetVersionInfo()) hashers := []struct { - name string - size int - hasher hash.Hash + name string + size int + hash hash.Hash }{ {"ripemd160", 64, crypto.RIPEMD160.New()}, {"ripemd160", 512, crypto.RIPEMD160.New()}, @@ -32,20 +32,22 @@ func BenchmarkHash(b *testing.B) { for _, h := range hashers { prefix := fmt.Sprintf("%s-%d", h.name, h.size) + hasher := h b.Run(prefix, func(sub *testing.B) { - benchHasher(sub, h.hasher, h.size) + benchHasher(sub, hasher.hash, hasher.size) }) } } -func benchHasher(b *testing.B, hasher hash.Hash, size int) { +func benchHasher(b *testing.B, hash hash.Hash, size int) { // create all random bytes before to avoid timing this inputs := randBytes(b.N + size + 1) for i := 0; i < b.N; i++ { - hasher.Reset() + hash.Reset() // grab a slice of size bytes from random string - hasher.Write(inputs[i : i+size]) - hasher.Sum(nil) + _, err := hash.Write(inputs[i : i+size]) + require.NoError(b, err) + hash.Sum(nil) } } diff --git a/cmd/iavlserver/main.go b/cmd/iavlserver/main.go index 346276b49..e925f4f4b 100644 --- a/cmd/iavlserver/main.go +++ b/cmd/iavlserver/main.go @@ -19,6 +19,7 @@ import ( "github.com/pkg/errors" dbm "github.com/tendermint/tm-db" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/grpclog" pb "github.com/cosmos/iavl/proto" @@ -110,7 +111,7 @@ func startRPCGateway() error { runtime.WithProtoErrorHandler(runtime.DefaultHTTPProtoErrorHandler), ) - dialOpts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithBlock()} + dialOpts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()} err := pb.RegisterIAVLServiceHandlerFromEndpoint( context.Background(), gatewayMux, *gRPCEndpoint, dialOpts, @@ -174,7 +175,7 @@ func openDB() (dbm.DB, error) { // trapSignal will listen for any OS signal and invokes a callback function to // perform any necessary cleanup. func trapSignal(cb func()) { - var sigCh = make(chan os.Signal) + var sigCh = make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGTERM) signal.Notify(sigCh, syscall.SIGINT) diff --git a/export_test.go b/export_test.go index 2b27a7137..56d3ef818 100644 --- a/export_test.go +++ b/export_test.go @@ -52,8 +52,8 @@ func setupExportTreeRandom(t *testing.T) *ImmutableTree { keySize = 16 valueSize = 16 - versions = 32 // number of versions to generate - versionOps = 4096 // number of operations (create/update/delete) per version + versions = 8 // number of versions to generate + versionOps = 1024 // number of operations (create/update/delete) per version updateRatio = 0.4 // ratio of updates out of all operations deleteRatio = 0.2 // ratio of deletes out of all operations ) @@ -211,8 +211,8 @@ func TestExporter_Import(t *testing.T) { require.Equal(t, tree.Version(), newTree.Version(), "Tree version mismatch") tree.Iterate(func(key, value []byte) bool { - index, _ := tree.Get(key) - newIndex, newValue := newTree.Get(key) + index, _ := tree.GetWithIndex(key) + newIndex, newValue := newTree.GetWithIndex(key) require.Equal(t, index, newIndex, "Index mismatch for key %v", key) require.Equal(t, value, newValue, "Value mismatch for key %v", key) return false diff --git a/fast_iterator.go b/fast_iterator.go new file mode 100644 index 000000000..aa423cae1 --- /dev/null +++ b/fast_iterator.go @@ -0,0 +1,133 @@ +package iavl + +import ( + "errors" + + dbm "github.com/tendermint/tm-db" +) + +var errFastIteratorNilNdbGiven = errors.New("fast iterator must be created with a nodedb but it was nil") + +// FastIterator is a dbm.Iterator for ImmutableTree +// it iterates over the latest state via fast nodes, +// taking advantage of keys being located in sequence in the underlying database. +type FastIterator struct { + start, end []byte + + valid bool + + ascending bool + + err error + + ndb *nodeDB + + nextFastNode *FastNode + + fastIterator dbm.Iterator +} + +var _ dbm.Iterator = (*FastIterator)(nil) + +func NewFastIterator(start, end []byte, ascending bool, ndb *nodeDB) *FastIterator { + iter := &FastIterator{ + start: start, + end: end, + err: nil, + ascending: ascending, + ndb: ndb, + nextFastNode: nil, + fastIterator: nil, + } + // Move iterator before the first element + iter.Next() + return iter +} + +// Domain implements dbm.Iterator. +// Maps the underlying nodedb iterator domain, to the 'logical' keys involved. +func (iter *FastIterator) Domain() ([]byte, []byte) { + if iter.fastIterator == nil { + return iter.start, iter.end + } + + start, end := iter.fastIterator.Domain() + + if start != nil { + start = start[1:] + if len(start) == 0 { + start = nil + } + } + + if end != nil { + end = end[1:] + if len(end) == 0 { + end = nil + } + } + + return start, end +} + +// Valid implements dbm.Iterator. +func (iter *FastIterator) Valid() bool { + return iter.fastIterator != nil && iter.fastIterator.Valid() && iter.valid +} + +// Key implements dbm.Iterator +func (iter *FastIterator) Key() []byte { + if iter.valid { + return iter.nextFastNode.key + } + return nil +} + +// Value implements dbm.Iterator +func (iter *FastIterator) Value() []byte { + if iter.valid { + return iter.nextFastNode.value + } + return nil +} + +// Next implements dbm.Iterator +func (iter *FastIterator) Next() { + if iter.ndb == nil { + iter.err = errFastIteratorNilNdbGiven + iter.valid = false + return + } + + if iter.fastIterator == nil { + iter.fastIterator, iter.err = iter.ndb.getFastIterator(iter.start, iter.end, iter.ascending) + iter.valid = true + } else { + iter.fastIterator.Next() + } + + if iter.err == nil { + iter.err = iter.fastIterator.Error() + } + + iter.valid = iter.valid && iter.fastIterator.Valid() + if iter.valid { + iter.nextFastNode, iter.err = DeserializeFastNode(iter.fastIterator.Key()[1:], iter.fastIterator.Value()) + iter.valid = iter.err == nil + } +} + +// Close implements dbm.Iterator +func (iter *FastIterator) Close() error { + if iter.fastIterator != nil { + iter.err = iter.fastIterator.Close() + } + iter.valid = false + iter.fastIterator = nil + return iter.err +} + +// Error implements dbm.Iterator +func (iter *FastIterator) Error() error { + return iter.err +} diff --git a/fast_node.go b/fast_node.go new file mode 100644 index 000000000..a116b8efb --- /dev/null +++ b/fast_node.go @@ -0,0 +1,69 @@ +package iavl + +import ( + "io" + + "github.com/cosmos/iavl/internal/encoding" + "github.com/pkg/errors" +) + +// NOTE: This file favors int64 as opposed to int for size/counts. +// The Tree on the other hand favors int. This is intentional. + +type FastNode struct { + key []byte + versionLastUpdatedAt int64 + value []byte +} + +// NewFastNode returns a new fast node from a value and version. +func NewFastNode(key []byte, value []byte, version int64) *FastNode { + return &FastNode{ + key: key, + versionLastUpdatedAt: version, + value: value, + } +} + +// DeserializeFastNode constructs an *FastNode from an encoded byte slice. +func DeserializeFastNode(key []byte, buf []byte) (*FastNode, error) { + ver, n, cause := encoding.DecodeVarint(buf) + if cause != nil { + return nil, errors.Wrap(cause, "decoding fastnode.version") + } + buf = buf[n:] + + val, _, cause := encoding.DecodeBytes(buf) + if cause != nil { + return nil, errors.Wrap(cause, "decoding fastnode.value") + } + + fastNode := &FastNode{ + key: key, + versionLastUpdatedAt: ver, + value: val, + } + + return fastNode, nil +} + +func (node *FastNode) encodedSize() int { + n := encoding.EncodeVarintSize(node.versionLastUpdatedAt) + encoding.EncodeBytesSize(node.value) + return n +} + +// writeBytes writes the FastNode as a serialized byte slice to the supplied io.Writer. +func (node *FastNode) writeBytes(w io.Writer) error { + if node == nil { + return errors.New("cannot write nil node") + } + cause := encoding.EncodeVarint(w, node.versionLastUpdatedAt) + if cause != nil { + return errors.Wrap(cause, "writing version last updated at") + } + cause = encoding.EncodeBytes(w, node.value) + if cause != nil { + return errors.Wrap(cause, "writing value") + } + return nil +} diff --git a/fast_node_test.go b/fast_node_test.go new file mode 100644 index 000000000..b6e1ffd98 --- /dev/null +++ b/fast_node_test.go @@ -0,0 +1,58 @@ +package iavl + +import ( + "bytes" + "encoding/hex" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFastNode_encodedSize(t *testing.T) { + fastNode := &FastNode{ + key: randBytes(10), + versionLastUpdatedAt: 1, + value: randBytes(20), + } + + expectedSize := 1 + len(fastNode.value) + 1 + + require.Equal(t, expectedSize, fastNode.encodedSize()) +} + +func TestFastNode_encode_decode(t *testing.T) { + testcases := map[string]struct { + node *FastNode + expectHex string + expectError bool + }{ + "nil": {nil, "", true}, + "empty": {&FastNode{}, "0000", false}, + "inner": {&FastNode{ + key: []byte{0x4}, + versionLastUpdatedAt: 1, + value: []byte{0x2}, + }, "020102", false}, + } + for name, tc := range testcases { + tc := tc + t.Run(name, func(t *testing.T) { + var buf bytes.Buffer + err := tc.node.writeBytes(&buf) + if tc.expectError { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tc.expectHex, hex.EncodeToString(buf.Bytes())) + + node, err := DeserializeFastNode(tc.node.key, buf.Bytes()) + require.NoError(t, err) + // since value and leafHash are always decoded to []byte{} we augment the expected struct here + if tc.node.value == nil { + tc.node.value = []byte{} + } + require.Equal(t, tc.node, node) + }) + } +} diff --git a/go.mod b/go.mod index 4d5b5051f..1c7b640a8 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/confio/ics23/go v0.7.0 github.com/gogo/gateway v1.1.0 github.com/gogo/protobuf v1.3.2 + github.com/golang/mock v1.6.0 github.com/golang/protobuf v1.5.2 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/grpc-ecosystem/grpc-gateway v1.16.0 diff --git a/go.sum b/go.sum index a1ee309cd..ca0f2edd3 100644 --- a/go.sum +++ b/go.sum @@ -59,6 +59,7 @@ contrib.go.opencensus.io/exporter/stackdriver v0.13.4/go.mod h1:aXENhDJ1Y4lIg4EU dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/Antonboom/errname v0.1.5/go.mod h1:DugbBstvPFQbv/5uLcRRzfrNqKE9tVdVCqWCLp6Cifo= github.com/Antonboom/nilnil v0.1.0/go.mod h1:PhHLvRPSghY5Y7mX4TW+BHZQYo1A8flE5H20D3IPZBo= +github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78/go.mod h1:LmzpDX56iTiv29bbRTIsUNlaFfuhWRQBWjQdVyAevI8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.0.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= @@ -78,7 +79,6 @@ github.com/Masterminds/sprig v2.15.0+incompatible/go.mod h1:y6hNFY5UBTIWBxnzTeuN github.com/Masterminds/sprig v2.22.0+incompatible/go.mod h1:y6hNFY5UBTIWBxnzTeuNhlNS5hqE0NB0E6fgfo2Br3o= github.com/Microsoft/go-winio v0.5.1/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= -github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/OpenPeeDeeP/depguard v1.1.0/go.mod h1:JtAMzWkmFEzDPyAd+W0NHl1lvpQKTvT9jnRVsohBKpc= github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= @@ -232,11 +232,8 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7 github.com/envoyproxy/protoc-gen-validate v0.6.2/go.mod h1:2t7qjJNvHPx8IjnBOzl9E9/baC+qXE/TeeyBRzgJDws= github.com/esimonov/ifshort v1.0.4/go.mod h1:Pe8zjlRrJ80+q2CxHLfEOfTwxCZ4O+MuhcHcfgNWTk0= github.com/ettle/strcase v0.1.1/go.mod h1:hzDLsPC7/lwKyBOywSHEP89nt2pDgdy+No1NBA9o9VY= -github.com/facebookgo/ensure v0.0.0-20160127193407-b4ab57deab51 h1:0JZ+dUmQeA8IIVUMzysrX4/AKuQwWhV2dYQuPZdvdSQ= github.com/facebookgo/ensure v0.0.0-20160127193407-b4ab57deab51/go.mod h1:Yg+htXGokKKdzcwhuNDwVvN+uBxDGXJ7G/VN1d8fa64= -github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 h1:JWuenKqqX8nojtoVVWjGfOF9635RETekkoH6Cc9SX0A= github.com/facebookgo/stack v0.0.0-20160209184415-751773369052/go.mod h1:UbMTZqLaRiH3MsBH8va0n7s1pQYcu3uTb8G4tygF4Zg= -github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870 h1:E2s37DuLxFhQDg5gKsWoLBOB0n+ZW8s599zru8FJ2/Y= github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870/go.mod h1:5tD+neXqOorC30/tWg0LCSkrqj/AR6gu8yY8/fpw1q0= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= @@ -252,7 +249,6 @@ github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM github.com/frankban/quicktest v1.14.0/go.mod h1:NeW+ay9A/U67EYXNFA1nPE8e/tnQv/09mUdL/ijj8og= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWpgI= github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU= github.com/fullstorydev/grpcurl v1.6.0/go.mod h1:ZQ+ayqbKMJNhzLmbpCiurTVlaK2M/3nqZCxaQ2Ze/sM= github.com/fzipp/gocyclo v0.4.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA= @@ -317,6 +313,7 @@ github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.1.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -542,7 +539,6 @@ github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfn github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kulti/thelper v0.5.1/go.mod h1:vMu2Cizjy/grP+jmsvOFDx1kYP6+PD1lqg4Yu5exl2U= github.com/kunwardeep/paralleltest v1.0.3/go.mod h1:vLydzomDFpk7yu5UX02RmP0H8QfRPOV/oFhWN85Mjb4= @@ -633,13 +629,11 @@ github.com/nats-io/nkeys v0.2.0/go.mod h1:XdZpAbhgyyODYqjTawOnIOI7VlbKSarI9Gfy1t github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/nbutton23/zxcvbn-go v0.0.0-20210217022336-fa2cb2858354/go.mod h1:KSVJerMDfblTH7p5MZaTt+8zaT2iEk3AkVb9PQdZuE8= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/nishanths/exhaustive v0.7.11/go.mod h1:gX+MP7DWMKJmNa1HfMozK+u04hQd3na9i0hyqf3/dOI= github.com/nishanths/predeclared v0.0.0-20190419143655-18a43bb90ffc/go.mod h1:62PewwiQTlm/7Rj+cxVYqZvDIUc+JjZq6GHAC1fsObQ= github.com/nishanths/predeclared v0.2.1/go.mod h1:HvkGJcA3naj4lOwnFXFDkFxVtSqQMB9sbB1usJ+xjQE= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= -github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/oasisprotocol/curve25519-voi v0.0.0-20210609091139-0a56a4bca00b/go.mod h1:TLJifjWF6eotcfzDjKZsDqWJ+73Uvj/N85MvVyrvynM= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= @@ -653,7 +647,6 @@ github.com/onsi/ginkgo v1.10.3/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+ github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= -github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc= github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/ginkgo/v2 v2.0.0/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= github.com/onsi/gomega v1.4.1/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= @@ -661,7 +654,6 @@ github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1Cpa github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= -github.com/onsi/gomega v1.17.0 h1:9Luw4uT5HTjHTN8+aNcSThgH1vdXnmdJ8xIfZ4wyTRE= github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/op/go-logging v0.0.0-20160315200505-970db520ece7/go.mod h1:HzydrMdWErDVzsI23lYNej1Htcns9BCg93Dk0bBINWk= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= @@ -788,7 +780,6 @@ github.com/sonatard/noctx v0.0.1/go.mod h1:9D2D/EoULe8Yy2joDHJj7bv3sZoq9AaSb8B4l github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/sourcegraph/go-diff v0.6.1/go.mod h1:iBszgVvyxdc8SFZ7gm69go2KDdt3ag071iBaWPF6cjs= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= -github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/afero v1.3.3/go.mod h1:5KUK8ByomD5Ti5Artl0RtHeI5pTF7MIDuXL3yY520V4= @@ -1310,7 +1301,6 @@ golang.org/x/tools v0.1.9/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= @@ -1487,7 +1477,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b h1:QRR6H1YWRnHb4Y/HeNFCTJLFVxaq6wH4YuVdsUOr75U= gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/cheggaaa/pb.v1 v1.0.28/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= @@ -1498,7 +1487,6 @@ gopkg.in/ini.v1 v1.63.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.66.2/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.66.3/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= @@ -1510,7 +1498,6 @@ gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.6/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= diff --git a/immutable_tree.go b/immutable_tree.go index 14d85dfee..bbce557d7 100644 --- a/immutable_tree.go +++ b/immutable_tree.go @@ -137,36 +137,90 @@ func (t *ImmutableTree) Export() *Exporter { return newExporter(t) } -// Get returns the index and value of the specified key if it exists, or nil and the next index +// GetWithIndex returns the index and value of the specified key if it exists, or nil and the next index // otherwise. The returned value must not be modified, since it may point to data stored within // IAVL. -func (t *ImmutableTree) Get(key []byte) (index int64, value []byte) { +// +// The index is the index in the list of leaf nodes sorted lexicographically by key. The leftmost leaf has index 0. +// It's neighbor has index 1 and so on. +func (t *ImmutableTree) GetWithIndex(key []byte) (int64, []byte) { if t.root == nil { return 0, nil } return t.root.get(t, key) } +// Get returns the value of the specified key if it exists, or nil. +// The returned value must not be modified, since it may point to data stored within IAVL. +// Get potentially employs a more performant strategy than GetWithIndex for retrieving the value. +func (t *ImmutableTree) Get(key []byte) []byte { + if t.root == nil { + return nil + } + + // attempt to get a FastNode directly from db/cache. + // if call fails, fall back to the original IAVL logic in place. + fastNode, err := t.ndb.GetFastNode(key) + if err != nil { + _, result := t.root.get(t, key) + return result + } + + if fastNode == nil { + // If the tree is of the latest version and fast node is not in the tree + // then the regular node is not in the tree either because fast node + // represents live state. + if t.version == t.ndb.latestVersion { + return nil + } + + _, result := t.root.get(t, key) + return result + } + + if fastNode.versionLastUpdatedAt <= t.version { + return fastNode.value + } + + // Otherwise the cached node was updated later than the current tree. In this case, + // we need to use the regular stategy for reading from the current tree to avoid staleness. + _, result := t.root.get(t, key) + return result +} + // GetByIndex gets the key and value at the specified index. func (t *ImmutableTree) GetByIndex(index int64) (key []byte, value []byte) { if t.root == nil { return nil, nil } + return t.root.getByIndex(t, index) } -// Iterate iterates over all keys of the tree, in order. The keys and values must not be modified, -// since they may point to data stored within IAVL. -func (t *ImmutableTree) Iterate(fn func(key []byte, value []byte) bool) (stopped bool) { +// Iterate iterates over all keys of the tree. The keys and values must not be modified, +// since they may point to data stored within IAVL. Returns true if stopped by callback, false otherwise +func (t *ImmutableTree) Iterate(fn func(key []byte, value []byte) bool) bool { if t.root == nil { return false } - return t.root.traverse(t, true, func(node *Node) bool { - if node.height == 0 { - return fn(node.key, node.value) + + itr := t.Iterator(nil, nil, true) + defer itr.Close() + for ; itr.Valid(); itr.Next() { + if fn(itr.Key(), itr.Value()) { + return true } - return false - }) + + } + return false +} + +// Iterator returns an iterator over the immutable tree. +func (t *ImmutableTree) Iterator(start, end []byte, ascending bool) dbm.Iterator { + if t.IsFastCacheEnabled() { + return NewFastIterator(start, end, ascending, t.ndb) + } + return NewIterator(start, end, ascending, t) } // IterateRange makes a callback for all nodes with key between start and end non-inclusive. @@ -199,6 +253,18 @@ func (t *ImmutableTree) IterateRangeInclusive(start, end []byte, ascending bool, }) } +// IsFastCacheEnabled returns true if fast cache is enabled, false otherwise. +// For fast cache to be enabled, the following 2 conditions must be met: +// 1. The tree is of the latest version. +// 2. The underlying storage has been upgraded to fast cache +func (t *ImmutableTree) IsFastCacheEnabled() bool { + return t.isLatestTreeVersion() && t.ndb.hasUpgradedToFastStorage() +} + +func (t *ImmutableTree) isLatestTreeVersion() bool { + return t.version == t.ndb.getLatestVersion() +} + // Clone creates a clone of the tree. // Used internally by MutableTree. func (t *ImmutableTree) clone() *ImmutableTree { diff --git a/iterator.go b/iterator.go index bd69fcd42..cb22e50c9 100644 --- a/iterator.go +++ b/iterator.go @@ -5,6 +5,7 @@ package iavl import ( "bytes" + "errors" dbm "github.com/tendermint/tm-db" ) @@ -18,6 +19,8 @@ type traversal struct { delayedNodes *delayedNodes // delayed nodes to be traversed } +var errIteratorNilTreeGiven = errors.New("iterator must be created with an immutable tree but the tree was nil") + func (node *Node) newTraversal(tree *ImmutableTree, start, end []byte, ascending bool, inclusive bool, post bool) *traversal { return &traversal{ tree: tree, @@ -157,23 +160,31 @@ type Iterator struct { valid bool + err error + t *traversal } -func (t *ImmutableTree) Iterator(start, end []byte, ascending bool) *Iterator { +var _ dbm.Iterator = (*Iterator)(nil) + +// Returns a new iterator over the immutable tree. If the tree is nil, the iterator will be invalid. +func NewIterator(start, end []byte, ascending bool, tree *ImmutableTree) dbm.Iterator { iter := &Iterator{ start: start, end: end, - valid: true, - t: t.root.newTraversal(t, start, end, ascending, false, false), } - iter.Next() + if tree == nil { + iter.err = errIteratorNilTreeGiven + } else { + iter.valid = true + iter.t = tree.root.newTraversal(tree, start, end, ascending, false, false) + // Move iterator before the first element + iter.Next() + } return iter } -var _ dbm.Iterator = &Iterator{} - // Domain implements dbm.Iterator. func (iter *Iterator) Domain() ([]byte, []byte) { return iter.start, iter.end @@ -219,10 +230,15 @@ func (iter *Iterator) Next() { func (iter *Iterator) Close() error { iter.t = nil iter.valid = false - return nil + return iter.err } // Error implements dbm.Iterator func (iter *Iterator) Error() error { - return nil + return iter.err +} + +// IsFast returnts true if iterator uses fast strategy +func (iter *Iterator) IsFast() bool { + return false } diff --git a/iterator_test.go b/iterator_test.go new file mode 100644 index 000000000..c5db3b0e2 --- /dev/null +++ b/iterator_test.go @@ -0,0 +1,326 @@ +package iavl + +import ( + "math/rand" + "sort" + "testing" + + "github.com/stretchr/testify/require" + dbm "github.com/tendermint/tm-db" +) + +func TestIterator_NewIterator_NilTree_Failure(t *testing.T) { + var start, end = []byte{'a'}, []byte{'c'} + ascending := true + + performTest := func(t *testing.T, itr dbm.Iterator) { + require.NotNil(t, itr) + require.False(t, itr.Valid()) + actualsStart, actualEnd := itr.Domain() + require.Equal(t, start, actualsStart) + require.Equal(t, end, actualEnd) + require.Error(t, itr.Error()) + } + + t.Run("Iterator", func(t *testing.T) { + itr := NewIterator(start, end, ascending, nil) + performTest(t, itr) + require.ErrorIs(t, errIteratorNilTreeGiven, itr.Error()) + }) + + t.Run("Fast Iterator", func(t *testing.T) { + itr := NewFastIterator(start, end, ascending, nil) + performTest(t, itr) + require.ErrorIs(t, errFastIteratorNilNdbGiven, itr.Error()) + }) + + t.Run("Unsaved Fast Iterator", func(t *testing.T) { + itr := NewUnsavedFastIterator(start, end, ascending, nil, map[string]*FastNode{}, map[string]interface{}{}) + performTest(t, itr) + require.ErrorIs(t, errFastIteratorNilNdbGiven, itr.Error()) + }) +} + +func TestUnsavedFastIterator_NewIterator_NilAdditions_Failure(t *testing.T) { + var start, end = []byte{'a'}, []byte{'c'} + ascending := true + + performTest := func(t *testing.T, itr dbm.Iterator) { + require.NotNil(t, itr) + require.False(t, itr.Valid()) + actualsStart, actualEnd := itr.Domain() + require.Equal(t, start, actualsStart) + require.Equal(t, end, actualEnd) + require.Error(t, itr.Error()) + } + + t.Run("Nil additions given", func(t *testing.T) { + tree, err := NewMutableTree(dbm.NewMemDB(), 0) + require.NoError(t, err) + itr := NewUnsavedFastIterator(start, end, ascending, tree.ndb, nil, tree.unsavedFastNodeRemovals) + performTest(t, itr) + require.ErrorIs(t, errUnsavedFastIteratorNilAdditionsGiven, itr.Error()) + }) + + t.Run("Nil removals given", func(t *testing.T) { + tree, err := NewMutableTree(dbm.NewMemDB(), 0) + require.NoError(t, err) + itr := NewUnsavedFastIterator(start, end, ascending, tree.ndb, tree.unsavedFastNodeAdditions, nil) + performTest(t, itr) + require.ErrorIs(t, errUnsavedFastIteratorNilRemovalsGiven, itr.Error()) + }) + + t.Run("All nil", func(t *testing.T) { + itr := NewUnsavedFastIterator(start, end, ascending, nil, nil, nil) + performTest(t, itr) + require.ErrorIs(t, errFastIteratorNilNdbGiven, itr.Error()) + }) + + t.Run("Additions and removals are nil", func(t *testing.T) { + tree, err := NewMutableTree(dbm.NewMemDB(), 0) + require.NoError(t, err) + itr := NewUnsavedFastIterator(start, end, ascending, tree.ndb, nil, nil) + performTest(t, itr) + require.ErrorIs(t, errUnsavedFastIteratorNilAdditionsGiven, itr.Error()) + }) +} + +func TestIterator_Empty_Invalid(t *testing.T) { + config := &iteratorTestConfig{ + startByteToSet: 'a', + endByteToSet: 'z', + startIterate: []byte("a"), + endIterate: []byte("a"), + ascending: true, + } + + performTest := func(t *testing.T, itr dbm.Iterator, mirror [][]string) { + require.Equal(t, 0, len(mirror)) + require.False(t, itr.Valid()) + } + + t.Run("Iterator", func(t *testing.T) { + itr, mirror := setupIteratorAndMirror(t, config) + performTest(t, itr, mirror) + }) + + t.Run("Fast Iterator", func(t *testing.T) { + itr, mirror := setupFastIteratorAndMirror(t, config) + performTest(t, itr, mirror) + }) + + t.Run("Unsaved Fast Iterator", func(t *testing.T) { + itr, mirror := setupUnsavedFastIterator(t, config) + performTest(t, itr, mirror) + }) +} + +func TestIterator_Basic_Ranged_Ascending_Success(t *testing.T) { + config := &iteratorTestConfig{ + startByteToSet: 'a', + endByteToSet: 'z', + startIterate: []byte("e"), + endIterate: []byte("w"), + ascending: true, + } + iteratorSuccessTest(t, config) +} + +func TestIterator_Basic_Ranged_Descending_Success(t *testing.T) { + config := &iteratorTestConfig{ + startByteToSet: 'a', + endByteToSet: 'z', + startIterate: []byte("e"), + endIterate: []byte("w"), + ascending: false, + } + iteratorSuccessTest(t, config) +} + +func TestIterator_Basic_Full_Ascending_Success(t *testing.T) { + config := &iteratorTestConfig{ + startByteToSet: 'a', + endByteToSet: 'z', + startIterate: nil, + endIterate: nil, + ascending: true, + } + + iteratorSuccessTest(t, config) +} + +func TestIterator_Basic_Full_Descending_Success(t *testing.T) { + config := &iteratorTestConfig{ + startByteToSet: 'a', + endByteToSet: 'z', + startIterate: nil, + endIterate: nil, + ascending: false, + } + iteratorSuccessTest(t, config) +} + +func TestIterator_WithDelete_Full_Ascending_Success(t *testing.T) { + config := &iteratorTestConfig{ + startByteToSet: 'a', + endByteToSet: 'z', + startIterate: nil, + endIterate: nil, + ascending: false, + } + + tree, mirror := getRandomizedTreeAndMirror(t) + + _, _, err := tree.SaveVersion() + require.NoError(t, err) + + randomizeTreeAndMirror(t, tree, mirror) + + _, _, err = tree.SaveVersion() + require.NoError(t, err) + + err = tree.DeleteVersion(1) + require.NoError(t, err) + + immutableTree, err := tree.GetImmutable(tree.ndb.getLatestVersion()) + require.NoError(t, err) + + // sort mirror for assertion + sortedMirror := make([][]string, 0, len(mirror)) + for k, v := range mirror { + sortedMirror = append(sortedMirror, []string{k, v}) + } + + sort.Slice(sortedMirror, func(i, j int) bool { + return sortedMirror[i][0] > sortedMirror[j][0] + }) + + t.Run("Iterator", func(t *testing.T) { + itr := NewIterator(config.startIterate, config.endIterate, config.ascending, immutableTree) + require.True(t, itr.Valid()) + assertIterator(t, itr, sortedMirror, config.ascending) + }) + + t.Run("Fast Iterator", func(t *testing.T) { + itr := NewFastIterator(config.startIterate, config.endIterate, config.ascending, immutableTree.ndb) + require.True(t, itr.Valid()) + assertIterator(t, itr, sortedMirror, config.ascending) + }) + + t.Run("Unsaved Fast Iterator", func(t *testing.T) { + itr := NewUnsavedFastIterator(config.startIterate, config.endIterate, config.ascending, immutableTree.ndb, tree.unsavedFastNodeAdditions, tree.unsavedFastNodeRemovals) + require.True(t, itr.Valid()) + assertIterator(t, itr, sortedMirror, config.ascending) + }) +} + +func iteratorSuccessTest(t *testing.T, config *iteratorTestConfig) { + performTest := func(t *testing.T, itr dbm.Iterator, mirror [][]string) { + actualStart, actualEnd := itr.Domain() + require.Equal(t, config.startIterate, actualStart) + require.Equal(t, config.endIterate, actualEnd) + + require.NoError(t, itr.Error()) + + assertIterator(t, itr, mirror, config.ascending) + } + + t.Run("Iterator", func(t *testing.T) { + itr, mirror := setupIteratorAndMirror(t, config) + require.True(t, itr.Valid()) + performTest(t, itr, mirror) + }) + + t.Run("Fast Iterator", func(t *testing.T) { + itr, mirror := setupFastIteratorAndMirror(t, config) + require.True(t, itr.Valid()) + performTest(t, itr, mirror) + }) + + t.Run("Unsaved Fast Iterator", func(t *testing.T) { + itr, mirror := setupUnsavedFastIterator(t, config) + require.True(t, itr.Valid()) + performTest(t, itr, mirror) + }) +} + +func setupIteratorAndMirror(t *testing.T, config *iteratorTestConfig) (dbm.Iterator, [][]string) { + tree, err := NewMutableTree(dbm.NewMemDB(), 0) + require.NoError(t, err) + + mirror := setupMirrorForIterator(t, config, tree) + _, _, err = tree.SaveVersion() + require.NoError(t, err) + + immutableTree, err := tree.GetImmutable(tree.ndb.getLatestVersion()) + require.NoError(t, err) + + itr := NewIterator(config.startIterate, config.endIterate, config.ascending, immutableTree) + return itr, mirror +} + +func setupFastIteratorAndMirror(t *testing.T, config *iteratorTestConfig) (dbm.Iterator, [][]string) { + tree, err := NewMutableTree(dbm.NewMemDB(), 0) + require.NoError(t, err) + + mirror := setupMirrorForIterator(t, config, tree) + _, _, err = tree.SaveVersion() + require.NoError(t, err) + + itr := NewFastIterator(config.startIterate, config.endIterate, config.ascending, tree.ndb) + return itr, mirror +} + +func setupUnsavedFastIterator(t *testing.T, config *iteratorTestConfig) (dbm.Iterator, [][]string) { + tree, err := NewMutableTree(dbm.NewMemDB(), 0) + require.NoError(t, err) + + // For unsaved fast iterator, we would like to test the state where + // there are saved fast nodes as well as some unsaved additions and removals. + // So, we split the byte range in half where the first half is saved and the second half is unsaved. + breakpointByte := (config.endByteToSet + config.startByteToSet) / 2 + + firstHalfConfig := *config + firstHalfConfig.endByteToSet = breakpointByte // exclusive + + secondHalfConfig := *config + secondHalfConfig.startByteToSet = breakpointByte + + // First half of the mirror + mirror := setupMirrorForIterator(t, &firstHalfConfig, tree) + _, _, err = tree.SaveVersion() + require.NoError(t, err) + + // No unsaved additions or removals should be present after saving + require.Equal(t, 0, len(tree.unsavedFastNodeAdditions)) + require.Equal(t, 0, len(tree.unsavedFastNodeRemovals)) + + // Ensure that there are unsaved additions and removals present + secondHalfMirror := setupMirrorForIterator(t, &secondHalfConfig, tree) + + require.True(t, len(tree.unsavedFastNodeAdditions) >= len(secondHalfMirror)) + require.Equal(t, 0, len(tree.unsavedFastNodeRemovals)) + + // Merge the two halves + if config.ascending { + mirror = append(mirror, secondHalfMirror...) + } else { + mirror = append(secondHalfMirror, mirror...) + } + + if len(mirror) > 0 { + // Remove random keys + for i := 0; i < len(mirror)/4; i++ { + randIndex := rand.Intn(len(mirror)) + keyToRemove := mirror[randIndex][0] + + _, removed := tree.Remove([]byte(keyToRemove)) + require.True(t, removed) + + mirror = append(mirror[:randIndex], mirror[randIndex+1:]...) + } + } + + itr := NewUnsavedFastIterator(config.startIterate, config.endIterate, config.ascending, tree.ndb, tree.unsavedFastNodeAdditions, tree.unsavedFastNodeRemovals) + return itr, mirror +} diff --git a/key_format.go b/key_format.go index 988a97cdf..31ff82a26 100644 --- a/key_format.go +++ b/key_format.go @@ -7,9 +7,10 @@ import ( // Provides a fixed-width lexicographically sortable []byte key format type KeyFormat struct { - prefix byte - layout []int - length int + layout []int + length int + prefix byte + unbounded bool } // Create a []byte key format based on a single byte prefix and fixed width key segments each of whose length is @@ -27,16 +28,21 @@ type KeyFormat struct { // hasher.Sum(nil) // return keyFormat.Key(version, hasher.Sum(nil)) // } +// if the last term of the layout ends in 0 func NewKeyFormat(prefix byte, layout ...int) *KeyFormat { // For prefix byte length := 1 - for _, l := range layout { + for i, l := range layout { length += l + if l == 0 && i != len(layout)-1 { + panic("Only the last item in a key format can be 0") + } } return &KeyFormat{ - prefix: prefix, - layout: layout, - length: length, + prefix: prefix, + layout: layout, + length: length, + unbounded: len(layout) > 0 && layout[len(layout)-1] == 0, } } @@ -53,16 +59,30 @@ func (kf *KeyFormat) KeyBytes(segments ...[]byte) []byte { } } + if kf.unbounded { + if len(segments) > 0 { + keyLen += len(segments[len(segments)-1]) + } + } key := make([]byte, keyLen) key[0] = kf.prefix n := 1 for i, s := range segments { l := kf.layout[i] - if len(s) > l { - panic(fmt.Errorf("length of segment %X provided to KeyFormat.KeyBytes() is longer than the %d bytes "+ - "required by layout for segment %d", s, l, i)) + + switch l { + case 0: + // If the expected segment length is unbounded, increase it by `string length` + n += len(s) + default: + if len(s) > l { + panic(fmt.Errorf("length of segment %X provided to KeyFormat.KeyBytes() is longer than the %d bytes "+ + "required by layout for segment %d", s, l, i)) + } + // Otherwise increase n by the segment length + n += l + } - n += l // Big endian so pad on left if not given the full width for this segment copy(key[n-len(s):n], s) } @@ -90,10 +110,17 @@ func (kf *KeyFormat) ScanBytes(key []byte) [][]byte { n := 1 for i, l := range kf.layout { n += l + // if current section is longer than key, then there are no more subsequent segments. if n > len(key) { return segments[:i] } - segments[i] = key[n-l : n] + // if unbounded, segment is rest of key + if l == 0 { + segments[i] = key[n:] + break + } else { + segments[i] = key[n-l : n] + } } return segments } diff --git a/key_format_test.go b/key_format_test.go index 7bb55c44d..cc3a2d4f7 100644 --- a/key_format_test.go +++ b/key_format_test.go @@ -7,12 +7,59 @@ import ( ) func TestKeyFormatBytes(t *testing.T) { - kf := NewKeyFormat(byte('e'), 8, 8, 8) - assert.Equal(t, []byte{'e', 0, 0, 0, 0, 0, 1, 2, 3}, kf.KeyBytes([]byte{1, 2, 3})) - assert.Equal(t, []byte{'e', 1, 2, 3, 4, 5, 6, 7, 8}, kf.KeyBytes([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - assert.Equal(t, []byte{'e', 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 1, 1, 2, 2, 3, 3}, - kf.KeyBytes([]byte{1, 2, 3, 4, 5, 6, 7, 8}, []byte{1, 2, 3, 4, 5, 6, 7, 8}, []byte{1, 1, 2, 2, 3, 3})) - assert.Equal(t, []byte{'e'}, kf.KeyBytes()) + type keyPairs struct { + key [][]byte + expected []byte + } + emptyTestVector := keyPairs{key: [][]byte{}, expected: []byte{'e'}} + threeByteTestVector := keyPairs{ + key: [][]byte{{1, 2, 3}}, + expected: []byte{'e', 0, 0, 0, 0, 0, 1, 2, 3}, + } + eightByteTestVector := keyPairs{ + key: [][]byte{{1, 2, 3, 4, 5, 6, 7, 8}}, + expected: []byte{'e', 1, 2, 3, 4, 5, 6, 7, 8}, + } + + tests := []struct { + name string + kf *KeyFormat + testVectors []keyPairs + }{{ + name: "simple 3 int key format", + kf: NewKeyFormat(byte('e'), 8, 8, 8), + testVectors: []keyPairs{ + emptyTestVector, + threeByteTestVector, + eightByteTestVector, + { + key: [][]byte{{1, 2, 3, 4, 5, 6, 7, 8}, {1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 2, 2, 3, 3}}, + expected: []byte{'e', 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 1, 1, 2, 2, 3, 3}, + }, + }, + }, { + name: "zero suffix key format", + kf: NewKeyFormat(byte('e'), 8, 0), + testVectors: []keyPairs{ + emptyTestVector, + threeByteTestVector, + eightByteTestVector, + { + key: [][]byte{{1, 2, 3, 4, 5, 6, 7, 8}, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, + expected: []byte{'e', 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + { + key: [][]byte{{1, 2, 3, 4, 5, 6, 7, 8}, []byte("hellohello")}, + expected: []byte{'e', 1, 2, 3, 4, 5, 6, 7, 8, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x68, 0x65, 0x6c, 0x6c, 0x6f}, + }, + }, + }} + for _, tc := range tests { + kf := tc.kf + for i, v := range tc.testVectors { + assert.Equal(t, v.expected, kf.KeyBytes(v.key...), "key format %s, test case %d", tc.name, i) + } + } } func TestKeyFormat(t *testing.T) { diff --git a/mock/db_mock.go b/mock/db_mock.go new file mode 100644 index 000000000..8120cc64f --- /dev/null +++ b/mock/db_mock.go @@ -0,0 +1,420 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: /root/go/pkg/mod/github.com/tendermint/tm-db@v0.6.4/types.go + +// Package mock is a generated GoMock package. +package mock + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + db "github.com/tendermint/tm-db" +) + +// MockDB is a mock of DB interface. +type MockDB struct { + ctrl *gomock.Controller + recorder *MockDBMockRecorder +} + +// MockDBMockRecorder is the mock recorder for MockDB. +type MockDBMockRecorder struct { + mock *MockDB +} + +// NewMockDB creates a new mock instance. +func NewMockDB(ctrl *gomock.Controller) *MockDB { + mock := &MockDB{ctrl: ctrl} + mock.recorder = &MockDBMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDB) EXPECT() *MockDBMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockDB) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockDBMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDB)(nil).Close)) +} + +// Delete mocks base method. +func (m *MockDB) Delete(arg0 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockDBMockRecorder) Delete(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockDB)(nil).Delete), arg0) +} + +// DeleteSync mocks base method. +func (m *MockDB) DeleteSync(arg0 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSync", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteSync indicates an expected call of DeleteSync. +func (mr *MockDBMockRecorder) DeleteSync(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSync", reflect.TypeOf((*MockDB)(nil).DeleteSync), arg0) +} + +// Get mocks base method. +func (m *MockDB) Get(arg0 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockDBMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockDB)(nil).Get), arg0) +} + +// Has mocks base method. +func (m *MockDB) Has(key []byte) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Has", key) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Has indicates an expected call of Has. +func (mr *MockDBMockRecorder) Has(key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Has", reflect.TypeOf((*MockDB)(nil).Has), key) +} + +// Iterator mocks base method. +func (m *MockDB) Iterator(start, end []byte) (db.Iterator, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Iterator", start, end) + ret0, _ := ret[0].(db.Iterator) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Iterator indicates an expected call of Iterator. +func (mr *MockDBMockRecorder) Iterator(start, end interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Iterator", reflect.TypeOf((*MockDB)(nil).Iterator), start, end) +} + +// NewBatch mocks base method. +func (m *MockDB) NewBatch() db.Batch { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewBatch") + ret0, _ := ret[0].(db.Batch) + return ret0 +} + +// NewBatch indicates an expected call of NewBatch. +func (mr *MockDBMockRecorder) NewBatch() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewBatch", reflect.TypeOf((*MockDB)(nil).NewBatch)) +} + +// Print mocks base method. +func (m *MockDB) Print() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Print") + ret0, _ := ret[0].(error) + return ret0 +} + +// Print indicates an expected call of Print. +func (mr *MockDBMockRecorder) Print() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Print", reflect.TypeOf((*MockDB)(nil).Print)) +} + +// ReverseIterator mocks base method. +func (m *MockDB) ReverseIterator(start, end []byte) (db.Iterator, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReverseIterator", start, end) + ret0, _ := ret[0].(db.Iterator) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReverseIterator indicates an expected call of ReverseIterator. +func (mr *MockDBMockRecorder) ReverseIterator(start, end interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReverseIterator", reflect.TypeOf((*MockDB)(nil).ReverseIterator), start, end) +} + +// Set mocks base method. +func (m *MockDB) Set(arg0, arg1 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Set", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Set indicates an expected call of Set. +func (mr *MockDBMockRecorder) Set(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockDB)(nil).Set), arg0, arg1) +} + +// SetSync mocks base method. +func (m *MockDB) SetSync(arg0, arg1 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetSync", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetSync indicates an expected call of SetSync. +func (mr *MockDBMockRecorder) SetSync(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSync", reflect.TypeOf((*MockDB)(nil).SetSync), arg0, arg1) +} + +// Stats mocks base method. +func (m *MockDB) Stats() map[string]string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stats") + ret0, _ := ret[0].(map[string]string) + return ret0 +} + +// Stats indicates an expected call of Stats. +func (mr *MockDBMockRecorder) Stats() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stats", reflect.TypeOf((*MockDB)(nil).Stats)) +} + +// MockBatch is a mock of Batch interface. +type MockBatch struct { + ctrl *gomock.Controller + recorder *MockBatchMockRecorder +} + +// MockBatchMockRecorder is the mock recorder for MockBatch. +type MockBatchMockRecorder struct { + mock *MockBatch +} + +// NewMockBatch creates a new mock instance. +func NewMockBatch(ctrl *gomock.Controller) *MockBatch { + mock := &MockBatch{ctrl: ctrl} + mock.recorder = &MockBatchMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBatch) EXPECT() *MockBatchMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockBatch) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockBatchMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockBatch)(nil).Close)) +} + +// Delete mocks base method. +func (m *MockBatch) Delete(key []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", key) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockBatchMockRecorder) Delete(key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockBatch)(nil).Delete), key) +} + +// Set mocks base method. +func (m *MockBatch) Set(key, value []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Set", key, value) + ret0, _ := ret[0].(error) + return ret0 +} + +// Set indicates an expected call of Set. +func (mr *MockBatchMockRecorder) Set(key, value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockBatch)(nil).Set), key, value) +} + +// Write mocks base method. +func (m *MockBatch) Write() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write") + ret0, _ := ret[0].(error) + return ret0 +} + +// Write indicates an expected call of Write. +func (mr *MockBatchMockRecorder) Write() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockBatch)(nil).Write)) +} + +// WriteSync mocks base method. +func (m *MockBatch) WriteSync() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteSync") + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteSync indicates an expected call of WriteSync. +func (mr *MockBatchMockRecorder) WriteSync() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteSync", reflect.TypeOf((*MockBatch)(nil).WriteSync)) +} + +// MockIterator is a mock of Iterator interface. +type MockIterator struct { + ctrl *gomock.Controller + recorder *MockIteratorMockRecorder +} + +// MockIteratorMockRecorder is the mock recorder for MockIterator. +type MockIteratorMockRecorder struct { + mock *MockIterator +} + +// NewMockIterator creates a new mock instance. +func NewMockIterator(ctrl *gomock.Controller) *MockIterator { + mock := &MockIterator{ctrl: ctrl} + mock.recorder = &MockIteratorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockIterator) EXPECT() *MockIteratorMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockIterator) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockIteratorMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockIterator)(nil).Close)) +} + +// Domain mocks base method. +func (m *MockIterator) Domain() ([]byte, []byte) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Domain") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].([]byte) + return ret0, ret1 +} + +// Domain indicates an expected call of Domain. +func (mr *MockIteratorMockRecorder) Domain() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Domain", reflect.TypeOf((*MockIterator)(nil).Domain)) +} + +// Error mocks base method. +func (m *MockIterator) Error() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Error") + ret0, _ := ret[0].(error) + return ret0 +} + +// Error indicates an expected call of Error. +func (mr *MockIteratorMockRecorder) Error() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockIterator)(nil).Error)) +} + +// Key mocks base method. +func (m *MockIterator) Key() []byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Key") + ret0, _ := ret[0].([]byte) + return ret0 +} + +// Key indicates an expected call of Key. +func (mr *MockIteratorMockRecorder) Key() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Key", reflect.TypeOf((*MockIterator)(nil).Key)) +} + +// Next mocks base method. +func (m *MockIterator) Next() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Next") +} + +// Next indicates an expected call of Next. +func (mr *MockIteratorMockRecorder) Next() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockIterator)(nil).Next)) +} + +// Valid mocks base method. +func (m *MockIterator) Valid() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Valid") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Valid indicates an expected call of Valid. +func (mr *MockIteratorMockRecorder) Valid() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Valid", reflect.TypeOf((*MockIterator)(nil).Valid)) +} + +// Value mocks base method. +func (m *MockIterator) Value() []byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Value") + ret0, _ := ret[0].([]byte) + return ret0 +} + +// Value indicates an expected call of Value. +func (mr *MockIteratorMockRecorder) Value() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Value", reflect.TypeOf((*MockIterator)(nil).Value)) +} diff --git a/mutable_tree.go b/mutable_tree.go index 496aa4293..05e54225c 100644 --- a/mutable_tree.go +++ b/mutable_tree.go @@ -4,8 +4,10 @@ import ( "bytes" "crypto/sha256" "fmt" + "runtime" "sort" "sync" + "time" "github.com/pkg/errors" @@ -26,14 +28,16 @@ var ErrVersionDoesNotExist = errors.New("version does not exist") // // The inner ImmutableTree should not be used directly by callers. type MutableTree struct { - *ImmutableTree // The current, working tree. - lastSaved *ImmutableTree // The most recently saved tree. - orphans map[string]int64 // Nodes removed by changes to working tree. - versions map[int64]bool // The previous, saved versions of the tree. - allRootLoaded bool // Whether all roots are loaded or not(by LazyLoadVersion) - ndb *nodeDB + *ImmutableTree // The current, working tree. + lastSaved *ImmutableTree // The most recently saved tree. + orphans map[string]int64 // Nodes removed by changes to working tree. + versions map[int64]bool // The previous, saved versions of the tree. + allRootLoaded bool // Whether all roots are loaded or not(by LazyLoadVersion) + unsavedFastNodeAdditions map[string]*FastNode // FastNodes that have not yet been saved to disk + unsavedFastNodeRemovals map[string]interface{} // FastNodes that have not yet been removed from disk + ndb *nodeDB - mtx sync.RWMutex // versions Read/write lock. + mtx sync.Mutex } // NewMutableTree returns a new tree with the specified cache size and datastore. @@ -47,12 +51,14 @@ func NewMutableTreeWithOpts(db dbm.DB, cacheSize int, opts *Options) (*MutableTr head := &ImmutableTree{ndb: ndb} return &MutableTree{ - ImmutableTree: head, - lastSaved: head.clone(), - orphans: map[string]int64{}, - versions: map[int64]bool{}, - allRootLoaded: false, - ndb: ndb, + ImmutableTree: head, + lastSaved: head.clone(), + orphans: map[string]int64{}, + versions: map[int64]bool{}, + allRootLoaded: false, + unsavedFastNodeAdditions: make(map[string]*FastNode), + unsavedFastNodeRemovals: make(map[string]interface{}), + ndb: ndb, }, nil } @@ -82,8 +88,8 @@ func (tree *MutableTree) VersionExists(version int64) bool { // AvailableVersions returns all available versions in ascending order func (tree *MutableTree) AvailableVersions() []int { - tree.mtx.RLock() - defer tree.mtx.RUnlock() + tree.mtx.Lock() + defer tree.mtx.Unlock() res := make([]int, 0, len(tree.versions)) for i, v := range tree.versions { @@ -107,7 +113,7 @@ func (tree *MutableTree) WorkingHash() []byte { } // String returns a string representation of the tree. -func (tree *MutableTree) String() string { +func (tree *MutableTree) String() (string, error) { return tree.ndb.String() } @@ -128,6 +134,20 @@ func (tree *MutableTree) Set(key, value []byte) (updated bool) { return updated } +// Get returns the value of the specified key if it exists, or nil otherwise. +// The returned value must not be modified, since it may point to data stored within IAVL. +func (tree *MutableTree) Get(key []byte) []byte { + if tree.root == nil { + return nil + } + + if fastNode, ok := tree.unsavedFastNodeAdditions[string(key)]; ok { + return fastNode.value + } + + return tree.ImmutableTree.Get(key) +} + // Import returns an importer for tree nodes previously exported by ImmutableTree.Export(), // producing an identical IAVL tree. The caller must call Close() on the importer when done. // @@ -140,12 +160,43 @@ func (tree *MutableTree) Import(version int64) (*Importer, error) { return newImporter(tree, version) } +// Iterate iterates over all keys of the tree. The keys and values must not be modified, +// since they may point to data stored within IAVL. Returns true if stopped by callnack, false otherwise +func (tree *MutableTree) Iterate(fn func(key []byte, value []byte) bool) (stopped bool) { + if tree.root == nil { + return false + } + + if !tree.IsFastCacheEnabled() { + return tree.ImmutableTree.Iterate(fn) + } + + itr := NewUnsavedFastIterator(nil, nil, true, tree.ndb, tree.unsavedFastNodeAdditions, tree.unsavedFastNodeRemovals) + defer itr.Close() + for ; itr.Valid(); itr.Next() { + if fn(itr.Key(), itr.Value()) { + return true + } + } + return false +} + +// Iterator returns an iterator over the mutable tree. +// CONTRACT: no updates are made to the tree while an iterator is active. +func (tree *MutableTree) Iterator(start, end []byte, ascending bool) dbm.Iterator { + if tree.IsFastCacheEnabled() { + return NewUnsavedFastIterator(start, end, ascending, tree.ndb, tree.unsavedFastNodeAdditions, tree.unsavedFastNodeRemovals) + } + return tree.ImmutableTree.Iterator(start, end, ascending) +} + func (tree *MutableTree) set(key []byte, value []byte) (orphans []*Node, updated bool) { if value == nil { panic(fmt.Sprintf("Attempt to store nil value at key '%s'", key)) } if tree.ImmutableTree.root == nil { + tree.addUnsavedAddition(key, NewFastNode(key, value, tree.version+1)) tree.ImmutableTree.root = NewNode(key, value, tree.version+1) return nil, updated } @@ -161,6 +212,8 @@ func (tree *MutableTree) recursiveSet(node *Node, key []byte, value []byte, orph version := tree.version + 1 if node.isLeaf() { + tree.addUnsavedAddition(key, NewFastNode(key, value, version)) + switch bytes.Compare(key, node.key) { case -1: return &Node{ @@ -225,6 +278,8 @@ func (tree *MutableTree) remove(key []byte) (value []byte, orphaned []*Node, rem return nil, nil, false } + tree.addUnsavedRemoval(key) + if newRoot == nil && newRootHash != nil { tree.root = tree.ndb.GetNode(newRootHash) } else { @@ -311,7 +366,10 @@ func (tree *MutableTree) LazyLoadVersion(targetVersion int64) (int64, error) { // no versions have been saved if the latest version is non-positive if latestVersion <= 0 { if targetVersion <= 0 { - return 0, nil + tree.mtx.Lock() + defer tree.mtx.Unlock() + _, err := tree.enableFastStorageAndCommitIfNotEnabled() + return 0, err } return 0, fmt.Errorf("no versions found while trying to load %v", targetVersion) } @@ -331,6 +389,7 @@ func (tree *MutableTree) LazyLoadVersion(targetVersion int64) (int64, error) { tree.mtx.Lock() defer tree.mtx.Unlock() + tree.versions[targetVersion] = true iTree := &ImmutableTree{ @@ -347,6 +406,11 @@ func (tree *MutableTree) LazyLoadVersion(targetVersion int64) (int64, error) { tree.ImmutableTree = iTree tree.lastSaved = iTree.clone() + // Attempt to upgrade + if _, err := tree.enableFastStorageAndCommitIfNotEnabled(); err != nil { + return 0, err + } + return targetVersion, nil } @@ -359,7 +423,10 @@ func (tree *MutableTree) LoadVersion(targetVersion int64) (int64, error) { if len(roots) == 0 { if targetVersion <= 0 { - return 0, nil + tree.mtx.Lock() + defer tree.mtx.Unlock() + _, err := tree.enableFastStorageAndCommitIfNotEnabled() + return 0, err } return 0, fmt.Errorf("no versions found while trying to load %v", targetVersion) } @@ -406,6 +473,11 @@ func (tree *MutableTree) LoadVersion(targetVersion int64) (int64, error) { tree.lastSaved = t.clone() tree.allRootLoaded = true + // Attempt to upgrade + if _, err := tree.enableFastStorageAndCommitIfNotEnabled(); err != nil { + return 0, err + } + return latestVersion, nil } @@ -421,7 +493,7 @@ func (tree *MutableTree) LoadVersionForOverwriting(targetVersion int64) (int64, return latestVersion, err } - if err = tree.ndb.Commit(); err != nil { + if err := tree.enableFastStorageAndCommitLocked(); err != nil { return latestVersion, err } @@ -439,6 +511,112 @@ func (tree *MutableTree) LoadVersionForOverwriting(targetVersion int64) (int64, return latestVersion, nil } +// Returns true if the tree may be auto-upgraded, false otherwise +// An example of when an upgrade may be performed is when we are enaling fast storage for the first time or +// need to overwrite fast nodes due to mismatch with live state. +func (tree *MutableTree) IsUpgradeable() bool { + return !tree.ndb.hasUpgradedToFastStorage() || tree.ndb.shouldForceFastStorageUpgrade() +} + +// enableFastStorageAndCommitIfNotEnabled if nodeDB doesn't mark fast storage as enabled, enable it, and commit the update. +// Checks whether the fast cache on disk matches latest live state. If not, deletes all existing fast nodes and repopulates them +// from latest tree. +// nolint: unparam +func (tree *MutableTree) enableFastStorageAndCommitIfNotEnabled() (bool, error) { + shouldForceUpdate := tree.ndb.shouldForceFastStorageUpgrade() + isFastStorageEnabled := tree.ndb.hasUpgradedToFastStorage() + + if !tree.IsUpgradeable() { + return false, nil + } + + if isFastStorageEnabled && shouldForceUpdate { + // If there is a mismatch between which fast nodes are on disk and the live state due to temporary + // downgrade and subsequent re-upgrade, we cannot know for sure which fast nodes have been removed while downgraded, + // Therefore, there might exist stale fast nodes on disk. As a result, to avoid persisting the stale state, it might + // be worth to delete the fast nodes from disk. + fastItr := NewFastIterator(nil, nil, true, tree.ndb) + defer fastItr.Close() + for ; fastItr.Valid(); fastItr.Next() { + if err := tree.ndb.DeleteFastNode(fastItr.Key()); err != nil { + return false, err + } + } + } + + // Force garbage collection before we proceed to enabling fast storage. + runtime.GC() + + if err := tree.enableFastStorageAndCommit(); err != nil { + tree.ndb.storageVersion = defaultStorageVersionValue + return false, err + } + return true, nil +} + +func (tree *MutableTree) enableFastStorageAndCommitLocked() error { + tree.mtx.Lock() + defer tree.mtx.Unlock() + return tree.enableFastStorageAndCommit() +} + +func (tree *MutableTree) enableFastStorageAndCommit() error { + var err error + + // We start a new thread to keep on checking if we are above 4GB, and if so garbage collect. + // This thread only lasts during the fast node migration. + // This is done to keep RAM usage down. + done := make(chan struct{}) + defer func() { + done <- struct{}{} + close(done) + }() + + go func() { + timer := time.NewTimer(time.Second) + var m runtime.MemStats + + for { + // Sample the current memory usage + runtime.ReadMemStats(&m) + + if m.Alloc > 4*1024*1024*1024 { + // If we are using more than 4GB of memory, we should trigger garbage collection + // to free up some memory. + runtime.GC() + } + + select { + case <-timer.C: + timer.Reset(time.Second) + case <-done: + if !timer.Stop() { + <-timer.C + } + return + } + } + }() + + itr := NewIterator(nil, nil, true, tree.ImmutableTree) + defer itr.Close() + for ; itr.Valid(); itr.Next() { + if err = tree.ndb.SaveFastNodeNoCache(NewFastNode(itr.Key(), itr.Value(), tree.version)); err != nil { + return err + } + } + + if err = itr.Error(); err != nil { + return err + } + + if err = tree.ndb.setFastStorageVersionToBatch(); err != nil { + return err + } + + return tree.ndb.Commit() +} + // GetImmutable loads an ImmutableTree at a given version for querying. The returned tree is // safe for concurrent access, provided the version is not deleted, e.g. via `DeleteVersion()`. func (tree *MutableTree) GetImmutable(version int64) (*ImmutableTree, error) { @@ -476,21 +654,32 @@ func (tree *MutableTree) Rollback() { tree.ImmutableTree = &ImmutableTree{ndb: tree.ndb, version: 0} } tree.orphans = map[string]int64{} + tree.unsavedFastNodeAdditions = map[string]*FastNode{} + tree.unsavedFastNodeRemovals = map[string]interface{}{} } // GetVersioned gets the value at the specified key and version. The returned value must not be // modified, since it may point to data stored within IAVL. -func (tree *MutableTree) GetVersioned(key []byte, version int64) ( - index int64, value []byte, -) { +func (tree *MutableTree) GetVersioned(key []byte, version int64) []byte { if tree.VersionExists(version) { + if tree.IsFastCacheEnabled() { + fastNode, _ := tree.ndb.GetFastNode(key) + if fastNode == nil && version == tree.ndb.latestVersion { + return nil + } + + if fastNode != nil && fastNode.versionLastUpdatedAt <= version { + return fastNode.value + } + } t, err := tree.GetImmutable(version) if err != nil { - return -1, nil + return nil } - return t.Get(key) + value := t.Get(key) + return value } - return -1, nil + return nil } // SaveVersion saves a new tree version to disk, based on the current state of @@ -538,13 +727,19 @@ func (tree *MutableTree) SaveVersion() ([]byte, int64, error) { } } else { logger.Debug("SAVE TREE %v\n", version) - tree.ndb.SaveBranch(tree.root) + if _, err := tree.ndb.SaveBranch(tree.root); err != nil { + return nil, 0, err + } tree.ndb.SaveOrphans(version, tree.orphans) if err := tree.ndb.SaveRoot(tree.root, version); err != nil { return nil, 0, err } } + if err := tree.saveFastNodeVersion(); err != nil { + return nil, version, err + } + if err := tree.ndb.Commit(); err != nil { return nil, version, err } @@ -558,10 +753,73 @@ func (tree *MutableTree) SaveVersion() ([]byte, int64, error) { tree.ImmutableTree = tree.ImmutableTree.clone() tree.lastSaved = tree.ImmutableTree.clone() tree.orphans = map[string]int64{} + tree.unsavedFastNodeAdditions = make(map[string]*FastNode) + tree.unsavedFastNodeRemovals = make(map[string]interface{}) return tree.Hash(), version, nil } +func (tree *MutableTree) saveFastNodeVersion() error { + if err := tree.saveFastNodeAdditions(); err != nil { + return err + } + if err := tree.saveFastNodeRemovals(); err != nil { + return err + } + return tree.ndb.setFastStorageVersionToBatch() +} + +// nolint: unused +func (tree *MutableTree) getUnsavedFastNodeAdditions() map[string]*FastNode { + return tree.unsavedFastNodeAdditions +} + +// getUnsavedFastNodeRemovals returns unsaved FastNodes to remove +// nolint: unused +func (tree *MutableTree) getUnsavedFastNodeRemovals() map[string]interface{} { + return tree.unsavedFastNodeRemovals +} + +func (tree *MutableTree) addUnsavedAddition(key []byte, node *FastNode) { + delete(tree.unsavedFastNodeRemovals, string(key)) + tree.unsavedFastNodeAdditions[string(key)] = node +} + +func (tree *MutableTree) saveFastNodeAdditions() error { + keysToSort := make([]string, 0, len(tree.unsavedFastNodeAdditions)) + for key := range tree.unsavedFastNodeAdditions { + keysToSort = append(keysToSort, key) + } + sort.Strings(keysToSort) + + for _, key := range keysToSort { + if err := tree.ndb.SaveFastNode(tree.unsavedFastNodeAdditions[key]); err != nil { + return err + } + } + return nil +} + +func (tree *MutableTree) addUnsavedRemoval(key []byte) { + delete(tree.unsavedFastNodeAdditions, string(key)) + tree.unsavedFastNodeRemovals[string(key)] = true +} + +func (tree *MutableTree) saveFastNodeRemovals() error { + keysToSort := make([]string, 0, len(tree.unsavedFastNodeRemovals)) + for key := range tree.unsavedFastNodeRemovals { + keysToSort = append(keysToSort, key) + } + sort.Strings(keysToSort) + + for _, key := range keysToSort { + if err := tree.ndb.DeleteFastNode([]byte(key)); err != nil { + return err + } + } + return nil +} + func (tree *MutableTree) deleteVersion(version int64) error { if version <= 0 { return errors.New("version must be greater than 0") diff --git a/mutable_tree_test.go b/mutable_tree_test.go index f7a75602c..d24cb0237 100644 --- a/mutable_tree_test.go +++ b/mutable_tree_test.go @@ -2,11 +2,16 @@ package iavl import ( "bytes" + "errors" "fmt" "runtime" + "sort" "strconv" "testing" + "github.com/cosmos/iavl/internal/encoding" + "github.com/cosmos/iavl/mock" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -101,7 +106,7 @@ func TestMutableTree_DeleteVersions(t *testing.T) { require.NoError(t, err) for _, e := range versionEntries[v] { - _, val := tree.Get(e.key) + val := tree.Get(e.key) require.Equal(t, e.value, val) } } @@ -178,12 +183,12 @@ func TestMutableTree_DeleteVersionsRange(t *testing.T) { require.NoError(err, version) require.Equal(v, version) - _, value := tree.Get([]byte("aaa")) + value := tree.Get([]byte("aaa")) require.Equal(string(value), "bbb") for _, count := range versions[:version] { countStr := strconv.Itoa(int(count)) - _, value := tree.Get([]byte("key" + countStr)) + value := tree.Get([]byte("key" + countStr)) require.Equal(string(value), "value"+countStr) } } @@ -202,17 +207,17 @@ func TestMutableTree_DeleteVersionsRange(t *testing.T) { require.NoError(err) require.Equal(v, version) - _, value := tree.Get([]byte("aaa")) + value := tree.Get([]byte("aaa")) require.Equal(string(value), "bbb") for _, count := range versions[:fromLength] { countStr := strconv.Itoa(int(count)) - _, value := tree.Get([]byte("key" + countStr)) + value := tree.Get([]byte("key" + countStr)) require.Equal(string(value), "value"+countStr) } for _, count := range versions[int64(maxLength/2)-1 : version] { countStr := strconv.Itoa(int(count)) - _, value := tree.Get([]byte("key" + countStr)) + value := tree.Get([]byte("key" + countStr)) require.Equal(string(value), "value"+countStr) } } @@ -318,9 +323,8 @@ func TestMutableTree_VersionExists(t *testing.T) { require.False(t, tree.VersionExists(3)) } -func checkGetVersioned(t *testing.T, tree *MutableTree, version, index int64, key, value []byte) { - idx, val := tree.GetVersioned(key, version) - require.True(t, idx == index) +func checkGetVersioned(t *testing.T, tree *MutableTree, version int64, key, value []byte) { + val := tree.GetVersioned(key, version) require.True(t, bytes.Equal(val, value)) } @@ -330,17 +334,17 @@ func TestMutableTree_GetVersioned(t *testing.T) { require.True(t, ver == 1) require.NoError(t, err) // check key of unloaded version - checkGetVersioned(t, tree, 1, 1, []byte{1}, []byte("a")) - checkGetVersioned(t, tree, 2, 1, []byte{1}, []byte("b")) - checkGetVersioned(t, tree, 3, -1, []byte{1}, nil) + checkGetVersioned(t, tree, 1, []byte{1}, []byte("a")) + checkGetVersioned(t, tree, 2, []byte{1}, []byte("b")) + checkGetVersioned(t, tree, 3, []byte{1}, nil) tree = prepareTree(t) ver, err = tree.LazyLoadVersion(2) require.True(t, ver == 2) require.NoError(t, err) - checkGetVersioned(t, tree, 1, 1, []byte{1}, []byte("a")) - checkGetVersioned(t, tree, 2, 1, []byte{1}, []byte("b")) - checkGetVersioned(t, tree, 3, -1, []byte{1}, nil) + checkGetVersioned(t, tree, 1, []byte{1}, []byte("a")) + checkGetVersioned(t, tree, 2, []byte{1}, []byte("b")) + checkGetVersioned(t, tree, 3, []byte{1}, nil) } func TestMutableTree_DeleteVersion(t *testing.T) { @@ -380,3 +384,648 @@ func TestMutableTree_LazyLoadVersionWithEmptyTree(t *testing.T) { require.True(t, newTree1.root == newTree2.root) } + +func TestMutableTree_SetSimple(t *testing.T) { + mdb := db.NewMemDB() + tree, err := NewMutableTree(mdb, 0) + require.NoError(t, err) + + const testKey1 = "a" + const testVal1 = "test" + + isUpdated := tree.Set([]byte(testKey1), []byte(testVal1)) + require.False(t, isUpdated) + + fastValue := tree.Get([]byte(testKey1)) + _, regularValue := tree.GetWithIndex([]byte(testKey1)) + + require.Equal(t, []byte(testVal1), fastValue) + require.Equal(t, []byte(testVal1), regularValue) + + fastNodeAdditions := tree.getUnsavedFastNodeAdditions() + require.Equal(t, 1, len(fastNodeAdditions)) + + fastNodeAddition := fastNodeAdditions[testKey1] + require.Equal(t, []byte(testKey1), fastNodeAddition.key) + require.Equal(t, []byte(testVal1), fastNodeAddition.value) + require.Equal(t, int64(1), fastNodeAddition.versionLastUpdatedAt) +} + +func TestMutableTree_SetTwoKeys(t *testing.T) { + mdb := db.NewMemDB() + tree, err := NewMutableTree(mdb, 0) + require.NoError(t, err) + + const testKey1 = "a" + const testVal1 = "test" + + const testKey2 = "b" + const testVal2 = "test2" + + isUpdated := tree.Set([]byte(testKey1), []byte(testVal1)) + require.False(t, isUpdated) + + isUpdated = tree.Set([]byte(testKey2), []byte(testVal2)) + require.False(t, isUpdated) + + fastValue := tree.Get([]byte(testKey1)) + _, regularValue := tree.GetWithIndex([]byte(testKey1)) + require.Equal(t, []byte(testVal1), fastValue) + require.Equal(t, []byte(testVal1), regularValue) + + fastValue2 := tree.Get([]byte(testKey2)) + _, regularValue2 := tree.GetWithIndex([]byte(testKey2)) + require.Equal(t, []byte(testVal2), fastValue2) + require.Equal(t, []byte(testVal2), regularValue2) + + fastNodeAdditions := tree.getUnsavedFastNodeAdditions() + require.Equal(t, 2, len(fastNodeAdditions)) + + fastNodeAddition := fastNodeAdditions[testKey1] + require.Equal(t, []byte(testKey1), fastNodeAddition.key) + require.Equal(t, []byte(testVal1), fastNodeAddition.value) + require.Equal(t, int64(1), fastNodeAddition.versionLastUpdatedAt) + + fastNodeAddition = fastNodeAdditions[testKey2] + require.Equal(t, []byte(testKey2), fastNodeAddition.key) + require.Equal(t, []byte(testVal2), fastNodeAddition.value) + require.Equal(t, int64(1), fastNodeAddition.versionLastUpdatedAt) +} + +func TestMutableTree_SetOverwrite(t *testing.T) { + mdb := db.NewMemDB() + tree, err := NewMutableTree(mdb, 0) + require.NoError(t, err) + + const testKey1 = "a" + const testVal1 = "test" + const testVal2 = "test2" + + isUpdated := tree.Set([]byte(testKey1), []byte(testVal1)) + require.False(t, isUpdated) + + isUpdated = tree.Set([]byte(testKey1), []byte(testVal2)) + require.True(t, isUpdated) + + fastValue := tree.Get([]byte(testKey1)) + _, regularValue := tree.GetWithIndex([]byte(testKey1)) + require.Equal(t, []byte(testVal2), fastValue) + require.Equal(t, []byte(testVal2), regularValue) + + fastNodeAdditions := tree.getUnsavedFastNodeAdditions() + require.Equal(t, 1, len(fastNodeAdditions)) + + fastNodeAddition := fastNodeAdditions[testKey1] + require.Equal(t, []byte(testKey1), fastNodeAddition.key) + require.Equal(t, []byte(testVal2), fastNodeAddition.value) + require.Equal(t, int64(1), fastNodeAddition.versionLastUpdatedAt) +} + +func TestMutableTree_SetRemoveSet(t *testing.T) { + mdb := db.NewMemDB() + tree, err := NewMutableTree(mdb, 0) + require.NoError(t, err) + + const testKey1 = "a" + const testVal1 = "test" + + // Set 1 + isUpdated := tree.Set([]byte(testKey1), []byte(testVal1)) + require.False(t, isUpdated) + + fastValue := tree.Get([]byte(testKey1)) + _, regularValue := tree.GetWithIndex([]byte(testKey1)) + require.Equal(t, []byte(testVal1), fastValue) + require.Equal(t, []byte(testVal1), regularValue) + + fastNodeAdditions := tree.getUnsavedFastNodeAdditions() + require.Equal(t, 1, len(fastNodeAdditions)) + + fastNodeAddition := fastNodeAdditions[testKey1] + require.Equal(t, []byte(testKey1), fastNodeAddition.key) + require.Equal(t, []byte(testVal1), fastNodeAddition.value) + require.Equal(t, int64(1), fastNodeAddition.versionLastUpdatedAt) + + // Remove + removedVal, isRemoved := tree.Remove([]byte(testKey1)) + require.NotNil(t, removedVal) + require.True(t, isRemoved) + + fastNodeAdditions = tree.getUnsavedFastNodeAdditions() + require.Equal(t, 0, len(fastNodeAdditions)) + + fastNodeRemovals := tree.getUnsavedFastNodeRemovals() + require.Equal(t, 1, len(fastNodeRemovals)) + + fastValue = tree.Get([]byte(testKey1)) + _, regularValue = tree.GetWithIndex([]byte(testKey1)) + require.Nil(t, fastValue) + require.Nil(t, regularValue) + + // Set 2 + isUpdated = tree.Set([]byte(testKey1), []byte(testVal1)) + require.False(t, isUpdated) + + fastValue = tree.Get([]byte(testKey1)) + _, regularValue = tree.GetWithIndex([]byte(testKey1)) + require.Equal(t, []byte(testVal1), fastValue) + require.Equal(t, []byte(testVal1), regularValue) + + fastNodeAdditions = tree.getUnsavedFastNodeAdditions() + require.Equal(t, 1, len(fastNodeAdditions)) + + fastNodeAddition = fastNodeAdditions[testKey1] + require.Equal(t, []byte(testKey1), fastNodeAddition.key) + require.Equal(t, []byte(testVal1), fastNodeAddition.value) + require.Equal(t, int64(1), fastNodeAddition.versionLastUpdatedAt) + + fastNodeRemovals = tree.getUnsavedFastNodeRemovals() + require.Equal(t, 0, len(fastNodeRemovals)) +} + +func TestMutableTree_FastNodeIntegration(t *testing.T) { + mdb := db.NewMemDB() + tree, err := NewMutableTree(mdb, 1000) + require.NoError(t, err) + + const key1 = "a" + const key2 = "b" + const key3 = "c" + + const testVal1 = "test" + const testVal2 = "test2" + + // Set key1 + res := tree.Set([]byte(key1), []byte(testVal1)) + require.False(t, res) + + unsavedNodeAdditions := tree.getUnsavedFastNodeAdditions() + require.Equal(t, len(unsavedNodeAdditions), 1) + + // Set key2 + res = tree.Set([]byte(key2), []byte(testVal1)) + require.False(t, res) + + unsavedNodeAdditions = tree.getUnsavedFastNodeAdditions() + require.Equal(t, len(unsavedNodeAdditions), 2) + + // Set key3 + res = tree.Set([]byte(key3), []byte(testVal1)) + require.False(t, res) + + unsavedNodeAdditions = tree.getUnsavedFastNodeAdditions() + require.Equal(t, len(unsavedNodeAdditions), 3) + + // Set key3 with new value + res = tree.Set([]byte(key3), []byte(testVal2)) + require.True(t, res) + + unsavedNodeAdditions = tree.getUnsavedFastNodeAdditions() + require.Equal(t, len(unsavedNodeAdditions), 3) + + // Remove key2 + removedVal, isRemoved := tree.Remove([]byte(key2)) + require.True(t, isRemoved) + require.Equal(t, []byte(testVal1), removedVal) + + unsavedNodeAdditions = tree.getUnsavedFastNodeAdditions() + require.Equal(t, len(unsavedNodeAdditions), 2) + + unsavedNodeRemovals := tree.getUnsavedFastNodeRemovals() + require.Equal(t, len(unsavedNodeRemovals), 1) + + // Save + _, _, err = tree.SaveVersion() + require.NoError(t, err) + + unsavedNodeAdditions = tree.getUnsavedFastNodeAdditions() + require.Equal(t, len(unsavedNodeAdditions), 0) + + unsavedNodeRemovals = tree.getUnsavedFastNodeRemovals() + require.Equal(t, len(unsavedNodeRemovals), 0) + + // Load + t2, err := NewMutableTree(mdb, 0) + require.NoError(t, err) + + _, err = t2.Load() + require.NoError(t, err) + + // Get and GetFast + fastValue := t2.Get([]byte(key1)) + _, regularValue := tree.GetWithIndex([]byte(key1)) + require.Equal(t, []byte(testVal1), fastValue) + require.Equal(t, []byte(testVal1), regularValue) + + fastValue = t2.Get([]byte(key2)) + _, regularValue = t2.GetWithIndex([]byte(key2)) + require.Nil(t, fastValue) + require.Nil(t, regularValue) + + fastValue = t2.Get([]byte(key3)) + _, regularValue = tree.GetWithIndex([]byte(key3)) + require.Equal(t, []byte(testVal2), fastValue) + require.Equal(t, []byte(testVal2), regularValue) +} + +func TestIterate_MutableTree_Unsaved(t *testing.T) { + tree, mirror := getRandomizedTreeAndMirror(t) + assertMutableMirrorIterate(t, tree, mirror) +} + +func TestIterate_MutableTree_Saved(t *testing.T) { + tree, mirror := getRandomizedTreeAndMirror(t) + + _, _, err := tree.SaveVersion() + require.NoError(t, err) + + assertMutableMirrorIterate(t, tree, mirror) +} + +func TestIterate_MutableTree_Unsaved_NextVersion(t *testing.T) { + tree, mirror := getRandomizedTreeAndMirror(t) + + _, _, err := tree.SaveVersion() + require.NoError(t, err) + + assertMutableMirrorIterate(t, tree, mirror) + + randomizeTreeAndMirror(t, tree, mirror) + + assertMutableMirrorIterate(t, tree, mirror) +} + +func TestIterator_MutableTree_Invalid(t *testing.T) { + tree, err := getTestTree(0) + require.NoError(t, err) + + itr := tree.Iterator([]byte("a"), []byte("b"), true) + + require.NotNil(t, itr) + require.False(t, itr.Valid()) +} + +func TestUpgradeStorageToFast_LatestVersion_Success(t *testing.T) { + // Setup + db := db.NewMemDB() + tree, err := NewMutableTree(db, 1000) + + // Default version when storage key does not exist in the db + require.NoError(t, err) + require.False(t, tree.IsFastCacheEnabled()) + + mirror := make(map[string]string) + // Fill with some data + randomizeTreeAndMirror(t, tree, mirror) + + // Enable fast storage + require.True(t, tree.IsUpgradeable()) + enabled, err := tree.enableFastStorageAndCommitIfNotEnabled() + require.NoError(t, err) + require.True(t, enabled) + require.False(t, tree.IsUpgradeable()) + + require.True(t, tree.IsFastCacheEnabled()) +} + +func TestUpgradeStorageToFast_AlreadyUpgraded_Success(t *testing.T) { + // Setup + db := db.NewMemDB() + tree, err := NewMutableTree(db, 1000) + + // Default version when storage key does not exist in the db + require.NoError(t, err) + require.False(t, tree.IsFastCacheEnabled()) + + mirror := make(map[string]string) + // Fill with some data + randomizeTreeAndMirror(t, tree, mirror) + + // Enable fast storage + require.True(t, tree.IsUpgradeable()) + enabled, err := tree.enableFastStorageAndCommitIfNotEnabled() + require.NoError(t, err) + require.True(t, enabled) + require.True(t, tree.IsFastCacheEnabled()) + require.False(t, tree.IsUpgradeable()) + + // Test enabling fast storage when already enabled + enabled, err = tree.enableFastStorageAndCommitIfNotEnabled() + require.NoError(t, err) + require.False(t, enabled) + require.True(t, tree.IsFastCacheEnabled()) + +} + +func TestUpgradeStorageToFast_DbErrorConstructor_Failure(t *testing.T) { + ctrl := gomock.NewController(t) + dbMock := mock.NewMockDB(ctrl) + rIterMock := mock.NewMockIterator(ctrl) + + // rIterMock is used to get the latest version from disk. We are mocking that rIterMock returns latestTreeVersion from disk + rIterMock.EXPECT().Valid().Return(true).Times(1) + rIterMock.EXPECT().Key().Return(rootKeyFormat.Key([]byte(defaultStorageVersionValue))) + rIterMock.EXPECT().Close().Return(nil).Times(1) + + expectedError := errors.New("some db error") + + dbMock.EXPECT().Get(gomock.Any()).Return(nil, expectedError).Times(1) + dbMock.EXPECT().NewBatch().Return(nil).Times(1) + dbMock.EXPECT().ReverseIterator(gomock.Any(), gomock.Any()).Return(rIterMock, nil).Times(1) + + tree, err := NewMutableTree(dbMock, 0) + require.Nil(t, err) + require.NotNil(t, tree) + require.False(t, tree.IsFastCacheEnabled()) +} + +func TestUpgradeStorageToFast_DbErrorEnableFastStorage_Failure(t *testing.T) { + ctrl := gomock.NewController(t) + dbMock := mock.NewMockDB(ctrl) + rIterMock := mock.NewMockIterator(ctrl) + + // rIterMock is used to get the latest version from disk. We are mocking that rIterMock returns latestTreeVersion from disk + rIterMock.EXPECT().Valid().Return(true).Times(1) + rIterMock.EXPECT().Key().Return(rootKeyFormat.Key([]byte(defaultStorageVersionValue))) + rIterMock.EXPECT().Close().Return(nil).Times(1) + + expectedError := errors.New("some db error") + + batchMock := mock.NewMockBatch(ctrl) + + dbMock.EXPECT().Get(gomock.Any()).Return(nil, nil).Times(1) + dbMock.EXPECT().NewBatch().Return(batchMock).Times(1) + dbMock.EXPECT().ReverseIterator(gomock.Any(), gomock.Any()).Return(rIterMock, nil).Times(1) + + batchMock.EXPECT().Set(gomock.Any(), gomock.Any()).Return(expectedError).Times(1) + + tree, err := NewMutableTree(dbMock, 0) + require.Nil(t, err) + require.NotNil(t, tree) + require.False(t, tree.IsFastCacheEnabled()) + + enabled, err := tree.enableFastStorageAndCommitIfNotEnabled() + require.ErrorIs(t, err, expectedError) + require.False(t, enabled) + require.False(t, tree.IsFastCacheEnabled()) +} + +func TestFastStorageReUpgradeProtection_NoForceUpgrade_Success(t *testing.T) { + ctrl := gomock.NewController(t) + dbMock := mock.NewMockDB(ctrl) + rIterMock := mock.NewMockIterator(ctrl) + + // We are trying to test downgrade and re-upgrade protection + // We need to set up a state where latest fast storage version is equal to latest tree version + const latestFastStorageVersionOnDisk = 1 + const latestTreeVersion = latestFastStorageVersionOnDisk + + // Setup fake reverse iterator db to traverse root versions, called by ndb's getLatestVersion + expectedStorageVersion := []byte(fastStorageVersionValue + fastStorageVersionDelimiter + strconv.Itoa(latestFastStorageVersionOnDisk)) + + // rIterMock is used to get the latest version from disk. We are mocking that rIterMock returns latestTreeVersion from disk + rIterMock.EXPECT().Valid().Return(true).Times(1) + rIterMock.EXPECT().Key().Return(rootKeyFormat.Key(latestTreeVersion)) + rIterMock.EXPECT().Close().Return(nil).Times(1) + + batchMock := mock.NewMockBatch(ctrl) + + dbMock.EXPECT().Get(gomock.Any()).Return(expectedStorageVersion, nil).Times(1) + dbMock.EXPECT().NewBatch().Return(batchMock).Times(1) + dbMock.EXPECT().ReverseIterator(gomock.Any(), gomock.Any()).Return(rIterMock, nil).Times(1) // called to get latest version + + tree, err := NewMutableTree(dbMock, 0) + require.Nil(t, err) + require.NotNil(t, tree) + + // Pretend that we called Load and have the latest state in the tree + tree.version = latestTreeVersion + require.Equal(t, tree.ndb.getLatestVersion(), int64(latestTreeVersion)) + + // Ensure that the right branch of enableFastStorageAndCommitIfNotEnabled will be triggered + require.True(t, tree.IsFastCacheEnabled()) + require.False(t, tree.ndb.shouldForceFastStorageUpgrade()) + + enabled, err := tree.enableFastStorageAndCommitIfNotEnabled() + require.NoError(t, err) + require.False(t, enabled) +} + +func TestFastStorageReUpgradeProtection_ForceUpgradeFirstTime_NoForceSecondTime_Success(t *testing.T) { + ctrl := gomock.NewController(t) + dbMock := mock.NewMockDB(ctrl) + batchMock := mock.NewMockBatch(ctrl) + iterMock := mock.NewMockIterator(ctrl) + rIterMock := mock.NewMockIterator(ctrl) + + // We are trying to test downgrade and re-upgrade protection + // We need to set up a state where latest fast storage version is of a lower version + // than tree version + const latestFastStorageVersionOnDisk = 1 + const latestTreeVersion = latestFastStorageVersionOnDisk + 1 + + // Setup db for iterator and reverse iterator mocks + expectedStorageVersion := []byte(fastStorageVersionValue + fastStorageVersionDelimiter + strconv.Itoa(latestFastStorageVersionOnDisk)) + + // Setup fake reverse iterator db to traverse root versions, called by ndb's getLatestVersion + // rItr, err := db.ReverseIterator(rootKeyFormat.Key(1), rootKeyFormat.Key(latestTreeVersion + 1)) + // require.NoError(t, err) + + // dbMock represents the underlying database under the hood of nodeDB + dbMock.EXPECT().Get(gomock.Any()).Return(expectedStorageVersion, nil).Times(1) + dbMock.EXPECT().NewBatch().Return(batchMock).Times(2) + dbMock.EXPECT().ReverseIterator(gomock.Any(), gomock.Any()).Return(rIterMock, nil).Times(1) // called to get latest version + startFormat := fastKeyFormat.Key() + endFormat := fastKeyFormat.Key() + endFormat[0]++ + dbMock.EXPECT().Iterator(startFormat, endFormat).Return(iterMock, nil).Times(1) + + // rIterMock is used to get the latest version from disk. We are mocking that rIterMock returns latestTreeVersion from disk + rIterMock.EXPECT().Valid().Return(true).Times(1) + rIterMock.EXPECT().Key().Return(rootKeyFormat.Key(latestTreeVersion)) + rIterMock.EXPECT().Close().Return(nil).Times(1) + + fastNodeKeyToDelete := []byte("some_key") + + // batchMock represents a structure that receives all the updates related to + // upgrade and then commits them all in the end. + updatedExpectedStorageVersion := make([]byte, len(expectedStorageVersion)) + copy(updatedExpectedStorageVersion, expectedStorageVersion) + updatedExpectedStorageVersion[len(updatedExpectedStorageVersion)-1]++ + batchMock.EXPECT().Delete(fastKeyFormat.Key(fastNodeKeyToDelete)).Return(nil).Times(1) + batchMock.EXPECT().Set(metadataKeyFormat.Key([]byte(storageVersionKey)), updatedExpectedStorageVersion).Return(nil).Times(1) + batchMock.EXPECT().Write().Return(nil).Times(1) + batchMock.EXPECT().Close().Return(nil).Times(1) + + // iterMock is used to mock the underlying db iterator behing fast iterator + // Here, we want to mock the behavior of deleting fast nodes from disk when + // force upgrade is detected. + iterMock.EXPECT().Valid().Return(true).Times(1) + iterMock.EXPECT().Error().Return(nil).Times(1) + iterMock.EXPECT().Key().Return(fastKeyFormat.Key(fastNodeKeyToDelete)).Times(1) + // encode value + var buf bytes.Buffer + testValue := "test_value" + buf.Grow(encoding.EncodeVarintSize(int64(latestFastStorageVersionOnDisk)) + encoding.EncodeBytesSize([]byte(testValue))) + err := encoding.EncodeVarint(&buf, int64(latestFastStorageVersionOnDisk)) + require.NoError(t, err) + err = encoding.EncodeBytes(&buf, []byte(testValue)) + require.NoError(t, err) + iterMock.EXPECT().Value().Return(buf.Bytes()).Times(1) // this is encoded as version 1 with value "2" + iterMock.EXPECT().Valid().Return(true).Times(1) + // Call Next at the end of loop iteration + iterMock.EXPECT().Next().Return().Times(1) + iterMock.EXPECT().Error().Return(nil).Times(1) + iterMock.EXPECT().Valid().Return(false).Times(1) + // Call Valid after first iteraton + iterMock.EXPECT().Valid().Return(false).Times(1) + iterMock.EXPECT().Close().Return(nil).Times(1) + + tree, err := NewMutableTree(dbMock, 0) + require.Nil(t, err) + require.NotNil(t, tree) + + // Pretend that we called Load and have the latest state in the tree + tree.version = latestTreeVersion + require.Equal(t, tree.ndb.getLatestVersion(), int64(latestTreeVersion)) + + // Ensure that the right branch of enableFastStorageAndCommitIfNotEnabled will be triggered + require.True(t, tree.IsFastCacheEnabled()) + require.True(t, tree.ndb.shouldForceFastStorageUpgrade()) + + // Actual method under test + enabled, err := tree.enableFastStorageAndCommitIfNotEnabled() + require.NoError(t, err) + require.True(t, enabled) + + // Test that second time we call this, force upgrade does not happen + enabled, err = tree.enableFastStorageAndCommitIfNotEnabled() + require.NoError(t, err) + require.False(t, enabled) +} + +func TestUpgradeStorageToFast_Integration_Upgraded_FastIterator_Success(t *testing.T) { + // Setup + tree, mirror := setupTreeAndMirrorForUpgrade(t) + + require.False(t, tree.IsFastCacheEnabled()) + require.True(t, tree.IsUpgradeable()) + + // Should auto enable in save version + _, _, err := tree.SaveVersion() + require.NoError(t, err) + + require.True(t, tree.IsFastCacheEnabled()) + require.False(t, tree.IsUpgradeable()) + + sut, _ := NewMutableTree(tree.ndb.db, 1000) + + require.False(t, sut.IsFastCacheEnabled()) + require.False(t, sut.IsUpgradeable()) // upgraded in save version + + // Load version - should auto enable fast storage + version, err := sut.Load() + require.NoError(t, err) + + require.True(t, sut.IsFastCacheEnabled()) + + require.Equal(t, int64(1), version) + + // Test that upgraded mutable tree iterates as expected + t.Run("Mutable tree", func(t *testing.T) { + i := 0 + sut.Iterate(func(k, v []byte) bool { + require.Equal(t, []byte(mirror[i][0]), k) + require.Equal(t, []byte(mirror[i][1]), v) + i++ + return false + }) + }) + + // Test that upgraded immutable tree iterates as expected + t.Run("Immutable tree", func(t *testing.T) { + immutableTree, err := sut.GetImmutable(sut.version) + require.NoError(t, err) + + i := 0 + immutableTree.Iterate(func(k, v []byte) bool { + require.Equal(t, []byte(mirror[i][0]), k) + require.Equal(t, []byte(mirror[i][1]), v) + i++ + return false + }) + }) +} + +func TestUpgradeStorageToFast_Integration_Upgraded_GetFast_Success(t *testing.T) { + // Setup + tree, mirror := setupTreeAndMirrorForUpgrade(t) + + require.False(t, tree.IsFastCacheEnabled()) + require.True(t, tree.IsUpgradeable()) + + // Should auto enable in save version + _, _, err := tree.SaveVersion() + require.NoError(t, err) + + require.True(t, tree.IsFastCacheEnabled()) + require.False(t, tree.IsUpgradeable()) + + sut, _ := NewMutableTree(tree.ndb.db, 1000) + + require.False(t, sut.IsFastCacheEnabled()) + require.False(t, sut.IsUpgradeable()) // upgraded in save version + + // LazyLoadVersion - should auto enable fast storage + version, err := sut.LazyLoadVersion(1) + require.NoError(t, err) + + require.True(t, sut.IsFastCacheEnabled()) + + require.Equal(t, int64(1), version) + + t.Run("Mutable tree", func(t *testing.T) { + for _, kv := range mirror { + v := sut.Get([]byte(kv[0])) + require.Equal(t, []byte(kv[1]), v) + } + }) + + t.Run("Immutable tree", func(t *testing.T) { + immutableTree, err := sut.GetImmutable(sut.version) + require.NoError(t, err) + + for _, kv := range mirror { + v := immutableTree.Get([]byte(kv[0])) + require.Equal(t, []byte(kv[1]), v) + } + }) +} + +func setupTreeAndMirrorForUpgrade(t *testing.T) (*MutableTree, [][]string) { + db := db.NewMemDB() + + tree, _ := NewMutableTree(db, 0) + + const numEntries = 100 + var keyPrefix, valPrefix = "key", "val" + + mirror := make([][]string, 0, numEntries) + for i := 0; i < numEntries; i++ { + key := fmt.Sprintf("%s_%d", keyPrefix, i) + val := fmt.Sprintf("%s_%d", valPrefix, i) + mirror = append(mirror, []string{key, val}) + require.False(t, tree.Set([]byte(key), []byte(val))) + } + + // Delete fast nodes from database to mimic a version with no upgrade + for i := 0; i < numEntries; i++ { + key := fmt.Sprintf("%s_%d", keyPrefix, i) + require.NoError(t, db.Delete(fastKeyFormat.Key([]byte(key)))) + } + + sort.Slice(mirror, func(i, j int) bool { + return mirror[i][0] < mirror[j][0] + }) + return tree, mirror +} diff --git a/node.go b/node.go index f2befb934..b7e369727 100644 --- a/node.go +++ b/node.go @@ -159,6 +159,9 @@ func (node *Node) has(t *ImmutableTree, key []byte) (has bool) { } // Get a key under the node. +// +// The index is the index in the list of leaf nodes sorted lexicographically by key. The leftmost leaf has index 0. +// It's neighbor has index 1 and so on. func (node *Node) get(t *ImmutableTree, key []byte) (index int64, value []byte) { if node.isLeaf() { switch bytes.Compare(node.key, key) { @@ -441,6 +444,7 @@ func (node *Node) calcBalance(t *ImmutableTree) int { } // traverse is a wrapper over traverseInRange when we want the whole tree +// nolint: unparam func (node *Node) traverse(t *ImmutableTree, ascending bool, cb func(*Node) bool) bool { return node.traverseInRange(t, nil, nil, ascending, false, false, func(node *Node) bool { return cb(node) diff --git a/nodedb.go b/nodedb.go index d1313ab5d..06e887002 100644 --- a/nodedb.go +++ b/nodedb.go @@ -7,6 +7,8 @@ import ( "fmt" "math" "sort" + "strconv" + "strings" "sync" "github.com/cosmos/iavl/internal/logger" @@ -15,9 +17,19 @@ import ( ) const ( - int64Size = 8 - hashSize = sha256.Size - genesisVersion = 1 + int64Size = 8 + hashSize = sha256.Size + genesisVersion = 1 + storageVersionKey = "storage_version" + // We store latest saved version together with storage version delimited by the constant below. + // This delimiter is valid only if fast storage is enabled (i.e. storageVersion >= fastStorageVersionValue). + // The latest saved version is needed for protection against downgrade and re-upgrade. In such a case, it would + // be possible to observe mismatch between the latest version state and the fast nodes on disk. + // Therefore, we would like to detect that and overwrite fast nodes on disk with the latest version state. + fastStorageVersionDelimiter = "-" + // Using semantic versioning: https://semver.org/ + defaultStorageVersionValue = "1.0.0" + fastStorageVersionValue = "1.1.0" ) var ( @@ -30,23 +42,47 @@ var ( // to exist, while the second number represents the *earliest* version at // which it is expected to exist - which starts out by being the version // of the node being orphaned. + // To clarify: + // When I write to key {X} with value V and old value O, we orphan O with =time of write + // and = version O was created at. orphanKeyFormat = NewKeyFormat('o', int64Size, int64Size, hashSize) // o + // Key Format for making reads and iterates go through a data-locality preserving db. + // The value at an entry will list what version it was written to. + // Then to query values, you first query state via this fast method. + // If its present, then check the tree version. If tree version >= result_version, + // return result_version. Else, go through old (slow) IAVL get method that walks through tree. + fastKeyFormat = NewKeyFormat('f', 0) // f + + // Key Format for storing metadata about the chain such as the vesion number. + // The value at an entry will be in a variable format and up to the caller to + // decide how to parse. + metadataKeyFormat = NewKeyFormat('m', 0) // v + // Root nodes are indexed separately by their version rootKeyFormat = NewKeyFormat('r', int64Size) // r ) +var ( + errInvalidFastStorageVersion = fmt.Sprintf("Fast storage version must be in the format %s", fastStorageVersionDelimiter) +) + type nodeDB struct { mtx sync.Mutex // Read/write lock. db dbm.DB // Persistent node storage. batch dbm.Batch // Batched writing buffer. opts Options // Options to customize for pruning/writing versionReaders map[int64]uint32 // Number of active version readers + storageVersion string // Storage version latestVersion int64 nodeCache map[string]*list.Element // Node cache. nodeCacheSize int // Node cache size limit in elements. nodeCacheQueue *list.List // LRU queue of cache elements. Used for deletion. + + fastNodeCache map[string]*list.Element // FastNode cache. + fastNodeCacheSize int // FastNode cache size limit in elements. + fastNodeCacheQueue *list.List // LRU queue of cache elements. Used for deletion. } func newNodeDB(db dbm.DB, cacheSize int, opts *Options) *nodeDB { @@ -54,15 +90,26 @@ func newNodeDB(db dbm.DB, cacheSize int, opts *Options) *nodeDB { o := DefaultOptions() opts = &o } + + storeVersion, err := db.Get(metadataKeyFormat.Key([]byte(storageVersionKey))) + + if err != nil || storeVersion == nil { + storeVersion = []byte(defaultStorageVersionValue) + } + return &nodeDB{ - db: db, - batch: db.NewBatch(), - opts: *opts, - latestVersion: 0, // initially invalid - nodeCache: make(map[string]*list.Element), - nodeCacheSize: cacheSize, - nodeCacheQueue: list.New(), - versionReaders: make(map[int64]uint32, 8), + db: db, + batch: db.NewBatch(), + opts: *opts, + latestVersion: 0, // initially invalid + nodeCache: make(map[string]*list.Element), + nodeCacheSize: cacheSize, + nodeCacheQueue: list.New(), + fastNodeCache: make(map[string]*list.Element), + fastNodeCacheSize: 100000, + fastNodeCacheQueue: list.New(), + versionReaders: make(map[int64]uint32, 8), + storageVersion: string(storeVersion), } } @@ -104,6 +151,43 @@ func (ndb *nodeDB) GetNode(hash []byte) *Node { return node } +func (ndb *nodeDB) GetFastNode(key []byte) (*FastNode, error) { + if !ndb.hasUpgradedToFastStorage() { + return nil, errors.New("storage version is not fast") + } + + ndb.mtx.Lock() + defer ndb.mtx.Unlock() + + if len(key) == 0 { + return nil, fmt.Errorf("nodeDB.GetFastNode() requires key, len(key) equals 0") + } + + // Check the cache. + if elem, ok := ndb.fastNodeCache[string(key)]; ok { + // Already exists. Move to back of fastNodeCacheQueue. + ndb.fastNodeCacheQueue.MoveToBack(elem) + return elem.Value.(*FastNode), nil + } + + // Doesn't exist, load. + buf, err := ndb.db.Get(ndb.fastNodeKey(key)) + if err != nil { + return nil, fmt.Errorf("can't get FastNode %X: %w", key, err) + } + if buf == nil { + return nil, nil + } + + fastNode, err := DeserializeFastNode(key, buf) + if err != nil { + return nil, fmt.Errorf("error reading FastNode. bytes: %x, error: %w", buf, err) + } + + ndb.cacheFastNode(fastNode) + return fastNode, nil +} + // SaveNode saves a node to disk. func (ndb *nodeDB) SaveNode(node *Node) { ndb.mtx.Lock() @@ -132,6 +216,94 @@ func (ndb *nodeDB) SaveNode(node *Node) { ndb.cacheNode(node) } +// SaveNode saves a FastNode to disk and add to cache. +func (ndb *nodeDB) SaveFastNode(node *FastNode) error { + ndb.mtx.Lock() + defer ndb.mtx.Unlock() + return ndb.saveFastNodeUnlocked(node, true) +} + +// SaveNode saves a FastNode to disk without adding to cache. +func (ndb *nodeDB) SaveFastNodeNoCache(node *FastNode) error { + ndb.mtx.Lock() + defer ndb.mtx.Unlock() + return ndb.saveFastNodeUnlocked(node, false) +} + +// setFastStorageVersionToBatch sets storage version to fast where the version is +// 1.1.0-. Returns error if storage version is incorrect or on +// db error, nil otherwise. Requires changes to be committed after to be persisted. +func (ndb *nodeDB) setFastStorageVersionToBatch() error { + var newVersion string + if ndb.storageVersion >= fastStorageVersionValue { + // Storage version should be at index 0 and latest fast cache version at index 1 + versions := strings.Split(ndb.storageVersion, fastStorageVersionDelimiter) + + if len(versions) > 2 { + return errors.New(errInvalidFastStorageVersion) + } + + newVersion = versions[0] + } else { + newVersion = fastStorageVersionValue + } + + newVersion += fastStorageVersionDelimiter + strconv.Itoa(int(ndb.getLatestVersion())) + + if err := ndb.batch.Set(metadataKeyFormat.Key([]byte(storageVersionKey)), []byte(newVersion)); err != nil { + return err + } + ndb.storageVersion = newVersion + return nil +} + +func (ndb *nodeDB) getStorageVersion() string { + return ndb.storageVersion +} + +// Returns true if the upgrade to latest storage version has been performed, false otherwise. +func (ndb *nodeDB) hasUpgradedToFastStorage() bool { + return ndb.getStorageVersion() >= fastStorageVersionValue +} + +// Returns true if the upgrade to fast storage has occurred but it does not match the live state, false otherwise. +// When the live state is not matched, we must force reupgrade. +// We determine this by checking the version of the live state and the version of the live state when +// latest storage was updated on disk the last time. +func (ndb *nodeDB) shouldForceFastStorageUpgrade() bool { + versions := strings.Split(ndb.storageVersion, fastStorageVersionDelimiter) + + if len(versions) == 2 { + if versions[1] != strconv.Itoa(int(ndb.getLatestVersion())) { + return true + } + } + return false +} + +// SaveNode saves a FastNode to disk. +func (ndb *nodeDB) saveFastNodeUnlocked(node *FastNode, shouldAddToCache bool) error { + if node.key == nil { + return fmt.Errorf("cannot have FastNode with a nil value for key") + } + + // Save node bytes to db. + var buf bytes.Buffer + buf.Grow(node.encodedSize()) + + if err := node.writeBytes(&buf); err != nil { + return fmt.Errorf("error while writing fastnode bytes. Err: %w", err) + } + + if err := ndb.batch.Set(ndb.fastNodeKey(node.key), buf.Bytes()); err != nil { + return fmt.Errorf("error while writing key/val to nodedb batch. Err: %w", err) + } + if shouldAddToCache { + ndb.cacheFastNode(node) + } + return nil +} + // Has checks if a hash exists in the database. func (ndb *nodeDB) Has(hash []byte) (bool, error) { key := ndb.nodeKey(hash) @@ -155,16 +327,26 @@ func (ndb *nodeDB) Has(hash []byte) (bool, error) { // NOTE: This function clears leftNode/rigthNode recursively and // calls _hash() on the given node. // TODO refactor, maybe use hashWithCount() but provide a callback. -func (ndb *nodeDB) SaveBranch(node *Node) []byte { +func (ndb *nodeDB) SaveBranch(node *Node) ([]byte, error) { if node.persisted { - return node.hash + return node.hash, nil } + var err error if node.leftNode != nil { - node.leftHash = ndb.SaveBranch(node.leftNode) + node.leftHash, err = ndb.SaveBranch(node.leftNode) } + + if err != nil { + return nil, err + } + if node.rightNode != nil { - node.rightHash = ndb.SaveBranch(node.rightNode) + node.rightHash, err = ndb.SaveBranch(node.rightNode) + } + + if err != nil { + return nil, err } node._hash() @@ -172,16 +354,18 @@ func (ndb *nodeDB) SaveBranch(node *Node) []byte { // resetBatch only working on generate a genesis block if node.version <= genesisVersion { - ndb.resetBatch() + if err = ndb.resetBatch(); err != nil { + return nil, err + } } node.leftNode = nil node.rightNode = nil - return node.hash + return node.hash, nil } // resetBatch reset the db batch, keep low memory used -func (ndb *nodeDB) resetBatch() { +func (ndb *nodeDB) resetBatch() error { var err error if ndb.opts.Sync { err = ndb.batch.WriteSync() @@ -189,13 +373,20 @@ func (ndb *nodeDB) resetBatch() { err = ndb.batch.Write() } if err != nil { - panic(err) + return err } - ndb.batch.Close() + err = ndb.batch.Close() + if err != nil { + return err + } + ndb.batch = ndb.db.NewBatch() + + return nil } // DeleteVersion deletes a tree version from disk. +// calls deleteOrphans(version), deleteRoot(version, checkLatestVersion) func (ndb *nodeDB) DeleteVersion(version int64, checkLatestVersion bool) error { ndb.mtx.Lock() defer ndb.mtx.Unlock() @@ -204,9 +395,16 @@ func (ndb *nodeDB) DeleteVersion(version int64, checkLatestVersion bool) error { return errors.Errorf("unable to delete version %v, it has %v active readers", version, ndb.versionReaders[version]) } - ndb.deleteOrphans(version) - ndb.deleteRoot(version, checkLatestVersion) - return nil + err := ndb.deleteOrphans(version) + if err != nil { + return err + } + + err = ndb.deleteRoot(version, checkLatestVersion) + if err != nil { + return err + } + return err } // DeleteVersionsFrom permanently deletes all tree versions from the given version upwards. @@ -239,32 +437,64 @@ func (ndb *nodeDB) DeleteVersionsFrom(version int64) error { // Next, delete orphans: // - Delete orphan entries *and referred nodes* with fromVersion >= version // - Delete orphan entries with toVersion >= version-1 (since orphans at latest are not orphans) - ndb.traverseOrphans(func(key, hash []byte) { + err = ndb.traverseOrphans(func(key, hash []byte) error { var fromVersion, toVersion int64 orphanKeyFormat.Scan(key, &toVersion, &fromVersion) if fromVersion >= version { if err = ndb.batch.Delete(key); err != nil { - panic(err) + return err } if err = ndb.batch.Delete(ndb.nodeKey(hash)); err != nil { - panic(err) + return err } ndb.uncacheNode(hash) } else if toVersion >= version-1 { - if err := ndb.batch.Delete(key); err != nil { - panic(err) + if err = ndb.batch.Delete(key); err != nil { + return err } } + return nil }) - // Finally, delete the version root entries - ndb.traverseRange(rootKeyFormat.Key(version), rootKeyFormat.Key(int64(math.MaxInt64)), func(k, v []byte) { - if err := ndb.batch.Delete(k); err != nil { - panic(err) + if err != nil { + return err + } + + // Delete the version root entries + err = ndb.traverseRange(rootKeyFormat.Key(version), rootKeyFormat.Key(int64(math.MaxInt64)), func(k, v []byte) error { + if err = ndb.batch.Delete(k); err != nil { + return err + } + return nil + }) + + if err != nil { + return err + } + + // Delete fast node entries + err = ndb.traverseFastNodes(func(keyWithPrefix, v []byte) error { + key := keyWithPrefix[1:] + fastNode, err := DeserializeFastNode(key, v) + + if err != nil { + return err + } + + if version <= fastNode.versionLastUpdatedAt { + if err = ndb.batch.Delete(keyWithPrefix); err != nil { + return err + } + ndb.uncacheFastNode(key) } + return nil }) + if err != nil { + return err + } + return nil } @@ -296,11 +526,11 @@ func (ndb *nodeDB) DeleteVersionsRange(fromVersion, toVersion int64) error { // If the predecessor is earlier than the beginning of the lifetime, we can delete the orphan. // Otherwise, we shorten its lifetime, by moving its endpoint to the predecessor version. for version := fromVersion; version < toVersion; version++ { - ndb.traverseOrphansVersion(version, func(key, hash []byte) { + err := ndb.traverseOrphansVersion(version, func(key, hash []byte) error { var from, to int64 orphanKeyFormat.Scan(key, &to, &from) if err := ndb.batch.Delete(key); err != nil { - panic(err) + return err } if from > predecessor { if err := ndb.batch.Delete(ndb.nodeKey(hash)); err != nil { @@ -310,16 +540,34 @@ func (ndb *nodeDB) DeleteVersionsRange(fromVersion, toVersion int64) error { } else { ndb.saveOrphan(hash, from, predecessor) } + return nil }) + if err != nil { + return err + } } // Delete the version root entries - ndb.traverseRange(rootKeyFormat.Key(fromVersion), rootKeyFormat.Key(toVersion), func(k, v []byte) { + err := ndb.traverseRange(rootKeyFormat.Key(fromVersion), rootKeyFormat.Key(toVersion), func(k, v []byte) error { if err := ndb.batch.Delete(k); err != nil { - panic(err) + return err } + return nil }) + if err != nil { + return err + } + return nil +} + +func (ndb *nodeDB) DeleteFastNode(key []byte) error { + ndb.mtx.Lock() + defer ndb.mtx.Unlock() + if err := ndb.batch.Delete(ndb.fastNodeKey(key)); err != nil { + return err + } + ndb.uncacheFastNode(key) return nil } @@ -380,13 +628,13 @@ func (ndb *nodeDB) saveOrphan(hash []byte, fromVersion, toVersion int64) { // deleteOrphans deletes orphaned nodes from disk, and the associated orphan // entries. -func (ndb *nodeDB) deleteOrphans(version int64) { +func (ndb *nodeDB) deleteOrphans(version int64) error { // Will be zero if there is no previous version. predecessor := ndb.getPreviousVersion(version) // Traverse orphans with a lifetime ending at the version specified. // TODO optimize. - ndb.traverseOrphansVersion(version, func(key, hash []byte) { + return ndb.traverseOrphansVersion(version, func(key, hash []byte) error { var fromVersion, toVersion int64 // See comment on `orphanKeyFmt`. Note that here, `version` and @@ -395,7 +643,7 @@ func (ndb *nodeDB) deleteOrphans(version int64) { // Delete orphan key and reverse-lookup key. if err := ndb.batch.Delete(key); err != nil { - panic(err) + return err } // If there is no predecessor, or the predecessor is earlier than the @@ -406,13 +654,14 @@ func (ndb *nodeDB) deleteOrphans(version int64) { if predecessor < fromVersion || fromVersion == toVersion { logger.Debug("DELETE predecessor:%v fromVersion:%v toVersion:%v %X\n", predecessor, fromVersion, toVersion, hash) if err := ndb.batch.Delete(ndb.nodeKey(hash)); err != nil { - panic(err) + return err } ndb.uncacheNode(hash) } else { logger.Debug("MOVE predecessor:%v fromVersion:%v toVersion:%v %X\n", predecessor, fromVersion, toVersion, hash) ndb.saveOrphan(hash, fromVersion, predecessor) } + return nil }) } @@ -420,6 +669,10 @@ func (ndb *nodeDB) nodeKey(hash []byte) []byte { return nodeKeyFormat.KeyBytes(hash) } +func (ndb *nodeDB) fastNodeKey(key []byte) []byte { + return fastKeyFormat.KeyBytes(key) +} + func (ndb *nodeDB) orphanKey(fromVersion, toVersion int64, hash []byte) []byte { return orphanKeyFormat.Key(toVersion, fromVersion, hash) } @@ -470,58 +723,97 @@ func (ndb *nodeDB) getPreviousVersion(version int64) int64 { } // deleteRoot deletes the root entry from disk, but not the node it points to. -func (ndb *nodeDB) deleteRoot(version int64, checkLatestVersion bool) { +func (ndb *nodeDB) deleteRoot(version int64, checkLatestVersion bool) error { if checkLatestVersion && version == ndb.getLatestVersion() { - panic("Tried to delete latest version") + return errors.New("tried to delete latest version") } if err := ndb.batch.Delete(ndb.rootKey(version)); err != nil { - panic(err) + return err } + return nil } -func (ndb *nodeDB) traverseOrphans(fn func(k, v []byte)) { - ndb.traversePrefix(orphanKeyFormat.Key(), fn) +// Traverse orphans and return error if any, nil otherwise +func (ndb *nodeDB) traverseOrphans(fn func(keyWithPrefix, v []byte) error) error { + return ndb.traversePrefix(orphanKeyFormat.Key(), fn) } -// Traverse orphans ending at a certain version. -func (ndb *nodeDB) traverseOrphansVersion(version int64, fn func(k, v []byte)) { - ndb.traversePrefix(orphanKeyFormat.Key(version), fn) +// Traverse fast nodes and return error if any, nil otherwise +func (ndb *nodeDB) traverseFastNodes(fn func(k, v []byte) error) error { + return ndb.traversePrefix(fastKeyFormat.Key(), fn) } -// Traverse all keys. -//nolint:unused -func (ndb *nodeDB) traverse(fn func(key, value []byte)) { - ndb.traverseRange(nil, nil, fn) +// Traverse orphans ending at a certain version. return error if any, nil otherwise +func (ndb *nodeDB) traverseOrphansVersion(version int64, fn func(k, v []byte) error) error { + return ndb.traversePrefix(orphanKeyFormat.Key(version), fn) } -// Traverse all keys between a given range (excluding end). -func (ndb *nodeDB) traverseRange(start []byte, end []byte, fn func(k, v []byte)) { +// Traverse all keys and return error if any, nil otherwise +// nolint: unused +func (ndb *nodeDB) traverse(fn func(key, value []byte) error) error { + return ndb.traverseRange(nil, nil, fn) +} + +// Traverse all keys between a given range (excluding end) and return error if any, nil otherwise +func (ndb *nodeDB) traverseRange(start []byte, end []byte, fn func(k, v []byte) error) error { itr, err := ndb.db.Iterator(start, end) if err != nil { - panic(err) + return err } defer itr.Close() for ; itr.Valid(); itr.Next() { - fn(itr.Key(), itr.Value()) + if err := fn(itr.Key(), itr.Value()); err != nil { + return err + } } if err := itr.Error(); err != nil { - panic(err) + return err } + + return nil } -// Traverse all keys with a certain prefix. -func (ndb *nodeDB) traversePrefix(prefix []byte, fn func(k, v []byte)) { +// Traverse all keys with a certain prefix. Return error if any, nil otherwise +func (ndb *nodeDB) traversePrefix(prefix []byte, fn func(k, v []byte) error) error { itr, err := dbm.IteratePrefix(ndb.db, prefix) if err != nil { - panic(err) + return err } defer itr.Close() for ; itr.Valid(); itr.Next() { - fn(itr.Key(), itr.Value()) + if err := fn(itr.Key(), itr.Value()); err != nil { + return err + } } + + return nil +} + +// Get iterator for fast prefix and error, if any +func (ndb *nodeDB) getFastIterator(start, end []byte, ascending bool) (dbm.Iterator, error) { + var startFormatted, endFormatted []byte + + if start != nil { + startFormatted = fastKeyFormat.KeyBytes(start) + } else { + startFormatted = fastKeyFormat.Key() + } + + if end != nil { + endFormatted = fastKeyFormat.KeyBytes(end) + } else { + endFormatted = fastKeyFormat.Key() + endFormatted[0]++ + } + + if ascending { + return ndb.db.Iterator(startFormatted, endFormatted) + } + + return ndb.db.ReverseIterator(startFormatted, endFormatted) } func (ndb *nodeDB) uncacheNode(hash []byte) { @@ -544,6 +836,26 @@ func (ndb *nodeDB) cacheNode(node *Node) { } } +func (ndb *nodeDB) uncacheFastNode(key []byte) { + if elem, ok := ndb.fastNodeCache[string(key)]; ok { + ndb.fastNodeCacheQueue.Remove(elem) + delete(ndb.fastNodeCache, string(key)) + } +} + +// Add a node to the cache and pop the least recently used node if we've +// reached the cache size limit. +func (ndb *nodeDB) cacheFastNode(node *FastNode) { + elem := ndb.fastNodeCacheQueue.PushBack(node) + ndb.fastNodeCache[string(node.key)] = elem + + if ndb.fastNodeCacheQueue.Len() > ndb.fastNodeCacheSize { + oldest := ndb.fastNodeCacheQueue.Front() + key := ndb.fastNodeCacheQueue.Remove(oldest).(*FastNode).key + delete(ndb.fastNodeCache, string(key)) + } +} + // Write to disk. func (ndb *nodeDB) Commit() error { ndb.mtx.Lock() @@ -573,16 +885,15 @@ func (ndb *nodeDB) getRoot(version int64) ([]byte, error) { return ndb.db.Get(ndb.rootKey(version)) } -//nolint:unparam -func (ndb *nodeDB) getRoots() (map[int64][]byte, error) { - roots := map[int64][]byte{} - - ndb.traversePrefix(rootKeyFormat.Key(), func(k, v []byte) { +func (ndb *nodeDB) getRoots() (roots map[int64][]byte, err error) { + roots = make(map[int64][]byte) + err = ndb.traversePrefix(rootKeyFormat.Key(), func(k, v []byte) error { var version int64 rootKeyFormat.Scan(k, &version) roots[version] = v + return nil }) - return roots, nil + return roots, err } // SaveRoot creates an entry on disk for the given root, so that it can be @@ -634,36 +945,54 @@ func (ndb *nodeDB) decrVersionReaders(version int64) { // Utility and test functions -//nolint:unused -func (ndb *nodeDB) leafNodes() []*Node { +// nolint: unused +func (ndb *nodeDB) leafNodes() ([]*Node, error) { leaves := []*Node{} - ndb.traverseNodes(func(hash []byte, node *Node) { + err := ndb.traverseNodes(func(hash []byte, node *Node) error { if node.isLeaf() { leaves = append(leaves, node) } + return nil }) - return leaves + + if err != nil { + return nil, err + } + + return leaves, nil } -//nolint:unused -func (ndb *nodeDB) nodes() []*Node { +// nolint: unused +func (ndb *nodeDB) nodes() ([]*Node, error) { nodes := []*Node{} - ndb.traverseNodes(func(hash []byte, node *Node) { + err := ndb.traverseNodes(func(hash []byte, node *Node) error { nodes = append(nodes, node) + return nil }) - return nodes + + if err != nil { + return nil, err + } + + return nodes, nil } -//nolint:unused -func (ndb *nodeDB) orphans() [][]byte { +// nolint: unused +func (ndb *nodeDB) orphans() ([][]byte, error) { orphans := [][]byte{} - ndb.traverseOrphans(func(k, v []byte) { + err := ndb.traverseOrphans(func(k, v []byte) error { orphans = append(orphans, v) + return nil }) - return orphans + + if err != nil { + return nil, err + } + + return orphans, nil } // Not efficient. @@ -672,51 +1001,76 @@ func (ndb *nodeDB) orphans() [][]byte { //nolint:unused func (ndb *nodeDB) size() int { size := 0 - ndb.traverse(func(k, v []byte) { + err := ndb.traverse(func(k, v []byte) error { size++ + return nil }) + + if err != nil { + return -1 + } return size } -func (ndb *nodeDB) traverseNodes(fn func(hash []byte, node *Node)) { +func (ndb *nodeDB) traverseNodes(fn func(hash []byte, node *Node) error) error { nodes := []*Node{} - ndb.traversePrefix(nodeKeyFormat.Key(), func(key, value []byte) { + err := ndb.traversePrefix(nodeKeyFormat.Key(), func(key, value []byte) error { node, err := MakeNode(value) if err != nil { - panic(fmt.Sprintf("Couldn't decode node from database: %v", err)) + return err } nodeKeyFormat.Scan(key, &node.hash) nodes = append(nodes, node) + return nil }) + if err != nil { + return err + } + sort.Slice(nodes, func(i, j int) bool { return bytes.Compare(nodes[i].key, nodes[j].key) < 0 }) for _, n := range nodes { - fn(n.hash, n) + if err := fn(n.hash, n); err != nil { + return err + } } + return nil } -func (ndb *nodeDB) String() string { +func (ndb *nodeDB) String() (string, error) { buf := bufPool.Get().(*bytes.Buffer) defer bufPool.Put(buf) buf.Reset() index := 0 - ndb.traversePrefix(rootKeyFormat.Key(), func(key, value []byte) { + err := ndb.traversePrefix(rootKeyFormat.Key(), func(key, value []byte) error { fmt.Fprintf(buf, "%s: %x\n", key, value) + return nil }) + + if err != nil { + return "", err + } + buf.WriteByte('\n') - ndb.traverseOrphans(func(key, value []byte) { + err = ndb.traverseOrphans(func(key, value []byte) error { fmt.Fprintf(buf, "%s: %x\n", key, value) + return nil }) + + if err != nil { + return "", err + } + buf.WriteByte('\n') - ndb.traverseNodes(func(hash []byte, node *Node) { + err = ndb.traverseNodes(func(hash []byte, node *Node) error { switch { case len(hash) == 0: buf.WriteByte('\n') @@ -730,6 +1084,12 @@ func (ndb *nodeDB) String() string { nodeKeyFormat.Prefix(), hash, node.key, node.value, node.height, node.version) } index++ + return nil }) - return "-" + "\n" + buf.String() + "-" + + if err != nil { + return "", err + } + + return "-" + "\n" + buf.String() + "-", nil } diff --git a/nodedb_test.go b/nodedb_test.go index e5dadf446..3db11769e 100644 --- a/nodedb_test.go +++ b/nodedb_test.go @@ -2,12 +2,16 @@ package iavl import ( "encoding/binary" + "errors" "math/rand" + "strconv" "testing" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" + db "github.com/tendermint/tm-db" - dbm "github.com/tendermint/tm-db" + "github.com/cosmos/iavl/mock" ) func BenchmarkNodeKey(b *testing.B) { @@ -31,7 +35,8 @@ func BenchmarkTreeString(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - sink = tree.String() + sink, _ = tree.String() + require.NotNil(b, sink) } if sink == nil { @@ -40,21 +45,207 @@ func BenchmarkTreeString(b *testing.B) { sink = (interface{})(nil) } -func makeAndPopulateMutableTree(tb testing.TB) *MutableTree { - memDB := dbm.NewMemDB() - tree, err := NewMutableTreeWithOpts(memDB, 0, &Options{InitialVersion: 9}) - require.NoError(tb, err) +func TestNewNoDbStorage_StorageVersionInDb_Success(t *testing.T) { + const expectedVersion = defaultStorageVersionValue - for i := 0; i < 1e4; i++ { - buf := make([]byte, 0, (i/255)+1) - for j := 0; 1<>j)&0xff)) - } - tree.Set(buf, buf) - } - _, _, err = tree.SaveVersion() - require.Nil(tb, err, "Expected .SaveVersion to succeed") - return tree + ctrl := gomock.NewController(t) + dbMock := mock.NewMockDB(ctrl) + + dbMock.EXPECT().Get(gomock.Any()).Return([]byte(expectedVersion), nil).Times(1) + dbMock.EXPECT().NewBatch().Return(nil).Times(1) + + ndb := newNodeDB(dbMock, 0, nil) + require.Equal(t, expectedVersion, ndb.storageVersion) +} + +func TestNewNoDbStorage_ErrorInConstructor_DefaultSet(t *testing.T) { + const expectedVersion = defaultStorageVersionValue + + ctrl := gomock.NewController(t) + dbMock := mock.NewMockDB(ctrl) + + dbMock.EXPECT().Get(gomock.Any()).Return(nil, errors.New("some db error")).Times(1) + dbMock.EXPECT().NewBatch().Return(nil).Times(1) + + ndb := newNodeDB(dbMock, 0, nil) + require.Equal(t, expectedVersion, ndb.getStorageVersion()) +} + +func TestNewNoDbStorage_DoesNotExist_DefaultSet(t *testing.T) { + const expectedVersion = defaultStorageVersionValue + + ctrl := gomock.NewController(t) + dbMock := mock.NewMockDB(ctrl) + + dbMock.EXPECT().Get(gomock.Any()).Return(nil, nil).Times(1) + dbMock.EXPECT().NewBatch().Return(nil).Times(1) + + ndb := newNodeDB(dbMock, 0, nil) + require.Equal(t, expectedVersion, ndb.getStorageVersion()) +} + +func TestSetStorageVersion_Success(t *testing.T) { + const expectedVersion = fastStorageVersionValue + + db := db.NewMemDB() + + ndb := newNodeDB(db, 0, nil) + require.Equal(t, defaultStorageVersionValue, ndb.getStorageVersion()) + + err := ndb.setFastStorageVersionToBatch() + require.NoError(t, err) + require.Equal(t, expectedVersion+fastStorageVersionDelimiter+strconv.Itoa(int(ndb.getLatestVersion())), ndb.getStorageVersion()) + require.NoError(t, ndb.batch.Write()) +} + +func TestSetStorageVersion_DBFailure_OldKept(t *testing.T) { + ctrl := gomock.NewController(t) + dbMock := mock.NewMockDB(ctrl) + batchMock := mock.NewMockBatch(ctrl) + rIterMock := mock.NewMockIterator(ctrl) + + expectedErrorMsg := "some db error" + + expectedFastCacheVersion := 2 + + dbMock.EXPECT().Get(gomock.Any()).Return([]byte(defaultStorageVersionValue), nil).Times(1) + dbMock.EXPECT().NewBatch().Return(batchMock).Times(1) + + // rIterMock is used to get the latest version from disk. We are mocking that rIterMock returns latestTreeVersion from disk + rIterMock.EXPECT().Valid().Return(true).Times(1) + rIterMock.EXPECT().Key().Return(rootKeyFormat.Key(expectedFastCacheVersion)).Times(1) + rIterMock.EXPECT().Close().Return(nil).Times(1) + + dbMock.EXPECT().ReverseIterator(gomock.Any(), gomock.Any()).Return(rIterMock, nil).Times(1) + batchMock.EXPECT().Set(metadataKeyFormat.Key([]byte(storageVersionKey)), []byte(fastStorageVersionValue+fastStorageVersionDelimiter+strconv.Itoa(expectedFastCacheVersion))).Return(errors.New(expectedErrorMsg)).Times(1) + + ndb := newNodeDB(dbMock, 0, nil) + require.Equal(t, defaultStorageVersionValue, ndb.getStorageVersion()) + + err := ndb.setFastStorageVersionToBatch() + require.Error(t, err) + require.Equal(t, expectedErrorMsg, err.Error()) + require.Equal(t, defaultStorageVersionValue, ndb.getStorageVersion()) +} + +func TestSetStorageVersion_InvalidVersionFailure_OldKept(t *testing.T) { + ctrl := gomock.NewController(t) + dbMock := mock.NewMockDB(ctrl) + batchMock := mock.NewMockBatch(ctrl) + + expectedErrorMsg := errInvalidFastStorageVersion + + invalidStorageVersion := fastStorageVersionValue + fastStorageVersionDelimiter + "1" + fastStorageVersionDelimiter + "2" + + dbMock.EXPECT().Get(gomock.Any()).Return([]byte(invalidStorageVersion), nil).Times(1) + dbMock.EXPECT().NewBatch().Return(batchMock).Times(1) + + ndb := newNodeDB(dbMock, 0, nil) + require.Equal(t, invalidStorageVersion, ndb.getStorageVersion()) + + err := ndb.setFastStorageVersionToBatch() + require.Error(t, err) + require.Equal(t, expectedErrorMsg, err.Error()) + require.Equal(t, invalidStorageVersion, ndb.getStorageVersion()) +} + +func TestSetStorageVersion_FastVersionFirst_VersionAppended(t *testing.T) { + db := db.NewMemDB() + ndb := newNodeDB(db, 0, nil) + ndb.storageVersion = fastStorageVersionValue + ndb.latestVersion = 100 + + err := ndb.setFastStorageVersionToBatch() + require.NoError(t, err) + require.Equal(t, fastStorageVersionValue+fastStorageVersionDelimiter+strconv.Itoa(int(ndb.latestVersion)), ndb.storageVersion) +} + +func TestSetStorageVersion_FastVersionSecond_VersionAppended(t *testing.T) { + db := db.NewMemDB() + ndb := newNodeDB(db, 0, nil) + ndb.latestVersion = 100 + + storageVersionBytes := []byte(fastStorageVersionValue) + storageVersionBytes[len(fastStorageVersionValue)-1]++ // increment last byte + ndb.storageVersion = string(storageVersionBytes) + + err := ndb.setFastStorageVersionToBatch() + require.NoError(t, err) + require.Equal(t, string(storageVersionBytes)+fastStorageVersionDelimiter+strconv.Itoa(int(ndb.latestVersion)), ndb.storageVersion) +} + +func TestSetStorageVersion_SameVersionTwice(t *testing.T) { + db := db.NewMemDB() + ndb := newNodeDB(db, 0, nil) + ndb.latestVersion = 100 + + storageVersionBytes := []byte(fastStorageVersionValue) + storageVersionBytes[len(fastStorageVersionValue)-1]++ // increment last byte + ndb.storageVersion = string(storageVersionBytes) + + err := ndb.setFastStorageVersionToBatch() + require.NoError(t, err) + newStorageVersion := string(storageVersionBytes) + fastStorageVersionDelimiter + strconv.Itoa(int(ndb.latestVersion)) + require.Equal(t, newStorageVersion, ndb.storageVersion) + + err = ndb.setFastStorageVersionToBatch() + require.NoError(t, err) + require.Equal(t, newStorageVersion, ndb.storageVersion) +} + +// Test case where version is incorrect and has some extra garbage at the end +func TestShouldForceFastStorageUpdate_DefaultVersion_True(t *testing.T) { + db := db.NewMemDB() + ndb := newNodeDB(db, 0, nil) + ndb.storageVersion = defaultStorageVersionValue + ndb.latestVersion = 100 + + require.False(t, ndb.shouldForceFastStorageUpgrade()) +} + +func TestShouldForceFastStorageUpdate_FastVersion_Greater_True(t *testing.T) { + db := db.NewMemDB() + ndb := newNodeDB(db, 0, nil) + ndb.latestVersion = 100 + ndb.storageVersion = fastStorageVersionValue + fastStorageVersionDelimiter + strconv.Itoa(int(ndb.latestVersion+1)) + + require.True(t, ndb.shouldForceFastStorageUpgrade()) +} + +func TestShouldForceFastStorageUpdate_FastVersion_Smaller_True(t *testing.T) { + db := db.NewMemDB() + ndb := newNodeDB(db, 0, nil) + ndb.latestVersion = 100 + ndb.storageVersion = fastStorageVersionValue + fastStorageVersionDelimiter + strconv.Itoa(int(ndb.latestVersion-1)) + + require.True(t, ndb.shouldForceFastStorageUpgrade()) +} + +func TestShouldForceFastStorageUpdate_FastVersion_Match_False(t *testing.T) { + db := db.NewMemDB() + ndb := newNodeDB(db, 0, nil) + ndb.latestVersion = 100 + ndb.storageVersion = fastStorageVersionValue + fastStorageVersionDelimiter + strconv.Itoa(int(ndb.latestVersion)) + + require.False(t, ndb.shouldForceFastStorageUpgrade()) +} + +func TestIsFastStorageEnabled_True(t *testing.T) { + db := db.NewMemDB() + ndb := newNodeDB(db, 0, nil) + ndb.latestVersion = 100 + ndb.storageVersion = fastStorageVersionValue + fastStorageVersionDelimiter + strconv.Itoa(int(ndb.latestVersion)) + + require.True(t, ndb.hasUpgradedToFastStorage()) +} + +func TestIsFastStorageEnabled_False(t *testing.T) { + db := db.NewMemDB() + ndb := newNodeDB(db, 0, nil) + ndb.latestVersion = 100 + ndb.storageVersion = defaultStorageVersionValue + + require.False(t, ndb.shouldForceFastStorageUpgrade()) } func makeHashes(b *testing.B, seed int64) [][]byte { @@ -72,3 +263,20 @@ func makeHashes(b *testing.B, seed int64) [][]byte { b.StartTimer() return hashes } + +func makeAndPopulateMutableTree(tb testing.TB) *MutableTree { + memDB := db.NewMemDB() + tree, err := NewMutableTreeWithOpts(memDB, 0, &Options{InitialVersion: 9}) + require.NoError(tb, err) + + for i := 0; i < 1e4; i++ { + buf := make([]byte, 0, (i/255)+1) + for j := 0; 1<>j)&0xff)) + } + tree.Set(buf, buf) + } + _, _, err = tree.SaveVersion() + require.Nil(tb, err, "Expected .SaveVersion to succeed") + return tree +} diff --git a/proof_ics23.go b/proof_ics23.go index c598b2c33..04c29ba6d 100644 --- a/proof_ics23.go +++ b/proof_ics23.go @@ -30,7 +30,7 @@ If the key exists in the tree, this will return an error. */ func (t *ImmutableTree) GetNonMembershipProof(key []byte) (*ics23.CommitmentProof, error) { // idx is one node right of what we want.... - idx, val := t.Get(key) + idx, val := t.GetWithIndex(key) if val != nil { return nil, fmt.Errorf("cannot create NonExistanceProof when Key in State") } diff --git a/proof_ics23_test.go b/proof_ics23_test.go index dcb912d7b..43b2f656c 100644 --- a/proof_ics23_test.go +++ b/proof_ics23_test.go @@ -42,11 +42,11 @@ func TestGetMembership(t *testing.T) { for name, tc := range cases { tc := tc t.Run(name, func(t *testing.T) { - tree, allkeys, err := BuildTree(tc.size) + tree, allkeys, err := BuildTree(tc.size, 0) require.NoError(t, err, "Creating tree: %+v", err) key := GetKey(allkeys, tc.loc) - _, val := tree.Get(key) + val := tree.Get(key) proof, err := tree.GetMembershipProof(key) require.NoError(t, err, "Creating Proof: %+v", err) @@ -72,26 +72,105 @@ func TestGetNonMembership(t *testing.T) { "big right": {size: 5431, loc: Right}, } + performTest := func(tree *MutableTree, allKeys [][]byte, loc Where) { + key := GetNonKey(allKeys, loc) + + proof, err := tree.GetNonMembershipProof(key) + require.NoError(t, err, "Creating Proof: %+v", err) + + root := tree.Hash() + valid := ics23.VerifyNonMembership(ics23.IavlSpec, root, proof, key) + if !valid { + require.NoError(t, err, "Non Membership Proof Invalid") + } + } + for name, tc := range cases { tc := tc - t.Run(name, func(t *testing.T) { - tree, allkeys, err := BuildTree(tc.size) + t.Run("fast-"+name, func(t *testing.T) { + tree, allkeys, err := BuildTree(tc.size, 0) require.NoError(t, err, "Creating tree: %+v", err) + // Save version to enable fast cache + _, _, err = tree.SaveVersion() + require.NoError(t, err) - key := GetNonKey(allkeys, tc.loc) + require.True(t, tree.IsFastCacheEnabled()) - proof, err := tree.GetNonMembershipProof(key) - require.NoError(t, err, "Creating Proof: %+v", err) + performTest(tree, allkeys, tc.loc) + }) - root := tree.Hash() - valid := ics23.VerifyNonMembership(ics23.IavlSpec, root, proof, key) - if !valid { - require.NoError(t, err, "Non Membership Proof Invalid") - } + t.Run("regular-"+name, func(t *testing.T) { + tree, allkeys, err := BuildTree(tc.size, 0) + require.NoError(t, err, "Creating tree: %+v", err) + require.False(t, tree.IsFastCacheEnabled()) + + performTest(tree, allkeys, tc.loc) }) } } +func BenchmarkGetNonMembership(b *testing.B) { + cases := []struct { + size int + loc Where + }{ + {size: 100, loc: Left}, + {size: 100, loc: Middle}, + {size: 100, loc: Right}, + {size: 5431, loc: Left}, + {size: 5431, loc: Middle}, + {size: 5431, loc: Right}, + } + + performTest := func(tree *MutableTree, allKeys [][]byte, loc Where) { + key := GetNonKey(allKeys, loc) + + proof, err := tree.GetNonMembershipProof(key) + require.NoError(b, err, "Creating Proof: %+v", err) + + b.StopTimer() + root := tree.Hash() + valid := ics23.VerifyNonMembership(ics23.IavlSpec, root, proof, key) + if !valid { + require.NoError(b, err, "Non Membership Proof Invalid") + } + b.StartTimer() + } + + b.Run("fast", func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + caseIdx := rand.Intn(len(cases)) + tc := cases[caseIdx] + + tree, allkeys, err := BuildTree(tc.size, 100000) + require.NoError(b, err, "Creating tree: %+v", err) + // Save version to enable fast cache + _, _, err = tree.SaveVersion() + require.NoError(b, err) + + require.True(b, tree.IsFastCacheEnabled()) + b.StartTimer() + performTest(tree, allkeys, tc.loc) + } + }) + + b.Run("regular", func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + caseIdx := rand.Intn(len(cases)) + tc := cases[caseIdx] + + tree, allkeys, err := BuildTree(tc.size, 100000) + require.NoError(b, err, "Creating tree: %+v", err) + require.False(b, tree.IsFastCacheEnabled()) + + b.StartTimer() + performTest(tree, allkeys, tc.loc) + } + }) +} + // Test Helpers // Result is the result of one match @@ -106,7 +185,11 @@ type Result struct { // // returns a range proof and the root hash of the tree func GenerateResult(size int, loc Where) (*Result, error) { - tree, allkeys, err := BuildTree(size) + tree, allkeys, err := BuildTree(size, 0) + if err != nil { + return nil, err + } + _, _, err = tree.SaveVersion() if err != nil { return nil, err } @@ -172,8 +255,8 @@ func GetNonKey(allkeys [][]byte, loc Where) []byte { // BuildTree creates random key/values and stores in tree // returns a list of all keys in sorted order -func BuildTree(size int) (itree *ImmutableTree, keys [][]byte, err error) { - tree, _ := NewMutableTree(db.NewMemDB(), 0) +func BuildTree(size int, cacheSize int) (itree *MutableTree, keys [][]byte, err error) { + tree, _ := NewMutableTree(db.NewMemDB(), cacheSize) // insert lots of info and store the bytes keys = make([][]byte, size) @@ -189,7 +272,7 @@ func BuildTree(size int) (itree *ImmutableTree, keys [][]byte, err error) { return bytes.Compare(keys[i], keys[j]) < 0 }) - return tree.ImmutableTree, keys, nil + return tree, keys, nil } // sink is kept as a global to ensure that value checks and assignments to it can't be diff --git a/repair.go b/repair.go index e0c7a052c..e688b9cda 100644 --- a/repair.go +++ b/repair.go @@ -41,20 +41,21 @@ func Repair013Orphans(db dbm.DB) (uint64, error) { ) batch := db.NewBatch() defer batch.Close() - ndb.traverseRange(orphanKeyFormat.Key(version), orphanKeyFormat.Key(int64(math.MaxInt64)), func(k, v []byte) { + err = ndb.traverseRange(orphanKeyFormat.Key(version), orphanKeyFormat.Key(int64(math.MaxInt64)), func(k, v []byte) error { // Sanity check so we don't remove stuff we shouldn't var toVersion int64 orphanKeyFormat.Scan(k, &toVersion) if toVersion < version { err = errors.Errorf("Found unexpected orphan with toVersion=%v, lesser than latest version %v", toVersion, version) - return + return err } repaired++ err = batch.Delete(k) if err != nil { - return + return err } + return nil }) if err != nil { return 0, err diff --git a/repair_test.go b/repair_test.go index ec6b598b3..560ea51c3 100644 --- a/repair_test.go +++ b/repair_test.go @@ -59,7 +59,7 @@ func TestRepair013Orphans(t *testing.T) { require.NoError(t, err) // Reading "rm7" (which should not have been deleted now) would panic with a broken database. - _, value := tree.Get([]byte("rm7")) + value := tree.Get([]byte("rm7")) require.Equal(t, []byte{1}, value) // Check all persisted versions. @@ -91,7 +91,7 @@ func assertVersion(t *testing.T, tree *MutableTree, version int64) { version = itree.version // The "current" value should have the current version for <= 6, then 6 afterwards - _, value := itree.Get([]byte("current")) + value := itree.Get([]byte("current")) if version >= 6 { require.EqualValues(t, []byte{6}, value) } else { @@ -101,14 +101,14 @@ func assertVersion(t *testing.T, tree *MutableTree, version int64) { // The "addX" entries should exist for 1-6 in the respective versions, and the // "rmX" entries should have been removed for 1-6 in the respective versions. for i := byte(1); i < 8; i++ { - _, value = itree.Get([]byte(fmt.Sprintf("add%v", i))) + value = itree.Get([]byte(fmt.Sprintf("add%v", i))) if i <= 6 && int64(i) <= version { require.Equal(t, []byte{i}, value) } else { require.Nil(t, value) } - _, value = itree.Get([]byte(fmt.Sprintf("rm%v", i))) + value = itree.Get([]byte(fmt.Sprintf("rm%v", i))) if i <= 6 && version >= int64(i) { require.Nil(t, value) } else { @@ -177,7 +177,9 @@ func copyDB(src, dest string) error { defer out.Close() in, err := os.Open(filepath.Join(src, entry.Name())) - defer in.Close() // nolint + defer func() { + in.Close() + }() if err != nil { return err } diff --git a/server/server.go b/server/server.go index aa7d5e8d2..2c0d2afaf 100644 --- a/server/server.go +++ b/server/server.go @@ -78,7 +78,7 @@ func (s *IAVLServer) Get(_ context.Context, req *pb.GetRequest) (*pb.GetResponse s.rwLock.RLock() defer s.rwLock.RUnlock() - idx, value := s.tree.Get(req.Key) + idx, value := s.tree.GetWithIndex(req.Key) return &pb.GetResponse{Index: idx, Value: value, NotFound: value == nil}, nil } @@ -139,7 +139,7 @@ func (s *IAVLServer) GetVersioned(_ context.Context, req *pb.GetVersionedRequest return nil, err } - idx, value := iTree.Get(req.Key) + idx, value := iTree.GetWithIndex(req.Key) return &pb.GetResponse{Index: idx, Value: value}, nil } diff --git a/server/server_test.go b/server/server_test.go index e8e26fcd0..4d2775599 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -836,19 +836,19 @@ func (suite *ServerTestSuite) TestList() { func (suite *ServerTestSuite) TestAvailableVersions() { res1, err := suite.server.GetAvailableVersions(context.Background(), nil) suite.NoError(err) - oldVersions := res1.Versions + versions := res1.Versions _, err = suite.server.SaveVersion(context.Background(), nil) suite.NoError(err) versionRes, err := suite.server.Version(context.Background(), nil) suite.NoError(err) - newVersions := append(oldVersions, versionRes.Version) + versions = append(versions, versionRes.Version) res2, err := suite.server.GetAvailableVersions(context.Background(), nil) suite.NoError(err) - suite.Equal(res2.Versions, newVersions) + suite.Equal(res2.Versions, versions) } diff --git a/testutils_test.go b/testutils_test.go index de9165059..540561ac7 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -5,9 +5,10 @@ import ( "bytes" "fmt" "runtime" + "sort" "testing" - mrand "math/rand" + "math/rand" "github.com/stretchr/testify/require" db "github.com/tendermint/tm-db" @@ -16,6 +17,12 @@ import ( iavlrand "github.com/cosmos/iavl/internal/rand" ) +type iteratorTestConfig struct { + startIterate, endIterate []byte + startByteToSet, endByteToSet byte + ascending bool +} + func randstr(length int) string { return iavlrand.RandStr(length) } @@ -81,11 +88,7 @@ func P(n *Node) string { } func randBytes(length int) []byte { - key := make([]byte, length) - // math.rand.Read always returns err=nil - // we do not need cryptographic randomness for this test: - mrand.Read(key) - return key + return iavlrand.RandBytes(length) } type traverser struct { @@ -115,6 +118,197 @@ func expectTraverse(t *testing.T, trav traverser, start, end string, count int) } } +func assertMutableMirrorIterate(t *testing.T, tree *MutableTree, mirror map[string]string) { + sortedMirrorKeys := make([]string, 0, len(mirror)) + for k := range mirror { + sortedMirrorKeys = append(sortedMirrorKeys, k) + } + sort.Strings(sortedMirrorKeys) + + curKeyIdx := 0 + tree.Iterate(func(k, v []byte) bool { + nextMirrorKey := sortedMirrorKeys[curKeyIdx] + nextMirrorValue := mirror[nextMirrorKey] + + require.Equal(t, []byte(nextMirrorKey), k) + require.Equal(t, []byte(nextMirrorValue), v) + + curKeyIdx++ + return false + }) +} + +func assertImmutableMirrorIterate(t *testing.T, tree *ImmutableTree, mirror map[string]string) { + sortedMirrorKeys := getSortedMirrorKeys(mirror) + + curKeyIdx := 0 + tree.Iterate(func(k, v []byte) bool { + nextMirrorKey := sortedMirrorKeys[curKeyIdx] + nextMirrorValue := mirror[nextMirrorKey] + + require.Equal(t, []byte(nextMirrorKey), k) + require.Equal(t, []byte(nextMirrorValue), v) + + curKeyIdx++ + return false + }) +} + +func getSortedMirrorKeys(mirror map[string]string) []string { + sortedMirrorKeys := make([]string, 0, len(mirror)) + for k := range mirror { + sortedMirrorKeys = append(sortedMirrorKeys, k) + } + sort.Strings(sortedMirrorKeys) + return sortedMirrorKeys +} + +func getRandomizedTreeAndMirror(t *testing.T) (*MutableTree, map[string]string) { + const cacheSize = 100 + + tree, err := getTestTree(cacheSize) + require.NoError(t, err) + + mirror := make(map[string]string) + + randomizeTreeAndMirror(t, tree, mirror) + return tree, mirror +} + +func randomizeTreeAndMirror(t *testing.T, tree *MutableTree, mirror map[string]string) { + if mirror == nil { + mirror = make(map[string]string) + } + const keyValLength = 5 + + numberOfSets := 1000 + numberOfUpdates := numberOfSets / 4 + numberOfRemovals := numberOfSets / 4 + + for numberOfSets > numberOfRemovals*3 { + key := randBytes(keyValLength) + value := randBytes(keyValLength) + + isUpdated := tree.Set(key, value) + require.False(t, isUpdated) + mirror[string(key)] = string(value) + + numberOfSets-- + } + + for numberOfSets+numberOfRemovals+numberOfUpdates > 0 { + randOp := rand.Intn(3) + + switch randOp { + case 0: + if numberOfSets == 0 { + continue + } + + numberOfSets-- + + key := randBytes(keyValLength) + value := randBytes(keyValLength) + + isUpdated := tree.Set(key, value) + require.False(t, isUpdated) + mirror[string(key)] = string(value) + case 1: + + if numberOfUpdates == 0 { + continue + } + numberOfUpdates-- + + key := getRandomKeyFrom(mirror) + value := randBytes(keyValLength) + + isUpdated := tree.Set([]byte(key), value) + require.True(t, isUpdated) + mirror[key] = string(value) + case 2: + if numberOfRemovals == 0 { + continue + } + numberOfRemovals-- + + key := getRandomKeyFrom(mirror) + + val, isRemoved := tree.Remove([]byte(key)) + require.True(t, isRemoved) + require.NotNil(t, val) + delete(mirror, key) + default: + t.Error("Invalid randOp", randOp) + } + } +} + +func getRandomKeyFrom(mirror map[string]string) string { + for k := range mirror { + return k + } + panic("no keys in mirror") +} + +func setupMirrorForIterator(t *testing.T, config *iteratorTestConfig, tree *MutableTree) [][]string { + var mirror [][]string + + startByteToSet := config.startByteToSet + endByteToSet := config.endByteToSet + + if !config.ascending { + startByteToSet, endByteToSet = endByteToSet, startByteToSet + } + + curByte := startByteToSet + for curByte != endByteToSet { + value := randBytes(5) + + if (config.startIterate == nil || curByte >= config.startIterate[0]) && (config.endIterate == nil || curByte < config.endIterate[0]) { + mirror = append(mirror, []string{string(curByte), string(value)}) + } + + isUpdated := tree.Set([]byte{curByte}, value) + require.False(t, isUpdated) + + if config.ascending { + curByte++ + } else { + curByte-- + } + } + return mirror +} + +// assertIterator confirms that the iterator returns the expected values desribed by mirror in the same order. +// mirror is a slice containing slices of the form [key, value]. In other words, key at index 0 and value at index 1. +func assertIterator(t *testing.T, itr db.Iterator, mirror [][]string, ascending bool) { + startIdx, endIdx := 0, len(mirror)-1 + increment := 1 + mirrorIdx := startIdx + + // flip the iteration order over mirror if descending + if !ascending { + startIdx = endIdx - 1 + endIdx = -1 + increment *= -1 + } + + for startIdx != endIdx { + nextExpectedPair := mirror[mirrorIdx] + + require.True(t, itr.Valid()) + require.Equal(t, []byte(nextExpectedPair[0]), itr.Key()) + require.Equal(t, []byte(nextExpectedPair[1]), itr.Value()) + itr.Next() + require.NoError(t, itr.Error()) + + startIdx += increment + mirrorIdx++ + } +} + func BenchmarkImmutableAvlTreeMemDB(b *testing.B) { db, err := db.NewDB("test", db.MemDBBackend, "") require.NoError(b, err) diff --git a/tree_fuzz_test.go b/tree_fuzz_test.go index f1ae680cf..6f760290b 100644 --- a/tree_fuzz_test.go +++ b/tree_fuzz_test.go @@ -118,7 +118,9 @@ func TestMutableTreeFuzz(t *testing.T) { program := genRandomProgram(size) err = program.Execute(tree) if err != nil { - t.Fatalf("Error after %d iterations (size %d): %s\n%s", iterations, size, err.Error(), tree.String()) + str, err := tree.String() + require.Nil(t, err) + t.Fatalf("Error after %d iterations (size %d): %s\n%s", iterations, size, err.Error(), str) } iterations++ } diff --git a/tree_random_test.go b/tree_random_test.go index 29e2fd425..422a3e3ed 100644 --- a/tree_random_test.go +++ b/tree_random_test.go @@ -7,6 +7,8 @@ import ( "math/rand" "os" "sort" + "strconv" + "strings" "testing" "github.com/stretchr/testify/require" @@ -333,29 +335,39 @@ func assertEmptyDatabase(t *testing.T, tree *MutableTree) { require.NoError(t, err) var ( - firstKey []byte - count int + foundKeys []string ) for ; iter.Valid(); iter.Next() { - count++ - if firstKey == nil { - firstKey = iter.Key() - } + foundKeys = append(foundKeys, string(iter.Key())) } require.NoError(t, iter.Error()) - require.EqualValues(t, 1, count, "Found %v database entries, expected 1", count) + require.EqualValues(t, 2, len(foundKeys), "Found %v database entries, expected 1", len(foundKeys)) // 1 for storage version and 1 for root + + firstKey := foundKeys[0] + secondKey := foundKeys[1] + + require.True(t, strings.HasPrefix(firstKey, metadataKeyFormat.Prefix())) + require.True(t, strings.HasPrefix(secondKey, rootKeyFormat.Prefix())) + + require.Equal(t, string(metadataKeyFormat.KeyBytes([]byte(storageVersionKey))), firstKey, "Unexpected storage version key") + + storageVersionValue, err := tree.ndb.db.Get([]byte(firstKey)) + require.NoError(t, err) + require.Equal(t, fastStorageVersionValue+fastStorageVersionDelimiter+strconv.Itoa(int(tree.ndb.getLatestVersion())), string(storageVersionValue)) var foundVersion int64 - rootKeyFormat.Scan(firstKey, &foundVersion) + rootKeyFormat.Scan([]byte(secondKey), &foundVersion) require.Equal(t, version, foundVersion, "Unexpected root version") } // Checks that the tree has the given number of orphan nodes. func assertOrphans(t *testing.T, tree *MutableTree, expected int) { count := 0 - tree.ndb.traverseOrphans(func(k, v []byte) { + err := tree.ndb.traverseOrphans(func(k, v []byte) error { count++ + return nil }) + require.Nil(t, err) require.EqualValues(t, expected, count, "Expected %v orphans, got %v", expected, count) } @@ -389,9 +401,53 @@ func assertMirror(t *testing.T, tree *MutableTree, mirror map[string]string, ver require.EqualValues(t, len(mirror), itree.Size()) require.EqualValues(t, len(mirror), iterated) for key, value := range mirror { - _, actual := itree.Get([]byte(key)) + actualFast := itree.Get([]byte(key)) + require.Equal(t, value, string(actualFast)) + _, actual := itree.GetWithIndex([]byte(key)) require.Equal(t, value, string(actual)) } + + assertFastNodeCacheIsLive(t, tree, mirror, version) + assertFastNodeDiskIsLive(t, tree, mirror, version) +} + +// Checks that fast node cache matches live state. +func assertFastNodeCacheIsLive(t *testing.T, tree *MutableTree, mirror map[string]string, version int64) { + if tree.ndb.getLatestVersion() != version { + // The fast node cache check should only be done to the latest version + return + } + + for key, cacheElem := range tree.ndb.fastNodeCache { + liveFastNode, ok := mirror[key] + + require.True(t, ok, "cached fast node must be in the live tree") + require.Equal(t, liveFastNode, string(cacheElem.Value.(*FastNode).value), "cached fast node's value must be equal to live state value") + } +} + +// Checks that fast nodes on disk match live state. +func assertFastNodeDiskIsLive(t *testing.T, tree *MutableTree, mirror map[string]string, version int64) { + if tree.ndb.getLatestVersion() != version { + // The fast node disk check should only be done to the latest version + return + } + + count := 0 + err := tree.ndb.traverseFastNodes(func(keyWithPrefix, v []byte) error { + key := keyWithPrefix[1:] + count++ + fastNode, err := DeserializeFastNode(key, v) + require.Nil(t, err) + + mirrorVal := mirror[string(fastNode.key)] + + require.NotNil(t, mirrorVal) + require.Equal(t, []byte(mirrorVal), fastNode.value) + return nil + }) + require.NoError(t, err) + require.Equal(t, len(mirror), count) } // Checks that all versions in the tree are present in the mirrors, and vice-versa. diff --git a/tree_test.go b/tree_test.go index fe9f2f3cb..f2f9b6c59 100644 --- a/tree_test.go +++ b/tree_test.go @@ -67,13 +67,17 @@ func TestVersionedRandomTree(t *testing.T) { } roots, err := tree.ndb.getRoots() require.NoError(err) - require.Equal(versions, len(roots), "wrong number of roots") - require.Equal(versions*keysPerVersion, len(tree.ndb.leafNodes()), "wrong number of nodes") + + leafNodes, err := tree.ndb.leafNodes() + require.Nil(err) + require.Equal(versions*keysPerVersion, len(leafNodes), "wrong number of nodes") // Before deleting old versions, we should have equal or more nodes in the // db than in the current tree version. - require.True(len(tree.ndb.nodes()) >= tree.nodeSize()) + nodes, err := tree.ndb.nodes() + require.Nil(err) + require.True(len(nodes) >= tree.nodeSize()) // Ensure it returns all versions in sorted order available := tree.AvailableVersions() @@ -97,9 +101,13 @@ func TestVersionedRandomTree(t *testing.T) { // After cleaning up all previous versions, we should have as many nodes // in the db as in the current tree version. - require.Len(tree.ndb.leafNodes(), int(tree.Size())) + leafNodes, err = tree.ndb.leafNodes() + require.Nil(err) + require.Len(leafNodes, int(tree.Size())) - require.Equal(tree.nodeSize(), len(tree.ndb.nodes())) + nodes, err = tree.ndb.nodes() + require.Nil(err) + require.Equal(tree.nodeSize(), len(nodes)) } // nolint: dupl @@ -200,13 +208,19 @@ func TestVersionedRandomTreeSmallKeys(t *testing.T) { // After cleaning up all previous versions, we should have as many nodes // in the db as in the current tree version. The simple tree must be equal // too. - require.Len(tree.ndb.leafNodes(), int(tree.Size())) - require.Len(tree.ndb.nodes(), tree.nodeSize()) - require.Len(tree.ndb.nodes(), singleVersionTree.nodeSize()) + leafNodes, err := tree.ndb.leafNodes() + require.Nil(err) + + nodes, err := tree.ndb.nodes() + require.Nil(err) + + require.Len(leafNodes, int(tree.Size())) + require.Len(nodes, tree.nodeSize()) + require.Len(nodes, singleVersionTree.nodeSize()) // Try getting random keys. for i := 0; i < keysPerVersion; i++ { - _, val := tree.Get([]byte(iavlrand.RandStr(1))) + val := tree.Get([]byte(iavlrand.RandStr(1))) require.NotNil(val) require.NotEmpty(val) } @@ -243,13 +257,19 @@ func TestVersionedRandomTreeSmallKeysRandomDeletes(t *testing.T) { // After cleaning up all previous versions, we should have as many nodes // in the db as in the current tree version. The simple tree must be equal // too. - require.Len(tree.ndb.leafNodes(), int(tree.Size())) - require.Len(tree.ndb.nodes(), tree.nodeSize()) - require.Len(tree.ndb.nodes(), singleVersionTree.nodeSize()) + leafNodes, err := tree.ndb.leafNodes() + require.Nil(err) + + nodes, err := tree.ndb.nodes() + require.Nil(err) + + require.Len(leafNodes, int(tree.Size())) + require.Len(nodes, tree.nodeSize()) + require.Len(nodes, singleVersionTree.nodeSize()) // Try getting random keys. for i := 0; i < keysPerVersion; i++ { - _, val := tree.Get([]byte(iavlrand.RandStr(1))) + val := tree.Get([]byte(iavlrand.RandStr(1))) require.NotNil(val) require.NotEmpty(val) } @@ -275,7 +295,9 @@ func TestVersionedTreeSpecial1(t *testing.T) { tree.DeleteVersion(2) tree.DeleteVersion(3) - require.Equal(t, tree.nodeSize(), len(tree.ndb.nodes())) + nodes, err := tree.ndb.nodes() + require.Nil(t, err) + require.Equal(t, tree.nodeSize(), len(nodes)) } func TestVersionedRandomTreeSpecial2(t *testing.T) { @@ -292,7 +314,10 @@ func TestVersionedRandomTreeSpecial2(t *testing.T) { tree.SaveVersion() tree.DeleteVersion(1) - require.Len(tree.ndb.nodes(), tree.nodeSize()) + + nodes, err := tree.ndb.nodes() + require.NoError(err) + require.Len(nodes, tree.nodeSize()) } func TestVersionedEmptyTree(t *testing.T) { @@ -361,9 +386,10 @@ func TestVersionedTree(t *testing.T) { tree, err := NewMutableTree(d, 0) require.NoError(err) - // We start with zero keys in the databse. + // We start with empty database. require.Equal(0, tree.ndb.size()) require.True(tree.IsEmpty()) + require.False(tree.IsFastCacheEnabled()) // version 0 @@ -371,7 +397,9 @@ func TestVersionedTree(t *testing.T) { tree.Set([]byte("key2"), []byte("val0")) // Still zero keys, since we haven't written them. - require.Len(tree.ndb.leafNodes(), 0) + nodes, err := tree.ndb.leafNodes() + require.NoError(err) + require.Len(nodes, 0) require.False(tree.IsEmpty()) // Now let's write the keys to storage. @@ -386,7 +414,8 @@ func TestVersionedTree(t *testing.T) { // key2 (root) version=1 // ----------- - nodes1 := tree.ndb.leafNodes() + nodes1, err := tree.ndb.leafNodes() + require.NoError(err) require.Len(nodes1, 2, "db should have a size of 2") // version 1 @@ -394,7 +423,9 @@ func TestVersionedTree(t *testing.T) { tree.Set([]byte("key1"), []byte("val1")) tree.Set([]byte("key2"), []byte("val1")) tree.Set([]byte("key3"), []byte("val1")) - require.Len(tree.ndb.leafNodes(), len(nodes1)) + nodes, err = tree.ndb.leafNodes() + require.NoError(err) + require.Len(nodes, len(nodes1)) hash2, v2, err := tree.SaveVersion() require.NoError(err) @@ -420,9 +451,12 @@ func TestVersionedTree(t *testing.T) { // key3 = val1 // ----------- - nodes2 := tree.ndb.leafNodes() + nodes2, err := tree.ndb.leafNodes() + require.NoError(err) require.Len(nodes2, 5, "db should have grown in size") - require.Len(tree.ndb.orphans(), 3, "db should have three orphans") + orphans, err := tree.ndb.orphans() + require.NoError(err) + require.Len(orphans, 3, "db should have three orphans") // Create three more orphans. tree.Remove([]byte("key1")) // orphans both leaf node and inner node containing "key1" and "key2" @@ -442,9 +476,13 @@ func TestVersionedTree(t *testing.T) { // key2 = val2 // ----------- - nodes3 := tree.ndb.leafNodes() + nodes3, err := tree.ndb.leafNodes() + require.NoError(err) require.Len(nodes3, 6, "wrong number of nodes") - require.Len(tree.ndb.orphans(), 7, "wrong number of orphans") + + orphans, err = tree.ndb.orphans() + require.NoError(err) + require.Len(orphans, 7, "wrong number of orphans") hash4, _, _ := tree.SaveVersion() require.EqualValues(hash3, hash4) @@ -459,48 +497,49 @@ func TestVersionedTree(t *testing.T) { // DB UNCHANGED // ------------ - nodes4 := tree.ndb.leafNodes() + nodes4, err := tree.ndb.leafNodes() + require.NoError(err) require.Len(nodes4, len(nodes3), "db should not have changed in size") tree.Set([]byte("key1"), []byte("val0")) // "key2" - _, val := tree.GetVersioned([]byte("key2"), 0) + val := tree.GetVersioned([]byte("key2"), 0) require.Nil(val) - _, val = tree.GetVersioned([]byte("key2"), 1) + val = tree.GetVersioned([]byte("key2"), 1) require.Equal("val0", string(val)) - _, val = tree.GetVersioned([]byte("key2"), 2) + val = tree.GetVersioned([]byte("key2"), 2) require.Equal("val1", string(val)) - _, val = tree.Get([]byte("key2")) + val = tree.Get([]byte("key2")) require.Equal("val2", string(val)) // "key1" - _, val = tree.GetVersioned([]byte("key1"), 1) + val = tree.GetVersioned([]byte("key1"), 1) require.Equal("val0", string(val)) - _, val = tree.GetVersioned([]byte("key1"), 2) + val = tree.GetVersioned([]byte("key1"), 2) require.Equal("val1", string(val)) - _, val = tree.GetVersioned([]byte("key1"), 3) + val = tree.GetVersioned([]byte("key1"), 3) require.Nil(val) - _, val = tree.GetVersioned([]byte("key1"), 4) + val = tree.GetVersioned([]byte("key1"), 4) require.Nil(val) - _, val = tree.Get([]byte("key1")) + val = tree.Get([]byte("key1")) require.Equal("val0", string(val)) // "key3" - _, val = tree.GetVersioned([]byte("key3"), 0) + val = tree.GetVersioned([]byte("key3"), 0) require.Nil(val) - _, val = tree.GetVersioned([]byte("key3"), 2) + val = tree.GetVersioned([]byte("key3"), 2) require.Equal("val1", string(val)) - _, val = tree.GetVersioned([]byte("key3"), 3) + val = tree.GetVersioned([]byte("key3"), 3) require.Equal("val1", string(val)) // Delete a version. After this the keys in that version should not be found. @@ -516,29 +555,31 @@ func TestVersionedTree(t *testing.T) { // key2 = val2 // ----------- - nodes5 := tree.ndb.leafNodes() + nodes5, err := tree.ndb.leafNodes() + require.NoError(err) + require.True(len(nodes5) < len(nodes4), "db should have shrunk after delete %d !< %d", len(nodes5), len(nodes4)) - _, val = tree.GetVersioned([]byte("key2"), 2) + val = tree.GetVersioned([]byte("key2"), 2) require.Nil(val) - _, val = tree.GetVersioned([]byte("key3"), 2) + val = tree.GetVersioned([]byte("key3"), 2) require.Nil(val) // But they should still exist in the latest version. - _, val = tree.Get([]byte("key2")) + val = tree.Get([]byte("key2")) require.Equal("val2", string(val)) - _, val = tree.Get([]byte("key3")) + val = tree.Get([]byte("key3")) require.Equal("val1", string(val)) // Version 1 should still be available. - _, val = tree.GetVersioned([]byte("key1"), 1) + val = tree.GetVersioned([]byte("key1"), 1) require.Equal("val0", string(val)) - _, val = tree.GetVersioned([]byte("key2"), 1) + val = tree.GetVersioned([]byte("key2"), 1) require.Equal("val0", string(val)) } @@ -554,29 +595,39 @@ func TestVersionedTreeVersionDeletingEfficiency(t *testing.T) { tree.Set([]byte("key2"), []byte("val0")) tree.SaveVersion() - require.Len(t, tree.ndb.leafNodes(), 3) + leafNodes, err := tree.ndb.leafNodes() + require.Nil(t, err) + require.Len(t, leafNodes, 3) tree.Set([]byte("key1"), []byte("val1")) tree.Set([]byte("key2"), []byte("val1")) tree.Set([]byte("key3"), []byte("val1")) tree.SaveVersion() - require.Len(t, tree.ndb.leafNodes(), 6) + leafNodes, err = tree.ndb.leafNodes() + require.Nil(t, err) + require.Len(t, leafNodes, 6) tree.Set([]byte("key0"), []byte("val2")) tree.Remove([]byte("key1")) tree.Set([]byte("key2"), []byte("val2")) tree.SaveVersion() - require.Len(t, tree.ndb.leafNodes(), 8) + leafNodes, err = tree.ndb.leafNodes() + require.Nil(t, err) + require.Len(t, leafNodes, 8) tree.DeleteVersion(2) - require.Len(t, tree.ndb.leafNodes(), 6) + leafNodes, err = tree.ndb.leafNodes() + require.Nil(t, err) + require.Len(t, leafNodes, 6) tree.DeleteVersion(1) - require.Len(t, tree.ndb.leafNodes(), 3) + leafNodes, err = tree.ndb.leafNodes() + require.Nil(t, err) + require.Len(t, leafNodes, 3) tree2, err := getTestTree(0) require.NoError(t, err) @@ -609,21 +660,23 @@ func TestVersionedTreeOrphanDeleting(t *testing.T) { tree.DeleteVersion(2) - _, val := tree.Get([]byte("key0")) + val := tree.Get([]byte("key0")) require.Equal(t, val, []byte("val2")) - _, val = tree.Get([]byte("key1")) + val = tree.Get([]byte("key1")) require.Nil(t, val) - _, val = tree.Get([]byte("key2")) + val = tree.Get([]byte("key2")) require.Equal(t, val, []byte("val2")) - _, val = tree.Get([]byte("key3")) + val = tree.Get([]byte("key3")) require.Equal(t, val, []byte("val1")) tree.DeleteVersion(1) - require.Len(t, tree.ndb.leafNodes(), 3) + leafNodes, err := tree.ndb.leafNodes() + require.Nil(t, err) + require.Len(t, leafNodes, 3) } func TestVersionedTreeSpecialCase(t *testing.T) { @@ -647,7 +700,7 @@ func TestVersionedTreeSpecialCase(t *testing.T) { tree.DeleteVersion(2) - _, val := tree.GetVersioned([]byte("key2"), 1) + val := tree.GetVersioned([]byte("key2"), 1) require.Equal("val0", string(val)) } @@ -676,7 +729,7 @@ func TestVersionedTreeSpecialCase2(t *testing.T) { require.NoError(tree.DeleteVersion(2)) - _, val := tree.GetVersioned([]byte("key2"), 1) + val := tree.GetVersioned([]byte("key2"), 1) require.Equal("val0", string(val)) } @@ -706,7 +759,9 @@ func TestVersionedTreeSpecialCase3(t *testing.T) { tree.DeleteVersion(3) tree.DeleteVersion(4) - require.Equal(tree.nodeSize(), len(tree.ndb.nodes())) + nodes, err := tree.ndb.nodes() + require.NoError(err) + require.Equal(tree.nodeSize(), len(nodes)) } func TestVersionedTreeSaveAndLoad(t *testing.T) { @@ -759,7 +814,9 @@ func TestVersionedTreeSaveAndLoad(t *testing.T) { require.False(ntree.IsEmpty()) require.Equal(int64(4), ntree.Size()) - require.Len(ntree.ndb.nodes(), ntree.nodeSize()) + nodes, err := tree.ndb.nodes() + require.NoError(err) + require.Len(nodes, ntree.nodeSize()) } func TestVersionedTreeErrors(t *testing.T) { @@ -781,7 +838,7 @@ func TestVersionedTreeErrors(t *testing.T) { require.Error(tree.DeleteVersion(1)) // Trying to get a key from a version which doesn't exist. - _, val := tree.GetVersioned([]byte("key"), 404) + val := tree.GetVersioned([]byte("key"), 404) require.Nil(val) // Same thing with proof. We get an error because a proof couldn't be @@ -811,19 +868,21 @@ func TestVersionedCheckpoints(t *testing.T) { keys[int64(i)] = append(keys[int64(i)], k) tree.Set(k, v) } - tree.SaveVersion() + _, _, err = tree.SaveVersion() + require.NoError(err, "failed to save version") } for i := 1; i <= versions; i++ { if i%versionsPerCheckpoint != 0 { - tree.DeleteVersion(int64(i)) + err = tree.DeleteVersion(int64(i)) + require.NoError(err, "failed to delete") } } // Make sure all keys exist at least once. for _, ks := range keys { for _, k := range ks { - _, val := tree.Get(k) + val := tree.Get(k) require.NotEmpty(val) } } @@ -832,7 +891,7 @@ func TestVersionedCheckpoints(t *testing.T) { for i := 1; i <= versions; i++ { if i%versionsPerCheckpoint != 0 { for _, k := range keys[int64(i)] { - _, val := tree.GetVersioned(k, int64(i)) + val := tree.GetVersioned(k, int64(i)) require.Nil(val) } } @@ -842,7 +901,7 @@ func TestVersionedCheckpoints(t *testing.T) { for i := 1; i <= versions; i++ { for _, k := range keys[int64(i)] { if i%versionsPerCheckpoint == 0 { - _, val := tree.GetVersioned(k, int64(i)) + val := tree.GetVersioned(k, int64(i)) require.NotEmpty(val) } } @@ -871,7 +930,7 @@ func TestVersionedCheckpointsSpecialCase(t *testing.T) { // checkpoint, which is version 10. tree.DeleteVersion(1) - _, val := tree.GetVersioned(key, 2) + val := tree.GetVersioned(key, 2) require.NotEmpty(val) require.Equal([]byte("val1"), val) } @@ -935,19 +994,19 @@ func TestVersionedCheckpointsSpecialCase4(t *testing.T) { tree.Set([]byte("X"), []byte("New")) tree.SaveVersion() - _, val := tree.GetVersioned([]byte("A"), 2) + val := tree.GetVersioned([]byte("A"), 2) require.Nil(t, val) - _, val = tree.GetVersioned([]byte("A"), 1) + val = tree.GetVersioned([]byte("A"), 1) require.NotEmpty(t, val) tree.DeleteVersion(1) tree.DeleteVersion(2) - _, val = tree.GetVersioned([]byte("A"), 2) + val = tree.GetVersioned([]byte("A"), 2) require.Nil(t, val) - _, val = tree.GetVersioned([]byte("A"), 1) + val = tree.GetVersioned([]byte("A"), 1) require.Nil(t, val) } @@ -1050,9 +1109,15 @@ func TestVersionedTreeEfficiency(t *testing.T) { // Keys of size one are likely to be overwritten. tree.Set([]byte(iavlrand.RandStr(1)), []byte(iavlrand.RandStr(8))) } - sizeBefore := len(tree.ndb.nodes()) + nodes, err := tree.ndb.nodes() + require.NoError(err) + sizeBefore := len(nodes) tree.SaveVersion() - sizeAfter := len(tree.ndb.nodes()) + _, err = tree.ndb.nodes() + require.NoError(err) + nodes, err = tree.ndb.nodes() + require.NoError(err) + sizeAfter := len(nodes) change := sizeAfter - sizeBefore keysAddedPerVersion[i] = change keysAdded += change @@ -1061,9 +1126,13 @@ func TestVersionedTreeEfficiency(t *testing.T) { keysDeleted := 0 for i := 1; i < versions; i++ { if tree.VersionExists(int64(i)) { - sizeBefore := len(tree.ndb.nodes()) + nodes, err := tree.ndb.nodes() + require.NoError(err) + sizeBefore := len(nodes) tree.DeleteVersion(int64(i)) - sizeAfter := len(tree.ndb.nodes()) + nodes, err = tree.ndb.nodes() + require.NoError(err) + sizeAfter := len(nodes) change := sizeBefore - sizeAfter keysDeleted += change @@ -1163,22 +1232,24 @@ func TestOrphans(t *testing.T) { for j := 1; j < NUMUPDATES; j++ { tree.Set(randBytes(2), randBytes(2)) } - _, _, err := tree.SaveVersion() + _, _, err = tree.SaveVersion() require.NoError(err, "SaveVersion should not error") } idx := iavlrand.RandPerm(NUMVERSIONS - 2) for _, v := range idx { - err := tree.DeleteVersion(int64(v + 1)) + err = tree.DeleteVersion(int64(v + 1)) require.NoError(err, "DeleteVersion should not error") } - tree.ndb.traverseOrphans(func(k, v []byte) { + err = tree.ndb.traverseOrphans(func(k, v []byte) error { var fromVersion, toVersion int64 orphanKeyFormat.Scan(k, &toVersion, &fromVersion) require.True(fromVersion == int64(1) || toVersion == int64(99), fmt.Sprintf(`Unexpected orphan key exists: %v with fromVersion = %d and toVersion = %d.\n Any orphan remaining in db should have either fromVersion == 1 or toVersion == 99. Since Version 1 and 99 are only versions in db`, k, fromVersion, toVersion)) + return nil }) + require.Nil(err) } func TestVersionedTreeHash(t *testing.T) { @@ -1225,12 +1296,12 @@ func TestCopyValueSemantics(t *testing.T) { val := []byte("v1") tree.Set([]byte("k"), val) - _, v := tree.Get([]byte("k")) + v := tree.Get([]byte("k")) require.Equal([]byte("v1"), v) val[1] = '2' - _, val = tree.Get([]byte("k")) + val = tree.Get([]byte("k")) require.Equal([]byte("v2"), val) } @@ -1254,13 +1325,13 @@ func TestRollback(t *testing.T) { require.Equal(int64(2), tree.Size()) - _, val := tree.Get([]byte("r")) + val := tree.Get([]byte("r")) require.Nil(val) - _, val = tree.Get([]byte("s")) + val = tree.Get([]byte("s")) require.Nil(val) - _, val = tree.Get([]byte("t")) + val = tree.Get([]byte("t")) require.Equal([]byte("v"), val) } @@ -1285,7 +1356,7 @@ func TestLazyLoadVersion(t *testing.T) { require.NoError(t, err, "unexpected error when lazy loading version") require.Equal(t, version, int64(maxVersions)) - _, value := tree.Get([]byte(fmt.Sprintf("key_%d", maxVersions))) + value := tree.Get([]byte(fmt.Sprintf("key_%d", maxVersions))) require.Equal(t, value, []byte(fmt.Sprintf("value_%d", maxVersions)), "unexpected value") // require the ability to lazy load an older version @@ -1293,7 +1364,7 @@ func TestLazyLoadVersion(t *testing.T) { require.NoError(t, err, "unexpected error when lazy loading version") require.Equal(t, version, int64(maxVersions-1)) - _, value = tree.Get([]byte(fmt.Sprintf("key_%d", maxVersions-1))) + value = tree.Get([]byte(fmt.Sprintf("key_%d", maxVersions-1))) require.Equal(t, value, []byte(fmt.Sprintf("value_%d", maxVersions-1)), "unexpected value") // require the inability to lazy load a non-valid version @@ -1607,7 +1678,9 @@ func TestLoadVersionForOverwritingCase2(t *testing.T) { removedNodes := []*Node{} - for _, n := range tree.ndb.nodes() { + nodes, err := tree.ndb.nodes() + require.NoError(err) + for _, n := range nodes { if n.version > 1 { removedNodes = append(removedNodes, n) } @@ -1617,7 +1690,7 @@ func TestLoadVersionForOverwritingCase2(t *testing.T) { require.NoError(err, "LoadVersionForOverwriting should not fail") for i := byte(0); i < 20; i++ { - _, v := tree.Get([]byte{i}) + v := tree.Get([]byte{i}) require.Equal([]byte{i}, v) } @@ -1660,7 +1733,9 @@ func TestLoadVersionForOverwritingCase3(t *testing.T) { removedNodes := []*Node{} - for _, n := range tree.ndb.nodes() { + nodes, err := tree.ndb.nodes() + require.NoError(err) + for _, n := range nodes { if n.version > 1 { removedNodes = append(removedNodes, n) } @@ -1681,7 +1756,173 @@ func TestLoadVersionForOverwritingCase3(t *testing.T) { } for i := byte(0); i < 20; i++ { - _, v := tree.Get([]byte{i}) + v := tree.Get([]byte{i}) require.Equal([]byte{i}, v) } } + +func TestIterate_ImmutableTree_Version1(t *testing.T) { + tree, mirror := getRandomizedTreeAndMirror(t) + + _, _, err := tree.SaveVersion() + require.NoError(t, err) + + immutableTree, err := tree.GetImmutable(1) + require.NoError(t, err) + + assertImmutableMirrorIterate(t, immutableTree, mirror) +} + +func TestIterate_ImmutableTree_Version2(t *testing.T) { + tree, mirror := getRandomizedTreeAndMirror(t) + + _, _, err := tree.SaveVersion() + require.NoError(t, err) + + randomizeTreeAndMirror(t, tree, mirror) + + _, _, err = tree.SaveVersion() + require.NoError(t, err) + + immutableTree, err := tree.GetImmutable(2) + require.NoError(t, err) + + assertImmutableMirrorIterate(t, immutableTree, mirror) +} + +func TestGetByIndex_ImmutableTree(t *testing.T) { + tree, mirror := getRandomizedTreeAndMirror(t) + mirrorKeys := getSortedMirrorKeys(mirror) + + _, _, err := tree.SaveVersion() + require.NoError(t, err) + + immutableTree, err := tree.GetImmutable(1) + require.NoError(t, err) + + require.True(t, immutableTree.IsFastCacheEnabled()) + + for index, expectedKey := range mirrorKeys { + expectedValue := mirror[expectedKey] + + actualKey, actualValue := immutableTree.GetByIndex(int64(index)) + + require.Equal(t, expectedKey, string(actualKey)) + require.Equal(t, expectedValue, string(actualValue)) + } +} + +func TestGetWithIndex_ImmutableTree(t *testing.T) { + tree, mirror := getRandomizedTreeAndMirror(t) + mirrorKeys := getSortedMirrorKeys(mirror) + + _, _, err := tree.SaveVersion() + require.NoError(t, err) + + immutableTree, err := tree.GetImmutable(1) + require.NoError(t, err) + + require.True(t, immutableTree.IsFastCacheEnabled()) + + for expectedIndex, key := range mirrorKeys { + expectedValue := mirror[key] + + actualIndex, actualValue := immutableTree.GetWithIndex([]byte(key)) + + require.Equal(t, expectedValue, string(actualValue)) + require.Equal(t, int64(expectedIndex), actualIndex) + } +} + +func Benchmark_GetWithIndex(b *testing.B) { + db, err := db.NewDB("test", db.MemDBBackend, "") + require.NoError(b, err) + + const numKeyVals = 100000 + + t, err := NewMutableTree(db, numKeyVals) + require.NoError(b, err) + + keys := make([][]byte, 0, numKeyVals) + + for i := 0; i < numKeyVals; i++ { + key := randBytes(10) + keys = append(keys, key) + t.Set(key, randBytes(10)) + } + _, _, err = t.SaveVersion() + require.NoError(b, err) + + b.ReportAllocs() + runtime.GC() + + b.Run("fast", func(sub *testing.B) { + require.True(b, t.IsFastCacheEnabled()) + b.ResetTimer() + for i := 0; i < sub.N; i++ { + randKey := rand.Intn(numKeyVals) + t.GetWithIndex(keys[randKey]) + } + }) + + b.Run("regular", func(sub *testing.B) { + // get non-latest version to force regular storage + _, latestVersion, err := t.SaveVersion() + require.NoError(b, err) + + itree, err := t.GetImmutable(latestVersion - 1) + require.NoError(b, err) + + require.False(b, itree.IsFastCacheEnabled()) + b.ResetTimer() + for i := 0; i < sub.N; i++ { + randKey := rand.Intn(numKeyVals) + itree.GetWithIndex(keys[randKey]) + } + }) +} + +func Benchmark_GetByIndex(b *testing.B) { + db, err := db.NewDB("test", db.MemDBBackend, "") + require.NoError(b, err) + + const numKeyVals = 100000 + + t, err := NewMutableTree(db, numKeyVals) + require.NoError(b, err) + + for i := 0; i < numKeyVals; i++ { + key := randBytes(10) + t.Set(key, randBytes(10)) + } + _, _, err = t.SaveVersion() + require.NoError(b, err) + + b.ReportAllocs() + runtime.GC() + + b.Run("fast", func(sub *testing.B) { + require.True(b, t.IsFastCacheEnabled()) + b.ResetTimer() + for i := 0; i < sub.N; i++ { + randIdx := rand.Intn(numKeyVals) + t.GetByIndex(int64(randIdx)) + } + }) + + b.Run("regular", func(sub *testing.B) { + // get non-latest version to force regular storage + _, latestVersion, err := t.SaveVersion() + require.NoError(b, err) + + itree, err := t.GetImmutable(latestVersion - 1) + require.NoError(b, err) + + require.False(b, itree.IsFastCacheEnabled()) + b.ResetTimer() + for i := 0; i < sub.N; i++ { + randIdx := rand.Intn(numKeyVals) + itree.GetByIndex(int64(randIdx)) + } + }) +} diff --git a/unsaved_fast_iterator.go b/unsaved_fast_iterator.go new file mode 100644 index 000000000..69483156f --- /dev/null +++ b/unsaved_fast_iterator.go @@ -0,0 +1,228 @@ +package iavl + +import ( + "bytes" + "errors" + "sort" + + dbm "github.com/tendermint/tm-db" +) + +var ( + errUnsavedFastIteratorNilAdditionsGiven = errors.New("unsaved fast iterator must be created with unsaved additions but they were nil") + + errUnsavedFastIteratorNilRemovalsGiven = errors.New("unsaved fast iterator must be created with unsaved removals but they were nil") +) + +// UnsavedFastIterator is a dbm.Iterator for ImmutableTree +// it iterates over the latest state via fast nodes, +// taking advantage of keys being located in sequence in the underlying database. +type UnsavedFastIterator struct { + start, end []byte + + valid bool + + ascending bool + + err error + + ndb *nodeDB + + unsavedFastNodeAdditions map[string]*FastNode + + unsavedFastNodeRemovals map[string]interface{} + + unsavedFastNodesToSort []string + + nextKey []byte + + nextVal []byte + + nextUnsavedNodeIdx int + + fastIterator dbm.Iterator +} + +var _ dbm.Iterator = (*UnsavedFastIterator)(nil) + +func NewUnsavedFastIterator(start, end []byte, ascending bool, ndb *nodeDB, unsavedFastNodeAdditions map[string]*FastNode, unsavedFastNodeRemovals map[string]interface{}) *UnsavedFastIterator { + iter := &UnsavedFastIterator{ + start: start, + end: end, + ascending: ascending, + ndb: ndb, + unsavedFastNodeAdditions: unsavedFastNodeAdditions, + unsavedFastNodeRemovals: unsavedFastNodeRemovals, + nextKey: nil, + nextVal: nil, + nextUnsavedNodeIdx: 0, + fastIterator: NewFastIterator(start, end, ascending, ndb), + } + + // We need to ensure that we iterate over saved and unsaved state in order. + // The strategy is to sort unsaved nodes, the fast node on disk are already sorted. + // Then, we keep a pointer to both the unsaved and saved nodes, and iterate over them in order efficiently. + for _, fastNode := range unsavedFastNodeAdditions { + if start != nil && bytes.Compare(fastNode.key, start) < 0 { + continue + } + + if end != nil && bytes.Compare(fastNode.key, end) >= 0 { + continue + } + + iter.unsavedFastNodesToSort = append(iter.unsavedFastNodesToSort, string(fastNode.key)) + } + + sort.Slice(iter.unsavedFastNodesToSort, func(i, j int) bool { + if ascending { + return iter.unsavedFastNodesToSort[i] < iter.unsavedFastNodesToSort[j] + } + return iter.unsavedFastNodesToSort[i] > iter.unsavedFastNodesToSort[j] + }) + + if iter.ndb == nil { + iter.err = errFastIteratorNilNdbGiven + iter.valid = false + return iter + } + + if iter.unsavedFastNodeAdditions == nil { + iter.err = errUnsavedFastIteratorNilAdditionsGiven + iter.valid = false + return iter + } + + if iter.unsavedFastNodeRemovals == nil { + iter.err = errUnsavedFastIteratorNilRemovalsGiven + iter.valid = false + return iter + } + + // Move to the first elemenet + iter.Next() + + return iter +} + +// Domain implements dbm.Iterator. +// Maps the underlying nodedb iterator domain, to the 'logical' keys involved. +func (iter *UnsavedFastIterator) Domain() ([]byte, []byte) { + return iter.start, iter.end +} + +// Valid implements dbm.Iterator. +func (iter *UnsavedFastIterator) Valid() bool { + if iter.start != nil && iter.end != nil { + if bytes.Compare(iter.end, iter.start) != 1 { + return false + } + } + + return iter.fastIterator.Valid() || iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) || (iter.nextKey != nil && iter.nextVal != nil) +} + +// Key implements dbm.Iterator +func (iter *UnsavedFastIterator) Key() []byte { + return iter.nextKey +} + +// Value implements dbm.Iterator +func (iter *UnsavedFastIterator) Value() []byte { + return iter.nextVal +} + +// Next implements dbm.Iterator +// Its effectively running the constant space overhead algorithm for streaming through sorted lists: +// the sorted lists being underlying fast nodes & unsavedFastNodeChanges +func (iter *UnsavedFastIterator) Next() { + if iter.ndb == nil { + iter.err = errFastIteratorNilNdbGiven + iter.valid = false + return + } + + if iter.fastIterator.Valid() && iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) { + diskKeyStr := string(iter.fastIterator.Key()) + + if iter.unsavedFastNodeRemovals[diskKeyStr] != nil { + // If next fast node from disk is to be removed, skip it. + iter.fastIterator.Next() + iter.Next() + return + } + + nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx] + nextUnsavedNode := iter.unsavedFastNodeAdditions[nextUnsavedKey] + + var isUnsavedNext bool + if iter.ascending { + isUnsavedNext = diskKeyStr >= nextUnsavedKey + } else { + isUnsavedNext = diskKeyStr <= nextUnsavedKey + } + + if isUnsavedNext { + // Unsaved node is next + + if diskKeyStr == nextUnsavedKey { + // Unsaved update prevails over saved copy so we skip the copy from disk + iter.fastIterator.Next() + } + + iter.nextKey = nextUnsavedNode.key + iter.nextVal = nextUnsavedNode.value + + iter.nextUnsavedNodeIdx++ + return + } + // Disk node is next + iter.nextKey = iter.fastIterator.Key() + iter.nextVal = iter.fastIterator.Value() + + iter.fastIterator.Next() + return + } + + // if only nodes on disk are left, we return them + if iter.fastIterator.Valid() { + if iter.unsavedFastNodeRemovals[string(iter.fastIterator.Key())] != nil { + // If next fast node from disk is to be removed, skip it. + iter.fastIterator.Next() + iter.Next() + return + } + + iter.nextKey = iter.fastIterator.Key() + iter.nextVal = iter.fastIterator.Value() + + iter.fastIterator.Next() + return + } + + // if only unsaved nodes are left, we can just iterate + if iter.nextUnsavedNodeIdx < len(iter.unsavedFastNodesToSort) { + nextUnsavedKey := iter.unsavedFastNodesToSort[iter.nextUnsavedNodeIdx] + nextUnsavedNode := iter.unsavedFastNodeAdditions[nextUnsavedKey] + + iter.nextKey = nextUnsavedNode.key + iter.nextVal = nextUnsavedNode.value + + iter.nextUnsavedNodeIdx++ + return + } + + iter.nextKey = nil + iter.nextVal = nil +} + +// Close implements dbm.Iterator +func (iter *UnsavedFastIterator) Close() error { + iter.valid = false + return iter.fastIterator.Close() +} + +// Error implements dbm.Iterator +func (iter *UnsavedFastIterator) Error() error { + return iter.err +}