diff --git a/triton-vm/src/proof_item.rs b/triton-vm/src/proof_item.rs index 55d780b7..e3fe4374 100644 --- a/triton-vm/src/proof_item.rs +++ b/triton-vm/src/proof_item.rs @@ -19,10 +19,6 @@ type AuthenticationStructure = Vec>; #[derive(Debug, Clone, PartialEq, Eq, BFieldCodec)] pub struct FriResponse(pub Vec<(PartialAuthenticationPath, XFieldElement)>); -pub trait MayBeUncast { - fn uncast(&self) -> Vec; -} - #[derive(Debug, Clone, PartialEq, Eq)] pub enum ProofItem { CompressedAuthenticationPaths(AuthenticationStructure), @@ -35,44 +31,32 @@ pub enum ProofItem { RevealedCombinationElements(Vec), FriCodeword(Vec), FriResponse(FriResponse), - Uncast(Vec), } -impl MayBeUncast for ProofItem { - fn uncast(&self) -> Vec { - if let Self::Uncast(vector) = self { - let vector_len = BFieldElement::new(vector.len() as u64); - vec![vec![vector_len], vector.to_owned()].concat() - } else { - self.encode() - } +impl ProofItem { + /// The unique identifier for this item type. + pub const fn discriminant(&self) -> BFieldElement { + use ProofItem::*; + let discriminant: u64 = match self { + CompressedAuthenticationPaths(_) => 0, + MasterBaseTableRows(_) => 1, + MasterExtTableRows(_) => 2, + OutOfDomainBaseRow(_) => 3, + OutOfDomainExtRow(_) => 4, + MerkleRoot(_) => 5, + AuthenticationPath(_) => 6, + RevealedCombinationElements(_) => 7, + FriCodeword(_) => 8, + FriResponse(_) => 9, + }; + BFieldElement::new(discriminant) } -} -impl ProofItem -where - AuthenticationStructure: BFieldCodec, - Vec>: BFieldCodec, - Vec>: BFieldCodec, - Digest: BFieldCodec, - Vec: BFieldCodec, - Vec: BFieldCodec, - Vec: BFieldCodec, - BFieldElement: BFieldCodec, - XFieldElement: BFieldCodec, - FriResponse: BFieldCodec, -{ pub fn as_compressed_authentication_paths(&self) -> Result> { match self { Self::CompressedAuthenticationPaths(caps) => Ok(caps.to_owned()), - Self::Uncast(str) => match AuthenticationStructure::::decode(str) { - Ok(boxed_auth_struct) => Ok(*boxed_auth_struct), - Err(e) => bail!(ProofStreamError::new(&format!( - "cast to authentication structure failed: {e}" - ))), - }, other => bail!(ProofStreamError::new(&format!( - "expected compressed authentication paths, but got something else: {other:?}", + "expected compressed authentication paths, but got {other:?}", ))), } } @@ -80,183 +64,139 @@ where pub fn as_master_base_table_rows(&self) -> Result>> { match self { Self::MasterBaseTableRows(bss) => Ok(bss.to_owned()), - Self::Uncast(str) => match Vec::>::decode(str) { - Ok(base_element_vectors) => Ok(*base_element_vectors), - Err(_) => bail!(ProofStreamError::new("cast to base element vectors failed",)), - }, - _ => bail!(ProofStreamError::new( - "expected master base table rows, but got something else", - )), + other => bail!(ProofStreamError::new(&format!( + "expected master base table rows, but got something {other:?}", + ))), } } pub fn as_master_ext_table_rows(&self) -> Result>> { match self { Self::MasterExtTableRows(xss) => Ok(xss.to_owned()), - Self::Uncast(str) => match Vec::>::decode(str) { - Ok(ext_element_vectors) => Ok(*ext_element_vectors), - Err(_) => bail!(ProofStreamError::new( - "cast to extension field element vectors failed", - )), - }, - _ => bail!(ProofStreamError::new( - "expected master extension table rows, but got something else", - )), + other => bail!(ProofStreamError::new(&format!( + "expected master extension table rows, but got {other:?}", + ))), } } pub fn as_out_of_domain_base_row(&self) -> Result> { match self { Self::OutOfDomainBaseRow(xs) => Ok(xs.to_owned()), - Self::Uncast(str) => match Vec::::decode(str) { - Ok(xs) => { - if xs.len() != NUM_BASE_COLUMNS { - bail!(ProofStreamError::new( - "cast to out of domain base row failed" - )); - } - Ok(*xs) - } - Err(_) => bail!(ProofStreamError::new( - "cast to out of domain base row failed" - )), - }, - _ => bail!(ProofStreamError::new( - "expected out of domain base row, but got something else", - )), + other => bail!(ProofStreamError::new(&format!( + "expected out of domain base row, but got {other:?}", + ))), } } pub fn as_out_of_domain_ext_row(&self) -> Result> { match self { Self::OutOfDomainExtRow(xs) => Ok(xs.to_owned()), - Self::Uncast(str) => match Vec::::decode(str) { - Ok(xs) => { - if xs.len() != NUM_EXT_COLUMNS { - bail!(ProofStreamError::new( - "cast to out of domain extension row failed" - )); - } - Ok(*xs) - } - Err(_) => bail!(ProofStreamError::new( - "cast to out of domain extension row failed" - )), - }, - _ => bail!(ProofStreamError::new( - "expected out of domain extension row, but got something else", - )), + other => bail!(ProofStreamError::new(&format!( + "expected out of domain extension row, but got {other:?}", + ))), } } pub fn as_merkle_root(&self) -> Result { match self { Self::MerkleRoot(bs) => Ok(*bs), - Self::Uncast(str) => match Digest::decode(str) { - Ok(merkle_root) => Ok(*merkle_root), - Err(_) => bail!(ProofStreamError::new("cast to Merkle root failed",)), - }, - _ => bail!(ProofStreamError::new( - "expected merkle root, but got something else", - )), + other => bail!(ProofStreamError::new(&format!( + "expected merkle root, but got {other:?}", + ))), } } pub fn as_authentication_path(&self) -> Result> { match self { Self::AuthenticationPath(bss) => Ok(bss.to_owned()), - Self::Uncast(str) => match Vec::::decode(str) { - Ok(authentication_path) => Ok(*authentication_path), - Err(_) => bail!(ProofStreamError::new("cast to authentication path failed",)), - }, - _ => bail!(ProofStreamError::new( - "expected authentication path, but got something else", - )), + other => bail!(ProofStreamError::new(&format!( + "expected authentication path, but got {other:?}", + ))), } } pub fn as_revealed_combination_elements(&self) -> Result> { match self { Self::RevealedCombinationElements(xs) => Ok(xs.to_owned()), - Self::Uncast(str) => match Vec::::decode(str) { - Ok(revealed_combination_elements) => Ok(*revealed_combination_elements), - Err(_) => bail!(ProofStreamError::new( - "cast to revealed combination elements failed", - )), - }, - _ => bail!(ProofStreamError::new( - "expected revealed combination elements, but got something else", - )), + other => bail!(ProofStreamError::new(&format!( + "expected revealed combination elements, but got {other:?}", + ))), } } pub fn as_fri_codeword(&self) -> Result> { match self { Self::FriCodeword(xs) => Ok(xs.to_owned()), - Self::Uncast(str) => match Vec::::decode(str) { - Ok(fri_codeword) => Ok(*fri_codeword), - Err(_) => bail!(ProofStreamError::new("cast to FRI codeword failed",)), - }, - _ => bail!(ProofStreamError::new( - "expected FRI codeword, but got something else", - )), + other => bail!(ProofStreamError::new(&format!( + "expected FRI codeword, but got {other:?}", + ))), } } pub fn as_fri_response(&self) -> Result { match self { Self::FriResponse(fri_proof) => Ok(fri_proof.to_owned()), - Self::Uncast(str) => match FriResponse::decode(str) { - Ok(fri_proof) => Ok(*fri_proof), - Err(_) => bail!(ProofStreamError::new("cast to FRI proof failed",)), - }, - _ => bail!(ProofStreamError::new( - "expected FRI proof, but got something else", - )), + other => bail!(ProofStreamError::new(&format!( + "expected FRI proof, but got {other:?}" + ),)), } } } impl BFieldCodec for ProofItem { - /// Turn the given string of BFieldElements into a ProofItem. The first element denotes the - /// length of the encoding; make sure it is correct! + /// Turn the given string of BFieldElements into a ProofItem. The first element indicates the + /// field type, and the rest of the elements are the data for the item. fn decode(str: &[BFieldElement]) -> Result> { - let Some(len) = str.get(0) else { - bail!(ProofStreamError::new("empty buffer")) - }; - if len.value() + 1 != str.len() as u64 { - bail!(ProofStreamError::new("length mismatch")) + if str.is_empty() { + bail!(ProofStreamError::new("empty buffer")); } - let raw_item = Self::Uncast(str[1..].to_vec()); - Ok(Box::new(raw_item)) + let discriminant = str[0].value(); + let str = &str[1..]; + let item = match discriminant { + 0 => Self::CompressedAuthenticationPaths(*AuthenticationStructure::decode(str)?), + 1 => Self::MasterBaseTableRows(*Vec::>::decode(str)?), + 2 => Self::MasterExtTableRows(*Vec::>::decode(str)?), + 3 => Self::OutOfDomainBaseRow(*Vec::::decode(str)?), + 4 => Self::OutOfDomainExtRow(*Vec::::decode(str)?), + 5 => Self::MerkleRoot(*Digest::decode(str)?), + 6 => Self::AuthenticationPath(*Vec::::decode(str)?), + 7 => Self::RevealedCombinationElements(*Vec::::decode(str)?), + 8 => Self::FriCodeword(*Vec::::decode(str)?), + 9 => Self::FriResponse(*FriResponse::decode(str)?), + i => bail!(ProofStreamError::new(&format!( + "Unknown discriminant {i} for ProofItem." + ))), + }; + Ok(Box::new(item)) } /// Encode the ProofItem as a string of BFieldElements, with the first element denoting the /// length of the rest. fn encode(&self) -> Vec { - let mut tail = match self { - ProofItem::CompressedAuthenticationPaths(something) => something.encode(), - ProofItem::MasterBaseTableRows(something) => something.encode(), - ProofItem::MasterExtTableRows(something) => something.encode(), - ProofItem::OutOfDomainBaseRow(row) => { - debug_assert_eq!(NUM_BASE_COLUMNS, row.len()); - row.encode() - } - ProofItem::OutOfDomainExtRow(row) => { - debug_assert_eq!(NUM_EXT_COLUMNS, row.len()); - row.encode() - } - ProofItem::MerkleRoot(something) => something.encode(), - ProofItem::AuthenticationPath(something) => something.encode(), - ProofItem::RevealedCombinationElements(something) => something.encode(), - ProofItem::FriCodeword(something) => something.encode(), - ProofItem::FriResponse(something) => something.encode(), - ProofItem::Uncast(something) => something.encode(), + use ProofItem::*; + + #[cfg(debug_assertions)] + match self { + OutOfDomainBaseRow(row) => assert_eq!(NUM_BASE_COLUMNS, row.len()), + OutOfDomainExtRow(row) => assert_eq!(NUM_EXT_COLUMNS, row.len()), + _ => (), + } + + let discriminant = vec![self.discriminant()]; + let encoding = match self { + CompressedAuthenticationPaths(something) => something.encode(), + MasterBaseTableRows(something) => something.encode(), + MasterExtTableRows(something) => something.encode(), + OutOfDomainBaseRow(row) => row.encode(), + OutOfDomainExtRow(row) => row.encode(), + MerkleRoot(something) => something.encode(), + AuthenticationPath(something) => something.encode(), + RevealedCombinationElements(something) => something.encode(), + FriCodeword(something) => something.encode(), + FriResponse(something) => something.encode(), }; - let head = BFieldElement::new(tail.len() as u64); - tail.insert(0, head); - tail + [discriminant, encoding].concat() } fn static_length() -> Option { diff --git a/triton-vm/src/proof_stream.rs b/triton-vm/src/proof_stream.rs index 2eef0aa4..db39f30b 100644 --- a/triton-vm/src/proof_stream.rs +++ b/triton-vm/src/proof_stream.rs @@ -3,6 +3,7 @@ use std::fmt::Display; use anyhow::bail; use anyhow::Result; +use itertools::Itertools; use twenty_first::shared_math::b_field_element::BFieldElement; use twenty_first::shared_math::b_field_element::BFIELD_ONE; use twenty_first::shared_math::b_field_element::BFIELD_ZERO; @@ -12,12 +13,11 @@ use twenty_first::shared_math::x_field_element::XFieldElement; use twenty_first::util_types::algebraic_hasher::AlgebraicHasher; use crate::proof::Proof; -use crate::proof_item::MayBeUncast; #[derive(Debug, Clone, PartialEq, Eq)] pub struct ProofStream where - Item: Clone + BFieldCodec + MayBeUncast, + Item: Clone + BFieldCodec, H: AlgebraicHasher, { pub items: Vec, @@ -49,7 +49,7 @@ impl Error for ProofStreamError {} impl ProofStream where - Item: Clone + BFieldCodec + MayBeUncast, + Item: Clone + BFieldCodec, H: AlgebraicHasher, { pub fn new() -> Self { @@ -64,10 +64,12 @@ where self.items.is_empty() } + /// The number of items in the proof stream. pub fn len(&self) -> usize { self.items.len() } + /// The number of field elements required to encode the proof. pub fn transcript_length(&self) -> usize { let Proof(b_field_elements) = self.to_proof(); b_field_elements.len() @@ -75,11 +77,9 @@ where /// Convert the proof stream, _i.e._, the transcript, into a Proof. pub fn to_proof(&self) -> Proof { - let mut bfes = vec![]; - for item in self.items.iter() { - bfes.append(&mut item.encode()); - } - Proof(bfes) + let encoded_items = self.items.iter().map(|item| item.encode()).collect_vec(); + let complete_encoding = encoded_items.concat(); + Proof(complete_encoding) } /// Convert the proof into a proof stream for the verifier. @@ -96,7 +96,7 @@ where have {proof_len} but expected {next_index}" ))); } - let str = &proof.0[index..(next_index)]; + let str = &proof.0[index..next_index]; let item = Item::decode(str)?; items.push(*item); index = next_index; @@ -180,7 +180,7 @@ where impl Default for ProofStream where - Item: Clone + BFieldCodec + MayBeUncast, + Item: Clone + BFieldCodec, H: AlgebraicHasher, { fn default() -> Self { @@ -208,88 +208,44 @@ mod proof_stream_typed_tests { enum TestItem { ManyB(Vec), ManyX(Vec), - Uncast(Vec), - } - - impl MayBeUncast for TestItem { - fn uncast(&self) -> Vec { - if let Self::Uncast(vector) = self { - let mut str = vec![]; - str.push(BFieldElement::new(vector.len().try_into().unwrap())); - str.append(&mut vector.clone()); - str - } else { - self.encode() - } - } } impl TestItem { - pub fn as_bs(&self) -> Self { - match self { - Self::Uncast(bs) => Self::ManyB(bs.to_vec()), - _ => panic!("can only cast from Uncast"), - } - } - - pub fn as_xs(&self) -> Self { + /// The unique identifier for this item type. + pub fn discriminant(&self) -> BFieldElement { + use TestItem::*; match self { - Self::Uncast(bs) => Self::ManyX( - bs.chunks(3) - .collect_vec() - .into_iter() - .map(|bbb| { - XFieldElement::new( - bbb.try_into() - .expect("cannot unwrap chunk of 3 (?) BFieldElements"), - ) - }) - .collect_vec(), - ), - _ => panic!("can only cast from Uncast"), + ManyB(_) => BFieldElement::new(0), + ManyX(_) => BFieldElement::new(1), } } } impl BFieldCodec for TestItem { fn decode(str: &[BFieldElement]) -> Result> { - let maybe_element_zero = str.get(0); - match maybe_element_zero { - None => Err(ProofStreamError::new( - "trying to decode empty string into test item", - )), - Some(bfe) => { - if str.len() != 1 + (bfe.value() as usize) { - Err(ProofStreamError::new("length mismatch")) - } else { - Ok(Box::new(Self::Uncast(str[1..].to_vec()))) - } - } + if str.is_empty() { + bail!("trying to decode empty string into test item"); } + + let discriminant = str[0].value(); + let str = &str[1..]; + let item = match discriminant { + 0 => Self::ManyB(*Vec::::decode(str)?), + 1 => Self::ManyX(*Vec::::decode(str)?), + i => bail!("Unknown discriminant ID {i}."), + }; + Ok(Box::new(item)) } fn encode(&self) -> Vec { - let mut vect = vec![]; - match self { - Self::ManyB(bs) => { - for b in bs { - vect.append(&mut b.encode()); - } - } - Self::ManyX(xs) => { - for x in xs { - vect.append(&mut x.encode()); - } - } - Self::Uncast(bs) => { - for b in bs { - vect.append(&mut b.encode()); - } - } - } - vect.insert(0, BFieldElement::new(vect.len().try_into().unwrap())); - - vect + use TestItem::*; + + let discriminant = vec![self.discriminant()]; + let encoding = match self { + ManyB(bs) => bs.encode(), + ManyX(xs) => xs.encode(), + }; + [discriminant, encoding].concat() } fn static_length() -> Option { @@ -319,34 +275,19 @@ mod proof_stream_typed_tests { ProofStream::from_proof(&proof).expect("invalid parsing of proof"); let fs1_ = proof_stream.sponge_state.state; - match proof_stream - .dequeue(false) - .expect("can't dequeue item") - .as_bs() - { + match proof_stream.dequeue(false).expect("can't dequeue item") { TestItem::ManyB(manyb1_) => assert_eq!(manyb1, manyb1_), TestItem::ManyX(_) => panic!(), - TestItem::Uncast(_) => panic!(), }; let fs2_ = proof_stream.sponge_state.state; - match proof_stream - .dequeue(true) - .expect("can't dequeue item") - .as_xs() - { + match proof_stream.dequeue(true).expect("can't dequeue item") { TestItem::ManyB(_) => panic!(), TestItem::ManyX(manyx_) => assert_eq!(manyx, manyx_), - TestItem::Uncast(_) => panic!(), }; let fs3_ = proof_stream.sponge_state.state; - match proof_stream - .dequeue(true) - .expect("can't dequeue item") - .as_bs() - { + match proof_stream.dequeue(true).expect("can't dequeue item") { TestItem::ManyB(manyb2_) => assert_eq!(manyb2, manyb2_), TestItem::ManyX(_) => panic!(), - TestItem::Uncast(_) => panic!(), }; let fs4_ = proof_stream.sponge_state.state;