From d80760db47009ad0bb14e701880921b19500f97e Mon Sep 17 00:00:00 2001 From: David Boehm <91908103+dboehm-avalabs@users.noreply.github.com> Date: Mon, 2 Oct 2023 14:41:10 -0400 Subject: [PATCH] MerkleDB Path changes cleanup (#2120) --- x/merkledb/db.go | 1 - x/merkledb/path.go | 25 ++++++++++++------------- x/merkledb/trieview.go | 20 ++++++++------------ 3 files changed, 20 insertions(+), 26 deletions(-) diff --git a/x/merkledb/db.go b/x/merkledb/db.go index c685c7061259..5146b447b91e 100644 --- a/x/merkledb/db.go +++ b/x/merkledb/db.go @@ -958,7 +958,6 @@ func (db *merkleDB) VerifyChangeProof( return err } - // Note that if [start] is Nothing, smallestPath is the empty path. smallestPath := maybe.Bind(start, db.newPath) // Make sure the start proof, if given, is well-formed. diff --git a/x/merkledb/path.go b/x/merkledb/path.go index 36aee2195ca3..a696a4bd394f 100644 --- a/x/merkledb/path.go +++ b/x/merkledb/path.go @@ -106,17 +106,19 @@ func (p Path) HasPrefix(prefix Path) bool { // The number of tokens in the last byte of [prefix], or zero // if [prefix] fits into a whole number of bytes. - remainderTokens := prefix.tokensLength % p.tokensPerByte - if remainderTokens == 0 { + remainderTokensCount := prefix.tokensLength % p.tokensPerByte + if remainderTokensCount == 0 { return strings.HasPrefix(p.value, prefix.value) } // check that the tokens in the partially filled final byte of [prefix] are // equal to the tokens in the final byte of [p]. - for i := prefix.tokensLength - remainderTokens; i < prefix.tokensLength; i++ { - if p.Token(i) != prefix.Token(i) { - return false - } + remainderBitsMask := byte(0xFF << (8 - remainderTokensCount*int(p.tokenBitSize))) + prefixRemainderTokens := prefix.value[len(prefix.value)-1] & remainderBitsMask + remainderTokens := p.value[len(prefix.value)-1] & remainderBitsMask + + if prefixRemainderTokens != remainderTokens { + return false } // Note that this will never be an index OOB because len(prefix.value) > 0. @@ -203,11 +205,8 @@ func (p Path) bitsToShift(index int) byte { // bytesNeeded returns the number of bytes needed to store the passed number of tokens func (p Path) bytesNeeded(tokens int) int { - bytesNeeded := tokens / p.tokensPerByte - if tokens%p.tokensPerByte > 0 { - bytesNeeded++ - } - return bytesNeeded + // adding p.tokensPerByte - 1 causes the division to always round up + return (tokens + p.tokensPerByte - 1) / p.tokensPerByte } // Extend returns a new Path that equals the passed Path appended to the current Path @@ -262,7 +261,7 @@ func (p Path) Extend(path Path) Path { func shiftCopy(dst []byte, src string, shift byte) { i := 0 for ; i < len(src)-1; i++ { - dst[i] = src[i]<>(8-shift) + dst[i] = src[i]<>(8-shift) } if i < len(dst) { @@ -324,7 +323,7 @@ func (p Path) Take(tokensToTake int) Path { // We want to zero out everything to the right of the last token, which is at index [tokensToTake] - 1 // Mask will be (8-bitsToShift) number of 1's followed by (bitsToShift) number of 0's mask := byte(0xFF << p.bitsToShift(tokensToTake-1)) - buffer[len(buffer)-1] = buffer[len(buffer)-1] & mask + buffer[len(buffer)-1] &= mask result.value = byteSliceToString(buffer) return result diff --git a/x/merkledb/trieview.go b/x/merkledb/trieview.go index 47245711c1cf..e216755a7f71 100644 --- a/x/merkledb/trieview.go +++ b/x/merkledb/trieview.go @@ -161,31 +161,27 @@ func newTrieView( } for _, op := range changes.BatchOps { - newVal := maybe.Nothing[[]byte]() key := op.Key + if !changes.ConsumeBytes { + key = slices.Clone(op.Key) + } + + newVal := maybe.Nothing[[]byte]() if !op.Delete { - val := op.Value + newVal = maybe.Some(op.Value) if !changes.ConsumeBytes { - val = slices.Clone(op.Value) + newVal = maybe.Some(slices.Clone(op.Value)) } - newVal = maybe.Some(val) - } - if !changes.ConsumeBytes { - key = slices.Clone(op.Key) } if err := newView.recordValueChange(db.newPath(key), newVal); err != nil { return nil, err } } for key, val := range changes.MapOps { - bytesKey := stringToByteSlice(key) - if !changes.ConsumeBytes { - bytesKey = slices.Clone(bytesKey) - } if !changes.ConsumeBytes { val = maybe.Bind(val, slices.Clone[[]byte]) } - if err := newView.recordValueChange(db.newPath(bytesKey), val); err != nil { + if err := newView.recordValueChange(db.newPath(stringToByteSlice(key)), val); err != nil { return nil, err } }