Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(state): node hashes vs merkle values #2915

Merged
merged 2 commits into from
Feb 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions dot/state/offline_pruner.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (p *OfflinePruner) SetBloomFilter() (err error) {
}

latestBlockNum := header.Number
merkleValues := make(map[string]struct{})
nodeHashes := make(map[common.Hash]struct{})

logger.Infof("Latest block number is %d", latestBlockNum)

Expand All @@ -132,7 +132,7 @@ func (p *OfflinePruner) SetBloomFilter() (err error) {
return err
}

trie.PopulateNodeHashes(tr.RootNode(), merkleValues)
trie.PopulateNodeHashes(tr.RootNode(), nodeHashes)

// get parent header of current block
header, err = p.blockState.GetHeader(header.ParentHash)
Expand All @@ -142,14 +142,14 @@ func (p *OfflinePruner) SetBloomFilter() (err error) {
blockNum = header.Number
}

for key := range merkleValues {
err = p.filterDatabase.Put([]byte(key), nil)
for key := range nodeHashes {
err = p.filterDatabase.Put(key.ToBytes(), nil)
if err != nil {
return err
}
}

logger.Infof("Total keys added in bloom filter: %d", len(merkleValues))
logger.Infof("Total keys added in filter database: %d", len(nodeHashes))
return nil
}

