Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: move table_length to FirstRoundBuilder #380

Merged
merged 3 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 6 additions & 29 deletions crates/proof-of-sql/src/sql/proof/final_round_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@ use crate::base::{
polynomial::{CompositePolynomial, MultilinearExtension},
scalar::Scalar,
};
use alloc::{boxed::Box, vec, vec::Vec};
use num_traits::Zero;
use alloc::{boxed::Box, vec::Vec};

/// Track components used to form a query's proof
pub struct FinalRoundBuilder<'a, S: Scalar> {
table_length: usize,
num_sumcheck_variables: usize,
bit_distributions: Vec<BitDistribution>,
commitment_descriptor: Vec<CommittableColumn<'a>>,
Expand All @@ -30,13 +28,8 @@ pub struct FinalRoundBuilder<'a, S: Scalar> {
}

impl<'a, S: Scalar> FinalRoundBuilder<'a, S> {
pub fn new(
table_length: usize,
num_sumcheck_variables: usize,
post_result_challenges: Vec<S>,
) -> Self {
pub fn new(num_sumcheck_variables: usize, post_result_challenges: Vec<S>) -> Self {
Self {
table_length,
num_sumcheck_variables,
bit_distributions: Vec::new(),
commitment_descriptor: Vec::new(),
Expand All @@ -46,10 +39,6 @@ impl<'a, S: Scalar> FinalRoundBuilder<'a, S> {
}
}

pub fn table_length(&self) -> usize {
self.table_length
}

pub fn num_sumcheck_variables(&self) -> usize {
self.num_sumcheck_variables
}
Expand All @@ -58,6 +47,10 @@ impl<'a, S: Scalar> FinalRoundBuilder<'a, S> {
self.sumcheck_subpolynomials.len()
}

pub fn pcs_proof_mles(&self) -> &[Box<dyn MultilinearExtension<S> + 'a>] {
&self.pcs_proof_mles
}

/// Produce a bit distribution that describes which bits are constant
/// and which bits varying in a column of data
pub fn produce_bit_distribution(&mut self, dist: BitDistribution) {
Expand Down Expand Up @@ -152,22 +145,6 @@ impl<'a, S: Scalar> FinalRoundBuilder<'a, S> {
res
}

/// Given random multipliers, multiply and add together all of the MLEs used in sumcheck except
/// for those that correspond to result columns sent to the verifier.
#[tracing::instrument(
name = "FinalRoundBuilder::fold_pcs_proof_mles",
level = "debug",
skip_all
)]
pub fn fold_pcs_proof_mles(&self, multipliers: &[S]) -> Vec<S> {
assert_eq!(multipliers.len(), self.pcs_proof_mles.len());
let mut res = vec![Zero::zero(); self.table_length];
for (multiplier, evaluator) in multipliers.iter().zip(self.pcs_proof_mles.iter()) {
evaluator.mul_add(&mut res, multiplier);
}
res
}

pub fn bit_distributions(&self) -> &[BitDistribution] {
&self.bit_distributions
}
Expand Down
25 changes: 4 additions & 21 deletions crates/proof-of-sql/src/sql/proof/final_round_builder_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use num_traits::{One, Zero};
fn we_can_compute_commitments_for_intermediate_mles_using_a_zero_offset() {
let mle1 = [1, 2];
let mle2 = [10i64, 20];
let mut builder = FinalRoundBuilder::<Curve25519Scalar>::new(2, 1, Vec::new());
let mut builder = FinalRoundBuilder::<Curve25519Scalar>::new(1, Vec::new());
builder.produce_anchored_mle(&mle1);
builder.produce_intermediate_mle(&mle2[..]);
let offset_generators = 0_usize;
Expand All @@ -41,7 +41,7 @@ fn we_can_compute_commitments_for_intermediate_mles_using_a_zero_offset() {
fn we_can_compute_commitments_for_intermediate_mles_using_a_non_zero_offset() {
let mle1 = [1, 2];
let mle2 = [10i64, 20];
let mut builder = FinalRoundBuilder::<Curve25519Scalar>::new(2, 1, Vec::new());
let mut builder = FinalRoundBuilder::<Curve25519Scalar>::new(1, Vec::new());
builder.produce_anchored_mle(&mle1);
builder.produce_intermediate_mle(&mle2[..]);
let offset_generators = 123_usize;
Expand All @@ -60,7 +60,7 @@ fn we_can_compute_commitments_for_intermediate_mles_using_a_non_zero_offset() {
fn we_can_evaluate_pcs_proof_mles() {
let mle1 = [1, 2];
let mle2 = [10i64, 20];
let mut builder = FinalRoundBuilder::new(2, 1, Vec::new());
let mut builder = FinalRoundBuilder::new(1, Vec::new());
builder.produce_anchored_mle(&mle1);
builder.produce_intermediate_mle(&mle2[..]);
let evaluation_vec = [
Expand All @@ -80,7 +80,7 @@ fn we_can_form_an_aggregated_sumcheck_polynomial() {
let mle1 = [1, 2, -1];
let mle2 = [10i64, 20, 100, 30];
let mle3 = [2000i64, 3000, 5000, 7000];
let mut builder = FinalRoundBuilder::new(4, 2, Vec::new());
let mut builder = FinalRoundBuilder::new(2, Vec::new());
builder.produce_anchored_mle(&mle1);
builder.produce_intermediate_mle(&mle2[..]);
builder.produce_intermediate_mle(&mle3[..]);
Expand Down Expand Up @@ -166,26 +166,9 @@ fn we_can_form_the_provable_query_result() {
assert_eq!(res, expected_res);
}

#[test]
fn we_can_fold_pcs_proof_mles() {
let mle1 = [1, 2];
let mle2 = [10i64, 20];
let mut builder = FinalRoundBuilder::new(2, 1, Vec::new());
builder.produce_anchored_mle(&mle1);
builder.produce_intermediate_mle(&mle2[..]);
let multipliers = [Curve25519Scalar::from(100u64), Curve25519Scalar::from(2u64)];
let z = builder.fold_pcs_proof_mles(&multipliers);
let expected_z = [
Curve25519Scalar::from(120u64),
Curve25519Scalar::from(240u64),
];
assert_eq!(z, expected_z);
}

#[test]
fn we_can_consume_post_result_challenges_in_proof_builder() {
let mut builder = FinalRoundBuilder::new(
0,
0,
vec![
Curve25519Scalar::from(123),
Expand Down
24 changes: 17 additions & 7 deletions crates/proof-of-sql/src/sql/proof/first_round_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,29 @@ pub struct FirstRoundBuilder {
/// the prover after the prover sends the result, but before the prover
/// send commitments to the intermediate witness columns.
num_post_result_challenges: usize,
}

impl Default for FirstRoundBuilder {
fn default() -> Self {
Self::new()
}
/// Used to determine the indices of generators we use
range_length: usize,
}

impl FirstRoundBuilder {
/// Create a new result builder for a table with the given length. For multi table queries, this will likely need to change.
pub fn new() -> Self {
pub fn new(range_length: usize) -> Self {
Self {
num_post_result_challenges: 0,
range_length,
}
}

pub fn range_length(&self) -> usize {
self.range_length
}

/// Used if a `ProofPlan` can cause output `table_length` to be larger
/// than the largest of the input ones e.g. unions and joins since it will
/// force us to update `range_length`.
pub fn update_range_length(&mut self, table_length: usize) {
if table_length > self.range_length {
self.range_length = table_length;
}
}

Expand Down
37 changes: 21 additions & 16 deletions crates/proof-of-sql/src/sql/proof/query_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ pub struct QueryProof<CP: CommitmentEvaluationProof> {
pub pcs_proof_evaluations: Vec<CP::Scalar>,
/// Inner product proof of the MLEs' evaluations
pub evaluation_proof: CP,
/// Length of the range of generators we use
pub range_length: usize,
}

impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
Expand All @@ -71,10 +73,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
setup: &CP::ProverPublicSetup<'_>,
) -> (Self, ProvableQueryResult) {
let (min_row_num, max_row_num) = get_index_range(accessor, expr.get_table_references());
let range_length = max_row_num - min_row_num;
let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
assert!(num_sumcheck_variables > 0);

let initial_range_length = max_row_num - min_row_num;
let alloc = Bump::new();

let total_col_refs = expr.get_column_references();
Expand All @@ -95,8 +94,11 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
let provable_result = expr.result_evaluate(&alloc, &table_map).into();

// Prover First Round
let mut first_round_builder = FirstRoundBuilder::new();
let mut first_round_builder = FirstRoundBuilder::new(initial_range_length);
expr.first_round_evaluate(&mut first_round_builder);
let range_length = first_round_builder.range_length();
let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
assert!(num_sumcheck_variables > 0);

// construct a transcript for the proof
let mut transcript: Keccak256Transcript =
Expand All @@ -112,8 +114,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
.take(first_round_builder.num_post_result_challenges())
.collect();

let mut builder =
FinalRoundBuilder::new(range_length, num_sumcheck_variables, post_result_challenges);
let mut builder = FinalRoundBuilder::new(num_sumcheck_variables, post_result_challenges);

for col_ref in total_col_refs {
builder.produce_anchored_mle(accessor.get_column(col_ref));
Expand Down Expand Up @@ -159,7 +160,12 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
core::iter::repeat_with(|| transcript.scalar_challenge_as_be())
.take(pcs_proof_evaluations.len())
.collect();
let folded_mle = builder.fold_pcs_proof_mles(&random_scalars);

assert_eq!(random_scalars.len(), builder.pcs_proof_mles().len());
let mut folded_mle = vec![Zero::zero(); range_length];
for (multiplier, evaluator) in random_scalars.iter().zip(builder.pcs_proof_mles().iter()) {
evaluator.mul_add(&mut folded_mle, multiplier);
}

// finally, form the inner product proof of the MLEs' evaluations
let evaluation_proof = CP::new(
Expand All @@ -176,6 +182,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
sumcheck_proof,
pcs_proof_evaluations,
evaluation_proof,
range_length,
};
(proof, provable_result)
}
Expand All @@ -190,10 +197,8 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
setup: &CP::VerifierPublicSetup<'_>,
) -> QueryResult<CP::Scalar> {
let owned_table_result = result.to_owned_table(&expr.get_column_result_fields())?;

let (min_row_num, max_row_num) = get_index_range(accessor, expr.get_table_references());
let range_length = max_row_num - min_row_num;
let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
let (min_row_num, _) = get_index_range(accessor, expr.get_table_references());
let num_sumcheck_variables = cmp::max(log2_up(self.range_length), 1);
assert!(num_sumcheck_variables > 0);

// validate bit decompositions
Expand Down Expand Up @@ -222,7 +227,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {

// construct a transcript for the proof
let mut transcript: Keccak256Transcript =
make_transcript(expr, result, range_length, min_row_num);
make_transcript(expr, result, self.range_length, min_row_num);

// These are the challenges that will be consumed by the proof
// Specifically, these are the challenges that the verifier sends to
Expand All @@ -244,7 +249,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
.take(num_random_scalars)
.collect();
let sumcheck_random_scalars =
SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables);
SumcheckRandomScalars::new(&random_scalars, self.range_length, num_sumcheck_variables);

// verify sumcheck up to the evaluation check
let poly_info = CompositePolynomialInfo {
Expand All @@ -271,7 +276,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {

// pass over the provable AST to fill in the verification builder
let sumcheck_evaluations = SumcheckMleEvaluations::new(
range_length,
self.range_length,
owned_table_result.num_rows(),
&subclaim.evaluation_point,
&sumcheck_random_scalars,
Expand Down Expand Up @@ -327,7 +332,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
&product,
&subclaim.evaluation_point,
min_row_num as u64,
range_length,
self.range_length,
setup,
)
.map_err(|_e| ProofError::VerificationError {
Expand Down
10 changes: 7 additions & 3 deletions crates/proof-of-sql/src/sql/proof_exprs/equals_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,12 @@ impl ProofExpr for EqualsExpr {
let rhs_scale = self.rhs.data_type().scale().unwrap_or(0);
let res = scale_and_subtract(alloc, lhs_column, rhs_column, lhs_scale, rhs_scale, true)
.expect("Failed to scale and subtract");
Column::Boolean(prover_evaluate_equals_zero(builder, alloc, res))
Column::Boolean(prover_evaluate_equals_zero(
table.num_rows(),
builder,
alloc,
res,
))
}

fn verifier_evaluate<S: Scalar>(
Expand Down Expand Up @@ -103,12 +108,11 @@ pub fn result_evaluate_equals_zero<'a, S: Scalar>(
}

pub fn prover_evaluate_equals_zero<'a, S: Scalar>(
table_length: usize,
builder: &mut FinalRoundBuilder<'a, S>,
alloc: &'a Bump,
lhs: &'a [S],
) -> &'a [bool] {
let table_length = builder.table_length();

// lhs_pseudo_inv
let lhs_pseudo_inv = alloc.alloc_slice_copy(lhs);
slice_ops::batch_inversion(lhs_pseudo_inv);
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl ProofExpr for InequalityExpr {
};

// diff == 0
let equals_zero = prover_evaluate_equals_zero(builder, alloc, diff);
let equals_zero = prover_evaluate_equals_zero(table.num_rows(), builder, alloc, diff);

// sign(diff) == -1
let sign = prover_evaluate_sign(
Expand Down
4 changes: 2 additions & 2 deletions crates/proof-of-sql/src/sql/proof_exprs/sign_expr_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn prover_evaluation_generates_the_bit_distribution_of_a_constant_column() {
let dist = BitDistribution::new::<Curve25519Scalar, _>(&data);
let alloc = Bump::new();
let data: Vec<Curve25519Scalar> = data.into_iter().map(Curve25519Scalar::from).collect();
let mut builder = FinalRoundBuilder::new(3, 2, Vec::new());
let mut builder = FinalRoundBuilder::new(2, Vec::new());
let sign = prover_evaluate_sign(&mut builder, &alloc, &data, false);
assert_eq!(sign, [false; 3]);
assert_eq!(builder.bit_distributions(), [dist]);
Expand All @@ -27,7 +27,7 @@ fn prover_evaluation_generates_the_bit_distribution_of_a_negative_constant_colum
let dist = BitDistribution::new::<Curve25519Scalar, _>(&data);
let alloc = Bump::new();
let data: Vec<Curve25519Scalar> = data.into_iter().map(Curve25519Scalar::from).collect();
let mut builder = FinalRoundBuilder::new(3, 2, Vec::new());
let mut builder = FinalRoundBuilder::new(2, Vec::new());
let sign = prover_evaluate_sign(&mut builder, &alloc, &data, false);
assert_eq!(sign, [true; 3]);
assert_eq!(builder.bit_distributions(), [dist]);
Expand Down
3 changes: 2 additions & 1 deletion crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ impl ProverEvaluate for FilterExec {
&columns,
selection,
&filtered_columns,
table.num_rows(),
result_len,
);
Table::<'a, S>::try_from_iter_with_options(
Expand Down Expand Up @@ -278,9 +279,9 @@ pub(super) fn prove_filter<'a, S: Scalar + 'a>(
c: &[Column<S>],
s: &'a [bool],
d: &[Column<S>],
n: usize,
m: usize,
) {
let n = builder.table_length();
let chi = alloc.alloc_slice_fill_copy(n, false);
chi[..m].fill(true);

Expand Down
Loading
Loading