Skip to content

Commit

Permalink
fix: fix group by vulnerability (#426)
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner authored Dec 11, 2024
2 parents 5b3c440 + 080e36f commit 8683e3c
Showing 1 changed file with 45 additions and 45 deletions.
90 changes: 45 additions & 45 deletions crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::{
use alloc::{boxed::Box, vec, vec::Vec};
use bumpalo::Bump;
use core::{iter, iter::repeat_with};
use num_traits::One;
use num_traits::{One, Zero};
use proof_of_sql_parser::Identifier;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -128,12 +128,14 @@ impl ProofPlan for GroupByExec {

let alpha = builder.consume_post_result_challenge();
let beta = builder.consume_post_result_challenge();
let output_one_eval = builder.consume_one_evaluation();

verify_group_by(
builder,
alpha,
beta,
input_one_eval,
output_one_eval,
(group_by_evals, aggregate_evals, where_eval),
(
group_by_result_columns_evals.clone(),
Expand Down Expand Up @@ -171,7 +173,6 @@ impl ProofPlan for GroupByExec {
.chain(sum_result_columns_evals)
.chain(iter::once(count_column_eval))
.collect::<Vec<_>>();
let output_one_eval = builder.consume_one_evaluation();
Ok(TableEvaluation::new(column_evals, output_one_eval))
}

Expand Down Expand Up @@ -343,38 +344,39 @@ fn verify_group_by<S: Scalar>(
builder: &mut VerificationBuilder<S>,
alpha: S,
beta: S,
one_eval: S,
input_one_eval: S,
output_one_eval: S,
(g_in_evals, sum_in_evals, sel_in_eval): (Vec<S>, Vec<S>, S),
(g_out_evals, sum_out_evals, count_out_eval): (Vec<S>, Vec<S>, S),
) -> Result<(), ProofError> {
// g_in_fold = alpha + sum beta^j * g_in[j]
let g_in_fold_eval = alpha * one_eval + fold_vals(beta, &g_in_evals);
// g_out_bar_fold = alpha + sum beta^j * g_out_bar[j]
let g_out_bar_fold_eval = alpha * one_eval + fold_vals(beta, &g_out_evals);
// sum_in_fold = 1 + sum beta^(j+1) * sum_in[j]
let sum_in_fold_eval = one_eval + beta * fold_vals(beta, &sum_in_evals);
// sum_out_bar_fold = count_out_bar + sum beta^(j+1) * sum_out_bar[j]
let sum_out_bar_fold_eval = count_out_eval + beta * fold_vals(beta, &sum_out_evals);
// g_in_fold = alpha * sum beta^j * g_in[j]
let g_in_fold_eval = alpha * fold_vals(beta, &g_in_evals);
// g_out_fold = alpha * sum beta^j * g_out[j]
let g_out_fold_eval = alpha * fold_vals(beta, &g_out_evals);
// sum_in_fold = input_ones + sum beta^(j+1) * sum_in[j]
let sum_in_fold_eval = input_one_eval + beta * fold_vals(beta, &sum_in_evals);
// sum_out_fold = count_out + sum beta^(j+1) * sum_out[j]
let sum_out_fold_eval = count_out_eval + beta * fold_vals(beta, &sum_out_evals);

let g_in_star_eval = builder.consume_intermediate_mle();
let g_out_star_eval = builder.consume_intermediate_mle();

// sum g_in_star * sel_in * sum_in_fold - g_out_star * sum_out_bar_fold = 0
// sum g_in_star * sel_in * sum_in_fold - g_out_star * sum_out_fold = 0
builder.produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::ZeroSum,
g_in_star_eval * sel_in_eval * sum_in_fold_eval - g_out_star_eval * sum_out_bar_fold_eval,
g_in_star_eval * sel_in_eval * sum_in_fold_eval - g_out_star_eval * sum_out_fold_eval,
);

// g_in_star * g_in_fold - input_ones = 0
// g_in_star + g_in_star * g_in_fold - input_ones = 0
builder.produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::Identity,
g_in_star_eval * g_in_fold_eval - one_eval,
g_in_star_eval + g_in_star_eval * g_in_fold_eval - input_one_eval,
);

// g_out_star * g_out_bar_fold - input_ones = 0
// g_out_star + g_out_star * g_out_fold - output_ones = 0
builder.produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::Identity,
g_out_star_eval * g_out_bar_fold_eval - one_eval,
g_out_star_eval + g_out_star_eval * g_out_fold_eval - output_one_eval,
);

Ok(())
Expand All @@ -393,39 +395,41 @@ pub fn prove_group_by<'a, S: Scalar>(
(g_out, sum_out, count_out): (&[Column<S>], &[&'a [S]], &'a [i64]),
n: usize,
) {
let m_out = count_out.len();
let m = count_out.len();
let input_ones = alloc.alloc_slice_fill_copy(n, true);
let output_ones = alloc.alloc_slice_fill_copy(m, true);

// g_in_fold = alpha + sum beta^j * g_in[j]
let g_in_fold = alloc.alloc_slice_fill_copy(n, alpha);
fold_columns(g_in_fold, One::one(), beta, g_in);
// g_in_fold = alpha * sum beta^j * g_in[j]
let g_in_fold = alloc.alloc_slice_fill_copy(n, Zero::zero());
fold_columns(g_in_fold, alpha, beta, g_in);

// g_out_bar_fold = alpha + sum beta^j * g_out_bar[j]
let g_out_bar_fold = alloc.alloc_slice_fill_copy(n, alpha);
fold_columns(g_out_bar_fold, One::one(), beta, g_out);
// g_out_fold = alpha * sum beta^j * g_out[j]
let g_out_fold = alloc.alloc_slice_fill_copy(m, Zero::zero());
fold_columns(g_out_fold, alpha, beta, g_out);

// sum_in_fold = 1 + sum beta^(j+1) * sum_in[j]
let sum_in_fold = alloc.alloc_slice_fill_copy(n, One::one());
fold_columns(sum_in_fold, beta, beta, sum_in);

// sum_out_bar_fold = count_out_bar + sum beta^(j+1) * sum_out_bar[j]
let sum_out_bar_fold = alloc.alloc_slice_fill_default(n);
slice_ops::slice_cast_mut(count_out, sum_out_bar_fold);
fold_columns(sum_out_bar_fold, beta, beta, sum_out);
// sum_out_fold = count_out + sum beta^(j+1) * sum_out[j]
let sum_out_fold = alloc.alloc_slice_fill_default(m);
slice_ops::slice_cast_mut(count_out, sum_out_fold);
fold_columns(sum_out_fold, beta, beta, sum_out);

// g_in_star = g_in_fold^(-1)
// g_in_star = (1 + g_in_fold)^(-1)
let g_in_star = alloc.alloc_slice_copy(g_in_fold);
slice_ops::add_const::<S, S>(g_in_star, One::one());
slice_ops::batch_inversion(g_in_star);

// g_out_star = g_out_bar_fold^(-1), which is simply alpha^(-1) when beyond the output length
let g_out_star = alloc.alloc_slice_copy(g_out_bar_fold);
g_out_star[m_out..].fill(alpha.inv().expect("alpha should never be 0"));
slice_ops::batch_inversion(&mut g_out_star[..m_out]);
// g_out_star = (1 + g_out_fold)^(-1)
let g_out_star = alloc.alloc_slice_copy(g_out_fold);
slice_ops::add_const::<S, S>(g_out_star, One::one());
slice_ops::batch_inversion(g_out_star);

builder.produce_intermediate_mle(g_in_star as &[_]);
builder.produce_intermediate_mle(g_out_star as &[_]);

// sum g_in_star * sel_in * sum_in_fold - g_out_star * sum_out_bar_fold = 0
// sum g_in_star * sel_in * sum_in_fold - g_out_star * sum_out_fold = 0
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::ZeroSum,
vec![
Expand All @@ -439,18 +443,16 @@ pub fn prove_group_by<'a, S: Scalar>(
),
(
-S::one(),
vec![
Box::new(g_out_star as &[_]),
Box::new(sum_out_bar_fold as &[_]),
],
vec![Box::new(g_out_star as &[_]), Box::new(sum_out_fold as &[_])],
),
],
);

// g_in_star * g_in_fold - input_ones = 0
// g_in_star + g_in_star * g_in_fold - input_ones = 0
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::Identity,
vec![
(S::one(), vec![Box::new(g_in_star as &[_])]),
(
S::one(),
vec![Box::new(g_in_star as &[_]), Box::new(g_in_fold as &[_])],
Expand All @@ -459,18 +461,16 @@ pub fn prove_group_by<'a, S: Scalar>(
],
);

// g_out_star * g_out_bar_fold - input_ones = 0
// g_out_star + g_out_star * g_out_fold - output_ones = 0
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::Identity,
vec![
(S::one(), vec![Box::new(g_out_star as &[_])]),
(
S::one(),
vec![
Box::new(g_out_star as &[_]),
Box::new(g_out_bar_fold as &[_]),
],
vec![Box::new(g_out_star as &[_]), Box::new(g_out_fold as &[_])],
),
(-S::one(), vec![Box::new(input_ones as &[_])]),
(-S::one(), vec![Box::new(output_ones as &[_])]),
],
);
}

0 comments on commit 8683e3c

Please sign in to comment.