From 5c3735e898552badaa80e46a765a7d062c87d453 Mon Sep 17 00:00:00 2001 From: M Berger <76954195+tessico@users.noreply.github.com> Date: Fri, 20 Jan 2023 08:36:30 +0100 Subject: [PATCH] Digest-generic MerkleTreeGadget (#167) * Add an associated type DigestGadget to MerkleTreeGadget * make struct fields private * replace usages of RescueGadget with generic DigestGadget * constrain M::NodeValue to be of type PrimeField * position should be of type `Variable`, not `usize` --- primitives/src/circuit/merkle_tree/mod.rs | 49 +++++++++++++++++-- .../circuit/merkle_tree/rescue_merkle_tree.rs | 17 +++---- .../circuit/merkle_tree/sparse_merkle_tree.rs | 22 ++++----- 3 files changed, 61 insertions(+), 27 deletions(-) diff --git a/primitives/src/circuit/merkle_tree/mod.rs b/primitives/src/circuit/merkle_tree/mod.rs index fc2ea49b7..ed0e191be 100644 --- a/primitives/src/circuit/merkle_tree/mod.rs +++ b/primitives/src/circuit/merkle_tree/mod.rs @@ -10,12 +10,15 @@ use crate::{ merkle_tree::{MerkleTreeScheme, UniversalMerkleTreeScheme}, rescue::RescueParameter, }; +use ark_ff::PrimeField; use jf_relation::{errors::CircuitError, BoolVar, Circuit, PlonkCircuit, Variable}; mod rescue_merkle_tree; mod sparse_merkle_tree; use ark_std::vec::Vec; +use super::rescue::RescueNativeGadget; + /// Gadget for a Merkle tree /// /// # Examples @@ -60,12 +63,16 @@ use ark_std::vec::Vec; pub trait MerkleTreeGadget where M: MerkleTreeScheme, + M::NodeValue: PrimeField, { /// Type to represent the merkle proof of the concrete MT instantiation. /// It is MT-specific, e.g arity will affect the exact definition of the /// underlying Merkle path. type MembershipProofVar; + /// Gadget for the digest algorithm. + type DigestGadget: DigestAlgorithmGadget; + /// Allocate a variable for the membership proof. fn create_membership_proof_variable( &mut self, @@ -149,6 +156,7 @@ where pub trait UniversalMerkleTreeGadget: MerkleTreeGadget where M: UniversalMerkleTreeScheme, + M::NodeValue: PrimeField, { /// Type to represent the merkle non-membership proof of the concrete MT /// instantiation. It is MT-specific, e.g arity will affect the exact @@ -212,13 +220,13 @@ fn constrain_sibling_order( /// Circuit variable for a node in the Merkle path. pub struct Merkle3AryNodeVar { /// First sibling of the node. - pub sibling1: Variable, + sibling1: Variable, /// Second sibling of the node. - pub sibling2: Variable, + sibling2: Variable, /// Boolean variable indicating whether the node is a left child. - pub is_left_child: BoolVar, + is_left_child: BoolVar, /// Boolean variable indicating whether the node is a right child. - pub is_right_child: BoolVar, + is_right_child: BoolVar, } /// Circuit variable for a Merkle non-membership proof of a 3-ary Merkle tree. @@ -240,3 +248,36 @@ pub struct Merkle3AryMembershipProofVar { node_vars: Vec, elem_var: Variable, } +/// Circuit counterpart to DigestAlgorithm +pub trait DigestAlgorithmGadget +where + F: PrimeField, +{ + /// Digest a list of variables + fn digest(circuit: &mut PlonkCircuit, data: &[Variable]) -> Result; + + /// Digest an indexed element + fn digest_leaf( + circuit: &mut PlonkCircuit, + pos: usize, + elem: Variable, + ) -> Result; +} + +/// Digest gadget using for the Rescue hash function. +pub struct RescueDigestGadget {} + +impl DigestAlgorithmGadget for RescueDigestGadget { + fn digest(circuit: &mut PlonkCircuit, data: &[Variable]) -> Result { + Ok(RescueNativeGadget::::rescue_sponge_no_padding(circuit, data, 1)?[0]) + } + + fn digest_leaf( + circuit: &mut PlonkCircuit, + pos: Variable, + elem: Variable, + ) -> Result { + let zero = circuit.zero(); + Ok(RescueNativeGadget::::rescue_sponge_no_padding(circuit, &[zero, pos, elem], 1)?[0]) + } +} diff --git a/primitives/src/circuit/merkle_tree/rescue_merkle_tree.rs b/primitives/src/circuit/merkle_tree/rescue_merkle_tree.rs index 55b0a7542..a410b3991 100644 --- a/primitives/src/circuit/merkle_tree/rescue_merkle_tree.rs +++ b/primitives/src/circuit/merkle_tree/rescue_merkle_tree.rs @@ -8,7 +8,7 @@ //! with a Rescue hash function. use crate::{ - circuit::rescue::RescueNativeGadget, + circuit::merkle_tree::DigestAlgorithmGadget, merkle_tree::{ internal::MerkleNode, prelude::RescueMerkleTree, MerkleTreeScheme, ToTraversalPath, }, @@ -16,13 +16,13 @@ use crate::{ }; use ark_std::{string::ToString, vec::Vec}; use jf_relation::{errors::CircuitError, BoolVar, Circuit, PlonkCircuit, Variable}; - type NodeVal = as MerkleTreeScheme>::NodeValue; type MembershipProof = as MerkleTreeScheme>::MembershipProof; use typenum::U3; use super::{ constrain_sibling_order, Merkle3AryMembershipProofVar, Merkle3AryNodeVar, MerkleTreeGadget, + RescueDigestGadget, }; impl MerkleTreeGadget> for PlonkCircuit @@ -31,6 +31,8 @@ where { type MembershipProofVar = Merkle3AryMembershipProofVar; + type DigestGadget = RescueDigestGadget; + fn create_membership_proof_variable( &mut self, merkle_proof: &MembershipProof, @@ -96,14 +98,10 @@ where ) -> Result { let computed_root_var = { let proof_var = &proof_var; - let zero_var = self.zero(); // elem label = H(0, uid, elem) - let mut cur_label = RescueNativeGadget::::rescue_sponge_no_padding( - self, - &[zero_var, elem_idx_var, proof_var.elem_var], - 1, - )?[0]; + let mut cur_label = + Self::DigestGadget::digest_leaf(self, elem_idx_var, proof_var.elem_var)?; for cur_node in proof_var.node_vars.iter() { let input_labels = constrain_sibling_order( self, @@ -115,8 +113,7 @@ where )?; // check that the left child's label is non-zero self.non_zero_gate(input_labels[0])?; - cur_label = - RescueNativeGadget::::rescue_sponge_no_padding(self, &input_labels, 1)?[0]; + cur_label = Self::DigestGadget::digest(self, &input_labels)?; } Ok(cur_label) }?; diff --git a/primitives/src/circuit/merkle_tree/sparse_merkle_tree.rs b/primitives/src/circuit/merkle_tree/sparse_merkle_tree.rs index c4c78bce0..51e733a50 100644 --- a/primitives/src/circuit/merkle_tree/sparse_merkle_tree.rs +++ b/primitives/src/circuit/merkle_tree/sparse_merkle_tree.rs @@ -8,7 +8,7 @@ //! with a Rescue hash function. use crate::{ - circuit::rescue::RescueNativeGadget, + circuit::merkle_tree::DigestAlgorithmGadget, merkle_tree::{ internal::MerkleNode, prelude::RescueSparseMerkleTree, MerkleTreeScheme, ToTraversalPath, }, @@ -18,7 +18,6 @@ use ark_std::{string::ToString, vec::Vec}; use jf_relation::{errors::CircuitError, BoolVar, Circuit, PlonkCircuit, Variable}; type SparseMerkleTree = RescueSparseMerkleTree; - type NodeVal = as MerkleTreeScheme>::NodeValue; type MembershipProof = as MerkleTreeScheme>::MembershipProof; use num_bigint::BigUint; @@ -26,7 +25,8 @@ use typenum::U3; use super::{ constrain_sibling_order, Merkle3AryMembershipProofVar, Merkle3AryNodeVar, - Merkle3AryNonMembershipProofVar, MerkleTreeGadget, UniversalMerkleTreeGadget, + Merkle3AryNonMembershipProofVar, MerkleTreeGadget, RescueDigestGadget, + UniversalMerkleTreeGadget, }; impl UniversalMerkleTreeGadget> for PlonkCircuit @@ -57,8 +57,7 @@ where )?; // check that the left child's label is non-zero self.non_zero_gate(input_labels[0])?; - cur_label = - RescueNativeGadget::::rescue_sponge_no_padding(self, &input_labels, 1)?[0]; + cur_label = Self::DigestGadget::digest(self, &input_labels)?; } Ok(cur_label) }?; @@ -126,6 +125,8 @@ where { type MembershipProofVar = Merkle3AryMembershipProofVar; + type DigestGadget = RescueDigestGadget; + fn create_membership_proof_variable( &mut self, merkle_proof: &MembershipProof, @@ -191,14 +192,10 @@ where ) -> Result { let computed_root_var = { let proof_var = &proof_var; - let zero_var = self.zero(); // elem label = H(0, uid, elem) - let mut cur_label = RescueNativeGadget::::rescue_sponge_no_padding( - self, - &[zero_var, elem_idx_var, proof_var.elem_var], - 1, - )?[0]; + let mut cur_label = + Self::DigestGadget::digest_leaf(self, elem_idx_var, proof_var.elem_var)?; for cur_node in proof_var.node_vars.iter() { let input_labels = constrain_sibling_order( self, @@ -210,8 +207,7 @@ where )?; // check that the left child's label is non-zero self.non_zero_gate(input_labels[0])?; - cur_label = - RescueNativeGadget::::rescue_sponge_no_padding(self, &input_labels, 1)?[0]; + cur_label = Self::DigestGadget::digest(self, &input_labels)?; } Ok(cur_label) }?;