From b13f4573d70914325a82e590f8d4d21775e72901 Mon Sep 17 00:00:00 2001 From: Adrian Hamelink Date: Fri, 8 Dec 2023 19:00:00 +0100 Subject: [PATCH] Sumcheck update (#163) * Include number of instances in transcript * Clarify Sumcheck and verifier * Fix clippy --- src/spartan/batched.rs | 9 + src/spartan/batched_ppsnark.rs | 322 ++++++++++++++++++--------------- 2 files changed, 189 insertions(+), 142 deletions(-) diff --git a/src/spartan/batched.rs b/src/spartan/batched.rs index a161ea77..84814276 100644 --- a/src/spartan/batched.rs +++ b/src/spartan/batched.rs @@ -160,6 +160,10 @@ where let mut transcript = E::TE::new(b"BatchedRelaxedR1CSSNARK"); transcript.absorb(b"vk", &pk.vk_digest); + if num_instances > 1 { + let num_instances_field = E::Scalar::from(num_instances as u64); + transcript.absorb(b"n", &num_instances_field); + } U.iter().for_each(|u| { transcript.absorb(b"U", u); }); @@ -385,9 +389,14 @@ where } fn verify(&self, vk: &Self::VerifierKey, U: &[RelaxedR1CSInstance]) -> Result<(), NovaError> { + let num_instances = U.len(); let mut transcript = E::TE::new(b"BatchedRelaxedR1CSSNARK"); transcript.absorb(b"vk", &vk.digest()); + if num_instances > 1 { + let num_instances_field = E::Scalar::from(num_instances as u64); + transcript.absorb(b"n", &num_instances_field); + } U.iter().for_each(|u| { transcript.absorb(b"U", u); }); diff --git a/src/spartan/batched_ppsnark.rs b/src/spartan/batched_ppsnark.rs index 9b9ddb8d..9af2b67a 100644 --- a/src/spartan/batched_ppsnark.rs +++ b/src/spartan/batched_ppsnark.rs @@ -37,7 +37,7 @@ use crate::{ use abomonation::Abomonation; use abomonation_derive::Abomonation; use ff::{Field, PrimeField}; -use itertools::Itertools as _; +use itertools::{chain, Itertools as _}; use once_cell::sync::*; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -278,6 +278,8 @@ where let N = pk.S_repr.iter().map(|s| s.N).collect::>(); assert!(N.iter().all(|&Ni| Ni.is_power_of_two())); + let num_instances = U.len(); + // Pad [(Wᵢ,Eᵢ)] to the next power of 2 (not to Ni) let W = zip_with_par_iter!((W, S), |w, s| w.pad(s)).collect::>>(); @@ -285,9 +287,12 @@ where let num_rounds_sc = N.iter().max().unwrap().log_2(); // Initialize transcript with vk || [Uᵢ] - // NOTE: We should prepend with the number of instances let mut transcript = E::TE::new(b"BatchedRelaxedR1CSSNARK"); transcript.absorb(b"vk", &pk.vk_digest); + if num_instances > 1 { + let num_instances_field = E::Scalar::from(num_instances as u64); + transcript.absorb(b"n", &num_instances_field); + } U.iter().for_each(|u| { transcript.absorb(b"U", u); }); @@ -372,7 +377,7 @@ where transcript.absorb(b"e", &evals.as_slice()); }); - // Pad Z, E to Nᵢ + // Pad Zᵢ, E to Nᵢ let polys_Z = polys_Z .into_par_iter() .zip_eq(N.par_iter()) @@ -700,11 +705,7 @@ where evals_mem_preprocessed ), |Az_Bz_Cz_W_E, L_row_col, mem_oracles, mem_preprocessed| { - Az_Bz_Cz_W_E - .iter() - .chain(L_row_col) - .chain(mem_oracles) - .chain(mem_preprocessed) + chain![Az_Bz_Cz_W_E, L_row_col, mem_oracles, mem_preprocessed] .cloned() .collect::>() } @@ -720,12 +721,12 @@ where pk.S_comm ), |Az_Bz_Cz, comms_W_E, L_row_col, mem_oracles, S_comm| { - Az_Bz_Cz - .iter() - .chain(comms_W_E) - .chain(L_row_col) - .chain(mem_oracles) - .chain([ + chain![ + Az_Bz_Cz, + comms_W_E, + L_row_col, + mem_oracles, + [ &S_comm.comm_val_A, &S_comm.comm_val_B, &S_comm.comm_val_C, @@ -733,7 +734,8 @@ where &S_comm.comm_col, &S_comm.comm_ts_row, &S_comm.comm_ts_col, - ]) + ] + ] } ) .flatten() @@ -750,12 +752,12 @@ where pk.S_repr.iter() ), |Az_Bz_Cz, W, E, L_row_col, mem_oracles, S_repr| { - Az_Bz_Cz - .into_iter() - .chain([W, E]) - .chain(L_row_col) - .chain(mem_oracles) - .chain([ + chain![ + Az_Bz_Cz, + [W, E], + L_row_col, + mem_oracles, + [ S_repr.val_A.clone(), S_repr.val_B.clone(), S_repr.val_C.clone(), @@ -763,7 +765,8 @@ where S_repr.col.clone(), S_repr.ts_row.clone(), S_repr.ts_col.clone(), - ]) + ] + ] } ) .flatten() @@ -821,13 +824,25 @@ where } fn verify(&self, vk: &Self::VerifierKey, U: &[RelaxedR1CSInstance]) -> Result<(), NovaError> { + let num_instances = U.len(); + let num_claims_per_instance = 10; + + // number of rounds of sum-check + let num_rounds = vk.S_comm.iter().map(|s| s.N.log_2()).collect::>(); + let num_rounds_max = *num_rounds.iter().max().unwrap(); + let mut transcript = E::TE::new(b"BatchedRelaxedR1CSSNARK"); transcript.absorb(b"vk", &vk.digest()); + if num_instances > 1 { + let num_instances_field = E::Scalar::from(num_instances as u64); + transcript.absorb(b"n", &num_instances_field); + } U.iter().for_each(|u| { transcript.absorb(b"U", u); }); + // Decompress commitments let comms_Az_Bz_Cz = self .comms_Az_Bz_Cz .iter() @@ -861,28 +876,28 @@ where }) .collect::, _>>()?; + // Add commitments [Az, Bz, Cz] to the transcript comms_Az_Bz_Cz .iter() .for_each(|comms| transcript.absorb(b"c", &comms.as_slice())); - // number of rounds of sum-check - let num_rounds_sc = vk.S_comm.iter().map(|s| s.N.log_2()).max().unwrap(); let tau = transcript.squeeze(b"t")?; - let tau_coords = PowPolynomial::new(&tau, num_rounds_sc).coordinates(); + let tau_coords = PowPolynomial::new(&tau, num_rounds_max).coordinates(); // absorb the claimed evaluations into the transcript self.evals_Az_Bz_Cz_at_tau.iter().for_each(|evals| { transcript.absorb(b"e", &evals.as_slice()); }); + // absorb commitments to L_row and L_col in the transcript comms_L_row_col.iter().for_each(|comms| { - // absorb commitments to L_row and L_col in the transcript transcript.absorb(b"e", &comms.as_slice()); }); // Batch at tau for each instance let c = transcript.squeeze(b"c")?; + // Compute eval_Mz = eval_Az_at_tau + c * eval_Bz_at_tau + c^2 * eval_Cz_at_tau let evals_Mz: Vec<_> = zip_with_iter!( (comms_Az_Bz_Cz, self.evals_Az_Bz_Cz_at_tau), |comm_Az_Bz_Cz, evals_Az_Bz_Cz_at_tau| { @@ -904,31 +919,44 @@ where transcript.absorb(b"l", &comms.as_slice()); }); - let num_instances = U.len(); - let num_claims = num_instances * 10; - let rho = transcript.squeeze(b"r")?; + let s = transcript.squeeze(b"r")?; - let coeffs = powers::(&s, num_claims); - - // Scale initial claims by 2^{log(N)-log(Ni)} - let claim = zip_with!( - (coeffs.chunks_exact(10), evals_Mz.iter(), vk.S_comm.iter()), - |coeffs, eval_Mz, s_comm| { - let scaling = 1 << (num_rounds_sc - s_comm.N.log_2()) as u64; - E::Scalar::from(scaling) * (coeffs[7] + coeffs[8]) * eval_Mz - } - ) - .sum::(); + let s_powers = powers::(&s, num_instances * num_claims_per_instance); + + let (claim_sc_final, rand_sc) = { + // Gather all claims into a single vector + let claims = evals_Mz + .iter() + .flat_map(|&eval_Mz| { + let mut claims = vec![E::Scalar::ZERO; num_claims_per_instance]; + claims[7] = eval_Mz; + claims[8] = eval_Mz; + claims.into_iter() + }) + .collect::>(); + + // Number of rounds for each claim + let num_rounds_by_claim = num_rounds + .iter() + .flat_map(|num_rounds_i| vec![*num_rounds_i; num_claims_per_instance].into_iter()) + .collect::>(); + + self + .sc + .verify_batch(&claims, &num_rounds_by_claim, &s_powers, 3, &mut transcript)? + }; - // verify sc - let (claim_sc_final, rand_sc) = self.sc.verify(claim, num_rounds_sc, 3, &mut transcript)?; + // Truncated sumcheck randomness for each instance + let rand_sc_i = num_rounds + .iter() + .map(|num_rounds| rand_sc[(num_rounds_max - num_rounds)..].to_vec()) + .collect::>(); let claim_sc_final_expected = zip_with!( ( vk.num_vars.iter(), - vk.S_comm.iter(), - coeffs.chunks_exact(10), + rand_sc_i.iter(), U.iter(), self.evals_Az_Bz_Cz_W_E.iter().cloned(), self.evals_L_row_col.iter().cloned(), @@ -936,8 +964,7 @@ where self.evals_mem_preprocessed.iter().cloned() ), |num_vars, - s_comm, - coeffs, + rand_sc, U, evals_Az_Bz_Cz_W_E, evals_L_row_col, @@ -949,10 +976,8 @@ where eval_mem_oracle; let [val_A, val_B, val_C, row, col, ts_row, ts_col] = eval_mem_preprocessed; - let num_rounds_i = s_comm.N.log_2(); + let num_rounds_i = rand_sc.len(); let num_vars_log = num_vars.log_2(); - // Only consider the last log(Ni) rounds of Sumcheck - let (_, rand_sc) = rand_sc.split_at(num_rounds_sc - num_rounds_i); let eq_rho = { let rho_coords = PowPolynomial::new(&rho, num_rounds_i).coordinates(); @@ -1031,27 +1056,29 @@ where w + r }; - let claim_mem_final_expected: E::Scalar = coeffs[0] * (t_plus_r_inv_row - w_plus_r_inv_row) - + coeffs[1] * (t_plus_r_inv_col - w_plus_r_inv_col) - + coeffs[2] * (eq_rho * (t_plus_r_inv_row * t_plus_r_row - ts_row)) - + coeffs[3] * (eq_rho * (w_plus_r_inv_row * w_plus_r_row - E::Scalar::ONE)) - + coeffs[4] * (eq_rho * (t_plus_r_inv_col * t_plus_r_col - ts_col)) - + coeffs[5] * (eq_rho * (w_plus_r_inv_col * w_plus_r_col - E::Scalar::ONE)); + let claims_mem = [ + t_plus_r_inv_row - w_plus_r_inv_row, + t_plus_r_inv_col - w_plus_r_inv_col, + eq_rho * (t_plus_r_inv_row * t_plus_r_row - ts_row), + eq_rho * (w_plus_r_inv_row * w_plus_r_row - E::Scalar::ONE), + eq_rho * (t_plus_r_inv_col * t_plus_r_col - ts_col), + eq_rho * (w_plus_r_inv_col * w_plus_r_col - E::Scalar::ONE), + ]; - let claim_outer_final_expected = coeffs[6] * eq_tau * (Az * Bz - U.u * Cz - E) - + coeffs[7] * eq_tau * (Az + c * Bz + c * c * Cz); - let claim_inner_final_expected = - coeffs[8] * L_row * L_col * (val_A + c * val_B + c * c * val_C); + let claims_outer = [ + eq_tau * (Az * Bz - U.u * Cz - E), + eq_tau * (Az + c * Bz + c * c * Cz), + ]; + let claims_inner = [L_row * L_col * (val_A + c * val_B + c * c * val_C)]; - let claims_witness_final_expected = coeffs[9] * eq_masked_tau * W; + let claims_witness = [eq_masked_tau * W]; - claim_mem_final_expected - + claim_outer_final_expected - + claim_inner_final_expected - + claims_witness_final_expected + chain![claims_mem, claims_outer, claims_inner, claims_witness] } ) - .sum::(); + .flatten() + .zip_eq(s_powers) + .fold(E::Scalar::ZERO, |acc, (claim, s)| acc + s * claim); if claim_sc_final_expected != claim_sc_final { return Err(NovaError::InvalidSumcheckProof); @@ -1065,75 +1092,65 @@ where self.evals_mem_preprocessed ), |Az_Bz_Cz_W_E, L_row_col, mem_oracles, mem_preprocessed| { - Az_Bz_Cz_W_E - .iter() - .chain(L_row_col) - .chain(mem_oracles) - .chain(mem_preprocessed) + chain![Az_Bz_Cz_W_E, L_row_col, mem_oracles, mem_preprocessed] .cloned() .collect::>() } ) .collect::>(); - let comms_vec = zip_with_iter!( - ( - comms_Az_Bz_Cz, - U, - comms_L_row_col, - comms_mem_oracles, - vk.S_comm - ), - |Az_Bz_Cz, U, L_row_col, mem_oracles, S_comm| { - Az_Bz_Cz - .iter() - .chain([&U.comm_W, &U.comm_E]) - .chain(L_row_col) - .chain(mem_oracles) - .chain([ - &S_comm.comm_val_A, - &S_comm.comm_val_B, - &S_comm.comm_val_C, - &S_comm.comm_row, - &S_comm.comm_col, - &S_comm.comm_ts_row, - &S_comm.comm_ts_col, - ]) - } - ) - .flatten() - .cloned() - .collect::>(); - + // Add all Sumcheck evaluations to the transcript evals_vec.iter().for_each(|evals| { transcript.absorb(b"e", &evals.as_slice()); // comm_vec is already in the transcript }); - // Rescale all evaluations by L_0(rand_sc_lo) - let evals_vec = evals_vec - .into_iter() - .zip_eq(vk.S_comm.iter()) - .flat_map(|(evals, s_comm)| { - let Ni = s_comm.N; - let (rand_sc_lo, _) = rand_sc.split_at(num_rounds_sc - Ni.log_2()); - let scaling = rand_sc_lo - .iter() - .fold(E::Scalar::ONE, |acc, r| acc * (E::Scalar::ONE - r)); - evals.into_iter().map(move |eval| scaling * eval) - }) + let c = transcript.squeeze(b"c")?; + + // Compute batched polynomial evaluation instance at rand_sc + let u = { + let num_evals = evals_vec[0].len(); + + let evals_vec = evals_vec.into_iter().flatten().collect::>(); + + let num_vars = num_rounds + .iter() + .flat_map(|num_rounds| vec![*num_rounds; num_evals].into_iter()) + .collect::>(); + + let comms_vec = zip_with!( + ( + comms_Az_Bz_Cz.into_iter(), + U.iter(), + comms_L_row_col.into_iter(), + comms_mem_oracles.into_iter(), + vk.S_comm.iter() + ), + |Az_Bz_Cz, U, L_row_col, mem_oracles, S_comm| { + chain![ + Az_Bz_Cz, + [U.comm_W, U.comm_E], + L_row_col, + mem_oracles, + [ + S_comm.comm_val_A, + S_comm.comm_val_B, + S_comm.comm_val_C, + S_comm.comm_row, + S_comm.comm_col, + S_comm.comm_ts_row, + S_comm.comm_ts_col, + ] + ] + } + ) + .flatten() .collect::>(); - let c = transcript.squeeze(b"c")?; - let u: PolyEvalInstance = PolyEvalInstance::batch(&comms_vec, &rand_sc, &evals_vec, &c); + PolyEvalInstance::::batch_diff_size(&comms_vec, &evals_vec, &num_vars, rand_sc, c) + }; + // verify - EE::verify( - &vk.vk_ee, - &mut transcript, - &u.c, - &rand_sc, - &u.e, - &self.eval_arg, - )?; + EE::verify(&vk.vk_ee, &mut transcript, &u.c, &u.x, &u.e, &self.eval_arg)?; Ok(()) } @@ -1226,7 +1243,9 @@ where assert!(inner.iter().all(|inst| inst.degree() == degree)); assert!(witness.iter().all(|inst| inst.degree() == degree)); - // these claims are already added to the transcript, so we do not need to add + // Collect all claims from the instances. If the instances is defined over `m` variables, + // which is less that the total number of rounds `n`, + // the individual claims σ are scaled by 2^{n-m}. let claims = zip_with_iter!( (mem, outer, inner, witness), |mem, outer, inner, witness| { @@ -1240,19 +1259,28 @@ where .flatten() .collect::>(); + // Sample a challenge for the random linear combination of all scaled claims let s = transcript.squeeze(b"r")?; let coeffs = powers::(&s, claims.len()); - // compute the joint claim - let claim = zip_with_iter!((claims, coeffs), |c_1, c_2| *c_1 * c_2).sum(); + // At the start of each round, the running claim is equal to the random linear combination + // of the Sumcheck claims, evaluated over the bound polynomials. + // Initially, it is equal to the random linear combination of the scaled input claims. + let mut running_claim = zip_with_iter!((claims, coeffs), |c_1, c_2| *c_1 * c_2).sum(); - let mut e = claim; + // Keep track of the verifier challenges r, and the univariate polynomials sent by the prover + // in each round let mut r: Vec = Vec::new(); let mut cubic_polys: Vec> = Vec::new(); for i in 0..num_rounds { + // At the start of round i, there input polynomials are defined over at most n-i variables. let remaining_variables = num_rounds - i; + // For each claim j, compute the evaluations of its univariate polynomial S_j(X_i) + // at X = 0, 2, 3. The polynomial is such that S_{j-1}(r_{j-1}) = S_j(0) + S_j(1). + // If the number of variable m of the claim is m < n-i, then the polynomial is + // constants and equal to the initial claim σ_j scaled by 2^{n-m-i-1}. let evals = zip_with_par_iter!( (mem, outer, inner, witness), |mem, outer, inner, witness| { @@ -1282,16 +1310,18 @@ where assert_eq!(evals.len(), claims.len()); + // Random linear combination of the univariate evaluations at X_i = 0, 2, 3 let evals_combined_0 = (0..evals.len()).map(|i| evals[i][0] * coeffs[i]).sum(); let evals_combined_2 = (0..evals.len()).map(|i| evals[i][1] * coeffs[i]).sum(); let evals_combined_3 = (0..evals.len()).map(|i| evals[i][2] * coeffs[i]).sum(); let evals = vec![ evals_combined_0, - e - evals_combined_0, + running_claim - evals_combined_0, evals_combined_2, evals_combined_3, ]; + // Coefficient representation of S(X_i) let poly = UniPoly::from_evals(&evals); // append the prover's message to the transcript @@ -1301,6 +1331,9 @@ where let r_i = transcript.squeeze(b"c")?; r.push(r_i); + // Bind the variable X_i of polynomials across all claims to r_i. + // If the claim is defined over m variables and m < n-i, then + // binding has no effect on the polynomial. zip_with_par_iter_mut_for_each!( (mem, outer, inner, witness), |mem, outer, inner, witness| { @@ -1321,10 +1354,12 @@ where } ); - e = poly.evaluate(&r_i); + running_claim = poly.evaluate(&r_i); cubic_polys.push(poly.compress()); } + // Collect evaluations at (r_{n-m}, ..., r_{n-1}) of polynomials over all claims, + // where m is the initial number of variables the individual claims are defined over. let claims_outer = outer.into_iter().map(|inst| inst.final_claims()).collect(); let claims_inner = inner.into_iter().map(|inst| inst.final_claims()).collect(); let claims_mem = mem.into_iter().map(|inst| inst.final_claims()).collect(); @@ -1343,14 +1378,18 @@ where )) } - // When the size of the current round is larger than the instance's size, - // the evaluations are constant and equal to the initial claims, appropriately - // scaled to the current round number. + /// In round i, computes the evaluations at X_i = 0, 2, 3 of the univariate polynomials S(X_i) + /// for each claim in the instance. + /// Let `n` be the total number of Sumcheck rounds, and assume the instance is defined over `m` variables. + /// We define `remaining_variables` as n-i. + /// If m < n-i, then the polynomials in the instance are not defined over X_i, so the univariate + /// polynomial is constant and equal to 2^{n-m-i-1}*σ, where σ is the initial claim. fn get_evals>(inst: &T, remaining_variables: usize) -> Vec> { - let expected_current_size = 1 << remaining_variables; - if inst.size() != expected_current_size { + let num_instance_variables = inst.size().log_2(); // m + if num_instance_variables < remaining_variables { let deg = inst.degree(); + // The evaluations at X_i = 0, 2, 3 are all equal to the scaled claim Self::scaled_claims(inst, remaining_variables - 1) .into_iter() .map(|scaled_claim| vec![scaled_claim; deg]) @@ -1360,22 +1399,21 @@ where } } - // When the size of the current round size is larger than the instance's size, - // binding the polynomials to r has no effect on the polynomial that we imagine repeats. + /// In round i after receiving challenge r_i, we partially evaluate all polynomials in the instance + /// at X_i = r_i. If the instance is defined over m variables m which is less than n-i, then + /// the polynomials do not depend on X_i, so binding them to r_i has no effect. fn bind>(inst: &mut T, remaining_variables: usize, r: &E::Scalar) { - let expected_current_size = 1 << remaining_variables; - if inst.size() == expected_current_size { + let num_instance_variables = inst.size().log_2(); // m + if remaining_variables <= num_instance_variables { inst.bound(r) } } - // In the current round, if the polynomials in the instance are smaller than the expected size, - // the claims are equal to the initial ones, scaled by expected_size/round_size to - // account for the imagined repetitions of the input polynomials. + /// Given an instance defined over m variables, the sum over n = `remaining_variables` is equal + /// to the initial claim scaled by 2^{n-m}, when m ≤ n. fn scaled_claims>(inst: &T, remaining_variables: usize) -> Vec { - let expected_current_size = 1 << remaining_variables; - let inst_size = inst.size(); - let num_repetitions = expected_current_size / inst_size; + let num_instance_variables = inst.size().log_2(); // m + let num_repetitions = 1 << (remaining_variables - num_instance_variables); let scaling = E::Scalar::from(num_repetitions as u64); inst .initial_claims()