Skip to content

Commit

Permalink
refactor: setup for sumcheck memory usage reduction (#418)
Browse files Browse the repository at this point in the history
# Rationale for this change

The sumcheck implementation uses memory inefficiently. This is primarily
due to unneeded cloning during the `ProverState` creation. This PR does
some setup work for a followup PR that will solve this.

# What changes are included in this PR?

There are several small changes in this PR that are relatively
disconnected.
A main overarching goal is the creation of the
`make_sumcheck_prover_state` function.
In the next PR, this will be refactored to directly create the
`ProverState`.

# Are these changes tested?
Yes
  • Loading branch information
JayWhite2357 authored Dec 10, 2024
2 parents 02be8a6 + 97b8c03 commit 2ecce43
Show file tree
Hide file tree
Showing 21 changed files with 266 additions and 163 deletions.
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod proof_test;
pub use proof::SumcheckProof;

mod prover_state;
use prover_state::ProverState;
pub(crate) use prover_state::ProverState;

mod prover_round;
use prover_round::prove_round;
Expand Down
16 changes: 6 additions & 10 deletions crates/proof-of-sql/src/proof_primitive/sumcheck/proof.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
base::{
polynomial::{interpolate_evaluations_to_reverse_coefficients, CompositePolynomial},
polynomial::interpolate_evaluations_to_reverse_coefficients,
proof::{ProofError, Transcript},
scalar::Scalar,
},
Expand Down Expand Up @@ -28,19 +28,15 @@ impl<S: Scalar> SumcheckProof<S> {
pub fn create(
transcript: &mut impl Transcript,
evaluation_point: &mut [S],
polynomial: &CompositePolynomial<S>,
mut state: ProverState<S>,
) -> Self {
assert_eq!(evaluation_point.len(), polynomial.num_variables);
transcript.extend_as_be([
polynomial.max_multiplicands as u64,
polynomial.num_variables as u64,
]);
assert_eq!(evaluation_point.len(), state.num_vars);
transcript.extend_as_be([state.max_multiplicands as u64, state.num_vars as u64]);
// This challenge is in order to keep transcript messages grouped. (This simplifies the Solidity implementation.)
transcript.scalar_challenge_as_be::<S>();
let mut r = None;
let mut state = ProverState::create(polynomial);
let mut coefficients = Vec::with_capacity(polynomial.num_variables);
for scalar in evaluation_point.iter_mut().take(polynomial.num_variables) {
let mut coefficients = Vec::with_capacity(state.num_vars);
for scalar in evaluation_point.iter_mut().take(state.num_vars) {
let round_evaluations = prove_round(&mut state, &r);
let round_coefficients =
interpolate_evaluations_to_reverse_coefficients(&round_evaluations);
Expand Down
30 changes: 20 additions & 10 deletions crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use super::test_cases::sumcheck_test_cases;
use crate::base::{
polynomial::CompositePolynomial,
proof::Transcript as _,
scalar::{test_scalar::TestScalar, Curve25519Scalar, MontScalar, Scalar},
};
/*
* Adopted from arkworks
*
* See third_party/license/arkworks.LICENSE
*/
use crate::proof_primitive::sumcheck::proof::*;
use super::test_cases::sumcheck_test_cases;
use crate::{
base::{
polynomial::CompositePolynomial,
proof::Transcript as _,
scalar::{test_scalar::TestScalar, Curve25519Scalar, MontScalar, Scalar},
},
proof_primitive::sumcheck::{ProverState, SumcheckProof},
};
use alloc::rc::Rc;
use ark_std::UniformRand;
use merlin::Transcript;
Expand All @@ -29,7 +31,11 @@ fn test_create_verify_proof() {
let fa = Rc::new(a_vec.to_vec());
poly.add_product([fa], Curve25519Scalar::from(1u64));
let mut transcript = Transcript::new(b"sumchecktest");
let mut proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, &poly);
let mut proof = SumcheckProof::create(
&mut transcript,
&mut evaluation_point,
ProverState::create(&poly),
);

// verify proof
let mut transcript = Transcript::new(b"sumchecktest");
Expand Down Expand Up @@ -133,7 +139,11 @@ fn test_polynomial(nv: usize, num_multiplicands_range: (usize, usize), num_produ
// create a proof
let mut transcript = Transcript::new(b"sumchecktest");
let mut evaluation_point = vec![Curve25519Scalar::zero(); poly.num_variables];
let proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, &poly);
let proof = SumcheckProof::create(
&mut transcript,
&mut evaluation_point,
ProverState::create(&poly),
);

// verify proof
let mut transcript = Transcript::new(b"sumchecktest");
Expand Down Expand Up @@ -180,7 +190,7 @@ fn we_can_verify_many_random_test_cases() {
let proof = SumcheckProof::create(
&mut transcript,
&mut evaluation_point,
&test_case.polynomial,
ProverState::create(&test_case.polynomial),
);

let mut transcript = Transcript::new(b"sumchecktest");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,13 @@ pub fn prove_round<S: Scalar>(prover_state: &mut ProverState<S>, r_maybe: &Optio
"first round should be prover first."
);

prover_state.randomness.push(*r);

// fix argument
let r_as_field = prover_state.randomness[prover_state.round - 1];
if_rayon!(
prover_state.flattened_ml_extensions.par_iter_mut(),
prover_state.flattened_ml_extensions.iter_mut()
)
.for_each(|multiplicand| {
in_place_fix_variable(
multiplicand,
r_as_field,
prover_state.num_vars - prover_state.round,
);
in_place_fix_variable(multiplicand, *r, prover_state.num_vars - prover_state.round);
});
} else if prover_state.round > 0 {
panic!("verifier message is empty");
Expand Down
29 changes: 20 additions & 9 deletions crates/proof-of-sql/src/proof_primitive/sumcheck/prover_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ use crate::base::scalar::Scalar;
use alloc::vec::Vec;

pub struct ProverState<S: Scalar> {
/// sampled randomness given by the verifier
pub randomness: Vec<S>,
/// Stores the list of products that is meant to be added together. Each multiplicand is represented by
/// the index in `flattened_ml_extensions`
pub list_of_products: Vec<(S, Vec<usize>)>,
Expand All @@ -21,6 +19,21 @@ pub struct ProverState<S: Scalar> {
}

impl<S: Scalar> ProverState<S> {
pub fn new(
list_of_products: Vec<(S, Vec<usize>)>,
flattened_ml_extensions: Vec<Vec<S>>,
num_vars: usize,
max_multiplicands: usize,
) -> Self {
ProverState {
list_of_products,
flattened_ml_extensions,
num_vars,
max_multiplicands,
round: 0,
}
}

#[tracing::instrument(name = "ProverState::create", level = "debug", skip_all)]
pub fn create(polynomial: &CompositePolynomial<S>) -> Self {
assert!(
Expand All @@ -35,13 +48,11 @@ impl<S: Scalar> ProverState<S> {
.map(|x| x.as_ref().clone())
.collect();

ProverState {
randomness: Vec::with_capacity(polynomial.num_variables),
list_of_products: polynomial.products.clone(),
ProverState::new(
polynomial.products.clone(),
flattened_ml_extensions,
num_vars: polynomial.num_variables,
max_multiplicands: polynomial.max_multiplicands,
round: 0,
}
polynomial.num_variables,
polynomial.max_multiplicands,
)
}
}
34 changes: 6 additions & 28 deletions crates/proof-of-sql/src/sql/proof/final_round_builder.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use super::{
CompositePolynomialBuilder, SumcheckRandomScalars, SumcheckSubpolynomial,
SumcheckSubpolynomialTerm, SumcheckSubpolynomialType,
};
use super::{SumcheckSubpolynomial, SumcheckSubpolynomialTerm, SumcheckSubpolynomialType};
use crate::base::{
bit::BitDistribution,
commitment::{Commitment, CommittableColumn, VecCommitmentExt},
polynomial::{CompositePolynomial, MultilinearExtension},
polynomial::MultilinearExtension,
scalar::Scalar,
};
use alloc::{boxed::Box, vec::Vec};
Expand Down Expand Up @@ -105,29 +102,10 @@ impl<'a, S: Scalar> FinalRoundBuilder<'a, S> {
)
}

/// Given random multipliers, construct an aggregatated sumcheck polynomial from all
/// the individual subpolynomials.
#[tracing::instrument(
name = "FinalRoundBuilder::make_sumcheck_polynomial",
level = "debug",
skip_all
)]
pub fn make_sumcheck_polynomial(
&self,
scalars: &SumcheckRandomScalars<S>,
) -> CompositePolynomial<S> {
let mut builder = CompositePolynomialBuilder::new(
self.num_sumcheck_variables,
&scalars.compute_entrywise_multipliers(),
);
for (multiplier, subpoly) in scalars
.subpolynomial_multipliers
.iter()
.zip(self.sumcheck_subpolynomials.iter())
{
subpoly.compose(&mut builder, *multiplier);
}
builder.make_composite_polynomial()
/// Produce a subpolynomial to be aggegated into sumcheck where the sum across binary
/// values of the variables is zero.
pub fn sumcheck_subpolynomials(&self) -> &[SumcheckSubpolynomial<'a, S>] {
&self.sumcheck_subpolynomials
}

/// Given the evaluation vector, compute evaluations of all the MLEs used in sumcheck except
Expand Down
73 changes: 5 additions & 68 deletions crates/proof-of-sql/src/sql/proof/final_round_builder_test.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
use super::{FinalRoundBuilder, ProvableQueryResult, SumcheckRandomScalars};
use crate::{
base::{
commitment::{Commitment, CommittableColumn},
database::{Column, ColumnField, ColumnType},
polynomial::{compute_evaluation_vector, CompositePolynomial, MultilinearExtension},
scalar::Curve25519Scalar,
},
sql::proof::SumcheckSubpolynomialType,
use super::{FinalRoundBuilder, ProvableQueryResult};
use crate::base::{
commitment::{Commitment, CommittableColumn},
database::{Column, ColumnField, ColumnType},
scalar::Curve25519Scalar,
};
use alloc::sync::Arc;
#[cfg(feature = "arrow")]
Expand All @@ -16,7 +12,6 @@ use arrow::{
record_batch::RecordBatch,
};
use curve25519_dalek::RistrettoPoint;
use num_traits::{One, Zero};

#[test]
fn we_can_compute_commitments_for_intermediate_mles_using_a_zero_offset() {
Expand Down Expand Up @@ -75,64 +70,6 @@ fn we_can_evaluate_pcs_proof_mles() {
assert_eq!(evals, expected_evals);
}

#[test]
fn we_can_form_an_aggregated_sumcheck_polynomial() {
let mle1 = [1, 2, -1];
let mle2 = [10i64, 20, 100, 30];
let mle3 = [2000i64, 3000, 5000, 7000];
let mut builder = FinalRoundBuilder::new(2, Vec::new());
builder.produce_anchored_mle(&mle1);
builder.produce_intermediate_mle(&mle2[..]);
builder.produce_intermediate_mle(&mle3[..]);

builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::Identity,
vec![(-Curve25519Scalar::one(), vec![Box::new(&mle1)])],
);
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::Identity,
vec![(-Curve25519Scalar::from(10u64), vec![Box::new(&mle2)])],
);
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::ZeroSum,
vec![(Curve25519Scalar::from(9876u64), vec![Box::new(&mle3)])],
);

let multipliers = [
Curve25519Scalar::from(5u64),
Curve25519Scalar::from(2u64),
Curve25519Scalar::from(50u64),
Curve25519Scalar::from(25u64),
Curve25519Scalar::from(11u64),
];

let mut evaluation_vector = vec![Zero::zero(); 4];
compute_evaluation_vector(&mut evaluation_vector, &multipliers[..2]);

let poly = builder.make_sumcheck_polynomial(&SumcheckRandomScalars::new(&multipliers, 4, 2));
let mut expected_poly = CompositePolynomial::new(2);
let fr = (&evaluation_vector).to_sumcheck_term(2);
expected_poly.add_product(
[fr.clone(), (&mle1).to_sumcheck_term(2)],
-Curve25519Scalar::from(1u64) * multipliers[2],
);
expected_poly.add_product(
[fr, (&mle2).to_sumcheck_term(2)],
-Curve25519Scalar::from(10u64) * multipliers[3],
);
expected_poly.add_product(
[(&mle3).to_sumcheck_term(2)],
Curve25519Scalar::from(9876u64) * multipliers[4],
);
let random_point = [
Curve25519Scalar::from(123u64),
Curve25519Scalar::from(101_112_u64),
];
let eval = poly.evaluate(&random_point);
let expected_eval = expected_poly.evaluate(&random_point);
assert_eq!(eval, expected_eval);
}

#[cfg(feature = "arrow")]
#[test]
fn we_can_form_the_provable_query_result() {
Expand Down
Loading

0 comments on commit 2ecce43

Please sign in to comment.