Skip to content

Commit

Permalink
perf: minimize squeezes for combination weights
Browse files Browse the repository at this point in the history
Sample the weights for the random linear combination of the DEEP'd
codewords at the same time as the weights for other random linear
combinations. This is possible without introducing soundness problems
because all the sampling happens within the same stage of the STARK.
  • Loading branch information
jan-ferdinand committed Mar 20, 2024
1 parent 14d08ef commit 50b803c
Showing 1 changed file with 70 additions and 52 deletions.
122 changes: 70 additions & 52 deletions triton-vm/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,8 @@ impl Stark {
));
prof_stop!(maybe_profiler, "out-of-domain rows");

// Get weights for remainder of the combination codeword.
prof_start!(maybe_profiler, "Fiat-Shamir", "hash");
let (base_weights, ext_weights, quotient_segment_weights) =
Self::sample_linear_combination_weights(&mut proof_stream);
assert_eq!(NUM_BASE_COLUMNS, base_weights.len());
assert_eq!(NUM_EXT_COLUMNS, ext_weights.len());
assert_eq!(NUM_QUOTIENT_SEGMENTS, quotient_segment_weights.len());
let weights = LinearCombinationWeights::sample(&mut proof_stream);
prof_stop!(maybe_profiler, "Fiat-Shamir");

