diff --git a/triton-vm/benches/proof_size.rs b/triton-vm/benches/proof_size.rs index 4e1711af..0a99f289 100644 --- a/triton-vm/benches/proof_size.rs +++ b/triton-vm/benches/proof_size.rs @@ -183,7 +183,7 @@ fn program_halt() -> ProgramAndInput { /// The base 2, integer logarithm of the FRI domain length. fn log_2_fri_domain_length(parameters: StarkParameters, proof: &Proof) -> u32 { let padded_height = proof.padded_height().unwrap(); - let fri = Stark::derive_fri(parameters, padded_height); + let fri = Stark::derive_fri(parameters, padded_height).unwrap(); fri.domain.length.ilog2() } diff --git a/triton-vm/benches/prove_fib.rs b/triton-vm/benches/prove_fib.rs index dd3e230e..5e37ec5e 100644 --- a/triton-vm/benches/prove_fib.rs +++ b/triton-vm/benches/prove_fib.rs @@ -47,7 +47,7 @@ fn prover_timing_report(claim: &Claim, aet: &AlgebraicExecutionTrace) -> Report profiler.finish(); let padded_height = proof.padded_height().unwrap(); - let fri = Stark::derive_fri(parameters, padded_height); + let fri = Stark::derive_fri(parameters, padded_height).unwrap(); profiler .report() .with_cycle_count(aet.processor_trace.nrows()) diff --git a/triton-vm/benches/prove_halt.rs b/triton-vm/benches/prove_halt.rs index 2892cdfa..cc156d34 100644 --- a/triton-vm/benches/prove_halt.rs +++ b/triton-vm/benches/prove_halt.rs @@ -38,7 +38,7 @@ fn prove_halt(criterion: &mut Criterion) { println!("Writing report ..."); let padded_height = proof.padded_height().unwrap(); - let fri = Stark::derive_fri(parameters, padded_height); + let fri = Stark::derive_fri(parameters, padded_height).unwrap(); let report = profiler .report() .with_cycle_count(aet.processor_trace.nrows()) diff --git a/triton-vm/benches/verify_halt.rs b/triton-vm/benches/verify_halt.rs index e6edb082..9a377ade 100644 --- a/triton-vm/benches/verify_halt.rs +++ b/triton-vm/benches/verify_halt.rs @@ -30,7 +30,7 @@ fn verify_halt(criterion: &mut Criterion) { let mut profiler = profiler.unwrap(); profiler.finish(); let padded_height = proof.padded_height().unwrap(); - let fri = Stark::derive_fri(parameters, padded_height); + let fri = Stark::derive_fri(parameters, padded_height).unwrap(); let report = profiler .report() .with_cycle_count(aet.processor_trace.nrows()) diff --git a/triton-vm/src/arithmetic_domain.rs b/triton-vm/src/arithmetic_domain.rs index 3537facb..afc33ac9 100644 --- a/triton-vm/src/arithmetic_domain.rs +++ b/triton-vm/src/arithmetic_domain.rs @@ -2,7 +2,6 @@ use std::ops::MulAssign; use num_traits::One; use twenty_first::shared_math::b_field_element::BFieldElement; -use twenty_first::shared_math::other::is_power_of_two; use twenty_first::shared_math::polynomial::Polynomial; use twenty_first::shared_math::traits::FiniteField; use twenty_first::shared_math::traits::ModPowU32; @@ -36,7 +35,7 @@ impl ArithmeticDomain { /// The domain length must be a power of 2. pub fn generator_for_length(domain_length: u64) -> BFieldElement { assert!( - 0 == domain_length || is_power_of_two(domain_length), + 0 == domain_length || domain_length.is_power_of_two(), "The domain length must be a power of 2 but was {domain_length}.", ); BFieldElement::primitive_root_of_unity(domain_length).unwrap() diff --git a/triton-vm/src/error.rs b/triton-vm/src/error.rs index cadc13be..9685d989 100644 --- a/triton-vm/src/error.rs +++ b/triton-vm/src/error.rs @@ -117,6 +117,19 @@ pub enum ProofStreamError { DecodingError(#[from] as BFieldCodec>::Error), } +#[non_exhaustive] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)] +pub enum FriSetupError { + #[error("the expansion factor must be greater than 1")] + ExpansionFactorTooSmall, + + #[error("the expansion factor must be a power of 2")] + ExpansionFactorUnsupported, + + #[error("the expansion factor must be smaller than the domain length")] + ExpansionFactorMismatch, +} + #[non_exhaustive] #[derive(Debug, Error)] pub enum FriProvingError { @@ -239,6 +252,9 @@ pub enum VerificationError { #[error(transparent)] ProofStreamError(#[from] ProofStreamError), + #[error(transparent)] + FriSetupError(#[from] FriSetupError), + #[error(transparent)] FriValidationError(#[from] FriValidationError), } diff --git a/triton-vm/src/fri.rs b/triton-vm/src/fri.rs index 2fbe7b04..bac640cd 100644 --- a/triton-vm/src/fri.rs +++ b/triton-vm/src/fri.rs @@ -17,6 +17,7 @@ use twenty_first::util_types::merkle_tree_maker::MerkleTreeMaker; use crate::arithmetic_domain::ArithmeticDomain; use crate::error::FriProvingError; +use crate::error::FriSetupError; use crate::error::FriValidationError; use crate::error::FriValidationError::*; use crate::profiler::prof_start; @@ -27,8 +28,10 @@ use crate::proof_item::ProofItem; use crate::proof_stream::ProofStream; use crate::stark::MTMaker; -type VerifierResult = Result; -type ProverResult = Result; +pub(crate) type SetupResult = Result; +pub(crate) type ProverResult = Result; +pub(crate) type VerifierResult = Result; + pub type AuthenticationStructure = Vec; #[derive(Debug, Clone, Copy)] @@ -550,17 +553,21 @@ impl Fri { domain: ArithmeticDomain, expansion_factor: usize, num_collinearity_checks: usize, - ) -> Self { - assert!(expansion_factor > 1); - assert!(expansion_factor.is_power_of_two()); - assert!(domain.length >= expansion_factor); + ) -> SetupResult { + match expansion_factor { + ef if ef <= 1 => return Err(FriSetupError::ExpansionFactorTooSmall), + ef if !ef.is_power_of_two() => return Err(FriSetupError::ExpansionFactorUnsupported), + ef if ef > domain.length => return Err(FriSetupError::ExpansionFactorMismatch), + _ => (), + }; - Self { + let fri = Self { domain, expansion_factor, num_collinearity_checks, _hasher: PhantomData, - } + }; + Ok(fri) } /// Create a FRI proof and return indices of revealed elements of round 0. @@ -708,7 +715,7 @@ mod tests { let domain_length = max(sampled_domain_length, min_expanded_domain_length); let fri_domain = ArithmeticDomain::of_length(domain_length).with_offset(offset); - Fri::new(fri_domain, expansion_factor, num_collinearity_checks) + Fri::new(fri_domain, expansion_factor, num_collinearity_checks).unwrap() } } @@ -807,42 +814,42 @@ mod tests { let domain = ArithmeticDomain::of_length(2); let expansion_factor = 2; let num_collinearity_checks = 1; - Fri::new(domain, expansion_factor, num_collinearity_checks) + Fri::new(domain, expansion_factor, num_collinearity_checks).unwrap() } #[test] - #[should_panic] fn too_small_expansion_factor_is_rejected() { let domain = ArithmeticDomain::of_length(2); let expansion_factor = 1; let num_collinearity_checks = 1; - Fri::::new(domain, expansion_factor, num_collinearity_checks); + let err = Fri::::new(domain, expansion_factor, num_collinearity_checks).unwrap_err(); + assert_eq!(FriSetupError::ExpansionFactorTooSmall, err); } #[proptest] - #[should_panic] fn expansion_factor_not_a_power_of_two_is_rejected( - #[strategy(2_usize..)] expansion_factor: usize, - #[strategy(arb())] offset: BFieldElement, + #[strategy(2_usize..(1 << 32))] + #[filter(!#expansion_factor.is_power_of_two())] + expansion_factor: usize, ) { - if expansion_factor.is_power_of_two() { - return Ok(()); - } - let domain = ArithmeticDomain::of_length(2 * expansion_factor).with_offset(offset); + let largest_supported_domain_size = 1 << 32; + let domain = ArithmeticDomain::of_length(largest_supported_domain_size); let num_collinearity_checks = 1; - Fri::::new(domain, expansion_factor, num_collinearity_checks); + let err = Fri::::new(domain, expansion_factor, num_collinearity_checks).unwrap_err(); + prop_assert_eq!(FriSetupError::ExpansionFactorUnsupported, err); } #[proptest] - #[should_panic] fn domain_size_smaller_than_expansion_factor_is_rejected( - #[strategy(1_usize..=8)] log_2_expansion_factor: usize, - #[strategy(arb())] offset: BFieldElement, + #[strategy(1_usize..32)] log_2_expansion_factor: usize, + #[strategy(..#log_2_expansion_factor)] log_2_domain_length: usize, ) { let expansion_factor = (1 << log_2_expansion_factor) as usize; - let domain = ArithmeticDomain::of_length(expansion_factor - 1).with_offset(offset); + let domain_length = (1 << log_2_domain_length) as usize; + let domain = ArithmeticDomain::of_length(domain_length); let num_collinearity_checks = 1; - Fri::::new(domain, expansion_factor, num_collinearity_checks); + let err = Fri::::new(domain, expansion_factor, num_collinearity_checks).unwrap_err(); + prop_assert_eq!(FriSetupError::ExpansionFactorMismatch, err); } // todo: add test fuzzing proof_stream diff --git a/triton-vm/src/profiler.rs b/triton-vm/src/profiler.rs index b69e7105..98e9fd35 100644 --- a/triton-vm/src/profiler.rs +++ b/triton-vm/src/profiler.rs @@ -14,7 +14,6 @@ use colored::Color; use colored::ColoredString; use colored::Colorize; use criterion::profiler::Profiler; -use twenty_first::shared_math::other::log_2_floor; use unicode_width::UnicodeWidthStr; const GET_PROFILE_OUTPUT_AS_YOU_GO_ENV_VAR_NAME: &str = "PROFILE_AS_YOU_GO"; @@ -594,7 +593,7 @@ impl Display for Report { if let Some(fri_domain_length) = self.fri_domain_len { if fri_domain_length != 0 { - let log_2_fri_domain_length = log_2_floor(fri_domain_length as u128); + let log_2_fri_domain_length = fri_domain_length.ilog2(); writeln!(f, "FRI domain length is 2^{log_2_fri_domain_length}")?; } } diff --git a/triton-vm/src/proof_stream.rs b/triton-vm/src/proof_stream.rs index 2eb2baf5..2f71bdc7 100644 --- a/triton-vm/src/proof_stream.rs +++ b/triton-vm/src/proof_stream.rs @@ -3,7 +3,6 @@ use twenty_first::shared_math::b_field_element::BFieldElement; use twenty_first::shared_math::b_field_element::BFIELD_ONE; use twenty_first::shared_math::b_field_element::BFIELD_ZERO; use twenty_first::shared_math::bfield_codec::BFieldCodec; -use twenty_first::shared_math::other::is_power_of_two; use twenty_first::shared_math::x_field_element::XFieldElement; use twenty_first::util_types::algebraic_hasher::AlgebraicHasher; @@ -108,7 +107,7 @@ where /// - `upper_bound`: The (non-inclusive) upper bound. Must be a power of two. /// - `num_indices`: The number of indices to sample pub fn sample_indices(&mut self, upper_bound: usize, num_indices: usize) -> Vec { - assert!(is_power_of_two(upper_bound)); + assert!(upper_bound.is_power_of_two()); assert!(upper_bound <= BFieldElement::MAX as usize); H::sample_indices(&mut self.sponge_state, upper_bound as u32, num_indices) .into_iter() diff --git a/triton-vm/src/shared_tests.rs b/triton-vm/src/shared_tests.rs index 7732ca45..ab0f3f96 100644 --- a/triton-vm/src/shared_tests.rs +++ b/triton-vm/src/shared_tests.rs @@ -159,7 +159,7 @@ pub(crate) fn construct_master_base_table( aet: &AlgebraicExecutionTrace, ) -> MasterBaseTable { let padded_height = aet.padded_height(); - let fri = Stark::derive_fri(parameters, padded_height); + let fri = Stark::derive_fri(parameters, padded_height).unwrap(); let max_degree = Stark::derive_max_degree(padded_height, parameters.num_trace_randomizers); let quotient_domain = Stark::quotient_domain(fri.domain, max_degree); MasterBaseTable::new( diff --git a/triton-vm/src/stark.rs b/triton-vm/src/stark.rs index 10a37801..5ad237bc 100644 --- a/triton-vm/src/stark.rs +++ b/triton-vm/src/stark.rs @@ -31,6 +31,7 @@ use crate::aet::AlgebraicExecutionTrace; use crate::arithmetic_domain::ArithmeticDomain; use crate::error::VerificationError; use crate::error::VerificationError::*; +use crate::fri; use crate::fri::Fri; use crate::profiler::prof_itr0; use crate::profiler::prof_start; @@ -151,7 +152,7 @@ impl Stark { prof_start!(maybe_profiler, "derive additional parameters"); let padded_height = aet.padded_height(); let max_degree = Self::derive_max_degree(padded_height, parameters.num_trace_randomizers); - let fri = Self::derive_fri(parameters, padded_height); + let fri = Self::derive_fri(parameters, padded_height).unwrap(); let quotient_domain = Self::quotient_domain(fri.domain, max_degree); proof_stream.enqueue(ProofItem::Log2PaddedHeight(padded_height.ilog2())); prof_stop!(maybe_profiler, "derive additional parameters"); @@ -626,13 +627,17 @@ impl Stark { /// In principle, the FRI domain is also influenced by the AIR's degree /// (see [`AIR_TARGET_DEGREE`]). However, by segmenting the quotient polynomial into /// [`AIR_TARGET_DEGREE`]-many parts, that influence is mitigated. - pub fn derive_fri(parameters: StarkParameters, padded_height: usize) -> Fri { + pub fn derive_fri( + parameters: StarkParameters, + padded_height: usize, + ) -> fri::SetupResult> { let interpolant_degree = interpolant_degree(padded_height, parameters.num_trace_randomizers); let interpolant_codeword_length = interpolant_degree as usize + 1; let fri_domain_length = parameters.fri_expansion_factor * interpolant_codeword_length; let coset_offset = BFieldElement::generator(); let domain = ArithmeticDomain::of_length(fri_domain_length).with_offset(coset_offset); + Fri::new( domain, parameters.fri_expansion_factor, @@ -733,7 +738,7 @@ impl Stark { prof_start!(maybe_profiler, "derive additional parameters"); let log_2_padded_height = proof_stream.dequeue()?.as_log2_padded_height()?; let padded_height = 1 << log_2_padded_height; - let fri = Self::derive_fri(parameters, padded_height); + let fri = Self::derive_fri(parameters, padded_height)?; let merkle_tree_height = fri.domain.length.ilog2() as usize; prof_stop!(maybe_profiler, "derive additional parameters"); @@ -2220,7 +2225,7 @@ pub(crate) mod tests { assert!(let Ok(()) = Stark::verify(parameters, &claim, &proof, &mut None)); let_assert!(Ok(padded_height) = proof.padded_height()); - let fri = Stark::derive_fri(parameters, padded_height); + let fri = Stark::derive_fri(parameters, padded_height).unwrap(); let report = profiler .report() .with_padded_height(padded_height) @@ -2244,7 +2249,7 @@ pub(crate) mod tests { assert!(let Ok(()) = Stark::verify(parameters, &claim, &proof, &mut None)); let_assert!(Ok(padded_height) = proof.padded_height()); - let fri = Stark::derive_fri(parameters, padded_height); + let fri = Stark::derive_fri(parameters, padded_height).unwrap(); let report = profiler .report() .with_padded_height(padded_height) @@ -2286,7 +2291,7 @@ pub(crate) mod tests { assert!(let Ok(()) = Stark::verify(parameters, &claim, &proof, &mut None)); let_assert!(Ok(padded_height) = proof.padded_height()); - let fri = Stark::derive_fri(parameters, padded_height); + let fri = Stark::derive_fri(parameters, padded_height).unwrap(); let report = profiler .report() .with_padded_height(padded_height)