diff --git a/x/merkledb/codec.go b/x/merkledb/codec.go index c9837abb509f..973dd5888ab6 100644 --- a/x/merkledb/codec.go +++ b/x/merkledb/codec.go @@ -160,7 +160,7 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error { return io.ErrUnexpectedEOF } - n.children = make(map[byte]child, numChildren) + n.children = make(map[byte]*child, numChildren) var previousChild uint64 for i := uint64(0); i < numChildren; i++ { index, err := c.decodeUint(src) @@ -184,7 +184,7 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error { if err != nil { return err } - n.children[byte(index)] = child{ + n.children[byte(index)] = &child{ compressedKey: compressedKey, id: childID, hasValue: hasValue, diff --git a/x/merkledb/codec_test.go b/x/merkledb/codec_test.go index 00e5790b3171..699db9a4bd81 100644 --- a/x/merkledb/codec_test.go +++ b/x/merkledb/codec_test.go @@ -146,7 +146,7 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) { numChildren := r.Intn(int(bf)) // #nosec G404 - children := map[byte]child{} + children := map[byte]*child{} for i := 0; i < numChildren; i++ { var childID ids.ID _, _ = r.Read(childID[:]) // #nosec G404 @@ -154,7 +154,7 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) { childKeyBytes := make([]byte, r.Intn(32)) // #nosec G404 _, _ = r.Read(childKeyBytes) // #nosec G404 - children[byte(i)] = child{ + children[byte(i)] = &child{ compressedKey: ToKey(childKeyBytes), id: childID, } @@ -202,14 +202,14 @@ func FuzzEncodeHashValues(f *testing.F) { for _, bf := range validBranchFactors { // Create a random node r := rand.New(rand.NewSource(int64(randSeed))) // #nosec G404 - children := map[byte]child{} + children := map[byte]*child{} numChildren := r.Intn(int(bf)) // #nosec G404 for i := 0; i < numChildren; i++ { compressedKeyLen := r.Intn(32) // #nosec G404 compressedKeyBytes := make([]byte, compressedKeyLen) _, _ = r.Read(compressedKeyBytes) // #nosec G404 - children[byte(i)] = child{ + children[byte(i)] = &child{ compressedKey: ToKey(compressedKeyBytes), id: ids.GenerateTestID(), hasValue: r.Intn(2) == 1, // #nosec G404 diff --git a/x/merkledb/db.go b/x/merkledb/db.go index b1ee699bab97..c813a9478b1f 100644 --- a/x/merkledb/db.go +++ b/x/merkledb/db.go @@ -205,6 +205,7 @@ type merkleDB struct { // It is the node with a nil key and is the ancestor of all nodes in the trie. // If it has a value or has multiple children, it is also the root of the trie. sentinelNode *node + rootID ids.ID // Valid children of this trie. childViews []*trieView @@ -260,14 +261,13 @@ func newDatabase( tokenSize: BranchFactorToTokenSize[config.BranchFactor], } - root, err := trieDB.initializeRootIfNeeded() - if err != nil { + if err := trieDB.initializeRoot(); err != nil { return nil, err } // add current root to history (has no changes) trieDB.history.record(&changeSummary{ - rootID: root, + rootID: trieDB.rootID, values: map[Key]*change[maybe.Maybe[[]byte]]{}, nodes: map[Key]*change[*node]{}, }) @@ -578,13 +578,7 @@ func (db *merkleDB) GetMerkleRoot(ctx context.Context) (ids.ID, error) { // Assumes [db.lock] is read locked. func (db *merkleDB) getMerkleRoot() ids.ID { - if !isSentinelNodeTheRoot(db.sentinelNode) { - // if the sentinel node should be skipped, the trie's root is the nil key node's only child - for _, childEntry := range db.sentinelNode.children { - return childEntry.id - } - } - return db.sentinelNode.id + return db.rootID } // isSentinelNodeTheRoot returns true if the passed in sentinel node has a value and or multiple child nodes @@ -982,6 +976,7 @@ func (db *merkleDB) commitChanges(ctx context.Context, trieToCommit *trieView) e // Only modify in-memory state after the commit succeeds // so that we don't need to clean up on error. db.sentinelNode = sentinelChange.after + db.rootID = changes.rootID db.history.record(changes) return nil } @@ -1161,34 +1156,38 @@ func (db *merkleDB) invalidateChildrenExcept(exception *trieView) { } } -func (db *merkleDB) initializeRootIfNeeded() (ids.ID, error) { - // not sure if the sentinel node exists or if it had a value - // check under both prefixes +func (db *merkleDB) initializeRoot() error { + // Not sure if the sentinel node exists or if it had a value, + // so check under both prefixes var err error db.sentinelNode, err = db.intermediateNodeDB.Get(Key{}) + if errors.Is(err, database.ErrNotFound) { + // Didn't find the sentinel in the intermediateNodeDB, check the valueNodeDB db.sentinelNode, err = db.valueNodeDB.Get(Key{}) } - if err == nil { - // sentinel node already exists, so calculate the root ID of the trie - db.sentinelNode.calculateID(db.metrics) - return db.getMerkleRoot(), nil - } - if !errors.Is(err, database.ErrNotFound) { - return ids.Empty, err - } - - // sentinel node doesn't exist; make a new one. - db.sentinelNode = newNode(Key{}) - // update its ID - db.sentinelNode.calculateID(db.metrics) + if err != nil { + if !errors.Is(err, database.ErrNotFound) { + return err + } - if err := db.intermediateNodeDB.Put(Key{}, db.sentinelNode); err != nil { - return ids.Empty, err + // Sentinel node doesn't exist in either database prefix. + // Make a new one and store it in the intermediateNodeDB + db.sentinelNode = newNode(Key{}) + if err := db.intermediateNodeDB.Put(Key{}, db.sentinelNode); err != nil { + return err + } } - return db.sentinelNode.id, nil + db.rootID = db.sentinelNode.calculateID(db.metrics) + if !isSentinelNodeTheRoot(db.sentinelNode) { + // If the sentinel node is not the root, the trie's root is the sentinel node's only child + for _, childEntry := range db.sentinelNode.children { + db.rootID = childEntry.id + } + } + return nil } // Returns a view of the trie as it was when it had root [rootID] for keys within range [start, end]. @@ -1289,7 +1288,7 @@ func (db *merkleDB) Clear() error { // Clear root db.sentinelNode = newNode(Key{}) - db.sentinelNode.calculateID(db.metrics) + db.rootID = db.sentinelNode.calculateID(db.metrics) // Clear history db.history = newTrieHistory(db.history.maxHistoryLen) diff --git a/x/merkledb/history_test.go b/x/merkledb/history_test.go index 2ee1e5f4b31b..d2945c9c5018 100644 --- a/x/merkledb/history_test.go +++ b/x/merkledb/history_test.go @@ -660,8 +660,8 @@ func TestHistoryGetChangesToRoot(t *testing.T) { rootID: ids.GenerateTestID(), nodes: map[Key]*change[*node]{ ToKey([]byte{byte(i)}): { - before: &node{id: ids.GenerateTestID()}, - after: &node{id: ids.GenerateTestID()}, + before: &node{}, + after: &node{}, }, }, values: map[Key]*change[maybe.Maybe[[]byte]]{ diff --git a/x/merkledb/node.go b/x/merkledb/node.go index 3fd38021a0c8..9a63ef82c4a7 100644 --- a/x/merkledb/node.go +++ b/x/merkledb/node.go @@ -4,7 +4,6 @@ package merkledb import ( - "golang.org/x/exp/maps" "golang.org/x/exp/slices" "github.com/ava-labs/avalanchego/ids" @@ -17,7 +16,7 @@ const HashLength = 32 // Representation of a node stored in the database. type dbNode struct { value maybe.Maybe[[]byte] - children map[byte]child + children map[byte]*child } type child struct { @@ -29,7 +28,6 @@ type child struct { // node holds additional information on top of the dbNode that makes calculations easier to do type node struct { dbNode - id ids.ID key Key nodeBytes []byte valueDigest maybe.Maybe[[]byte] @@ -39,7 +37,7 @@ type node struct { func newNode(key Key) *node { return &node{ dbNode: dbNode{ - children: make(map[byte]child, 2), + children: make(map[byte]*child, 2), }, key: key, } @@ -78,19 +76,14 @@ func (n *node) bytes() []byte { // clear the cached values that will need to be recalculated whenever the node changes // for example, node ID and byte representation func (n *node) onNodeChanged() { - n.id = ids.Empty n.nodeBytes = nil } // Returns and caches the ID of this node. -func (n *node) calculateID(metrics merkleMetrics) { - if n.id != ids.Empty { - return - } - +func (n *node) calculateID(metrics merkleMetrics) ids.ID { metrics.HashCalculated() bytes := codec.encodeHashValues(n) - n.id = hashing.ComputeHash256Array(bytes) + return hashing.ComputeHash256Array(bytes) } // Set [n]'s value to [val]. @@ -114,16 +107,15 @@ func (n *node) setValueDigest() { func (n *node) addChild(childNode *node, tokenSize int) { n.setChildEntry( childNode.key.Token(n.key.length, tokenSize), - child{ + &child{ compressedKey: childNode.key.Skip(n.key.length + tokenSize), - id: childNode.id, hasValue: childNode.hasValue(), }, ) } // Adds a child to [n] without a reference to the child node. -func (n *node) setChildEntry(index byte, childEntry child) { +func (n *node) setChildEntry(index byte, childEntry *child) { n.onNodeChanged() n.children[index] = childEntry } @@ -139,16 +131,23 @@ func (n *node) removeChild(child *node, tokenSize int) { // if this ever changes, value will need to be copied as well // it is safe to clone all fields because they are only written/read while one or both of the db locks are held func (n *node) clone() *node { - return &node{ - id: n.id, + result := &node{ key: n.key, dbNode: dbNode{ value: n.value, - children: maps.Clone(n.children), + children: make(map[byte]*child, len(n.children)), }, valueDigest: n.valueDigest, nodeBytes: n.nodeBytes, } + for key, existing := range n.children { + result.children[key] = &child{ + compressedKey: existing.compressedKey, + id: existing.id, + hasValue: existing.hasValue, + } + } + return result } // Returns the ProofNode representation of this node. diff --git a/x/merkledb/proof.go b/x/merkledb/proof.go index e348a83f0f13..39ceff3d3157 100644 --- a/x/merkledb/proof.go +++ b/x/merkledb/proof.go @@ -847,7 +847,7 @@ func addPathInfo( // We only need the IDs to be correct so that the calculated hash is correct. n.setChildEntry( index, - child{ + &child{ id: childID, compressedKey: compressedKey, }) diff --git a/x/merkledb/trieview.go b/x/merkledb/trieview.go index 622bfcb11207..730d4b0187d0 100644 --- a/x/merkledb/trieview.go +++ b/x/merkledb/trieview.go @@ -251,9 +251,15 @@ func (t *trieView) calculateNodeIDs(ctx context.Context) error { } _ = t.db.calculateNodeIDsSema.Acquire(context.Background(), 1) - t.calculateNodeIDsHelper(t.sentinelNode) + t.changes.rootID = t.calculateNodeIDsHelper(t.sentinelNode) t.db.calculateNodeIDsSema.Release(1) - t.changes.rootID = t.getMerkleRoot() + + // If the sentinel node is not the root, the trie's root is the sentinel node's only child + if !isSentinelNodeTheRoot(t.sentinelNode) { + for _, childEntry := range t.sentinelNode.children { + t.changes.rootID = childEntry.id + } + } // ensure no ancestor changes occurred during execution if t.isInvalid() { @@ -266,58 +272,40 @@ func (t *trieView) calculateNodeIDs(ctx context.Context) error { // Calculates the ID of all descendants of [n] which need to be recalculated, // and then calculates the ID of [n] itself. -func (t *trieView) calculateNodeIDsHelper(n *node) { - var ( - // We use [wg] to wait until all descendants of [n] have been updated. - wg sync.WaitGroup - updatedChildren = make(chan *node, len(n.children)) - ) +func (t *trieView) calculateNodeIDsHelper(n *node) ids.ID { + // We use [wg] to wait until all descendants of [n] have been updated. + var wg sync.WaitGroup - for childIndex, child := range n.children { - childKey := n.key.Extend(ToToken(childIndex, t.tokenSize), child.compressedKey) + for childIndex := range n.children { + childEntry := n.children[childIndex] + childKey := n.key.Extend(ToToken(childIndex, t.tokenSize), childEntry.compressedKey) childNodeChange, ok := t.changes.nodes[childKey] if !ok { // This child wasn't changed. continue } - - wg.Add(1) - calculateChildID := func() { - defer wg.Done() - - t.calculateNodeIDsHelper(childNodeChange.after) - - // Note that this will never block - updatedChildren <- childNodeChange.after - } + n.onNodeChanged() + childEntry.hasValue = childNodeChange.after.hasValue() // Try updating the child and its descendants in a goroutine. if ok := t.db.calculateNodeIDsSema.TryAcquire(1); ok { + wg.Add(1) go func() { - calculateChildID() + childEntry.id = t.calculateNodeIDsHelper(childNodeChange.after) t.db.calculateNodeIDsSema.Release(1) + wg.Done() }() } else { // We're at the goroutine limit; do the work in this goroutine. - calculateChildID() + childEntry.id = t.calculateNodeIDsHelper(childNodeChange.after) } } // Wait until all descendants of [n] have been updated. wg.Wait() - close(updatedChildren) - - for updatedChild := range updatedChildren { - index := updatedChild.key.Token(n.key.length, t.tokenSize) - n.setChildEntry(index, child{ - compressedKey: n.children[index].compressedKey, - id: updatedChild.id, - hasValue: updatedChild.hasValue(), - }) - } // The IDs [n]'s descendants are up to date so we can calculate [n]'s ID. - n.calculateID(t.db.metrics) + return n.calculateID(t.db.metrics) } // GetProof returns a proof that [bytesPath] is in or not in trie [t]. @@ -381,8 +369,7 @@ func (t *trieView) getProof(ctx context.Context, key []byte) (*Proof, error) { return proof, nil } - childNode, err := t.getNodeWithID( - child.id, + childNode, err := t.getNode( closestNode.key.Extend(ToToken(nextIndex, t.tokenSize), child.compressedKey), child.hasValue, ) @@ -568,17 +555,7 @@ func (t *trieView) GetMerkleRoot(ctx context.Context) (ids.ID, error) { if err := t.calculateNodeIDs(ctx); err != nil { return ids.Empty, err } - return t.getMerkleRoot(), nil -} - -func (t *trieView) getMerkleRoot() ids.ID { - if !isSentinelNodeTheRoot(t.sentinelNode) { - for _, childEntry := range t.sentinelNode.children { - return childEntry.id - } - } - - return t.sentinelNode.id + return t.changes.rootID, nil } func (t *trieView) GetValues(ctx context.Context, keys [][]byte) ([][]byte, []error) { @@ -650,7 +627,7 @@ func (t *trieView) remove(key Key) error { } // confirm a node exists with a value - keyNode, err := t.getNodeWithID(ids.Empty, key, true) + keyNode, err := t.getNode(key, true) if err != nil { if errors.Is(err, database.ErrNotFound) { // key didn't exist @@ -719,7 +696,7 @@ func (t *trieView) compressNodePath(parent, node *node) error { } var ( - childEntry child + childEntry *child childKey Key ) // There is only one child, but we don't know the index. @@ -733,7 +710,7 @@ func (t *trieView) compressNodePath(parent, node *node) error { // [node] is the first node with multiple children. // combine it with the [node] passed in. parent.setChildEntry(childKey.Token(parent.key.length, t.tokenSize), - child{ + &child{ compressedKey: childKey.Skip(parent.key.length + t.tokenSize), id: childEntry.id, hasValue: childEntry.hasValue, @@ -765,7 +742,7 @@ func (t *trieView) visitPathToKey(key Key, visitNode func(*node) error) error { return nil } // grab the next node along the path - currentNode, err = t.getNodeWithID(nextChildEntry.id, key.Take(currentNode.key.length+t.tokenSize+nextChildEntry.compressedKey.length), nextChildEntry.hasValue) + currentNode, err = t.getNode(key.Take(currentNode.key.length+t.tokenSize+nextChildEntry.compressedKey.length), nextChildEntry.hasValue) if err != nil { return err } @@ -784,7 +761,7 @@ func (t *trieView) getEditableNode(key Key, hadValue bool) (*node, error) { } // grab the node in question - n, err := t.getNodeWithID(ids.Empty, key, hadValue) + n, err := t.getNode(key, hadValue) if err != nil { return nil, err } @@ -880,7 +857,7 @@ func (t *trieView) insert( // add the existing child onto the branch node branchNode.setChildEntry( existingChildEntry.compressedKey.Token(commonPrefixLength, t.tokenSize), - child{ + &child{ compressedKey: existingChildEntry.compressedKey.Skip(commonPrefixLength + t.tokenSize), id: existingChildEntry.id, hasValue: existingChildEntry.hasValue, @@ -924,8 +901,7 @@ func (t *trieView) getRoot() (*node, error) { if !isSentinelNodeTheRoot(t.sentinelNode) { // sentinelNode has one child, which is the root for index, childEntry := range t.sentinelNode.children { - return t.getNodeWithID( - childEntry.id, + return t.getNode( t.sentinelNode.key.Extend(ToToken(index, t.tokenSize), childEntry.compressedKey), childEntry.hasValue) } @@ -1004,7 +980,7 @@ func (t *trieView) recordValueChange(key Key, value maybe.Maybe[[]byte]) error { // sets the node's ID to [id]. // If the node is loaded from the baseDB, [hasValue] determines which database the node is stored in. // Returns database.ErrNotFound if the node doesn't exist. -func (t *trieView) getNodeWithID(id ids.ID, key Key, hasValue bool) (*node, error) { +func (t *trieView) getNode(key Key, hasValue bool) (*node, error) { // check for the key within the changed nodes if nodeChange, isChanged := t.changes.nodes[key]; isChanged { t.db.metrics.ViewNodeCacheHit() @@ -1015,17 +991,7 @@ func (t *trieView) getNodeWithID(id ids.ID, key Key, hasValue bool) (*node, erro } // get the node from the parent trie and store a local copy - parentTrieNode, err := t.getParentTrie().getEditableNode(key, hasValue) - if err != nil { - return nil, err - } - - // only need to initialize the id if it's from the parent trie. - // nodes in the current view change list have already been initialized. - if id != ids.Empty { - parentTrieNode.id = id - } - return parentTrieNode, nil + return t.getParentTrie().getEditableNode(key, hasValue) } // Get the parent trie of the view