Skip to content

Commit

Permalink
refactor!: communicate possible FRI setup failures with Result
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Jan 24, 2024
1 parent 9db6cd2 commit 3fe35ad
Show file tree
Hide file tree
Showing 11 changed files with 67 additions and 42 deletions.
2 changes: 1 addition & 1 deletion triton-vm/benches/proof_size.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ fn program_halt() -> ProgramAndInput {
/// The base 2, integer logarithm of the FRI domain length.
fn log_2_fri_domain_length(parameters: StarkParameters, proof: &Proof) -> u32 {
let padded_height = proof.padded_height().unwrap();
let fri = Stark::derive_fri(parameters, padded_height);
let fri = Stark::derive_fri(parameters, padded_height).unwrap();
fri.domain.length.ilog2()
}

Expand Down
2 changes: 1 addition & 1 deletion triton-vm/benches/prove_fib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn prover_timing_report(claim: &Claim, aet: &AlgebraicExecutionTrace) -> Report
profiler.finish();

let padded_height = proof.padded_height().unwrap();
let fri = Stark::derive_fri(parameters, padded_height);
let fri = Stark::derive_fri(parameters, padded_height).unwrap();
profiler
.report()
.with_cycle_count(aet.processor_trace.nrows())
Expand Down
2 changes: 1 addition & 1 deletion triton-vm/benches/prove_halt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ fn prove_halt(criterion: &mut Criterion) {

println!("Writing report ...");
let padded_height = proof.padded_height().unwrap();
let fri = Stark::derive_fri(parameters, padded_height);
let fri = Stark::derive_fri(parameters, padded_height).unwrap();
let report = profiler
.report()
.with_cycle_count(aet.processor_trace.nrows())
Expand Down
2 changes: 1 addition & 1 deletion triton-vm/benches/verify_halt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ fn verify_halt(criterion: &mut Criterion) {
let mut profiler = profiler.unwrap();
profiler.finish();
let padded_height = proof.padded_height().unwrap();
let fri = Stark::derive_fri(parameters, padded_height);
let fri = Stark::derive_fri(parameters, padded_height).unwrap();
let report = profiler
.report()
.with_cycle_count(aet.processor_trace.nrows())
Expand Down
3 changes: 1 addition & 2 deletions triton-vm/src/arithmetic_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::ops::MulAssign;

use num_traits::One;
use twenty_first::shared_math::b_field_element::BFieldElement;
use twenty_first::shared_math::other::is_power_of_two;
use twenty_first::shared_math::polynomial::Polynomial;
use twenty_first::shared_math::traits::FiniteField;
use twenty_first::shared_math::traits::ModPowU32;
Expand Down Expand Up @@ -36,7 +35,7 @@ impl ArithmeticDomain {
/// The domain length must be a power of 2.
pub fn generator_for_length(domain_length: u64) -> BFieldElement {
assert!(
0 == domain_length || is_power_of_two(domain_length),
0 == domain_length || domain_length.is_power_of_two(),
"The domain length must be a power of 2 but was {domain_length}.",
);
BFieldElement::primitive_root_of_unity(domain_length).unwrap()
Expand Down
16 changes: 16 additions & 0 deletions triton-vm/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,19 @@ pub enum ProofStreamError {
DecodingError(#[from] <ProofStream<StarkHasher> as BFieldCodec>::Error),
}

#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
pub enum FriSetupError {
#[error("the expansion factor must be greater than 1")]
ExpansionFactorTooSmall,

#[error("the expansion factor must be a power of 2")]
ExpansionFactorUnsupported,

#[error("the expansion factor must be smaller than the domain length")]
ExpansionFactorMismatch,
}

#[non_exhaustive]
#[derive(Debug, Error)]
pub enum FriProvingError {
Expand Down Expand Up @@ -239,6 +252,9 @@ pub enum VerificationError {
#[error(transparent)]
ProofStreamError(#[from] ProofStreamError),

#[error(transparent)]
FriSetupError(#[from] FriSetupError),

#[error(transparent)]
FriValidationError(#[from] FriValidationError),
}
Expand Down
57 changes: 32 additions & 25 deletions triton-vm/src/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use twenty_first::util_types::merkle_tree_maker::MerkleTreeMaker;

use crate::arithmetic_domain::ArithmeticDomain;
use crate::error::FriProvingError;
use crate::error::FriSetupError;
use crate::error::FriValidationError;
use crate::error::FriValidationError::*;
use crate::profiler::prof_start;
Expand All @@ -27,8 +28,10 @@ use crate::proof_item::ProofItem;
use crate::proof_stream::ProofStream;
use crate::stark::MTMaker;

type VerifierResult<T> = Result<T, FriValidationError>;
type ProverResult<T> = Result<T, FriProvingError>;
pub(crate) type SetupResult<T> = Result<T, FriSetupError>;
pub(crate) type ProverResult<T> = Result<T, FriProvingError>;
pub(crate) type VerifierResult<T> = Result<T, FriValidationError>;

pub type AuthenticationStructure = Vec<Digest>;

#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -550,17 +553,21 @@ impl<H: AlgebraicHasher> Fri<H> {
domain: ArithmeticDomain,
expansion_factor: usize,
num_collinearity_checks: usize,
) -> Self {
assert!(expansion_factor > 1);
assert!(expansion_factor.is_power_of_two());
assert!(domain.length >= expansion_factor);
) -> SetupResult<Self> {
match expansion_factor {
ef if ef <= 1 => return Err(FriSetupError::ExpansionFactorTooSmall),
ef if !ef.is_power_of_two() => return Err(FriSetupError::ExpansionFactorUnsupported),
ef if ef > domain.length => return Err(FriSetupError::ExpansionFactorMismatch),
_ => (),
};

Self {
let fri = Self {
domain,
expansion_factor,
num_collinearity_checks,
_hasher: PhantomData,
}
};
Ok(fri)
}

/// Create a FRI proof and return indices of revealed elements of round 0.
Expand Down Expand Up @@ -708,7 +715,7 @@ mod tests {
let domain_length = max(sampled_domain_length, min_expanded_domain_length);

let fri_domain = ArithmeticDomain::of_length(domain_length).with_offset(offset);
Fri::new(fri_domain, expansion_factor, num_collinearity_checks)
Fri::new(fri_domain, expansion_factor, num_collinearity_checks).unwrap()
}
}

Expand Down Expand Up @@ -807,42 +814,42 @@ mod tests {
let domain = ArithmeticDomain::of_length(2);
let expansion_factor = 2;
let num_collinearity_checks = 1;
Fri::new(domain, expansion_factor, num_collinearity_checks)
Fri::new(domain, expansion_factor, num_collinearity_checks).unwrap()
}

#[test]
#[should_panic]
fn too_small_expansion_factor_is_rejected() {
let domain = ArithmeticDomain::of_length(2);
let expansion_factor = 1;
let num_collinearity_checks = 1;
Fri::<Tip5>::new(domain, expansion_factor, num_collinearity_checks);
let err = Fri::<Tip5>::new(domain, expansion_factor, num_collinearity_checks).unwrap_err();
assert_eq!(FriSetupError::ExpansionFactorTooSmall, err);
}

#[proptest]
#[should_panic]
fn expansion_factor_not_a_power_of_two_is_rejected(
#[strategy(2_usize..)] expansion_factor: usize,
#[strategy(arb())] offset: BFieldElement,
#[strategy(2_usize..(1 << 32))]
#[filter(!#expansion_factor.is_power_of_two())]
expansion_factor: usize,
) {
if expansion_factor.is_power_of_two() {
return Ok(());
}
let domain = ArithmeticDomain::of_length(2 * expansion_factor).with_offset(offset);
let largest_supported_domain_size = 1 << 32;
let domain = ArithmeticDomain::of_length(largest_supported_domain_size);
let num_collinearity_checks = 1;
Fri::<Tip5>::new(domain, expansion_factor, num_collinearity_checks);
let err = Fri::<Tip5>::new(domain, expansion_factor, num_collinearity_checks).unwrap_err();
prop_assert_eq!(FriSetupError::ExpansionFactorUnsupported, err);
}

#[proptest]
#[should_panic]
fn domain_size_smaller_than_expansion_factor_is_rejected(
#[strategy(1_usize..=8)] log_2_expansion_factor: usize,
#[strategy(arb())] offset: BFieldElement,
#[strategy(1_usize..32)] log_2_expansion_factor: usize,
#[strategy(..#log_2_expansion_factor)] log_2_domain_length: usize,
) {
let expansion_factor = (1 << log_2_expansion_factor) as usize;
let domain = ArithmeticDomain::of_length(expansion_factor - 1).with_offset(offset);
let domain_length = (1 << log_2_domain_length) as usize;
let domain = ArithmeticDomain::of_length(domain_length);
let num_collinearity_checks = 1;
Fri::<Tip5>::new(domain, expansion_factor, num_collinearity_checks);
let err = Fri::<Tip5>::new(domain, expansion_factor, num_collinearity_checks).unwrap_err();
prop_assert_eq!(FriSetupError::ExpansionFactorMismatch, err);
}

// todo: add test fuzzing proof_stream
Expand Down
3 changes: 1 addition & 2 deletions triton-vm/src/profiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use colored::Color;
use colored::ColoredString;
use colored::Colorize;
use criterion::profiler::Profiler;
use twenty_first::shared_math::other::log_2_floor;
use unicode_width::UnicodeWidthStr;

const GET_PROFILE_OUTPUT_AS_YOU_GO_ENV_VAR_NAME: &str = "PROFILE_AS_YOU_GO";
Expand Down Expand Up @@ -594,7 +593,7 @@ impl Display for Report {

if let Some(fri_domain_length) = self.fri_domain_len {
if fri_domain_length != 0 {
let log_2_fri_domain_length = log_2_floor(fri_domain_length as u128);
let log_2_fri_domain_length = fri_domain_length.ilog2();
writeln!(f, "FRI domain length is 2^{log_2_fri_domain_length}")?;
}
}
Expand Down
3 changes: 1 addition & 2 deletions triton-vm/src/proof_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use twenty_first::shared_math::b_field_element::BFieldElement;
use twenty_first::shared_math::b_field_element::BFIELD_ONE;
use twenty_first::shared_math::b_field_element::BFIELD_ZERO;
use twenty_first::shared_math::bfield_codec::BFieldCodec;
use twenty_first::shared_math::other::is_power_of_two;
use twenty_first::shared_math::x_field_element::XFieldElement;
use twenty_first::util_types::algebraic_hasher::AlgebraicHasher;

Expand Down Expand Up @@ -108,7 +107,7 @@ where
/// - `upper_bound`: The (non-inclusive) upper bound. Must be a power of two.
/// - `num_indices`: The number of indices to sample
pub fn sample_indices(&mut self, upper_bound: usize, num_indices: usize) -> Vec<usize> {
assert!(is_power_of_two(upper_bound));
assert!(upper_bound.is_power_of_two());
assert!(upper_bound <= BFieldElement::MAX as usize);
H::sample_indices(&mut self.sponge_state, upper_bound as u32, num_indices)
.into_iter()
Expand Down
2 changes: 1 addition & 1 deletion triton-vm/src/shared_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ pub(crate) fn construct_master_base_table(
aet: &AlgebraicExecutionTrace,
) -> MasterBaseTable {
let padded_height = aet.padded_height();
let fri = Stark::derive_fri(parameters, padded_height);
let fri = Stark::derive_fri(parameters, padded_height).unwrap();
let max_degree = Stark::derive_max_degree(padded_height, parameters.num_trace_randomizers);
let quotient_domain = Stark::quotient_domain(fri.domain, max_degree);
MasterBaseTable::new(
Expand Down
17 changes: 11 additions & 6 deletions triton-vm/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use crate::aet::AlgebraicExecutionTrace;
use crate::arithmetic_domain::ArithmeticDomain;
use crate::error::VerificationError;
use crate::error::VerificationError::*;
use crate::fri;
use crate::fri::Fri;
use crate::profiler::prof_itr0;
use crate::profiler::prof_start;
Expand Down Expand Up @@ -151,7 +152,7 @@ impl Stark {
prof_start!(maybe_profiler, "derive additional parameters");
let padded_height = aet.padded_height();
let max_degree = Self::derive_max_degree(padded_height, parameters.num_trace_randomizers);
let fri = Self::derive_fri(parameters, padded_height);
let fri = Self::derive_fri(parameters, padded_height).unwrap();
let quotient_domain = Self::quotient_domain(fri.domain, max_degree);
proof_stream.enqueue(ProofItem::Log2PaddedHeight(padded_height.ilog2()));
prof_stop!(maybe_profiler, "derive additional parameters");
Expand Down Expand Up @@ -626,13 +627,17 @@ impl Stark {
/// In principle, the FRI domain is also influenced by the AIR's degree
/// (see [`AIR_TARGET_DEGREE`]). However, by segmenting the quotient polynomial into
/// [`AIR_TARGET_DEGREE`]-many parts, that influence is mitigated.
pub fn derive_fri(parameters: StarkParameters, padded_height: usize) -> Fri<StarkHasher> {
pub fn derive_fri(
parameters: StarkParameters,
padded_height: usize,
) -> fri::SetupResult<Fri<StarkHasher>> {
let interpolant_degree =
interpolant_degree(padded_height, parameters.num_trace_randomizers);
let interpolant_codeword_length = interpolant_degree as usize + 1;
let fri_domain_length = parameters.fri_expansion_factor * interpolant_codeword_length;
let coset_offset = BFieldElement::generator();
let domain = ArithmeticDomain::of_length(fri_domain_length).with_offset(coset_offset);

Fri::new(
domain,
parameters.fri_expansion_factor,
Expand Down Expand Up @@ -733,7 +738,7 @@ impl Stark {
prof_start!(maybe_profiler, "derive additional parameters");
let log_2_padded_height = proof_stream.dequeue()?.as_log2_padded_height()?;
let padded_height = 1 << log_2_padded_height;
let fri = Self::derive_fri(parameters, padded_height);
let fri = Self::derive_fri(parameters, padded_height)?;
let merkle_tree_height = fri.domain.length.ilog2() as usize;
prof_stop!(maybe_profiler, "derive additional parameters");

Expand Down Expand Up @@ -2220,7 +2225,7 @@ pub(crate) mod tests {
assert!(let Ok(()) = Stark::verify(parameters, &claim, &proof, &mut None));

let_assert!(Ok(padded_height) = proof.padded_height());
let fri = Stark::derive_fri(parameters, padded_height);
let fri = Stark::derive_fri(parameters, padded_height).unwrap();
let report = profiler
.report()
.with_padded_height(padded_height)
Expand All @@ -2244,7 +2249,7 @@ pub(crate) mod tests {
assert!(let Ok(()) = Stark::verify(parameters, &claim, &proof, &mut None));

let_assert!(Ok(padded_height) = proof.padded_height());
let fri = Stark::derive_fri(parameters, padded_height);
let fri = Stark::derive_fri(parameters, padded_height).unwrap();
let report = profiler
.report()
.with_padded_height(padded_height)
Expand Down Expand Up @@ -2286,7 +2291,7 @@ pub(crate) mod tests {
assert!(let Ok(()) = Stark::verify(parameters, &claim, &proof, &mut None));

let_assert!(Ok(padded_height) = proof.padded_height());
let fri = Stark::derive_fri(parameters, padded_height);
let fri = Stark::derive_fri(parameters, padded_height).unwrap();
let report = profiler
.report()
.with_padded_height(padded_height)
Expand Down

0 comments on commit 3fe35ad

Please sign in to comment.