Skip to content

Commit

Permalink
verify proof passes all test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
weiihann committed Nov 1, 2024
1 parent 6c5fdcf commit 0c14bb7
Show file tree
Hide file tree
Showing 6 changed files with 392 additions and 232 deletions.
59 changes: 45 additions & 14 deletions core/trie/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,41 @@ func NewKey(length uint8, keyBytes []byte) Key {
return k
}

// MostSignificantBits returns a new Key containing the first n most significant bits of the original key
func (k *Key) MostSignificantBits(n uint8) (*Key, error) {
if n > k.len {
return nil, fmt.Errorf("cannot take %d bits from key of length %d", n, k.len)
}

Check warning on line 31 in core/trie/key.go

View check run for this annotation

Codecov / codecov/patch

core/trie/key.go#L30-L31

Added lines #L30 - L31 were not covered by tests

if n == k.len {
return k.Copy(), nil
}

Check warning on line 35 in core/trie/key.go

View check run for this annotation

Codecov / codecov/patch

core/trie/key.go#L34-L35

Added lines #L34 - L35 were not covered by tests

newKey := &Key{len: n}

// Calculate how many bytes we need to copy
bytesToCopy := (n + 7) / 8
if bytesToCopy > 0 {
// Copy the required bytes from the original key
startPos := len(k.bitset) - int((k.len+7)/8)

Check failure on line 43 in core/trie/key.go

View workflow job for this annotation

GitHub Actions / lint

Magic number: 8, in <argument> detected (mnd)
copy(newKey.bitset[len(newKey.bitset)-int(bytesToCopy):], k.bitset[startPos:])
}

// Clear any extra bits in the last byte if necessary
if n%8 != 0 && bytesToCopy > 0 {
lastBytePos := len(newKey.bitset) - int(bytesToCopy)
mask := byte(0xFF >> (8 - (n % 8)))

Check failure on line 50 in core/trie/key.go

View workflow job for this annotation

GitHub Actions / lint

Magic number: 0xFF, in <argument> detected (mnd)
newKey.bitset[lastBytePos] &= mask
}

Check warning on line 52 in core/trie/key.go

View check run for this annotation

Codecov / codecov/patch

core/trie/key.go#L49-L52

Added lines #L49 - L52 were not covered by tests

// Clear any remaining bytes
for i := 0; i < len(newKey.bitset)-int(bytesToCopy); i++ {
newKey.bitset[i] = 0
}

return newKey, nil
}

