diff --git a/core/state/statedb.go b/core/state/statedb.go index 48adadf085d6..5c33e2d7e130 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -32,6 +32,8 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie/trienode" "github.com/ethereum/go-ethereum/trie/triestate" ) @@ -720,6 +722,43 @@ func (s *StateDB) CreateAccount(addr common.Address) { } } +func (s *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error { + so := s.getStateObject(addr) + if so == nil { + return nil + } + tr, err := so.getTrie() + if err != nil { + return err + } + trieIt, err := tr.NodeIterator(nil) + if err != nil { + return err + } + it := trie.NewIterator(trieIt) + + for it.Next() { + key := common.BytesToHash(s.trie.GetKey(it.Key)) + if value, dirty := so.dirtyStorage[key]; dirty { + if !cb(key, value) { + return nil + } + continue + } + + if len(it.Value) > 0 { + _, content, _, err := rlp.Split(it.Value) + if err != nil { + return err + } + if !cb(key, common.BytesToHash(content)) { + return nil + } + } + } + return nil +} + // Copy creates a deep, independent copy of the state. // Snapshots of the copied state cannot be applied to the copy. func (s *StateDB) Copy() *StateDB { diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 869410ff431b..c08a0d959fa7 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -35,7 +35,6 @@ import ( "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie/triedb/hashdb" "github.com/ethereum/go-ethereum/trie/triedb/pathdb" @@ -448,43 +447,6 @@ func (test *snapshotTest) run() bool { return true } -func forEachStorage(s *StateDB, addr common.Address, cb func(key, value common.Hash) bool) error { - so := s.getStateObject(addr) - if so == nil { - return nil - } - tr, err := so.getTrie() - if err != nil { - return err - } - trieIt, err := tr.NodeIterator(nil) - if err != nil { - return err - } - it := trie.NewIterator(trieIt) - - for it.Next() { - key := common.BytesToHash(s.trie.GetKey(it.Key)) - if value, dirty := so.dirtyStorage[key]; dirty { - if !cb(key, value) { - return nil - } - continue - } - - if len(it.Value) > 0 { - _, content, _, err := rlp.Split(it.Value) - if err != nil { - return err - } - if !cb(key, common.BytesToHash(content)) { - return nil - } - } - } - return nil -} - // checkEqual checks that methods of state and checkstate return the same values. func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { for _, addr := range test.addrs { @@ -506,10 +468,10 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr)) // Check storage. if obj := state.getStateObject(addr); obj != nil { - forEachStorage(state, addr, func(key, value common.Hash) bool { + state.ForEachStorage(addr, func(key, value common.Hash) bool { return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value) }) - forEachStorage(checkstate, addr, func(key, value common.Hash) bool { + checkstate.ForEachStorage(addr, func(key, value common.Hash) bool { return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value) }) }