diff --git a/core/types/hashing_test.go b/core/types/hashing_test.go
index 6d1ebf897c..de71ee41a4 100644
--- a/core/types/hashing_test.go
+++ b/core/types/hashing_test.go
@@ -26,6 +26,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
+ "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp"
@@ -38,7 +39,8 @@ func TestDeriveSha(t *testing.T) {
t.Fatal(err)
}
for len(txs) < 1000 {
- exp := types.DeriveSha(txs, new(trie.Trie))
+ tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase()))
+ exp := types.DeriveSha(txs, tr)
got := types.DeriveSha(txs, trie.NewStackTrie(nil))
if !bytes.Equal(got[:], exp[:]) {
t.Fatalf("%d txs: got %x exp %x", len(txs), got, exp)
@@ -85,7 +87,8 @@ func BenchmarkDeriveSha200(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
- exp = types.DeriveSha(txs, new(trie.Trie))
+ tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase()))
+ exp = types.DeriveSha(txs, tr)
}
})
@@ -106,7 +109,8 @@ func TestFuzzDeriveSha(t *testing.T) {
rndSeed := mrand.Int()
for i := 0; i < 10; i++ {
seed := rndSeed + i
- exp := types.DeriveSha(newDummy(i), new(trie.Trie))
+ tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase()))
+ exp := types.DeriveSha(newDummy(i), tr)
got := types.DeriveSha(newDummy(i), trie.NewStackTrie(nil))
if !bytes.Equal(got[:], exp[:]) {
printList(newDummy(seed))
@@ -134,7 +138,8 @@ func TestDerivableList(t *testing.T) {
},
}
for i, tc := range tcs[1:] {
- exp := types.DeriveSha(flatList(tc), new(trie.Trie))
+ tr, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase()))
+ exp := types.DeriveSha(flatList(tc), tr)
got := types.DeriveSha(flatList(tc), trie.NewStackTrie(nil))
if !bytes.Equal(got[:], exp[:]) {
t.Fatalf("case %d: got %x exp %x", i, got, exp)
diff --git a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go
index 09ee6bb9c7..5d7097b137 100644
--- a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go
+++ b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go
@@ -24,6 +24,7 @@ import (
"sort"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/trie"
)
@@ -62,7 +63,7 @@ func (f *fuzzer) readInt() uint64 {
func (f *fuzzer) randomTrie(n int) (*trie.Trie, map[string]*kv) {
- trie := new(trie.Trie)
+ trie, _ := trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase()))
vals := make(map[string]*kv)
size := f.readInt()
// Fill it with some fluff
@@ -182,8 +183,10 @@ func (f *fuzzer) fuzz() int {
// The function must return
// 1 if the fuzzer should increase priority of the
-// given input during subsequent fuzzing (for example, the input is lexically
-// correct and was parsed successfully);
+//
+// given input during subsequent fuzzing (for example, the input is lexically
+// correct and was parsed successfully);
+//
// -1 if the input must not be added to corpus even if gives new coverage; and
// 0 otherwise; other values are reserved for future use.
func Fuzz(input []byte) int {
diff --git a/trie/committer.go b/trie/committer.go
index 0721990a21..b74572ee27 100644
--- a/trie/committer.go
+++ b/trie/committer.go
@@ -91,7 +91,7 @@ func (c *committer) commit(n node, db *Database) (node, int, error) {
if hash != nil && !dirty {
return hash, 0, nil
}
- // Commit children, then parent, and remove remove the dirty flag.
+ // Commit children, then parent, and remove the dirty flag.
switch cn := n.(type) {
case *shortNode:
// Commit child
diff --git a/trie/iterator_test.go b/trie/iterator_test.go
index 1f984c0f4b..679ae2cdcc 100644
--- a/trie/iterator_test.go
+++ b/trie/iterator_test.go
@@ -24,6 +24,7 @@ import (
"testing"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
@@ -296,7 +297,7 @@ func TestUnionIterator(t *testing.T) {
}
func TestIteratorNoDups(t *testing.T) {
- var tr Trie
+ tr, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
for _, val := range testdata1 {
tr.Update([]byte(val.k), []byte(val.v))
}
diff --git a/trie/proof.go b/trie/proof.go
index 51ecea0c39..2c2da9cb82 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -23,7 +23,6 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb"
- "github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -335,9 +334,9 @@ findFork:
// unset removes all internal node references either the left most or right most.
// It can meet these scenarios:
//
-// - The given path is existent in the trie, unset the associated nodes with the
-// specific direction
-// - The given path is non-existent in the trie
+// - The given path is existent in the trie, unset the associated nodes with the
+// specific direction
+// - The given path is non-existent in the trie
// - the fork point is a fullnode, the corresponding child pointed by path
// is nil, return
// - the fork point is a shortnode, the shortnode is included in the range,
@@ -452,15 +451,15 @@ func hasRightElement(node node, key []byte) bool {
// Expect the normal case, this function can also be used to verify the following
// range proofs:
//
-// - All elements proof. In this case the proof can be nil, but the range should
-// be all the leaves in the trie.
+// - All elements proof. In this case the proof can be nil, but the range should
+// be all the leaves in the trie.
//
-// - One element proof. In this case no matter the edge proof is a non-existent
-// proof or not, we can always verify the correctness of the proof.
+// - One element proof. In this case no matter the edge proof is a non-existent
+// proof or not, we can always verify the correctness of the proof.
//
-// - Zero element proof. In this case a single non-existent proof is enough to prove.
-// Besides, if there are still some other leaves available on the right side, then
-// an error will be returned.
+// - Zero element proof. In this case a single non-existent proof is enough to prove.
+// Besides, if there are still some other leaves available on the right side, then
+// an error will be returned.
//
// Except returning the error to indicate the proof is valid or not, the function will
// also return a flag to indicate whether there exists more accounts/slots in the trie.
@@ -553,7 +552,7 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key
}
// Rebuild the trie with the leaf stream, the shape of trie
// should be same with the original one.
- tr := &Trie{root: root, db: NewDatabase(memorydb.New())}
+ tr := newWithRootNode(root)
if empty {
tr.root = nil
}
diff --git a/trie/proof_test.go b/trie/proof_test.go
index 95ad6169c3..19ca51e259 100644
--- a/trie/proof_test.go
+++ b/trie/proof_test.go
@@ -26,6 +26,7 @@ import (
"time"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
)
@@ -79,7 +80,7 @@ func TestProof(t *testing.T) {
}
func TestOneElementProof(t *testing.T) {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
updateString(trie, "k", "v")
for i, prover := range makeProvers(trie) {
proof := prover([]byte("k"))
@@ -130,7 +131,7 @@ func TestBadProof(t *testing.T) {
// Tests that missing keys can also be proven. The test explicitly uses a single
// entry trie and checks for missing keys both before and after the single entry.
func TestMissingKeyProof(t *testing.T) {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
updateString(trie, "k", "v")
for i, key := range []string{"a", "j", "l", "z"} {
@@ -386,7 +387,7 @@ func TestOneElementRangeProof(t *testing.T) {
}
// Test the mini trie with only a single element.
- tinyTrie := new(Trie)
+ tinyTrie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
entry := &kv{randBytes(32), randBytes(20), false}
tinyTrie.Update(entry.k, entry.v)
@@ -458,7 +459,7 @@ func TestAllElementsProof(t *testing.T) {
// TestSingleSideRangeProof tests the range starts from zero.
func TestSingleSideRangeProof(t *testing.T) {
for i := 0; i < 64; i++ {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
var entries entrySlice
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
@@ -493,7 +494,7 @@ func TestSingleSideRangeProof(t *testing.T) {
// TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff.
func TestReverseSingleSideRangeProof(t *testing.T) {
for i := 0; i < 64; i++ {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
var entries entrySlice
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
@@ -600,7 +601,7 @@ func TestBadRangeProof(t *testing.T) {
// TestGappedRangeProof focuses on the small trie with embedded nodes.
// If the gapped node is embedded in the trie, it should be detected too.
func TestGappedRangeProof(t *testing.T) {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
var entries []*kv // Sorted entries
for i := byte(0); i < 10; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
@@ -674,7 +675,7 @@ func TestSameSideProofs(t *testing.T) {
}
func TestHasRightElement(t *testing.T) {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
var entries entrySlice
for i := 0; i < 4096; i++ {
value := &kv{randBytes(32), randBytes(20), false}
@@ -1027,7 +1028,7 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
}
func randomTrie(n int) (*Trie, map[string]*kv) {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
vals := make(map[string]*kv)
for i := byte(0); i < 100; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
@@ -1052,7 +1053,7 @@ func randBytes(n int) []byte {
}
func nonRandomTrie(n int) (*Trie, map[string]*kv) {
- trie := new(Trie)
+ trie, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
vals := make(map[string]*kv)
max := uint64(0xffffffffffffffff)
for i := uint64(0); i < uint64(n); i++ {
diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go
index fb6c38ee22..a3ece84b57 100644
--- a/trie/secure_trie_test.go
+++ b/trie/secure_trie_test.go
@@ -112,8 +112,7 @@ func TestSecureTrieConcurrency(t *testing.T) {
threads := runtime.NumCPU()
tries := make([]*SecureTrie, threads)
for i := 0; i < threads; i++ {
- cpy := *trie
- tries[i] = &cpy
+ tries[i] = trie.Copy()
}
// Start a batch of goroutines interactng with the trie
pend := new(sync.WaitGroup)
diff --git a/trie/trie.go b/trie/trie.go
index eb53258700..79ed3176f0 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -24,6 +24,7 @@ import (
"sync"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
@@ -66,6 +67,8 @@ type Trie struct {
// hashing operation. This number will not directly map to the number of
// actually unhashed nodes
unhashed int
+
+ tracer *tracer
}
// newFlag returns the cache flag value for a newly created node.
@@ -73,6 +76,16 @@ func (t *Trie) newFlag() nodeFlag {
return nodeFlag{dirty: true}
}
+// newWithRootNode initializes the trie with the given root node.
+// It's only used by range prover.
+func newWithRootNode(root node) *Trie {
+ return &Trie{
+ root: root,
+ //tracer: newTracer(),
+ db: NewDatabase(rawdb.NewMemoryDatabase()),
+ }
+}
+
// New creates a trie with an existing root node from db.
//
// If root is the zero hash or the sha3 hash of an empty string, the
@@ -85,6 +98,7 @@ func New(root common.Hash, db *Database) (*Trie, error) {
}
trie := &Trie{
db: db,
+ //tracer: newTracer(),
}
if root != (common.Hash{}) && root != emptyRoot {
rootnode, err := trie.resolveHash(root[:], nil)
@@ -317,6 +331,11 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
if matchlen == 0 {
return true, branch, nil
}
+
+ // New branch node is created as a child of the original short node.
+ // Track the newly inserted node in the tracer. The node identifier
+ // passed is the path from the root node.
+ t.tracer.onInsert(append(prefix, key[:matchlen]...))
// Otherwise, replace it with a short node leading up to the branch.
return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil
@@ -331,6 +350,10 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return true, n, nil
case nil:
+ // New short node is created and track it in the tracer. The node identifier
+ // passed is the path from the root node. Note the valueNode won't be tracked
+ // since it's always embedded in its parent.
+ t.tracer.onInsert(prefix)
return true, &shortNode{key, value, t.newFlag()}, nil
case hashNode:
@@ -372,6 +395,33 @@ func (t *Trie) TryDelete(key []byte) error {
return nil
}
+// traverse method mostly for learning and testing purposes.
+func (t *Trie) traverse(n node, prefix []byte) {
+ switch n := n.(type) {
+ case *shortNode:
+ // If it's a short node, print the prefix and key
+ newPrefix := append(prefix, n.Key...)
+ fmt.Printf("[Traverse] Short Node: %+v\n", n)
+ t.traverse(n.Val, newPrefix)
+ case *fullNode:
+ // If it's a full node, print the prefix and each child
+
+ for i, _ := range n.Children {
+ if n.Children[i] != nil {
+ fmt.Printf("[Traverse] Full Node: %+v\n", n.Children[i])
+ newPrefix := append(prefix, byte(i))
+ t.traverse(n.Children[i], newPrefix)
+ }
+ }
+ case valueNode:
+ // If it's a value node, print the prefix and value
+ fmt.Printf("Value Node: %s -> %s\n", string(prefix), string(n))
+ case hashNode:
+ // If it's a hash node, resolve it and traverse the result
+ fmt.Printf("Hash Node: %s -> %s \n", string(prefix), string(n))
+ }
+}
+
// delete returns the new root of the trie with key deleted.
// It reduces the trie to minimal form by simplifying
// nodes on the way up after deleting recursively.
@@ -383,6 +433,10 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
return false, n, nil // don't replace n on mismatch
}
if matchlen == len(key) {
+ // It means that matched short node is deleted entirely, and track
+ // it in the deletion set. The same the valueNode doesn't need
+ // to be tracked at all since it's always be embedded in its parent.
+ t.tracer.onDelete(prefix)
return true, nil, nil // remove n entirely for whole matches
}
// The key is longer than n.Key. Remove the remaining suffix
@@ -395,6 +449,10 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
}
switch child := child.(type) {
case *shortNode:
+ // The child shortNode is merged into its parent, track
+ // is deleted as well.
+ t.tracer.onDelete(append(prefix, n.Key...))
+
// Deleting from the subtrie reduced it to another
// short node. Merge the nodes to avoid creating a
// shortNode{..., shortNode{...}}. Use concat (which
@@ -456,6 +514,11 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
return false, nil, err
}
if cnode, ok := cnode.(*shortNode); ok {
+ // Replace the entire full node with the short node.
+ // Mark the original short nodes as delete since the value
+ // is embedded in its parent now.
+ t.tracer.onDelete(append(prefix, byte(pos)))
+
k := append([]byte{byte(pos)}, cnode.Key...)
return true, &shortNode{k, cnode.Val, t.newFlag()}, nil
}
@@ -528,6 +591,9 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) {
if t.db == nil {
panic("commit called on trie with nil database")
}
+
+ defer t.tracer.reset()
+
if t.root == nil {
return emptyRoot, 0, nil
}
@@ -586,6 +652,7 @@ func (t *Trie) hashRoot() (node, node, error) {
func (t *Trie) Reset() {
t.root = nil
t.unhashed = 0
+ t.tracer.reset()
}
// Copy returns a copy of Trie.
@@ -594,5 +661,6 @@ func (t *Trie) Copy() *Trie {
db: t.db,
root: t.root,
unhashed: t.unhashed,
+ tracer: t.tracer.copy(),
}
}
diff --git a/trie/utils.go b/trie/utils.go
new file mode 100644
index 0000000000..be5e491bd8
--- /dev/null
+++ b/trie/utils.go
@@ -0,0 +1,134 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+// tracer tracks the changes of trie nodes. During the trie operations,
+// some nodes can be deleted from the trie, while these deleted nodes
+// won't be captured by trie.Hasher or trie.Commiter. Thus, these deleted
+// nodes won't be removed from the disk at all. Tracer is an auxiliary tool
+// used to track all insert and delete operations on trie and capture all
+// deleted nodes eventually.
+//
+// The changed nodes can be mainly divided into two categories: the leaf
+// nodes and intermediate nodes. The fromer is inserted/deleted by callers
+// white the latter is iserted/deleted in order to follow the rule of trie.
+// This tool can track all of them no matter is embedded in its
+// parent or nit, but the valueNode is never tracked.
+//
+// Note tracer is not thread-safe, callers should be responsible for handling
+// the concurrency issues by themselves.
+type tracer struct {
+ insert map[string]struct{}
+ delete map[string]struct{}
+}
+
+// newTracer initlializes tride node diff tracer.
+func newTracer() *tracer {
+ return &tracer{
+ insert: make(map[string]struct{}),
+ delete: make(map[string]struct{}),
+ }
+}
+
+// onInsert tracks the newly inserted trie node. If it's already
+// in the delete set(resurrected node), then just wipe it from
+// the deletion set as it's untouched.
+func (t *tracer) onInsert(key []byte) {
+ // Tracer isn't used right now, remove this check latter.
+ if t == nil {
+ return
+ }
+ // If the key is in the delete set, then it's a resurrected node, then wipe it.
+ if _, present := t.delete[string(key)]; present {
+ delete(t.delete, string(key))
+ return
+ }
+ t.insert[string(key)] = struct{}{}
+}
+
+// OnDelete tracks the newly deleted trie node. If it's already
+// in the addition set, then just wipe it from the addtion set
+// as it's untouched.
+func (t *tracer) onDelete(key []byte) {
+ // Tracer isn't used right now, remove this check latter.
+ if t == nil {
+ return
+ }
+ if _, present := t.insert[string(key)]; present {
+ delete(t.insert, string(key))
+ return
+ }
+ t.delete[string(key)] = struct{}{}
+}
+
+// insertList returns the tracked inserted trie nodes in list format.
+func (t *tracer) insertList() [][]byte {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return nil
+ }
+ var ret [][]byte
+ for key := range t.insert {
+ ret = append(ret, []byte(key))
+ }
+ return ret
+}
+
+// deleteList returns the tracked deleted trie nodes in list format.
+func (t *tracer) deleteList() [][]byte {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return nil
+ }
+ var ret [][]byte
+ for key := range t.delete {
+ ret = append(ret, []byte(key))
+ }
+ return ret
+}
+
+// reset clears the content tracked by tracer.
+func (t *tracer) reset() {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return
+ }
+ t.insert = make(map[string]struct{})
+ t.delete = make(map[string]struct{})
+}
+
+// copy returns a deep copied tracer instance.
+func (t *tracer) copy() *tracer {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return nil
+ }
+ var (
+ insert = make(map[string]struct{})
+ delete = make(map[string]struct{})
+ )
+ for key := range t.insert {
+ insert[key] = struct{}{}
+ }
+ for key := range t.delete {
+ delete[key] = struct{}{}
+ }
+ return &tracer{
+ insert: insert,
+ delete: delete,
+ }
+}
diff --git a/trie/utils_test.go b/trie/utils_test.go
new file mode 100644
index 0000000000..fadb0553b5
--- /dev/null
+++ b/trie/utils_test.go
@@ -0,0 +1,122 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+)
+
+// Tests if the trie diffs are tracked correctly.
+func TestTrieTracer(t *testing.T) {
+ db := NewDatabase(rawdb.NewMemoryDatabase())
+ trie, _ := New(common.Hash{}, db)
+ trie.tracer = newTracer()
+
+ // Insert a batch of entries, all the nodes should be marked as inserted
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"dog", "puppy"},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ for _, val := range vals {
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+ trie.Hash()
+
+ seen := make(map[string]struct{})
+ it := trie.NodeIterator(nil)
+ for it.Next(true) {
+ if it.Leaf() {
+ continue
+ }
+ seen[string(it.Path())] = struct{}{}
+ }
+ inserted := trie.tracer.insertList()
+ if len(inserted) != len(seen) {
+ t.Fatalf("Unexpected inserted node tracked want %d got %d", len(seen), len(inserted))
+ }
+ for _, k := range inserted {
+ _, ok := seen[string(k)]
+ if !ok {
+ t.Fatalf("Unexpected inserted node")
+ }
+ }
+ deleted := trie.tracer.deleteList()
+ if len(deleted) != 0 {
+ t.Fatalf("Unexpected deleted node tracked %d", len(deleted))
+ }
+
+ // Commit the changes
+ trie.Commit(nil)
+
+ // Delete all the elements, check deletion set
+ for _, val := range vals {
+ trie.Delete([]byte(val.k))
+ }
+ trie.Hash()
+
+ inserted = trie.tracer.insertList()
+ if len(inserted) != 0 {
+ t.Fatalf("Unexpected inserted node tracked %d", len(inserted))
+ }
+ deleted = trie.tracer.deleteList()
+ if len(deleted) != len(seen) {
+ t.Fatalf("Unexpected deleted node tracked want %d got %d", len(seen), len(deleted))
+ }
+ for _, k := range deleted {
+ _, ok := seen[string(k)]
+ if !ok {
+ t.Fatalf("Unexpected inserted node")
+ }
+ }
+}
+
+func TestTrieTracerNoop(t *testing.T) {
+ db := NewDatabase(rawdb.NewMemoryDatabase())
+ trie, _ := New(common.Hash{}, db)
+ trie.tracer = newTracer()
+
+ // Insert a batch of entries, all the nodes should be marked as inserted
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"dog", "puppy"},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ for _, val := range vals {
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+ for _, val := range vals {
+ trie.Delete([]byte(val.k))
+ }
+ if len(trie.tracer.insertList()) != 0 {
+ t.Fatalf("Unexpected inserted node tracked %d", len(trie.tracer.insertList()))
+ }
+ if len(trie.tracer.deleteList()) != 0 {
+ t.Fatalf("Unexpected deleted node tracked %d", len(trie.tracer.deleteList()))
+ }
+}