Skip to content

Commit

Permalink
refactor: simplify into VerificationBuilder::consume_mle_evaluation(s)
Browse files Browse the repository at this point in the history
`VerificationBuilder::consume_anchored_mle` and `VerificationBuilder::consume_intermediate_mle` are identical.
They are merged and remamed to `VerificationBuilder::consume_mle_evaluation`
Additionally, `VerificationBuilder::consume_mle_evaluations` is added for ergonomics.
  • Loading branch information
JayWhite2357 committed Dec 11, 2024
1 parent a01db55 commit 8dcc3cd
Show file tree
Hide file tree
Showing 14 changed files with 42 additions and 56 deletions.
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/proof/query_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
.collect();
let evaluation_accessor: IndexMap<_, _> = column_references
.into_iter()
.map(|col| (col, builder.consume_anchored_mle()))
.map(|col| (col, builder.consume_mle_evaluation()))
.collect();

let verifier_evaluations = expr.verifier_evaluate(
Expand Down
10 changes: 5 additions & 5 deletions crates/proof-of-sql/src/sql/proof/query_proof_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl ProofPlan for TrivialTestProofPlan {
_result: Option<&OwnedTable<S>>,
_one_eval_map: &IndexMap<TableRef, S>,
) -> Result<TableEvaluation<S>, ProofError> {
assert_eq!(builder.consume_intermediate_mle(), S::ZERO);
assert_eq!(builder.consume_mle_evaluation(), S::ZERO);
builder.produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::ZeroSum,
S::from(self.evaluation),
Expand Down Expand Up @@ -278,7 +278,7 @@ impl ProofPlan for SquareTestProofPlan {
ColumnType::BigInt,
))
.unwrap();
let res_eval = builder.consume_intermediate_mle();
let res_eval = builder.consume_mle_evaluation();
builder.produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::Identity,
res_eval - x_eval * x_eval,
Expand Down Expand Up @@ -474,8 +474,8 @@ impl ProofPlan for DoubleSquareTestProofPlan {
ColumnType::BigInt,
))
.unwrap();
let z_eval = builder.consume_intermediate_mle();
let res_eval = builder.consume_intermediate_mle();
let z_eval = builder.consume_mle_evaluation();
let res_eval = builder.consume_mle_evaluation();