func (k *Key) SubKey(n uint8) (*Key, error) {
if n > k.len {
return nil, errors.New(fmt.Sprint("cannot subtract key of length %i from key of length %i", n, k.len))
Expand Down Expand Up @@ -96,11 +131,13 @@ func (k *Key) Equal(other *Key) bool {
return k.len == other.len && k.bitset == other.bitset
}

func (k *Key) Test(bit uint8) bool {
// IsBitSet returns whether the bit at the given position is 1.
// Position 0 represents the least significant (rightmost) bit.
func (k *Key) IsBitSet(position uint8) bool {
const LSB = uint8(0x1)
byteIdx := bit / 8
byteIdx := position / 8
byteAtIdx := k.bitset[len(k.bitset)-int(byteIdx)-1]
bitIdx := bit % 8
bitIdx := position % 8
return ((byteAtIdx >> bitIdx) & LSB) != 0
}

Expand Down Expand Up @@ -136,20 +173,14 @@ func (k *Key) Truncate(length uint8) {
}
}

func (k *Key) RemoveLastBit() {
func (k *Key) RemoveMostSignificantBit() {

Check warning on line 176 in core/trie/key.go

View check run for this annotation

Codecov / codecov/patch

core/trie/key.go#L176

Added line #L176 was not covered by tests
if k.len == 0 {
return
}

k.len--

unusedBytes := k.unusedBytes()
clear(unusedBytes)
k.Truncate(k.len - 1)

Check warning on line 181 in core/trie/key.go

View check run for this annotation

Codecov / codecov/patch

core/trie/key.go#L181

Added line #L181 was not covered by tests
}

// clear upper bits on the last used byte
inUseBytes := k.inUseBytes()
unusedBitsCount := 8 - (k.len % 8)
if unusedBitsCount != 8 && len(inUseBytes) > 0 {
inUseBytes[0] = (inUseBytes[0] << unusedBitsCount) >> unusedBitsCount
}
func (k *Key) Copy() *Key {
return &Key{len: k.len, bitset: k.bitset}

Check warning on line 185 in core/trie/key.go

View check run for this annotation

Codecov / codecov/patch

core/trie/key.go#L184-L185

Added lines #L184 - L185 were not covered by tests
}
2 changes: 1 addition & 1 deletion core/trie/key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func BenchmarkKeyEncoding(b *testing.B) {
func TestKeyTest(t *testing.T) {
key := trie.NewKey(44, []byte{0x10, 0x02})
for i := 0; i < int(key.Len()); i++ {
assert.Equal(t, i == 1 || i == 12, key.Test(uint8(i)), i)
assert.Equal(t, i == 1 || i == 12, key.IsBitSet(uint8(i)), i)
}
}

Expand Down
111 changes: 74 additions & 37 deletions core/trie/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ func (e *Edge) PrettyPrint() {
fmt.Printf(" Path: %v\n", e.Path)
}

func (t *Trie) Prove(key *Key, proofSet *ProofSet) error {
nodesFromRoot, err := t.nodesFromRoot(key)
func (t *Trie) Prove(key *felt.Felt, proofSet *ProofSet) error {
k := t.FeltToKey(key)

nodesFromRoot, err := t.nodesFromRoot(&k)
if err != nil {
return err
}

Check warning on line 69 in core/trie/proof.go

View check run for this annotation

Codecov / codecov/patch

core/trie/proof.go#L68-L69

Added lines #L68 - L69 were not covered by tests
Expand Down Expand Up @@ -344,46 +346,79 @@ func GetProof(key *Key, tri *Trie) ([]ProofNode, error) {
return proofNodes, nil
}

// verifyProof checks if `leafPath` leads from `root` to `leafHash` along the `proofNodes`
// https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2006
func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode, hash hashFunc) bool {
// VerifyProof verifies that a proof path is valid for a given key in a binary trie.
// It walks through the proof nodes, verifying each step matches the expected path to reach the key.
//
// The verification process:
// 1. Starts at the root hash and retrieves the corresponding proof node
// 2. For each proof node:
// - Verifies the node's computed hash matches the expected hash
// - For Binary nodes:
// -- Uses the next unprocessed bit in the key to choose left/right path
// -- If key bit is 0, takes left path; if 1, takes right path
// - For Edge nodes:
// -- Verifies the compressed path matches the corresponding bits in the key
// -- Moves to the child node if paths match
//
// 3. Continues until all bits in the key are processed
//
// The proof is considered invalid if:
// - Any proof node is missing from the proofSet
// - Any node's computed hash doesn't match its expected hash
// - The path bits don't match the key bits
// - The proof ends before processing all key bits
func VerifyProof(root *felt.Felt, key *Key, proofSet *ProofSet, hash hashFunc) (*felt.Felt, bool) {
expectedHash := root
remainingPath := NewKey(key.len, key.bitset[:])
for i, proofNode := range proofs {
keyLen := key.Len()
var processedBits uint8

for {
proofNode, ok := proofSet.Get(*expectedHash)
if !ok {
return nil, false
}

Check warning on line 379 in core/trie/proof.go

View check run for this annotation

Codecov / codecov/patch

core/trie/proof.go#L378-L379

Added lines #L378 - L379 were not covered by tests

// Verify the hash matches
if !proofNode.Hash(hash).Equal(expectedHash) {
return false
return nil, false

Check warning on line 383 in core/trie/proof.go

View check run for this annotation

Codecov / codecov/patch

core/trie/proof.go#L383

Added line #L383 was not covered by tests
}

switch proofNode := proofNode.(type) {
case *Binary:
if remainingPath.Test(remainingPath.Len() - 1) {
expectedHash = proofNode.RightHash
} else {
expectedHash = proofNode.LeftHash
switch node := proofNode.(type) {
case *Binary: // Binary nodes represent left/right choices
if key.Len() <= processedBits {
return nil, false

Check warning on line 389 in core/trie/proof.go

View check run for this annotation

Codecov / codecov/patch

core/trie/proof.go#L389

Added line #L389 was not covered by tests
}
remainingPath.RemoveLastBit()
case *Edge:
subKey, err := remainingPath.SubKey(proofNode.Path.Len())
if err != nil {
return false
// Check the bit at parent's position
expectedHash = node.LeftHash
if key.IsBitSet(keyLen - processedBits - 1) {
expectedHash = node.RightHash
}
processedBits++
case *Edge: // Edge nodes represent paths between binary nodes
nodeLen := node.Path.Len()

// Todo:
// If we are verifying the key doesn't exist, then we should
// update subKey to point in the other direction
if value == nil && i == len(proofs)-1 {
return true
if key.Len() < processedBits+nodeLen {
return nil, false

Check warning on line 401 in core/trie/proof.go

View check run for this annotation

Codecov / codecov/patch

core/trie/proof.go#L401

Added line #L401 was not covered by tests
}

if !proofNode.Path.Equal(subKey) {
return false
// Ensure the bits between segment of the key and the node path match
start := keyLen - processedBits - nodeLen
end := keyLen - processedBits
for i := start; i < end; i++ { // check if the bits match
if key.IsBitSet(i) != node.Path.IsBitSet(i-start) {
return nil, false
}
}
expectedHash = proofNode.Child
remainingPath.Truncate(251 - proofNode.Path.Len()) //nolint:mnd

processedBits += nodeLen
expectedHash = node.Child
}
}

return expectedHash.Equal(value)
// We've consumed all bits in our path
if processedBits >= keyLen {
return expectedHash, true
}
}
}

// VerifyRangeProof verifies the range proof for the given range of keys.
Expand Down Expand Up @@ -417,9 +452,11 @@ func VerifyRangeProof(root *felt.Felt, keys, values []*felt.Felt, proofKeys [2]*
var err error
for i := 0; i < 2; i++ {
if proofs[i] != nil {

Check failure on line 454 in core/trie/proof.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary leading newline (whitespace)
if !VerifyProof(root, proofKeys[i], proofValues[i], proofs[i], hash) {
return false, fmt.Errorf("invalid proof for key %x", proofKeys[i].String())
}

// if !VerifyProof(root, proofKeys[i], proofValues[i], proofs[i], hash) {
// return false, fmt.Errorf("invalid proof for key %x", proofKeys[i].String())
// }
// TODO(weiihann): Verify proof

proofPaths[i], err = ProofToPath(proofs[i], proofKeys[i], hash)
if err != nil {
Expand Down Expand Up @@ -511,7 +548,7 @@ func assignChild(i, compressedParent int, parentNode *Node,
if err != nil {
return nil, err
}
if leafKey.Test(leafKey.len - parentKey.len - 1) {
if leafKey.IsBitSet(leafKey.len - parentKey.len - 1) {
parentNode.Right = childKey
parentNode.Left = nilKey
} else {
Expand Down Expand Up @@ -630,9 +667,9 @@ func getParentKey(idx int, compressedParentOffset uint8, leafKey *Key,
}

if _, ok := pNode.(*Binary); ok {
crntKey, err = leafKey.SubKey(height)
crntKey, err = leafKey.MostSignificantBits(height)
} else {
crntKey, err = leafKey.SubKey(height + compressedParentOffset)
crntKey, err = leafKey.MostSignificantBits(height + compressedParentOffset)
}
return crntKey, err
}
Expand All @@ -651,7 +688,7 @@ func getChildKey(childIdx int, crntKey, leafKey, nilKey *Key, proofNodes []Proof
return nilKey, nil
}

return leafKey.SubKey(crntKey.len + uint8(compressChild) + compressChildOffset)
return leafKey.MostSignificantBits(crntKey.len + uint8(compressChild) + compressChildOffset)
}

// BuildTrie builds a trie using the proof paths (including inner nodes), and then sets all the keys-values (leaves)
Expand Down
Loading

0 comments on commit 0c14bb7

Please sign in to comment.