Skip to content

Commit

Permalink
fix: fix padding vulnerability for filters
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Dec 10, 2024
1 parent e1823fd commit 95a21e5
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 18 deletions.
19 changes: 19 additions & 0 deletions crates/proof-of-sql/src/base/slice_ops/add_const.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use crate::base::if_rayon;
use core::ops::AddAssign;
#[cfg(feature = "rayon")]
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};

/// This operation does `result[i] += to_add` for `i` in `0..result.len()`.
pub fn add_const<T, S>(result: &mut [T], to_add: S)
where
T: Send + Sync + AddAssign<T> + Copy,
S: Into<T> + Sync + Copy,
{
if_rayon!(
result.par_iter_mut().with_min_len(super::MIN_RAYON_LEN),
result.iter_mut()
)
.for_each(|res_i| {
*res_i += to_add.into();
});
}
9 changes: 9 additions & 0 deletions crates/proof-of-sql/src/base/slice_ops/add_const_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use super::*;

#[test]
fn test_add_const() {
let mut a = vec![1, 2, 3, 4];
add_const(&mut a, 10);
let b = vec![1 + 10, 2 + 10, 3 + 10, 4 + 10];
assert_eq!(a, b);
}
4 changes: 4 additions & 0 deletions crates/proof-of-sql/src/base/slice_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
#[cfg(any(feature = "rayon", test))]
pub const MIN_RAYON_LEN: usize = 1 << 8;

mod add_const;
#[cfg(test)]
mod add_const_test;
mod inner_product;
#[cfg(test)]
mod inner_product_test;
Expand All @@ -16,6 +19,7 @@ mod slice_cast;
#[cfg(test)]
mod slice_cast_test;

pub use add_const::*;
pub use inner_product::*;
pub use mul_add_assign::*;
pub use slice_cast::*;
Expand Down
39 changes: 21 additions & 18 deletions crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ pub(super) fn verify_filter<S: Scalar>(
s_eval: S,
d_evals: &[S],
) -> Result<(), ProofError> {
let c_fold_eval = alpha * one_eval + fold_vals(beta, c_evals);
let d_bar_fold_eval = alpha * one_eval + fold_vals(beta, d_evals);
let c_fold_eval = alpha * fold_vals(beta, c_evals);
let d_fold_eval = alpha * fold_vals(beta, d_evals);
let c_star_eval = builder.consume_intermediate_mle();
let d_star_eval = builder.consume_intermediate_mle();

Expand All @@ -271,16 +271,16 @@ pub(super) fn verify_filter<S: Scalar>(
c_star_eval * s_eval - d_star_eval,
);

// c_fold * c_star - input_ones = 0
// c_star + c_fold * c_star - input_ones = 0
builder.produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::Identity,
c_fold_eval * c_star_eval - one_eval,
c_star_eval + c_fold_eval * c_star_eval - one_eval,
);

// d_bar_fold * d_star - chi = 0
// d_star + d_fold * d_star - chi = 0
builder.produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::Identity,
d_bar_fold_eval * d_star_eval - chi_eval,
d_star_eval + d_fold_eval * d_star_eval - chi_eval,
);

Ok(())
Expand All @@ -299,19 +299,20 @@ pub(super) fn prove_filter<'a, S: Scalar + 'a>(
m: usize,
) {
let input_ones = alloc.alloc_slice_fill_copy(n, true);
let chi = alloc.alloc_slice_fill_copy(n, false);
chi[..m].fill(true);
let chi = alloc.alloc_slice_fill_copy(m, true);

let c_fold = alloc.alloc_slice_fill_copy(n, alpha);
fold_columns(c_fold, One::one(), beta, c);
let d_bar_fold = alloc.alloc_slice_fill_copy(n, alpha);
fold_columns(d_bar_fold, One::one(), beta, d);
let c_fold = alloc.alloc_slice_fill_copy(n, Zero::zero());
fold_columns(c_fold, alpha, beta, c);
let d_fold = alloc.alloc_slice_fill_copy(m, Zero::zero());
fold_columns(d_fold, alpha, beta, d);

let c_star = alloc.alloc_slice_copy(c_fold);
let d_star = alloc.alloc_slice_copy(d_bar_fold);
d_star[m..].fill(Zero::zero());
slice_ops::add_const::<S, S>(c_star, One::one());
slice_ops::batch_inversion(c_star);
slice_ops::batch_inversion(&mut d_star[..m]);

let d_star = alloc.alloc_slice_copy(d_fold);
slice_ops::add_const::<S, S>(d_star, One::one());
slice_ops::batch_inversion(d_star);

builder.produce_intermediate_mle(c_star as &[_]);
builder.produce_intermediate_mle(d_star as &[_]);
Expand All @@ -325,10 +326,11 @@ pub(super) fn prove_filter<'a, S: Scalar + 'a>(
],
);

// c_fold * c_star - input_ones = 0
// c_star + c_fold * c_star - input_ones = 0
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::Identity,
vec![
(S::one(), vec![Box::new(c_star as &[_])]),
(
S::one(),
vec![Box::new(c_star as &[_]), Box::new(c_fold as &[_])],
Expand All @@ -337,13 +339,14 @@ pub(super) fn prove_filter<'a, S: Scalar + 'a>(
],
);

// d_bar_fold * d_star - chi = 0
// d_star + d_fold * d_star - chi = 0
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::Identity,
vec![
(S::one(), vec![Box::new(d_star as &[_])]),
(
S::one(),
vec![Box::new(d_star as &[_]), Box::new(d_bar_fold as &[_])],
vec![Box::new(d_star as &[_]), Box::new(d_fold as &[_])],
),
(-S::one(), vec![Box::new(chi as &[_])]),
],
Expand Down

0 comments on commit 95a21e5

Please sign in to comment.