// poly1
builder.produce_sumcheck_subpolynomial_evaluation(
Expand Down Expand Up @@ -681,7 +681,7 @@ impl ProofPlan for ChallengeTestProofPlan {
ColumnType::BigInt,
))
.unwrap();
let res_eval = builder.consume_intermediate_mle();
let res_eval = builder.consume_mle_evaluation();
builder.produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::Identity,
alpha * res_eval - alpha * x_eval * x_eval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,10 @@ impl ProofPlan for EmptyTestQueryExpr {
_result: Option<&OwnedTable<S>>,
_one_eval_map: &IndexMap<TableRef, S>,
) -> Result<TableEvaluation<S>, ProofError> {
let _ = std::iter::repeat_with(|| {
assert_eq!(builder.consume_intermediate_mle(), S::ZERO);
})
.take(self.columns)
.collect::<Vec<_>>();
assert_eq!(
builder.consume_mle_evaluations(self.columns),
vec![S::ZERO; self.columns]
);
Ok(TableEvaluation::new(
vec![S::ZERO; self.columns],
builder.consume_one_evaluation(),
Expand Down
17 changes: 9 additions & 8 deletions crates/proof-of-sql/src/sql/proof/verification_builder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::{SumcheckMleEvaluations, SumcheckSubpolynomialType};
use crate::base::{bit::BitDistribution, scalar::Scalar};
use alloc::vec::Vec;
use core::iter;

/// Track components used to verify a query's proof
pub struct VerificationBuilder<'a, S: Scalar> {
Expand Down Expand Up @@ -72,12 +73,19 @@ impl<'a, S: Scalar> VerificationBuilder<'a, S> {
/// Consume the evaluation of an anchored MLE used in sumcheck and provide the commitment of the MLE
///
/// An anchored MLE is an MLE where the verifier has access to the commitment
pub fn consume_anchored_mle(&mut self) -> S {
pub fn consume_mle_evaluation(&mut self) -> S {
let index = self.consumed_pcs_proof_mles;
self.consumed_pcs_proof_mles += 1;
self.mle_evaluations.pcs_proof_evaluations[index]
}

/// Consume multiple MLE evaluations
pub fn consume_mle_evaluations(&mut self, count: usize) -> Vec<S> {
iter::repeat_with(|| self.consume_mle_evaluation())
.take(count)
.collect()
}

/// Consume a bit distribution that describes which bits are constant
/// and which bits varying in a column of data
pub fn consume_bit_distribution(&mut self) -> BitDistribution {
Expand All @@ -86,13 +94,6 @@ impl<'a, S: Scalar> VerificationBuilder<'a, S> {
res
}

/// Consume the evaluation of an intermediate MLE used in sumcheck
///
/// An interemdiate MLE is one where the verifier doesn't have access to its commitment
pub fn consume_intermediate_mle(&mut self) -> S {
self.consume_anchored_mle()
}

/// Produce the evaluation of a subpolynomial used in sumcheck
pub fn produce_sumcheck_subpolynomial_evaluation(
&mut self,
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/proof_exprs/and_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl ProofExpr for AndExpr {
let rhs = self.rhs.verifier_evaluate(builder, accessor, one_eval)?;

// lhs_and_rhs
let lhs_and_rhs = builder.consume_intermediate_mle();
let lhs_and_rhs = builder.consume_mle_evaluation();

// subpolynomial: lhs_and_rhs - lhs * rhs
builder.produce_sumcheck_subpolynomial_evaluation(
Expand Down
4 changes: 2 additions & 2 deletions crates/proof-of-sql/src/sql/proof_exprs/equals_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ pub fn verifier_evaluate_equals_zero<S: Scalar>(
one_eval: S,
) -> S {
// consume mle evaluations
let lhs_pseudo_inv_eval = builder.consume_intermediate_mle();
let selection_not_eval = builder.consume_intermediate_mle();
let lhs_pseudo_inv_eval = builder.consume_mle_evaluation();
let selection_not_eval = builder.consume_mle_evaluation();
let selection_eval = one_eval - selection_not_eval;

// subpolynomial: selection * lhs
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/proof_exprs/multiply_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl ProofExpr for MultiplyExpr {
let rhs = self.rhs.verifier_evaluate(builder, accessor, one_eval)?;

// lhs_times_rhs
let lhs_times_rhs = builder.consume_intermediate_mle();
let lhs_times_rhs = builder.consume_mle_evaluation();

// subpolynomial: lhs_times_rhs - lhs * rhs
builder.produce_sumcheck_subpolynomial_evaluation(
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/proof_exprs/or_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ pub fn verifier_evaluate_or<S: Scalar>(
rhs: &S,
) -> S {
// lhs_and_rhs
let lhs_and_rhs = builder.consume_intermediate_mle();
let lhs_and_rhs = builder.consume_mle_evaluation();

// subpolynomial: lhs_and_rhs - lhs * rhs
builder.produce_sumcheck_subpolynomial_evaluation(
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/proof_exprs/sign_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ pub fn verifier_evaluate_sign<S: Scalar>(
// bits of the expression
let mut bit_evals = Vec::with_capacity(num_varying_bits);
for _ in 0..num_varying_bits {
let eval = builder.consume_intermediate_mle();
let eval = builder.consume_mle_evaluation();
bit_evals.push(eval);
}

Expand Down
10 changes: 4 additions & 6 deletions crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::{
};
use alloc::{boxed::Box, vec, vec::Vec};
use bumpalo::Bump;
use core::{iter::repeat_with, marker::PhantomData};
use core::marker::PhantomData;
use num_traits::{One, Zero};
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -99,9 +99,7 @@ where
.collect::<Result<Vec<_>, _>>()?,
);
// 3. filtered_columns
let filtered_columns_evals: Vec<_> = repeat_with(|| builder.consume_intermediate_mle())
.take(self.aliased_results.len())
.collect();
let filtered_columns_evals = builder.consume_mle_evaluations(self.aliased_results.len());
assert!(filtered_columns_evals.len() == self.aliased_results.len());

let alpha = builder.consume_post_result_challenge();
Expand Down Expand Up @@ -262,8 +260,8 @@ pub(super) fn verify_filter<S: Scalar>(
) -> Result<(), ProofError> {
let c_fold_eval = alpha * fold_vals(beta, c_evals);
let d_fold_eval = alpha * fold_vals(beta, d_evals);
let c_star_eval = builder.consume_intermediate_mle();
let d_star_eval = builder.consume_intermediate_mle();
let c_star_eval = builder.consume_mle_evaluation();
let d_star_eval = builder.consume_mle_evaluation();

// sum c_star * s - d_star = 0
builder.produce_sumcheck_subpolynomial_evaluation(
Expand Down
18 changes: 7 additions & 11 deletions crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{
};
use alloc::{boxed::Box, vec, vec::Vec};
use bumpalo::Bump;
use core::{iter, iter::repeat_with};
use core::iter;
use num_traits::{One, Zero};
use proof_of_sql_parser::Identifier;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -117,14 +117,10 @@ impl ProofPlan for GroupByExec {
})
.collect::<Result<Vec<_>, _>>()?;
// 3. filtered_columns
let group_by_result_columns_evals: Vec<_> =
repeat_with(|| builder.consume_intermediate_mle())
.take(self.group_by_exprs.len())
.collect();
let sum_result_columns_evals: Vec<_> = repeat_with(|| builder.consume_intermediate_mle())
.take(self.sum_expr.len())
.collect();
let count_column_eval = builder.consume_intermediate_mle();
let group_by_result_columns_evals =
builder.consume_mle_evaluations(self.group_by_exprs.len());
let sum_result_columns_evals = builder.consume_mle_evaluations(self.sum_expr.len());
let count_column_eval = builder.consume_mle_evaluation();

let alpha = builder.consume_post_result_challenge();
let beta = builder.consume_post_result_challenge();
Expand Down Expand Up @@ -358,8 +354,8 @@ fn verify_group_by<S: Scalar>(
// sum_out_fold = count_out + sum beta^(j+1) * sum_out[j]
let sum_out_fold_eval = count_out_eval + beta * fold_vals(beta, &sum_out_evals);

let g_in_star_eval = builder.consume_intermediate_mle();
let g_out_star_eval = builder.consume_intermediate_mle();
let g_in_star_eval = builder.consume_mle_evaluation();
let g_out_star_eval = builder.consume_mle_evaluation();

// sum g_in_star * sel_in * sum_in_fold - g_out_star * sum_out_fold = 0
builder.produce_sumcheck_subpolynomial_evaluation(
Expand Down
5 changes: 1 addition & 4 deletions crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use crate::{
};
use alloc::vec::Vec;
use bumpalo::Bump;
use core::iter::repeat_with;
use serde::{Deserialize, Serialize};

/// Provable expressions for queries of the form
Expand Down Expand Up @@ -69,9 +68,7 @@ impl ProofPlan for ProjectionExec {
.verifier_evaluate(builder, accessor, one_eval)
})
.collect::<Result<Vec<_>, _>>()?;
let column_evals = repeat_with(|| builder.consume_intermediate_mle())
.take(self.aliased_results.len())
.collect::<Vec<_>>();
let column_evals = builder.consume_mle_evaluations(self.aliased_results.len());
Ok(TableEvaluation::new(column_evals, one_eval))
}

Expand Down
6 changes: 2 additions & 4 deletions crates/proof-of-sql/src/sql/proof_plans/slice_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
};
use alloc::{boxed::Box, vec::Vec};
use bumpalo::Bump;
use core::iter::{repeat, repeat_with};
use core::iter::repeat;
use itertools::repeat_n;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -85,9 +85,7 @@ where
let max_one_eval = builder.consume_one_evaluation();
let selection_eval = max_one_eval - offset_one_eval;
// 3. filtered_columns
let filtered_columns_evals: Vec<_> = repeat_with(|| builder.consume_intermediate_mle())
.take(columns_evals.len())
.collect();
let filtered_columns_evals = builder.consume_mle_evaluations(columns_evals.len());
let alpha = builder.consume_post_result_challenge();
let beta = builder.consume_post_result_challenge();

Expand Down
9 changes: 3 additions & 6 deletions crates/proof-of-sql/src/sql/proof_plans/union_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use crate::{
};
use alloc::{boxed::Box, vec, vec::Vec};
use bumpalo::Bump;
use core::iter::repeat_with;
use num_traits::{One, Zero};
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -78,9 +77,7 @@ where
.iter()
.map(TableEvaluation::column_evals)
.collect::<Vec<_>>();
let output_column_evals: Vec<_> = repeat_with(|| builder.consume_intermediate_mle())
.take(self.schema.len())
.collect();
let output_column_evals = builder.consume_mle_evaluations(self.schema.len());
let input_one_evals = input_table_evals
.iter()
.map(TableEvaluation::one_eval)
Expand Down Expand Up @@ -199,7 +196,7 @@ fn verify_union<S: Scalar>(
.zip(input_one_evals)
.map(|(&input_eval, &input_one_eval)| {
let c_fold_eval = gamma * fold_vals(beta, input_eval);
let c_star_eval = builder.consume_intermediate_mle();
let c_star_eval = builder.consume_mle_evaluation();
// c_star + c_fold * c_star - input_ones = 0
builder.produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::Identity,
Expand All @@ -210,7 +207,7 @@ fn verify_union<S: Scalar>(
.collect::<Vec<_>>();

let d_bar_fold_eval = gamma * fold_vals(beta, output_eval);
let d_star_eval = builder.consume_intermediate_mle();
let d_star_eval = builder.consume_mle_evaluation();

// d_star + d_bar_fold * d_star - output_ones = 0
builder.produce_sumcheck_subpolynomial_evaluation(
Expand Down

0 comments on commit 8dcc3cd

Please sign in to comment.