diff --git a/crypto/src/subprotocols/sumcheck.rs b/crypto/src/subprotocols/sumcheck.rs index c936d95fa9..f516e18d91 100644 --- a/crypto/src/subprotocols/sumcheck.rs +++ b/crypto/src/subprotocols/sumcheck.rs @@ -7,6 +7,76 @@ use lambdaworks_math::polynomial::{ dense_multilinear_poly::DenseMultilinearPolynomial, polynomial::Polynomial, }; use lambdaworks_math::traits::ByteConversion; +use sha3::digest::typenum::NonZero; + +fn eval_points_quadratic( + poly_a: &DenseMultilinearPolynomial, + poly_b: &DenseMultilinearPolynomial, + comb_func: &F, +) -> (&FieldElement, &FieldElement) +where + F: Fn(&FieldElement, &FieldElement) -> FieldElement + Sync, +{ + let len = poly_a.len() / 2; + (0..len) + .iter() + .map(|i| { + // eval_0: A(low) + let eval_0 = comb_func(&poly_a[i], &poly_b[i]); + + // eval_2: -A(low) + 2*A(high) + let poly_a_eval_2 = poly_a[len + i] + poly_a[len + i] - poly_a[i]; + let poly_b_eval_2 = poly_b[len + i] + poly_b[len + i] - poly_b[i]; + let eval_2 = comb_func(&poly_a_eval_2, &poly_b_eval_2); + (eval_0, eval_2) + }) + .reduce( + || (&FieldElement::::zero(), &FieldElement::::zero()), + |a, b| (a.0 + b.0, a.1 + b.1), + ) +} + +fn eval_points_cubic( + poly_a: &DenseMultilinearPolynomial, + poly_b: &DenseMultilinearPolynomial, + poly_c: &DenseMultilinearPolynomial, + comb_func: &F, +) -> (&FieldElement, &FieldElement) +where + F: Fn(&FieldElement, &FieldElement, &FieldElement) -> FieldElement + Sync, +{ + let len = poly_a.len() / 2; + (0..len) + .iter() + .map(|i| { + // eval_0: A(low) + let eval_0 = comb_func(&poly_a[i], &poly_b[i], &poly_c[i]); + + // eval_2: -A(low) + 2*A(high) + let poly_a_eval_2 = poly_a[len + i] + poly_a[len + i] - poly_a[i]; + let poly_b_eval_2 = poly_b[len + i] + poly_b[len + i] - poly_b[i]; + let poly_c_eval_2 = poly_c[len + i] + poly_c[len + i] - poly_c[i]; + let eval_2 = comb_func(&poly_a_eval_2, &poly_b_eval_2, poly_c_eval_2); + + // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) + let poly_a_eval_3 = poly_a_eval_2 + poly_a[len + i] - poly_a[i]; + let poly_b_eval_3 = poly_b_eval_2 + poly_b[len + i] - poly_b[i]; + let poly_c_eval_3 = poly_c_eval_2 + poly_c[len + i] - poly_c[i]; + let eval_3 = comb_func(&poly_a_eval_2, &poly_b_eval_2, poly_c_eval_2); + + (eval_0, eval_2, eval_3) + }) + .reduce( + || { + ( + &FieldElement::::zero(), + &FieldElement::::zero(), + &FieldElement::::zero(), + ) + }, + |a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2), + ) +} // Proof attesting to sum over the boolean hypercube #[derive(Debug)] @@ -35,19 +105,147 @@ where ::BaseType: Send + Sync, FieldElement: ByteConversion, { - pub fn prove_quadratic() -> SumcheckProof { - todo!(); + //Used for sum_{(a * b)} + pub fn prove_quadratic( + sum: &FieldElement, + poly_a: &mut DenseMultilinearPolynomial, + poly_b: &mut DenseMultilinearPolynomial, + comb_func: F, + transcript: &mut impl Transcript, + ) -> SumcheckProof + where + F: Fn(&FieldElement, &FieldElement) -> FieldElement + Sync, + { + let mut round_uni_polys: Vec>> = + Vec::with_capacity(poly.num_vars()); + let mut challenges = Vec::with_capacity(poly.num_vars()); + let mut prev_round_claim = *sum; + + for _ in poly_a.num_vars() { + let poly = { + let len = poly_a.len() / 2; + let (eval_0, eval_2) = eval_points_quadratic(poly_a[i], poly_b[i], &comb_func); + let evals = vec![eval_0, prev_round_claim - eval_0, eval_2]; + Polynomial::new(&evals) + }; + + // append round's Univariate polynomial to transcript + + // Squeeze Verifier Challenge for next round + let challenge = FieldElement::from_bytes_be(&transcript.challenge()).unwrap(); + challenges.push(challenge.clone()); + + // add univariate polynomial for this round to the proof + round_uni_polys.push(poly); + + // compute next claim + prev_round_claim = poly.evaluate(&challenge); + + // fix next variable of poly + poly_a.fix_variable(&challenge); + poly_b.fix_variable(&challenge); + } + + SumcheckProof { + poly: poly.clone(), + sum: sum.clone(), + round_uni_polys, + } + } + + pub fn prove_quadratic_batched( + sum: &FieldElement, + poly_a: &mut Vec>, + poly_b: &mut Vec>, + /// Optional Powers of rho used for RLC + powers: &[FieldElement], + comb_func: F, + transcript: &mut impl Transcript, + ) -> SumcheckProof + where + F: Fn(&FieldElement, &FieldElement) -> FieldElement + Sync, + { + let mut round_uni_polys: Vec>> = + Vec::with_capacity(poly.num_vars()); + let mut challenges = Vec::with_capacity(poly.num_vars()); + let mut prev_round_claim = *sum; + + for _ in poly_a.num_vars() { + let mut evals: Vec<(G::Scalar, G::Scalar)> = Vec::new(); + + for (poly_a, poly_b) in poly_a.iter().zip(poly_b.iter()) { + let (eval_point_0, eval_point_2) = + Self::compute_eval_points_quadratic(poly_a, poly_b, &comb_func); + evals.push((eval_point_0, eval_point_2)); + } + + // TODO: make optional as we want to perform a batched check outside of this + let evals_combined_0 = (0..evals.len()).map(|i| evals[i].0 * powers[i]).sum(); + let evals_combined_2 = (0..evals.len()).map(|i| evals[i].1 * powers[i]).sum(); + + let evals = vec![ + evals_combined_0, + prev_round_claim - evals_combined_0, + evals_combined_2, + ]; + let poly = Polynomial::new(&evals); + + // append the prover's message to the transcript + + // Squeeze Verifier Challenge for next round + let challenge = FieldElement::from_bytes_be(&transcript.challenge()).unwrap(); + challenges.push(challenge.clone()); + + // bound all tables to the verifier's challenege + for (poly_a, poly_b) in poly_a.iter_mut().zip(poly_b.iter_mut()) { + poly_a.fix_variable(&r_i); + poly_b.fix_variable(&r_i); + } + + prev_round_claim = poly.evaluate(&r_i); + quad_polys.push(poly.compress()); + } + + SumcheckProof { + poly: poly_a.clone(), + sum: sum.clone(), + round_uni_polys, + } } - pub fn prove_quadratic_batched() -> SumcheckProof { - todo!(); + pub fn prove_cubic( + sum: &FieldElement, + poly_a: &mut DenseMultilinearPolynomial, + poly_b: &mut DenseMultilinearPolynomial, + poly_c: &mut DenseMultilinearPolynomial, + comb_func: F, + transcript: &mut impl Transcript, + ) -> SumcheckProof { + todo!() } - pub fn prove_cubic() -> SumcheckProof { + pub fn prove_cubic_batched( + sum: &FieldElement, + poly_a: &mut Vec>, + poly_b: &mut Vec>, + poly_c: &mut DenseMultilinearPolynomial, + comb_func: F, + transcript: &mut impl Transcript, + ) -> SumcheckProof { todo!() } - pub fn prove_cubic_batched() -> SumcheckProof { + // Special instance of sumcheck for a cubic polynomial with an additional additive term: + // this is used in Spartan: (a * ((b * c) - d)) + pub fn prove_cubic_additive_term( + sum: &FieldElement, + poly_a: &mut DenseMultilinearPolynomial, + poly_b: &mut DenseMultilinearPolynomial, + poly_c: &mut DenseMultilinearPolynomial, + poly_d: &mut DenseMultilinearPolynomial, + comb_func: F, + transcript: &mut impl Transcript, + ) -> SumcheckProof { todo!() } @@ -92,13 +290,16 @@ where let challenge = FieldElement::from_bytes_be(&transcript.challenge()).unwrap(); challenges.push(challenge.clone()); + // add univariate polynomial for this round to the proof + round_uni_polys.push(round_uni_poly); + + // grab next claim + prev_round_claim = round_uni_poly.evaluate(&challenge); + // takes mutable reference and fixes poly at challenge // On each round we evaluate over the hypercube to generate the univariate polynomial for this round. Then we fix the challenge for the next variable, // reassign and start the next round with the fixed variable. Each round the poly decreases in size poly.fix_variable(&challenge); - - // add univariate polynomial for this round to the proof - round_uni_polys.push(round_uni_poly); } SumcheckProof {