From 39daa423c31c269691eff36b4424d6701508e9fd Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Thu, 25 Jan 2024 22:24:05 +0800 Subject: [PATCH] Update digest circuit recursive test. --- src/benches/digest_tree.rs | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/benches/digest_tree.rs b/src/benches/digest_tree.rs index 74a34a150..06a6aa3a5 100644 --- a/src/benches/digest_tree.rs +++ b/src/benches/digest_tree.rs @@ -58,10 +58,11 @@ fn rand_leaf() -> DigestNode { fn prove_all_leaves(circuit: &CyclicCircuit, tree: &mut DigestTree) { tree.all_leaves().iter_mut().for_each(|leaf| { - if let DigestNode::Leaf(value, _, proof) = leaf { - let inputs = value.0.into_iter().map(F::from_canonical_u64).collect(); + if let DigestNode::Leaf(value, _, proof_result) = leaf { + let input = value.0.map(F::from_canonical_u64); + let init_proof = circuit.prove_init(U::new(vec![])).unwrap().0; - *proof = Some(prove_once(circuit, inputs)); + *proof_result = Some(prove_once(circuit, vec![(input, init_proof)])); } else { panic!("Must be a leaf of tree"); } @@ -76,12 +77,12 @@ fn prove_branches_recursive( (0..max_level).rev().into_iter().for_each(|level| { tree.branches_at_level(level).iter_mut().for_each(|branch| { if let DigestNode::Branch(children, .., proof) = branch { - let inputs = children + let inputs_proofs = children .iter() - .flat_map(|node| node.hash().elements) + .map(|node| (node.hash().elements, node.proof().clone().unwrap())) .collect(); - *proof = Some(prove_once(circuit, inputs)); + *proof = Some(prove_once(circuit, inputs_proofs)); } else { panic!("Must be a branch of tree"); } @@ -91,14 +92,12 @@ fn prove_branches_recursive( fn prove_once( circuit: &CyclicCircuit, - inputs: Vec, + inputs_proofs: Vec<([F; 4], ProofWithPublicInputs)>, ) -> ProofWithPublicInputs { - let init_proofs: Vec<_> = iter::repeat(circuit.prove_init(U::new(vec![])).unwrap().0) - .take(inputs.len()) - .collect(); + let (inputs, proofs): (Vec<_>, Vec<_>) = inputs_proofs.into_iter().unzip(); - let dummy_n = ARITY - init_proofs.len(); - let children: [Option>; ARITY] = init_proofs + let dummy_n = ARITY - proofs.len(); + let proofs: [Option>; ARITY] = proofs .into_iter() .map(Some) .chain(std::iter::repeat(None).take(dummy_n)) @@ -106,7 +105,7 @@ fn prove_once( .try_into() .unwrap(); - let proof = circuit.prove_step(U::new(inputs), &children).unwrap().0; + let proof = circuit.prove_step(U::new(inputs), &proofs).unwrap().0; circuit .verify_proof(proof.clone()) .expect("Failed to verify proof"); @@ -167,7 +166,7 @@ where impl DigestNode { pub fn new_branch(children: Vec) -> Self { - assert!(children.len() > 0 && children.len() < 5); + assert!(children.len() > 0 && children.len() <= ARITY); let inputs: Vec<_> = children .iter() @@ -229,4 +228,11 @@ impl DigestNode { Self::Leaf(_, hash, ..) => hash, } } + + pub fn proof(&self) -> &Option> { + match self { + Self::Branch(.., proof) => proof, + Self::Leaf(.., proof) => proof, + } + } }