Expand Down
4 changes: 2 additions & 2 deletions dot/state/pruner/pruner.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ type Config struct {

// Pruner is implemented by FullNode and ArchiveNode.
type Pruner interface {
StoreJournalRecord(deletedMerkleValues, insertedMerkleValues map[string]struct{},
StoreJournalRecord(deletedNodeHashes, insertedNodeHashes map[common.Hash]struct{},
blockHash common.Hash, blockNum int64) error
}

// ArchiveNode is a no-op since we don't prune nodes in archive mode.
type ArchiveNode struct{}

// StoreJournalRecord for archive node doesn't do anything.
func (*ArchiveNode) StoreJournalRecord(_, _ map[string]struct{},
func (*ArchiveNode) StoreJournalRecord(_, _ map[common.Hash]struct{},
_ common.Hash, _ int64) error {
return nil
}
8 changes: 4 additions & 4 deletions dot/state/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ func (s *StorageState) StoreTrie(ts *rtstorage.TrieState, header *types.Header)
s.tries.softSet(root, ts.Trie())

if header != nil {
insertedMerkleValues, deletedMerkleValues, err := ts.GetChangedNodeHashes()
insertedNodeHashes, deletedNodeHashes, err := ts.GetChangedNodeHashes()
if err != nil {
return fmt.Errorf("failed to get state trie inserted keys: block %s %w", header.Hash(), err)
return fmt.Errorf("getting trie changed node hashes for block hash %s: %w", header.Hash(), err)
}

err = s.pruner.StoreJournalRecord(deletedMerkleValues, insertedMerkleValues, header.Hash(), int64(header.Number))
err = s.pruner.StoreJournalRecord(deletedNodeHashes, insertedNodeHashes, header.Hash(), int64(header.Number))
qdm12 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return err
return fmt.Errorf("storing journal record: %w", err)
}
}

Expand Down
2 changes: 1 addition & 1 deletion lib/runtime/storage/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (s *TrieState) LoadCodeHash() (common.Hash, error) {

// GetChangedNodeHashes returns the two sets of hashes for all nodes
// inserted and deleted in the state trie since the last block produced (trie snapshot).
func (s *TrieState) GetChangedNodeHashes() (inserted, deleted map[string]struct{}, err error) {
func (s *TrieState) GetChangedNodeHashes() (inserted, deleted map[common.Hash]struct{}, err error) {
s.lock.RLock()
defer s.lock.RUnlock()
return s.t.GetChangedNodeHashes()
Expand Down
72 changes: 43 additions & 29 deletions lib/trie/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,33 +68,34 @@ func (t *Trie) loadNode(db Getter, n *Node) error {

merkleValue := child.MerkleValue

if len(merkleValue) == 0 {
if len(merkleValue) < 32 {
// node has already been loaded inline
// just set encoding + hash digest
// just set its encoding
_, err := child.CalculateMerkleValue()
if err != nil {
return fmt.Errorf("merkle value: %w", err)
}
continue
}

encodedNode, err := db.Get(merkleValue)
nodeHash := merkleValue
encodedNode, err := db.Get(nodeHash)
if err != nil {
return fmt.Errorf("cannot find child node key 0x%x in database: %w", merkleValue, err)
return fmt.Errorf("cannot find child node key 0x%x in database: %w", nodeHash, err)
}

reader := bytes.NewReader(encodedNode)
decodedNode, err := node.Decode(reader)
if err != nil {
return fmt.Errorf("decoding node with Merkle value 0x%x: %w", merkleValue, err)
return fmt.Errorf("decoding node with hash 0x%x: %w", nodeHash, err)
}

decodedNode.MerkleValue = merkleValue
decodedNode.MerkleValue = nodeHash
branch.Children[i] = decodedNode

err = t.loadNode(db, decodedNode)
if err != nil {
return fmt.Errorf("loading child at index %d with Merkle value 0x%x: %w", i, merkleValue, err)
return fmt.Errorf("loading child at index %d with node hash 0x%x: %w", i, nodeHash, err)
}

if decodedNode.Kind() == node.Branch {
Expand Down Expand Up @@ -132,7 +133,7 @@ func (t *Trie) loadNode(db Getter, n *Node) error {
// all its descendant nodes as keys to the nodeHashes map.
// It is assumed the node and its descendant nodes have their Merkle value already
// computed.
func PopulateNodeHashes(n *Node, nodeHashes map[string]struct{}) {
func PopulateNodeHashes(n *Node, nodeHashes map[common.Hash]struct{}) {
if n == nil {
return
}
Expand All @@ -148,7 +149,8 @@ func PopulateNodeHashes(n *Node, nodeHashes map[string]struct{}) {
return
}

nodeHashes[string(n.MerkleValue)] = struct{}{}
nodeHash := common.NewHash(n.MerkleValue)
nodeHashes[nodeHash] = struct{}{}

if n.Kind() == node.Leaf {
return
Expand Down Expand Up @@ -260,15 +262,15 @@ func getFromDBAtNode(db Getter, n *Node, key []byte) (
encodedChild, err := db.Get(childMerkleValue)
if err != nil {
return nil, fmt.Errorf(
"finding child node with Merkle value 0x%x in database: %w",
"finding child node with hash 0x%x in database: %w",
childMerkleValue, err)
}

reader := bytes.NewReader(encodedChild)
decodedChild, err := node.Decode(reader)
if err != nil {
return nil, fmt.Errorf(
"decoding child node with Merkle value 0x%x: %w",
"decoding child node with hash 0x%x: %w",
childMerkleValue, err)
}

Expand Down Expand Up @@ -305,11 +307,21 @@ func (t *Trie) writeDirtyNode(db Putter, n *Node) (err error) {
n.MerkleValue, err)
}

err = db.Put(merkleValue, encoding)
if len(merkleValue) < 32 {
// Merkle value is the node encoding which is less than 32 bytes.
// That means this node encoding is inlined in its parent node encoding,
// and so it is not needed to write it in the database.
n.SetClean()
return nil
}

nodeHash := merkleValue

err = db.Put(nodeHash, encoding)
if err != nil {
return fmt.Errorf(
"putting encoding of node with Merkle value 0x%x in database: %w",
merkleValue, err)
"putting encoding of node with node hash 0x%x in database: %w",
nodeHash, err)
}

if n.Kind() != node.Branch {
Expand Down Expand Up @@ -342,25 +354,20 @@ func (t *Trie) writeDirtyNode(db Putter, n *Node) (err error) {

// GetChangedNodeHashes returns the two sets of hashes for all nodes
// inserted and deleted in the state trie since the last snapshot.
// Returned maps are safe for mutation.
func (t *Trie) GetChangedNodeHashes() (inserted, deleted map[string]struct{}, err error) {
inserted = make(map[string]struct{})
// Returned inserted map is safe for mutation, but deleted is not safe for mutation.
func (t *Trie) GetChangedNodeHashes() (inserted, deleted map[common.Hash]struct{}, err error) {
inserted = make(map[common.Hash]struct{})
err = t.getInsertedNodeHashesAtNode(t.root, inserted)
if err != nil {
return nil, nil, fmt.Errorf("getting inserted node hashes: %w", err)
}

deletedNodeHashes := t.deltas.Deleted()
// TODO return deletedNodeHashes directly after changing MerkleValue -> NodeHash
deleted = make(map[string]struct{}, len(deletedNodeHashes))
for nodeHash := range deletedNodeHashes {
deleted[string(nodeHash[:])] = struct{}{}
}
deleted = t.deltas.Deleted()

return inserted, deleted, nil
}

func (t *Trie) getInsertedNodeHashesAtNode(n *Node, merkleValues map[string]struct{}) (err error) {
func (t *Trie) getInsertedNodeHashesAtNode(n *Node, nodeHashes map[common.Hash]struct{}) (err error) {
if n == nil || !n.Dirty {
return nil
}
Expand All @@ -372,12 +379,19 @@ func (t *Trie) getInsertedNodeHashesAtNode(n *Node, merkleValues map[string]stru
merkleValue, err = n.CalculateMerkleValue()
}
if err != nil {
return fmt.Errorf(
"encoding and hashing node with Merkle value 0x%x: %w",
n.MerkleValue, err)
return fmt.Errorf("calculating Merkle value: %w", err)
}

if len(merkleValue) < 32 {
// this is an inlined node and is encoded as part of its parent node.
// Therefore it is not written to disk and the online pruner does not
// need to track it. If the node encodes to less than 32B, it cannot have
// non-inlined children so it's safe to stop here and not recurse further.
return nil
}

merkleValues[string(merkleValue)] = struct{}{}
nodeHash := common.NewHash(merkleValue)
nodeHashes[nodeHash] = struct{}{}

if n.Kind() != node.Branch {
return nil
Expand All @@ -388,7 +402,7 @@ func (t *Trie) getInsertedNodeHashesAtNode(n *Node, merkleValues map[string]stru
continue
}

err := t.getInsertedNodeHashesAtNode(child, merkleValues)
err := t.getInsertedNodeHashesAtNode(child, nodeHashes)
if err != nil {
// Note: do not wrap error since this is called recursively.
return err
Expand Down
47 changes: 27 additions & 20 deletions lib/trie/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

"github.com/ChainSafe/chaindb"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -161,28 +162,34 @@ func Test_Trie_WriteDirty_ClearPrefix(t *testing.T) {
func Test_PopulateNodeHashes(t *testing.T) {
t.Parallel()

const (
merkleValue32Zeroes = "00000000000000000000000000000000"
merkleValue32Ones = "11111111111111111111111111111111"
merkleValue32Twos = "22222222222222222222222222222222"
merkleValue32Threes = "33333333333333333333333333333333"
var (
merkleValue32Zeroes = common.Hash{}
merkleValue32Ones = common.Hash{
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
merkleValue32Twos = common.Hash{
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}
merkleValue32Threes = common.Hash{
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}
)

testCases := map[string]struct {
node *Node
nodeHashes map[string]struct{}
nodeHashes map[common.Hash]struct{}
panicValue interface{}
}{
"nil_node": {
nodeHashes: map[string]struct{}{},
nodeHashes: map[common.Hash]struct{}{},
},
"inlined_leaf_node": {
node: &Node{MerkleValue: []byte("a")},
nodeHashes: map[string]struct{}{},
nodeHashes: map[common.Hash]struct{}{},
},
"leaf_node": {
node: &Node{MerkleValue: []byte(merkleValue32Zeroes)},
nodeHashes: map[string]struct{}{
node: &Node{MerkleValue: merkleValue32Zeroes.ToBytes()},
nodeHashes: map[common.Hash]struct{}{
merkleValue32Zeroes: {},
},
},
Expand All @@ -197,34 +204,34 @@ func Test_PopulateNodeHashes(t *testing.T) {
{MerkleValue: []byte("b")},
}),
},
nodeHashes: map[string]struct{}{},
nodeHashes: map[common.Hash]struct{}{},
},
"branch_node": {
node: &Node{
MerkleValue: []byte(merkleValue32Zeroes),
MerkleValue: merkleValue32Zeroes.ToBytes(),
Children: padRightChildren([]*Node{
{MerkleValue: []byte(merkleValue32Ones)},
{MerkleValue: merkleValue32Ones.ToBytes()},
}),
},
nodeHashes: map[string]struct{}{
nodeHashes: map[common.Hash]struct{}{
merkleValue32Zeroes: {},
merkleValue32Ones: {},
},
},
"nested_branch_node": {
node: &Node{
MerkleValue: []byte(merkleValue32Zeroes),
MerkleValue: merkleValue32Zeroes.ToBytes(),
Children: padRightChildren([]*Node{
{MerkleValue: []byte(merkleValue32Ones)},
{MerkleValue: merkleValue32Ones.ToBytes()},
{
MerkleValue: []byte(merkleValue32Twos),
MerkleValue: merkleValue32Twos.ToBytes(),
Children: padRightChildren([]*Node{
{MerkleValue: []byte(merkleValue32Threes)},
{MerkleValue: merkleValue32Threes.ToBytes()},
}),
},
}),
},
nodeHashes: map[string]struct{}{
nodeHashes: map[common.Hash]struct{}{
merkleValue32Zeroes: {},
merkleValue32Ones: {},
merkleValue32Twos: {},
Expand All @@ -238,7 +245,7 @@ func Test_PopulateNodeHashes(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()

nodeHashes := make(map[string]struct{})
nodeHashes := make(map[common.Hash]struct{})

if testCase.panicValue != nil {
assert.PanicsWithValue(t, testCase.panicValue, func() {
Expand Down
11 changes: 7 additions & 4 deletions lib/trie/proof/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func Generate(rootHash []byte, fullKeys [][]byte, database Database) (
buffer := pools.DigestBuffers.Get().(*bytes.Buffer)
defer pools.DigestBuffers.Put(buffer)

merkleValuesSeen := make(map[string]struct{})
nodeHashesSeen := make(map[common.Hash]struct{})
for _, fullKey := range fullKeys {
fullKeyNibbles := codec.KeyLEToNibbles(fullKey)
newEncodedProofNodes, err := walkRoot(rootNode, fullKeyNibbles)
Expand All @@ -56,13 +56,16 @@ func Generate(rootHash []byte, fullKeys [][]byte, database Database) (
if err != nil {
return nil, fmt.Errorf("blake2b hash: %w", err)
}
merkleValueString := buffer.String()
// Note: all encoded proof nodes are larger than 32B so their
// merkle value is the encoding hash digest (32B) and never the
// encoding itself.
nodeHash := common.NewHash(buffer.Bytes())

_, seen := merkleValuesSeen[merkleValueString]
_, seen := nodeHashesSeen[nodeHash]
if seen {
continue
}
merkleValuesSeen[merkleValueString] = struct{}{}
nodeHashesSeen[nodeHash] = struct{}{}

encodedProofNodes = append(encodedProofNodes, encodedProofNode)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/trie/proof/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func buildTrie(encodedProofNodes [][]byte, rootHash []byte) (t *trie.Trie, err e
buffer.Reset()
err = node.MerkleValueRoot(encodedProofNode, buffer)
if err != nil {
return nil, fmt.Errorf("calculating Merkle value: %w", err)
return nil, fmt.Errorf("calculating node hash: %w", err)
}
digest := buffer.Bytes()

Expand Down
Loading