From 54afb081f260c6f731e6baefcd9128a7c280d73f Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Thu, 8 Jun 2023 20:51:25 +0200 Subject: [PATCH] remove `TestItem`, test using actual `ProofItem`s instead This makes `ProofStream` less generic. --- triton-vm/src/fri.rs | 13 +- triton-vm/src/proof_item.rs | 157 ++++++------------ triton-vm/src/proof_stream.rs | 301 +++++++++++++++++++++++----------- triton-vm/src/stark.rs | 2 +- 4 files changed, 255 insertions(+), 218 deletions(-) diff --git a/triton-vm/src/fri.rs b/triton-vm/src/fri.rs index f9e7a297..acb947ed 100644 --- a/triton-vm/src/fri.rs +++ b/triton-vm/src/fri.rs @@ -81,7 +81,7 @@ impl Fri { indices: &[usize], codeword: &[XFieldElement], merkle_tree: &MerkleTree, - proof_stream: &mut ProofStream, + proof_stream: &mut ProofStream, ) { let partial_authentication_paths = merkle_tree.get_authentication_structure(indices); let revealed_values = indices.iter().map(|&i| codeword[i]).collect_vec(); @@ -101,7 +101,7 @@ impl Fri { root: Digest, tree_height: usize, indices: &[usize], - proof_stream: &mut ProofStream, + proof_stream: &mut ProofStream, ) -> Result> { let fri_response = proof_stream.dequeue(false)?.as_fri_response()?; let FriResponse(dequeued_paths_and_leafs) = fri_response; @@ -125,7 +125,7 @@ impl Fri { pub fn prove( &self, codeword: &[XFieldElement], - proof_stream: &mut ProofStream, + proof_stream: &mut ProofStream, ) -> (Vec, Digest) { debug_assert_eq!( self.domain.length, @@ -178,7 +178,7 @@ impl Fri { fn commit( &self, codeword: &[XFieldElement], - proof_stream: &mut ProofStream, + proof_stream: &mut ProofStream, ) -> Vec<(Vec, MerkleTree)> { let one = XFieldElement::one(); let two_inv = one / (one + one); @@ -248,7 +248,7 @@ impl Fri { /// Returns the indices and revealed elements of the codeword at the top level of the FRI proof. pub fn verify( &self, - proof_stream: &mut ProofStream, + proof_stream: &mut ProofStream, maybe_profiler: &mut Option, ) -> Result> { prof_start!(maybe_profiler, "init"); @@ -698,8 +698,7 @@ mod triton_xfri_tests { let proof = prover_proof_stream.into(); - let mut verifier_proof_stream: ProofStream = - ProofStream::try_from(&proof).unwrap(); + let mut verifier_proof_stream: ProofStream = ProofStream::try_from(&proof).unwrap(); assert_eq!(prover_proof_stream.len(), verifier_proof_stream.len()); for (prover_item, verifier_item) in prover_proof_stream diff --git a/triton-vm/src/proof_item.rs b/triton-vm/src/proof_item.rs index c3ca871e..cca7fe0c 100644 --- a/triton-vm/src/proof_item.rs +++ b/triton-vm/src/proof_item.rs @@ -207,135 +207,72 @@ impl BFieldCodec for ProofItem { #[cfg(test)] mod proof_item_typed_tests { use itertools::Itertools; - use rand::thread_rng; + use rand::distributions::Standard; + use rand::prelude::StdRng; + use rand::random; use rand::Rng; + use rand_core::RngCore; + use rand_core::SeedableRng; use twenty_first::shared_math::tip5::Tip5; use twenty_first::shared_math::x_field_element::XFieldElement; + use twenty_first::util_types::merkle_tree::CpuParallel; + use twenty_first::util_types::merkle_tree::MerkleTree; + use twenty_first::util_types::merkle_tree_maker::MerkleTreeMaker; + use crate::proof::Proof; use crate::proof_stream::ProofStream; use super::*; - fn random_bool() -> bool { - thread_rng().gen() + fn random_merkle_tree(seed: u64, num_leaves: usize) -> (MerkleTree, Vec) { + let rng = StdRng::seed_from_u64(seed); + let leaves: Vec = rng.sample_iter(Standard).take(num_leaves).collect(); + let leaves_as_digests: Vec = leaves.iter().map(|&x| x.into()).collect(); + (CpuParallel::from_digests(&leaves_as_digests), leaves) } - fn random_x_field_element() -> XFieldElement { - thread_rng().gen() - } - - fn random_digest() -> Digest { - thread_rng().gen() - } - - fn random_fri_response() -> FriResponse { - FriResponse( - (0..18) - .map(|r| { - ( - PartialAuthenticationPath( - (0..(20 - r)) - .map(|_| { - if random_bool() { - Some(random_digest()) - } else { - None - } - }) - .collect_vec(), - ), - random_x_field_element(), - ) - }) - .collect_vec(), - ) + fn fri_response( + merkle_tree: &MerkleTree, + leaves: &[XFieldElement], + revealed_indices: &[usize], + ) -> FriResponse { + let revealed_elements = revealed_indices.iter().map(|&i| leaves[i]).collect_vec(); + let auth_structure = merkle_tree.get_authentication_structure(revealed_indices); + let fri_response = auth_structure + .into_iter() + .zip(revealed_elements) + .collect_vec(); + FriResponse(fri_response) } #[test] fn serialize_fri_response_test() { - let fri_response = random_fri_response(); - let str = fri_response.encode(); - let fri_response_ = *FriResponse::decode(&str).unwrap(); - assert_eq!(fri_response, fri_response_); - } - - #[test] - fn test_serialize_stark_proof_with_fiat_shamir() { type H = Tip5; - let mut proof_stream: ProofStream<_, H> = ProofStream::new(); - let map = (0..7).map(|_| random_digest()).collect_vec(); - let auth_struct = (0..8) - .map(|_| { - PartialAuthenticationPath( - (0..11) - .map(|_| { - if random_bool() { - Some(random_digest()) - } else { - None - } - }) - .collect_vec(), - ) - }) - .collect_vec(); - let root = random_digest(); - let fri_response = random_fri_response(); - let mut fs = vec![]; - fs.push(proof_stream.sponge_state.state); - proof_stream.enqueue(&ProofItem::AuthenticationPath(map.clone()), false); - fs.push(proof_stream.sponge_state.state); - proof_stream.enqueue( - &ProofItem::CompressedAuthenticationPaths(auth_struct.clone()), - false, - ); - fs.push(proof_stream.sponge_state.state); - proof_stream.enqueue(&ProofItem::MerkleRoot(root), true); - fs.push(proof_stream.sponge_state.state); - proof_stream.enqueue(&ProofItem::FriResponse(fri_response.clone()), false); - fs.push(proof_stream.sponge_state.state); - - let proof = proof_stream.into(); - - let mut proof_stream_ = - ProofStream::::try_from(&proof).expect("invalid parsing of proof"); - - let mut fs_ = vec![]; - fs_.push(proof_stream_.sponge_state.state); - - let map_ = proof_stream_ - .dequeue(false) - .expect("can't dequeue item") - .as_authentication_path() - .expect("cannot parse dequeued item"); - assert_eq!(map, map_); - fs_.push(proof_stream_.sponge_state.state); + let seed = random(); + let mut rng = StdRng::seed_from_u64(seed); + println!("seed: {seed}"); - let auth_struct_ = proof_stream_ - .dequeue(false) - .expect("can't dequeue item") - .as_compressed_authentication_paths() - .expect("cannot parse dequeued item"); - assert_eq!(auth_struct, auth_struct_); - fs_.push(proof_stream_.sponge_state.state); - - let root_ = proof_stream_ - .dequeue(true) - .expect("can't dequeue item") - .as_merkle_root() - .expect("cannot parse dequeued item"); - assert_eq!(root, root_); - fs_.push(proof_stream_.sponge_state.state); + let codeword_len = 64; + let (merkle_tree, leaves) = random_merkle_tree(rng.next_u64(), codeword_len); + let num_indices = rng.gen_range(1..=codeword_len); + let revealed_indices = (0..num_indices) + .map(|_| rng.gen_range(0..codeword_len)) + .collect_vec(); + let fri_response = fri_response(&merkle_tree, &leaves, &revealed_indices); - let fri_response_ = proof_stream_ - .dequeue(false) - .expect("can't dequeue item") - .as_fri_response() - .expect("cannot parse dequeued item"); + // test encoding and decoding in isolation + let encoding = fri_response.encode(); + let fri_response_ = *FriResponse::decode(&encoding).unwrap(); assert_eq!(fri_response, fri_response_); - fs_.push(proof_stream_.sponge_state.state); - assert_eq!(fs, fs_); + // test encoding and decoding in a stream + let mut proof_stream = ProofStream::::new(); + proof_stream.enqueue(&ProofItem::FriResponse(fri_response.clone()), false); + let proof: Proof = proof_stream.into(); + let mut proof_stream = ProofStream::::try_from(&proof).unwrap(); + let fri_response_ = proof_stream.dequeue(false).unwrap(); + let fri_response_ = fri_response_.as_fri_response().unwrap(); + assert_eq!(fri_response, fri_response_); } } diff --git a/triton-vm/src/proof_stream.rs b/triton-vm/src/proof_stream.rs index b1222675..03ee8d37 100644 --- a/triton-vm/src/proof_stream.rs +++ b/triton-vm/src/proof_stream.rs @@ -11,14 +11,14 @@ use twenty_first::shared_math::x_field_element::XFieldElement; use twenty_first::util_types::algebraic_hasher::AlgebraicHasher; use crate::proof::Proof; +use crate::proof_item::ProofItem; #[derive(Debug, Clone, PartialEq, Eq)] -pub struct ProofStream +pub struct ProofStream where - Item: Clone + BFieldCodec, H: AlgebraicHasher, { - pub items: Vec, + pub items: Vec, pub items_index: usize, pub sponge_state: H::SpongeState, } @@ -45,9 +45,8 @@ impl Display for ProofStreamError { impl Error for ProofStreamError {} -impl ProofStream +impl ProofStream where - Item: Clone + BFieldCodec, H: AlgebraicHasher, { pub fn new() -> Self { @@ -73,7 +72,7 @@ where b_field_elements.len() } - fn encode_and_pad_item(item: &Item) -> Vec { + fn encode_and_pad_item(item: &ProofItem) -> Vec { let encoding = item.encode(); let last_chunk_len = (encoding.len() + 1) % H::RATE; let num_padding_zeros = match last_chunk_len { @@ -96,7 +95,7 @@ where /// in question was included (hashed) previously. /// - If the proof stream is not used to sample any more randomness, _i.e._, after the last /// round of interaction, no further items need to be included. - pub fn enqueue(&mut self, item: &Item, include_in_fs_heuristic: bool) { + pub fn enqueue(&mut self, item: &ProofItem, include_in_fs_heuristic: bool) { if include_in_fs_heuristic { H::absorb_repeatedly( &mut self.sponge_state, @@ -108,7 +107,7 @@ where /// Receive a proof item from prover as verifier. /// See [`ProofStream::enqueue`] for more details. - pub fn dequeue(&mut self, include_in_fs_heuristic: bool) -> Result { + pub fn dequeue(&mut self, include_in_fs_heuristic: bool) -> Result { let item = self .items .get(self.items_index) @@ -143,9 +142,8 @@ where } } -impl Default for ProofStream +impl Default for ProofStream where - Item: Clone + BFieldCodec, H: AlgebraicHasher, { fn default() -> Self { @@ -153,13 +151,12 @@ where } } -impl BFieldCodec for ProofStream +impl BFieldCodec for ProofStream where - Item: Clone + BFieldCodec, H: AlgebraicHasher, { fn decode(sequence: &[BFieldElement]) -> Result> { - let items = *Vec::::decode(sequence)?; + let items = *Vec::::decode(sequence)?; let proof_stream = ProofStream { items, items_index: 0, @@ -177,9 +174,8 @@ where } } -impl TryFrom<&Proof> for ProofStream +impl TryFrom<&Proof> for ProofStream where - Item: Clone + BFieldCodec, H: AlgebraicHasher, { type Error = anyhow::Error; @@ -190,31 +186,35 @@ where } } -impl From<&ProofStream> for Proof +impl From<&ProofStream> for Proof where - Item: Clone + BFieldCodec, H: AlgebraicHasher, { - fn from(proof_stream: &ProofStream) -> Self { + fn from(proof_stream: &ProofStream) -> Self { Proof(proof_stream.encode()) } } -impl From> for Proof +impl From> for Proof where - Item: Clone + BFieldCodec, H: AlgebraicHasher, { - fn from(proof_stream: ProofStream) -> Self { + fn from(proof_stream: ProofStream) -> Self { (&proof_stream).into() } } #[cfg(test)] mod proof_stream_typed_tests { - use anyhow::bail; use itertools::Itertools; - use twenty_first::shared_math::b_field_element::BFieldElement; + use rand::distributions::Standard; + use rand::prelude::Distribution; + use rand::prelude::SeedableRng; + use rand::prelude::StdRng; + use rand::random; + use rand::Rng; + use rand_core::RngCore; + use std::collections::VecDeque; use twenty_first::shared_math::other::random_elements; use twenty_first::shared_math::tip5::Tip5; use twenty_first::shared_math::x_field_element::XFieldElement; @@ -224,100 +224,201 @@ mod proof_stream_typed_tests { use crate::proof_item::FriResponse; use crate::proof_item::ProofItem; + use crate::table::master_table::NUM_BASE_COLUMNS; + use crate::table::master_table::NUM_EXT_COLUMNS; use super::*; - #[derive(Clone, Debug, PartialEq)] - enum TestItem { - ManyB(Vec), - ManyX(Vec), - } + #[test] + fn test_serialize_proof_with_fiat_shamir() { + type H = Tip5; - impl TestItem { - /// The unique identifier for this item type. - pub fn discriminant(&self) -> BFieldElement { - use TestItem::*; - match self { - ManyB(_) => BFieldElement::new(0), - ManyX(_) => BFieldElement::new(1), - } + fn random_elements(seed: u64, n: usize) -> Vec + where + Standard: Distribution, + { + let rng = StdRng::seed_from_u64(seed); + rng.sample_iter(Standard).take(n).collect() } - } - impl BFieldCodec for TestItem { - fn decode(str: &[BFieldElement]) -> Result> { - if str.is_empty() { - bail!("trying to decode empty string into test item"); - } - - let discriminant = str[0].value(); - let str = &str[1..]; - let item = match discriminant { - 0 => Self::ManyB(*Vec::::decode(str)?), - 1 => Self::ManyX(*Vec::::decode(str)?), - i => bail!("Unknown discriminant ID {i}."), - }; - Ok(Box::new(item)) - } + let seed = random(); + let mut rng = StdRng::seed_from_u64(seed); + println!("seed: {seed}"); + + let base_rows = vec![ + random_elements(rng.next_u64(), NUM_BASE_COLUMNS), + random_elements(rng.next_u64(), NUM_BASE_COLUMNS), + ]; + let ext_rows = vec![ + random_elements(rng.next_u64(), NUM_EXT_COLUMNS), + random_elements(rng.next_u64(), NUM_EXT_COLUMNS), + ]; + + let codeword_len = 32; + let fri_codeword: Vec = random_elements(rng.next_u64(), codeword_len); + let fri_codeword_digests = fri_codeword.iter().map(|&x| x.into()).collect_vec(); + let merkle_tree: MerkleTree = CpuParallel::from_digests(&fri_codeword_digests); + let root = merkle_tree.get_root(); + + let revealed_index = rng.gen_range(0..codeword_len); + let auth_path = merkle_tree.get_authentication_path(revealed_index); + + let num_revealed_indices = rng.gen_range(1..=codeword_len); + let revealed_indices = random_elements(rng.next_u64(), num_revealed_indices) + .into_iter() + .map(|idx: usize| idx % codeword_len) + .collect_vec(); + let auth_structure = merkle_tree.get_authentication_structure(&revealed_indices); - fn encode(&self) -> Vec { - use TestItem::*; + let ood_base_row = random_elements(rng.next_u64(), NUM_BASE_COLUMNS); + let ood_ext_row = random_elements(rng.next_u64(), NUM_EXT_COLUMNS); + let combination_elements = random_elements(rng.next_u64(), 5); - let discriminant = vec![self.discriminant()]; - let encoding = match self { - ManyB(bs) => bs.encode(), - ManyX(xs) => xs.encode(), - }; - [discriminant, encoding].concat() - } + let revealed_elements = revealed_indices + .iter() + .map(|&idx| fri_codeword[idx]) + .collect_vec(); + let fri_response = auth_structure + .clone() + .into_iter() + .zip(revealed_elements) + .collect_vec(); + let fri_response = FriResponse(fri_response); - fn static_length() -> Option { - None - } - } + let mut sponge_states = VecDeque::new(); + let mut proof_stream = ProofStream::::new(); - #[test] - fn test_serialize_proof_with_fiat_shamir() { - type H = Tip5; - let mut proof_stream: ProofStream<_, H> = ProofStream::new(); - let manyb1: Vec = random_elements(10); - let manyx: Vec = random_elements(13); - let manyb2: Vec = random_elements(11); - - let fs1 = proof_stream.sponge_state.state; - proof_stream.enqueue(&TestItem::ManyB(manyb1.clone()), false); - let fs2 = proof_stream.sponge_state.state; - proof_stream.enqueue(&TestItem::ManyX(manyx.clone()), true); - let fs3 = proof_stream.sponge_state.state; - proof_stream.enqueue(&TestItem::ManyB(manyb2.clone()), true); - let fs4 = proof_stream.sponge_state.state; + sponge_states.push_back(proof_stream.sponge_state.state); + proof_stream.enqueue( + &ProofItem::CompressedAuthenticationPaths(auth_structure.clone()), + false, + ); + sponge_states.push_back(proof_stream.sponge_state.state); + proof_stream.enqueue(&ProofItem::MasterBaseTableRows(base_rows.clone()), false); + sponge_states.push_back(proof_stream.sponge_state.state); + proof_stream.enqueue(&ProofItem::MasterExtTableRows(ext_rows.clone()), true); + sponge_states.push_back(proof_stream.sponge_state.state); + proof_stream.enqueue(&ProofItem::OutOfDomainBaseRow(ood_base_row.clone()), true); + sponge_states.push_back(proof_stream.sponge_state.state); + proof_stream.enqueue(&ProofItem::OutOfDomainExtRow(ood_ext_row.clone()), true); + sponge_states.push_back(proof_stream.sponge_state.state); + proof_stream.enqueue(&ProofItem::MerkleRoot(root), true); + sponge_states.push_back(proof_stream.sponge_state.state); + proof_stream.enqueue(&ProofItem::AuthenticationPath(auth_path.clone()), true); + sponge_states.push_back(proof_stream.sponge_state.state); + proof_stream.enqueue( + &ProofItem::RevealedCombinationElements(combination_elements.clone()), + true, + ); + sponge_states.push_back(proof_stream.sponge_state.state); + proof_stream.enqueue(&ProofItem::FriCodeword(fri_codeword.clone()), true); + sponge_states.push_back(proof_stream.sponge_state.state); + proof_stream.enqueue(&ProofItem::FriResponse(fri_response.clone()), true); + sponge_states.push_back(proof_stream.sponge_state.state); let proof = proof_stream.into(); - - let mut proof_stream: ProofStream = + let mut proof_stream: ProofStream = ProofStream::try_from(&proof).expect("invalid parsing of proof"); - let fs1_ = proof_stream.sponge_state.state; - match proof_stream.dequeue(false).expect("can't dequeue item") { - TestItem::ManyB(manyb1_) => assert_eq!(manyb1, manyb1_), - TestItem::ManyX(_) => panic!(), + assert_eq!( + sponge_states.pop_front(), + Some(proof_stream.sponge_state.state) + ); + match proof_stream.dequeue(false).unwrap() { + ProofItem::CompressedAuthenticationPaths(auth_structure_) => { + assert_eq!(auth_structure, auth_structure_) + } + _ => panic!(), + }; + + assert_eq!( + sponge_states.pop_front(), + Some(proof_stream.sponge_state.state) + ); + match proof_stream.dequeue(false).unwrap() { + ProofItem::MasterBaseTableRows(base_rows_) => assert_eq!(base_rows, base_rows_), + _ => panic!(), + }; + + assert_eq!( + sponge_states.pop_front(), + Some(proof_stream.sponge_state.state) + ); + match proof_stream.dequeue(true).unwrap() { + ProofItem::MasterExtTableRows(ext_rows_) => assert_eq!(ext_rows, ext_rows_), + _ => panic!(), + }; + + assert_eq!( + sponge_states.pop_front(), + Some(proof_stream.sponge_state.state) + ); + match proof_stream.dequeue(true).unwrap() { + ProofItem::OutOfDomainBaseRow(ood_base_row_) => assert_eq!(ood_base_row, ood_base_row_), + _ => panic!(), + }; + + assert_eq!( + sponge_states.pop_front(), + Some(proof_stream.sponge_state.state) + ); + match proof_stream.dequeue(true).unwrap() { + ProofItem::OutOfDomainExtRow(ood_ext_row_) => assert_eq!(ood_ext_row, ood_ext_row_), + _ => panic!(), + }; + + assert_eq!( + sponge_states.pop_front(), + Some(proof_stream.sponge_state.state) + ); + match proof_stream.dequeue(true).unwrap() { + ProofItem::MerkleRoot(root_) => assert_eq!(root, root_), + _ => panic!(), }; - let fs2_ = proof_stream.sponge_state.state; - match proof_stream.dequeue(true).expect("can't dequeue item") { - TestItem::ManyB(_) => panic!(), - TestItem::ManyX(manyx_) => assert_eq!(manyx, manyx_), + + assert_eq!( + sponge_states.pop_front(), + Some(proof_stream.sponge_state.state) + ); + match proof_stream.dequeue(true).unwrap() { + ProofItem::AuthenticationPath(auth_path_) => assert_eq!(auth_path, auth_path_), + _ => panic!(), }; - let fs3_ = proof_stream.sponge_state.state; - match proof_stream.dequeue(true).expect("can't dequeue item") { - TestItem::ManyB(manyb2_) => assert_eq!(manyb2, manyb2_), - TestItem::ManyX(_) => panic!(), + + assert_eq!( + sponge_states.pop_front(), + Some(proof_stream.sponge_state.state) + ); + match proof_stream.dequeue(true).unwrap() { + ProofItem::RevealedCombinationElements(combination_elements_) => { + assert_eq!(combination_elements, combination_elements_) + } + _ => panic!(), }; - let fs4_ = proof_stream.sponge_state.state; - assert_eq!(fs1, fs1_); - assert_eq!(fs2, fs2_); - assert_eq!(fs3, fs3_); - assert_eq!(fs4, fs4_); + assert_eq!( + sponge_states.pop_front(), + Some(proof_stream.sponge_state.state) + ); + match proof_stream.dequeue(true).unwrap() { + ProofItem::FriCodeword(fri_codeword_) => assert_eq!(fri_codeword, fri_codeword_), + _ => panic!(), + }; + + assert_eq!( + sponge_states.pop_front(), + Some(proof_stream.sponge_state.state) + ); + match proof_stream.dequeue(true).unwrap() { + ProofItem::FriResponse(fri_response_) => assert_eq!(fri_response, fri_response_), + _ => panic!(), + }; + + assert_eq!( + sponge_states.pop_front(), + Some(proof_stream.sponge_state.state) + ); + assert_eq!(sponge_states.len(), 0); } #[test] @@ -338,7 +439,7 @@ mod proof_stream_typed_tests { .collect_vec(); let fri_response = FriResponse(fri_response_content); - let mut proof_stream = ProofStream::::new(); + let mut proof_stream = ProofStream::::new(); proof_stream.enqueue(&ProofItem::FriResponse(fri_response), false); // TODO: Also check that deserializing from Proof works here. diff --git a/triton-vm/src/stark.rs b/triton-vm/src/stark.rs index 95fa3359..858ee1ef 100644 --- a/triton-vm/src/stark.rs +++ b/triton-vm/src/stark.rs @@ -47,7 +47,7 @@ use crate::vm::AlgebraicExecutionTrace; pub type StarkHasher = Tip5; pub type MTMaker = CpuParallel; -pub type StarkProofStream = ProofStream; +pub type StarkProofStream = ProofStream; /// All the security-related parameters for the zk-STARK. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)]