diff --git a/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs index 2d8fb8d90..e2a2d9163 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs @@ -23,7 +23,7 @@ use crate::{ use alloc::{boxed::Box, vec, vec::Vec}; use bumpalo::Bump; use core::{iter, iter::repeat_with}; -use num_traits::One; +use num_traits::{One, Zero}; use proof_of_sql_parser::Identifier; use serde::{Deserialize, Serialize}; @@ -128,12 +128,14 @@ impl ProofPlan for GroupByExec { let alpha = builder.consume_post_result_challenge(); let beta = builder.consume_post_result_challenge(); + let output_one_eval = builder.consume_one_evaluation(); verify_group_by( builder, alpha, beta, input_one_eval, + output_one_eval, (group_by_evals, aggregate_evals, where_eval), ( group_by_result_columns_evals.clone(), @@ -171,7 +173,6 @@ impl ProofPlan for GroupByExec { .chain(sum_result_columns_evals) .chain(iter::once(count_column_eval)) .collect::>(); - let output_one_eval = builder.consume_one_evaluation(); Ok(TableEvaluation::new(column_evals, output_one_eval)) } @@ -343,38 +344,39 @@ fn verify_group_by( builder: &mut VerificationBuilder, alpha: S, beta: S, - one_eval: S, + input_one_eval: S, + output_one_eval: S, (g_in_evals, sum_in_evals, sel_in_eval): (Vec, Vec, S), (g_out_evals, sum_out_evals, count_out_eval): (Vec, Vec, S), ) -> Result<(), ProofError> { - // g_in_fold = alpha + sum beta^j * g_in[j] - let g_in_fold_eval = alpha * one_eval + fold_vals(beta, &g_in_evals); - // g_out_bar_fold = alpha + sum beta^j * g_out_bar[j] - let g_out_bar_fold_eval = alpha * one_eval + fold_vals(beta, &g_out_evals); - // sum_in_fold = 1 + sum beta^(j+1) * sum_in[j] - let sum_in_fold_eval = one_eval + beta * fold_vals(beta, &sum_in_evals); - // sum_out_bar_fold = count_out_bar + sum beta^(j+1) * sum_out_bar[j] - let sum_out_bar_fold_eval = count_out_eval + beta * fold_vals(beta, &sum_out_evals); + // g_in_fold = alpha * sum beta^j * g_in[j] + let g_in_fold_eval = alpha * fold_vals(beta, &g_in_evals); + // g_out_fold = alpha * sum beta^j * g_out[j] + let g_out_fold_eval = alpha * fold_vals(beta, &g_out_evals); + // sum_in_fold = input_ones + sum beta^(j+1) * sum_in[j] + let sum_in_fold_eval = input_one_eval + beta * fold_vals(beta, &sum_in_evals); + // sum_out_fold = count_out + sum beta^(j+1) * sum_out[j] + let sum_out_fold_eval = count_out_eval + beta * fold_vals(beta, &sum_out_evals); let g_in_star_eval = builder.consume_intermediate_mle(); let g_out_star_eval = builder.consume_intermediate_mle(); - // sum g_in_star * sel_in * sum_in_fold - g_out_star * sum_out_bar_fold = 0 + // sum g_in_star * sel_in * sum_in_fold - g_out_star * sum_out_fold = 0 builder.produce_sumcheck_subpolynomial_evaluation( SumcheckSubpolynomialType::ZeroSum, - g_in_star_eval * sel_in_eval * sum_in_fold_eval - g_out_star_eval * sum_out_bar_fold_eval, + g_in_star_eval * sel_in_eval * sum_in_fold_eval - g_out_star_eval * sum_out_fold_eval, ); - // g_in_star * g_in_fold - input_ones = 0 + // g_in_star + g_in_star * g_in_fold - input_ones = 0 builder.produce_sumcheck_subpolynomial_evaluation( SumcheckSubpolynomialType::Identity, - g_in_star_eval * g_in_fold_eval - one_eval, + g_in_star_eval + g_in_star_eval * g_in_fold_eval - input_one_eval, ); - // g_out_star * g_out_bar_fold - input_ones = 0 + // g_out_star + g_out_star * g_out_fold - output_ones = 0 builder.produce_sumcheck_subpolynomial_evaluation( SumcheckSubpolynomialType::Identity, - g_out_star_eval * g_out_bar_fold_eval - one_eval, + g_out_star_eval + g_out_star_eval * g_out_fold_eval - output_one_eval, ); Ok(()) @@ -393,39 +395,41 @@ pub fn prove_group_by<'a, S: Scalar>( (g_out, sum_out, count_out): (&[Column], &[&'a [S]], &'a [i64]), n: usize, ) { - let m_out = count_out.len(); + let m = count_out.len(); let input_ones = alloc.alloc_slice_fill_copy(n, true); + let output_ones = alloc.alloc_slice_fill_copy(m, true); - // g_in_fold = alpha + sum beta^j * g_in[j] - let g_in_fold = alloc.alloc_slice_fill_copy(n, alpha); - fold_columns(g_in_fold, One::one(), beta, g_in); + // g_in_fold = alpha * sum beta^j * g_in[j] + let g_in_fold = alloc.alloc_slice_fill_copy(n, Zero::zero()); + fold_columns(g_in_fold, alpha, beta, g_in); - // g_out_bar_fold = alpha + sum beta^j * g_out_bar[j] - let g_out_bar_fold = alloc.alloc_slice_fill_copy(n, alpha); - fold_columns(g_out_bar_fold, One::one(), beta, g_out); + // g_out_fold = alpha * sum beta^j * g_out[j] + let g_out_fold = alloc.alloc_slice_fill_copy(m, Zero::zero()); + fold_columns(g_out_fold, alpha, beta, g_out); // sum_in_fold = 1 + sum beta^(j+1) * sum_in[j] let sum_in_fold = alloc.alloc_slice_fill_copy(n, One::one()); fold_columns(sum_in_fold, beta, beta, sum_in); - // sum_out_bar_fold = count_out_bar + sum beta^(j+1) * sum_out_bar[j] - let sum_out_bar_fold = alloc.alloc_slice_fill_default(n); - slice_ops::slice_cast_mut(count_out, sum_out_bar_fold); - fold_columns(sum_out_bar_fold, beta, beta, sum_out); + // sum_out_fold = count_out + sum beta^(j+1) * sum_out[j] + let sum_out_fold = alloc.alloc_slice_fill_default(m); + slice_ops::slice_cast_mut(count_out, sum_out_fold); + fold_columns(sum_out_fold, beta, beta, sum_out); - // g_in_star = g_in_fold^(-1) + // g_in_star = (1 + g_in_fold)^(-1) let g_in_star = alloc.alloc_slice_copy(g_in_fold); + slice_ops::add_const::(g_in_star, One::one()); slice_ops::batch_inversion(g_in_star); - // g_out_star = g_out_bar_fold^(-1), which is simply alpha^(-1) when beyond the output length - let g_out_star = alloc.alloc_slice_copy(g_out_bar_fold); - g_out_star[m_out..].fill(alpha.inv().expect("alpha should never be 0")); - slice_ops::batch_inversion(&mut g_out_star[..m_out]); + // g_out_star = (1 + g_out_fold)^(-1) + let g_out_star = alloc.alloc_slice_copy(g_out_fold); + slice_ops::add_const::(g_out_star, One::one()); + slice_ops::batch_inversion(g_out_star); builder.produce_intermediate_mle(g_in_star as &[_]); builder.produce_intermediate_mle(g_out_star as &[_]); - // sum g_in_star * sel_in * sum_in_fold - g_out_star * sum_out_bar_fold = 0 + // sum g_in_star * sel_in * sum_in_fold - g_out_star * sum_out_fold = 0 builder.produce_sumcheck_subpolynomial( SumcheckSubpolynomialType::ZeroSum, vec![ @@ -439,18 +443,16 @@ pub fn prove_group_by<'a, S: Scalar>( ), ( -S::one(), - vec![ - Box::new(g_out_star as &[_]), - Box::new(sum_out_bar_fold as &[_]), - ], + vec![Box::new(g_out_star as &[_]), Box::new(sum_out_fold as &[_])], ), ], ); - // g_in_star * g_in_fold - input_ones = 0 + // g_in_star + g_in_star * g_in_fold - input_ones = 0 builder.produce_sumcheck_subpolynomial( SumcheckSubpolynomialType::Identity, vec![ + (S::one(), vec![Box::new(g_in_star as &[_])]), ( S::one(), vec![Box::new(g_in_star as &[_]), Box::new(g_in_fold as &[_])], @@ -459,18 +461,16 @@ pub fn prove_group_by<'a, S: Scalar>( ], ); - // g_out_star * g_out_bar_fold - input_ones = 0 + // g_out_star + g_out_star * g_out_fold - output_ones = 0 builder.produce_sumcheck_subpolynomial( SumcheckSubpolynomialType::Identity, vec![ + (S::one(), vec![Box::new(g_out_star as &[_])]), ( S::one(), - vec![ - Box::new(g_out_star as &[_]), - Box::new(g_out_bar_fold as &[_]), - ], + vec![Box::new(g_out_star as &[_]), Box::new(g_out_fold as &[_])], ), - (-S::one(), vec![Box::new(input_ones as &[_])]), + (-S::one(), vec![Box::new(output_ones as &[_])]), ], ); }