let fri_domain_is_short_domain = fri.domain.length <= quotient_domain.length;
Expand Down Expand Up @@ -308,17 +303,17 @@ impl Stark {
prof_start!(maybe_profiler, "linear combination");
prof_start!(maybe_profiler, "base", "CC");
let base_codeword =
Self::random_linear_sum_base_field(short_domain_base_codewords, base_weights);
Self::random_linear_sum_base_field(short_domain_base_codewords, weights.base);
prof_stop!(maybe_profiler, "base");
prof_start!(maybe_profiler, "ext", "CC");
let ext_codeword = Self::random_linear_sum(short_domain_ext_codewords, ext_weights);
let ext_codeword = Self::random_linear_sum(short_domain_ext_codewords, weights.ext);
prof_stop!(maybe_profiler, "ext");
let base_and_ext_codeword = base_codeword + ext_codeword;

prof_start!(maybe_profiler, "quotient", "CC");
let quotient_segments_codeword = Self::random_linear_sum(
short_domain_quot_segment_codewords.view(),
quotient_segment_weights,
weights.quot_segments,
);
prof_stop!(maybe_profiler, "quotient");

Expand Down Expand Up @@ -368,10 +363,6 @@ impl Stark {
prof_stop!(maybe_profiler, "DEEP");

prof_start!(maybe_profiler, "combined DEEP polynomial");
prof_start!(maybe_profiler, "Fiat-Shamir", "hash");
let deep_codeword_weights =
Array1::from(proof_stream.sample_scalars(NUM_DEEP_CODEWORD_COMPONENTS));
prof_stop!(maybe_profiler, "Fiat-Shamir");
prof_start!(maybe_profiler, "sum", "CC");
let deep_codeword_components = [
base_and_ext_curr_row_deep_codeword,
Expand All @@ -383,7 +374,7 @@ impl Stark {
deep_codeword_components.concat(),
)
.unwrap();
let weighted_deep_codeword_components = &deep_codeword_components * &deep_codeword_weights;
let weighted_deep_codeword_components = &deep_codeword_components * &weights.deep;
let deep_codeword = weighted_deep_codeword_components.sum_axis(Axis(1));
prof_stop!(maybe_profiler, "sum");
let fri_combination_codeword = if fri_domain_is_short_domain {
Expand Down Expand Up @@ -491,25 +482,6 @@ impl Stark {
random_linear_sum
}

fn sample_linear_combination_weights(
proof_stream: &mut ProofStream,
) -> (
Array1<XFieldElement>,
Array1<XFieldElement>,
Array1<XFieldElement>,
) {
const NUM_WEIGHTS: usize = NUM_BASE_COLUMNS + NUM_EXT_COLUMNS + NUM_QUOTIENT_SEGMENTS;
const BASE_END: usize = NUM_BASE_COLUMNS;
const EXT_END: usize = BASE_END + NUM_EXT_COLUMNS;

let weights = proof_stream.sample_scalars(NUM_WEIGHTS);
let base_weights = weights[..BASE_END].to_vec().into();
let ext_weights = weights[BASE_END..EXT_END].to_vec().into();
let quotient_segment_weights = weights[EXT_END..].to_vec().into();

(base_weights, ext_weights, quotient_segment_weights)
}

fn fri_domain_segment_polynomials(
quotient_segment_polynomials: ArrayView1<Polynomial<XFieldElement>>,
fri_domain: ArithmeticDomain,
Expand Down Expand Up @@ -805,16 +777,8 @@ impl Stark {
prof_stop!(maybe_profiler, "verify quotient's segments");

prof_start!(maybe_profiler, "Fiat-Shamir 2", "hash");
let num_base_and_ext_and_quotient_segment_codeword_weights =
NUM_BASE_COLUMNS + NUM_EXT_COLUMNS + NUM_QUOTIENT_SEGMENTS;
let base_and_ext_and_quotient_segment_codeword_weights =
proof_stream.sample_scalars(num_base_and_ext_and_quotient_segment_codeword_weights);
let (base_and_ext_codeword_weights, quotient_segment_codeword_weights) =
base_and_ext_and_quotient_segment_codeword_weights
.split_at(NUM_BASE_COLUMNS + NUM_EXT_COLUMNS);
let base_and_ext_codeword_weights = Array1::from(base_and_ext_codeword_weights.to_vec());
let quotient_segment_codeword_weights =
Array1::from(quotient_segment_codeword_weights.to_vec());
let weights = LinearCombinationWeights::sample(&mut proof_stream);
let base_and_ext_codeword_weights = weights.base_and_ext();
prof_stop!(maybe_profiler, "Fiat-Shamir 2");

prof_start!(maybe_profiler, "sum out-of-domain values", "CC");
Expand All @@ -830,15 +794,11 @@ impl Stark {
base_and_ext_codeword_weights.view(),
maybe_profiler,
);
let out_of_domain_curr_row_quotient_segment_value =
quotient_segment_codeword_weights.dot(&out_of_domain_curr_row_quot_segments);
let out_of_domain_curr_row_quotient_segment_value = weights
.quot_segments
.dot(&out_of_domain_curr_row_quot_segments);
prof_stop!(maybe_profiler, "sum out-of-domain values");

prof_start!(maybe_profiler, "Fiat-Shamir", "hash");
let deep_codeword_weights =
Array1::from(proof_stream.sample_scalars(NUM_DEEP_CODEWORD_COMPONENTS));
prof_stop!(maybe_profiler, "Fiat-Shamir");

// verify low degree of combination polynomial with FRI
prof_start!(maybe_profiler, "FRI");
let revealed_fri_indices_and_elements = fri.verify(&mut proof_stream, maybe_profiler)?;
Expand Down Expand Up @@ -960,7 +920,8 @@ impl Stark {
base_and_ext_codeword_weights.view(),
maybe_profiler,
);
let quotient_segments_curr_row_element = quotient_segment_codeword_weights
let quotient_segments_curr_row_element = weights
.quot_segments
.dot(&Array1::from(quotient_segments_elements.to_vec()));
prof_stop!(maybe_profiler, "base & ext elements");

Expand Down Expand Up @@ -991,7 +952,7 @@ impl Stark {
base_and_ext_next_row_deep_value,
quot_curr_row_deep_value,
]);
if fri_value != deep_codeword_weights.dot(&deep_value_components) {
if fri_value != weights.deep.dot(&deep_value_components) {
return Err(VerificationError::CombinationCodewordMismatch);
};
prof_stop!(maybe_profiler, "combination codeword equality");
Expand Down Expand Up @@ -1053,6 +1014,47 @@ impl<'a> Arbitrary<'a> for Stark {
}
}

/// Fiat-Shamir-sampled challenges to compress a row into a single
/// [extension field element][XFieldElement].
struct LinearCombinationWeights {
/// of length [`NUM_BASE_COLUMNS`]
base: Array1<XFieldElement>,

/// of length [`NUM_EXT_COLUMNS`]
ext: Array1<XFieldElement>,

/// of length [`NUM_QUOTIENT_SEGMENTS`]
quot_segments: Array1<XFieldElement>,

/// of length [`NUM_DEEP_CODEWORD_COMPONENTS`]
deep: Array1<XFieldElement>,
}

impl LinearCombinationWeights {
const NUM: usize =
NUM_BASE_COLUMNS + NUM_EXT_COLUMNS + NUM_QUOTIENT_SEGMENTS + NUM_DEEP_CODEWORD_COMPONENTS;

fn sample(proof_stream: &mut ProofStream) -> Self {
const BASE_END: usize = NUM_BASE_COLUMNS;
const EXT_END: usize = BASE_END + NUM_EXT_COLUMNS;
const QUOT_END: usize = EXT_END + NUM_QUOTIENT_SEGMENTS;

let weights = proof_stream.sample_scalars(Self::NUM);

Self {
base: weights[..BASE_END].to_vec().into(),
ext: weights[BASE_END..EXT_END].to_vec().into(),
quot_segments: weights[EXT_END..QUOT_END].to_vec().into(),
deep: weights[QUOT_END..].to_vec().into(),
}
}

fn base_and_ext(&self) -> Array1<XFieldElement> {
let base = self.base.clone().into_iter();
base.chain(self.ext.clone()).collect()
}
}

#[cfg(test)]
pub(crate) mod tests {
use std::collections::HashMap;
Expand Down Expand Up @@ -2660,4 +2662,20 @@ pub(crate) mod tests {
}
}
}

#[proptest]
fn linear_combination_weights_samples_correct_number_of_elements(
#[strategy(arb())] mut proof_stream: ProofStream,
) {
let weights = LinearCombinationWeights::sample(&mut proof_stream);

prop_assert_eq!(NUM_BASE_COLUMNS, weights.base.len());
prop_assert_eq!(NUM_EXT_COLUMNS, weights.ext.len());
prop_assert_eq!(NUM_QUOTIENT_SEGMENTS, weights.quot_segments.len());
prop_assert_eq!(NUM_DEEP_CODEWORD_COMPONENTS, weights.deep.len());
prop_assert_eq!(
NUM_BASE_COLUMNS + NUM_EXT_COLUMNS,
weights.base_and_ext().len()
);
}
}

0 comments on commit 50b803c

Please sign in to comment.