Skip to content

Commit

Permalink
Refactor ProofNode to use interface (#2176)
Browse files Browse the repository at this point in the history
  • Loading branch information
weiihann authored Sep 24, 2024
1 parent 67430de commit 38638ca
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 214 deletions.
198 changes: 106 additions & 92 deletions core/trie/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,58 +7,59 @@ import (
"github.com/NethermindEth/juno/core/felt"
)

// https://github.com/starknet-io/starknet-p2p-specs/blob/main/p2p/proto/snapshot.proto#L6
type ProofNode struct {
Binary *Binary
Edge *Edge
var (
ErrUnknownProofNode = errors.New("unknown proof node")
ErrChildHashNotFound = errors.New("can't determine the child hash from the parent and child")
)

type ProofNode interface {
Hash(hash hashFunc) *felt.Felt
Len() uint8
PrettyPrint()
}

// Note: does not work for leaves
func (pn *ProofNode) Hash(hash hashFunc) *felt.Felt {
switch {
case pn.Binary != nil:
return hash(pn.Binary.LeftHash, pn.Binary.RightHash)
case pn.Edge != nil:
length := make([]byte, len(pn.Edge.Path.bitset))
length[len(pn.Edge.Path.bitset)-1] = pn.Edge.Path.len
pathFelt := pn.Edge.Path.Felt()
lengthFelt := new(felt.Felt).SetBytes(length)
return new(felt.Felt).Add(hash(pn.Edge.Child, &pathFelt), lengthFelt)
default:
return nil
}
type Binary struct {
LeftHash *felt.Felt
RightHash *felt.Felt
}

func (pn *ProofNode) Len() uint8 {
if pn.Binary != nil {
return 1
}
return pn.Edge.Path.len
func (b *Binary) Hash(hash hashFunc) *felt.Felt {
return hash(b.LeftHash, b.RightHash)
}

func (pn *ProofNode) PrettyPrint() {
if pn.Binary != nil {
fmt.Printf(" Binary:\n")
fmt.Printf(" LeftHash: %v\n", pn.Binary.LeftHash)
fmt.Printf(" RightHash: %v\n", pn.Binary.RightHash)
}
if pn.Edge != nil {
fmt.Printf(" Edge:\n")
fmt.Printf(" Child: %v\n", pn.Edge.Child)
fmt.Printf(" Path: %v\n", pn.Edge.Path)
}
func (b *Binary) Len() uint8 {
return 1
}

type Binary struct {
LeftHash *felt.Felt
RightHash *felt.Felt
func (b *Binary) PrettyPrint() {
fmt.Printf(" Binary:\n")
fmt.Printf(" LeftHash: %v\n", b.LeftHash)
fmt.Printf(" RightHash: %v\n", b.RightHash)
}

type Edge struct {
Child *felt.Felt // child hash
Path *Key // path from parent to child
}

func (e *Edge) Hash(hash hashFunc) *felt.Felt {
length := make([]byte, len(e.Path.bitset))
length[len(e.Path.bitset)-1] = e.Path.len
pathFelt := e.Path.Felt()
lengthFelt := new(felt.Felt).SetBytes(length)
return new(felt.Felt).Add(hash(e.Child, &pathFelt), lengthFelt)
}

func (e *Edge) Len() uint8 {
return e.Path.Len()
}

func (e *Edge) PrettyPrint() {
fmt.Printf(" Edge:\n")
fmt.Printf(" Child: %v\n", e.Child)
fmt.Printf(" Path: %v\n", e.Path)
}

func GetBoundaryProofs(leftBoundary, rightBoundary *Key, tri *Trie) ([2][]ProofNode, error) {
proofs := [2][]ProofNode{}
leftProof, err := GetProof(leftBoundary, tri)
Expand Down Expand Up @@ -110,19 +111,19 @@ func transformNode(tri *Trie, parentKey *Key, sNode StorageNode) (*Edge, *Binary
rightHash := rNode.Value
if isEdge(sNode.key, StorageNode{node: rNode, key: sNode.node.Right}) {
edgePath := path(sNode.node.Right, sNode.key)
rEdge := ProofNode{Edge: &Edge{
rEdge := &Edge{
Path: &edgePath,
Child: rNode.Value,
}}
}
rightHash = rEdge.Hash(tri.hash)
}
leftHash := lNode.Value
if isEdge(sNode.key, StorageNode{node: lNode, key: sNode.node.Left}) {
edgePath := path(sNode.node.Left, sNode.key)
lEdge := ProofNode{Edge: &Edge{
lEdge := &Edge{
Path: &edgePath,
Child: lNode.Value,
}}
}
leftHash = lEdge.Hash(tri.hash)
}
binary := &Binary{
Expand All @@ -139,19 +140,20 @@ func transformNode(tri *Trie, parentKey *Key, sNode StorageNode) (*Edge, *Binary
func pathSplitOccurredCheck(mergedPath []ProofNode, nodeHashes map[felt.Felt]ProofNode) error {
splitHappened := false
for _, node := range mergedPath {
if node.Edge != nil {
switch node := node.(type) {
case *Edge:
continue
}

_, leftExists := nodeHashes[*node.Binary.LeftHash]
_, rightExists := nodeHashes[*node.Binary.RightHash]

if leftExists && rightExists {
if splitHappened {
return errors.New("split happened more than once")
case *Binary:
_, leftExists := nodeHashes[*node.LeftHash]
_, rightExists := nodeHashes[*node.RightHash]
if leftExists && rightExists {
if splitHappened {
return errors.New("split happened more than once")
}
splitHappened = true
}

splitHappened = true
default:
return fmt.Errorf("%w: %T", ErrUnknownProofNode, node)
}
}
return nil
Expand All @@ -173,9 +175,10 @@ func rootNodeExistsCheck(rootHash *felt.Felt, nodeHashes map[felt.Felt]ProofNode
func traverseNodes(currNode ProofNode, path *[]ProofNode, nodeHashes map[felt.Felt]ProofNode) {
*path = append(*path, currNode)

if currNode.Binary != nil {
nodeLeft, leftExist := nodeHashes[*currNode.Binary.LeftHash]
nodeRight, rightExist := nodeHashes[*currNode.Binary.RightHash]
switch currNode := currNode.(type) {
case *Binary:
nodeLeft, leftExist := nodeHashes[*currNode.LeftHash]
nodeRight, rightExist := nodeHashes[*currNode.RightHash]

if leftExist && rightExist {
return
Expand All @@ -184,8 +187,8 @@ func traverseNodes(currNode ProofNode, path *[]ProofNode, nodeHashes map[felt.Fe
} else if rightExist {
traverseNodes(nodeRight, path, nodeHashes)
}
} else if currNode.Edge != nil {
edgeNode, exist := nodeHashes[*currNode.Edge.Child]
case *Edge:
edgeNode, exist := nodeHashes[*currNode.Child]
if exist {
traverseNodes(edgeNode, path, nodeHashes)
}
Expand Down Expand Up @@ -269,8 +272,8 @@ func SplitProofPath(mergedPath []ProofNode, rootHash *felt.Felt, hash hashFunc)

currNode = commonPath[len(commonPath)-1]

leftNode := nodeHashes[*currNode.Binary.LeftHash]
rightNode := nodeHashes[*currNode.Binary.RightHash]
leftNode := nodeHashes[*currNode.(*Binary).LeftHash]
rightNode := nodeHashes[*currNode.(*Binary).RightHash]

traverseNodes(leftNode, &leftPath, nodeHashes)
traverseNodes(rightNode, &rightPath, nodeHashes)
Expand Down Expand Up @@ -298,11 +301,11 @@ func GetProof(key *Key, tri *Trie) ([]ProofNode, error) {
isLeaf := sNode.key.len == tri.height

if sNodeEdge != nil && !isLeaf { // Internal Edge
proofNodes = append(proofNodes, []ProofNode{{Edge: sNodeEdge}, {Binary: sNodeBinary}}...)
proofNodes = append(proofNodes, sNodeEdge, sNodeBinary)
} else if sNodeEdge == nil && !isLeaf { // Internal Binary
proofNodes = append(proofNodes, []ProofNode{{Binary: sNodeBinary}}...)
proofNodes = append(proofNodes, sNodeBinary)
} else if sNodeEdge != nil && isLeaf { // Leaf Edge
proofNodes = append(proofNodes, []ProofNode{{Edge: sNodeEdge}}...)
proofNodes = append(proofNodes, sNodeEdge)
} else if sNodeEdge == nil && sNodeBinary == nil { // sNode is a binary leaf
break
}
Expand All @@ -321,16 +324,16 @@ func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode
return false
}

switch {
case proofNode.Binary != nil:
switch proofNode := proofNode.(type) {
case *Binary:
if remainingPath.Test(remainingPath.Len() - 1) {
expectedHash = proofNode.Binary.RightHash
expectedHash = proofNode.RightHash
} else {
expectedHash = proofNode.Binary.LeftHash
expectedHash = proofNode.LeftHash
}
remainingPath.RemoveLastBit()
case proofNode.Edge != nil:
subKey, err := remainingPath.SubKey(proofNode.Edge.Path.Len())
case *Edge:
subKey, err := remainingPath.SubKey(proofNode.Path.Len())
if err != nil {
return false
}
Expand All @@ -342,11 +345,11 @@ func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode
return true
}

if !proofNode.Edge.Path.Equal(subKey) {
if !proofNode.Path.Equal(subKey) {
return false
}
expectedHash = proofNode.Edge.Child
remainingPath.Truncate(251 - proofNode.Edge.Path.Len()) //nolint:mnd
expectedHash = proofNode.Child
remainingPath.Truncate(251 - proofNode.Path.Len()) //nolint:mnd
}
}

Expand Down Expand Up @@ -438,27 +441,33 @@ func ensureMonotonicIncreasing(proofKeys [2]*Key, keys []*felt.Felt) error {

// compressNode determines if the node needs compressed, and if so, the len needed to arrive at the next key
func compressNode(idx int, proofNodes []ProofNode, hashF hashFunc) (int, uint8, error) {
parent := &proofNodes[idx]
parent := proofNodes[idx]

if idx == len(proofNodes)-1 {
if parent.Edge != nil {
if _, ok := parent.(*Edge); ok {
return 1, parent.Len(), nil
}
return 0, parent.Len(), nil
}

child := &proofNodes[idx+1]

switch {
case parent.Edge != nil && child.Binary != nil:
return 1, parent.Edge.Path.len, nil
case parent.Binary != nil && child.Edge != nil:
child := proofNodes[idx+1]
_, isChildBinary := child.(*Binary)
isChildEdge := !isChildBinary
switch parent := parent.(type) {
case *Edge:
if isChildEdge {
break
}
return 1, parent.Len(), nil
case *Binary:
if isChildBinary {
break
}
childHash := child.Hash(hashF)
if parent.Binary.LeftHash.Equal(childHash) || parent.Binary.RightHash.Equal(childHash) {
return 1, child.Edge.Path.len, nil
} else {
return 0, 0, errors.New("can't determine the child hash from the parent and child")
if parent.LeftHash.Equal(childHash) || parent.RightHash.Equal(childHash) {
return 1, child.Len(), nil
}
return 0, 0, ErrChildHashNotFound
}

return 0, 1, nil
Expand Down Expand Up @@ -539,6 +548,7 @@ func ProofToPath(proofNodes []ProofNode, leafKey *Key, hashF hashFunc) ([]Storag
break
}
}

return pathNodes, nil
}

Expand All @@ -558,14 +568,20 @@ func skipNode(pNode ProofNode, pathNodes []StorageNode, hashF hashFunc) bool {
}

func getLeftRightHash(parentInd int, proofNodes []ProofNode) (*felt.Felt, *felt.Felt, error) {
parent := &proofNodes[parentInd]
if parent.Binary == nil {
parent := proofNodes[parentInd]

switch parent := parent.(type) {
case *Binary:
return parent.LeftHash, parent.RightHash, nil
case *Edge:
if parentInd+1 > len(proofNodes)-1 {
return nil, nil, errors.New("cant get hash of children from proof node, out of range")
}
parent = &proofNodes[parentInd+1]
parentBinary := proofNodes[parentInd+1].(*Binary)
return parentBinary.LeftHash, parentBinary.RightHash, nil
default:
return nil, nil, fmt.Errorf("%w: %T", ErrUnknownProofNode, parent)
}
return parent.Binary.LeftHash, parent.Binary.RightHash, nil
}

func getParentKey(idx int, compressedParentOffset uint8, leafKey *Key,
Expand All @@ -576,16 +592,14 @@ func getParentKey(idx int, compressedParentOffset uint8, leafKey *Key,

var height uint8
if len(pathNodes) > 0 {
if proofNodes[idx].Edge != nil {
height = pathNodes[len(pathNodes)-1].key.len + proofNodes[idx].Edge.Path.len
if p, ok := proofNodes[idx].(*Edge); ok {
height = pathNodes[len(pathNodes)-1].key.len + p.Path.len
} else {
height = pathNodes[len(pathNodes)-1].key.len + 1
}
} else {
height = 0
}

if pNode.Binary != nil {
if _, ok := pNode.(*Binary); ok {
crntKey, err = leafKey.SubKey(height)
} else {
crntKey, err = leafKey.SubKey(height + compressedParentOffset)
Expand Down
Loading

0 comments on commit 38638ca

Please sign in to comment.