Skip to content

Commit

Permalink
fix: fix filter vulnerability (#423)
Browse files Browse the repository at this point in the history
Please be sure to look over the pull request guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md#submit-pr.

# Please go through the following checklist
- [x] The PR title and commit messages adhere to guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md.
In particular `!` is used if and only if at least one breaking change
has been introduced.
- [x] I have run the ci check script with `source
scripts/run_ci_checks.sh`.

# Rationale for this change
Recently @JayWhite2357 found a vulnerability in our `ProofPlan`s with
nontrivial proofs related to range length. Now we are implementing a
fix.
<!--
Why are you proposing this change? If this is already explained clearly
in the linked issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.

 Example:
 Add `NestedLoopJoinExec`.
 Closes #345.

Since we added `HashJoinExec` in #323 it has been possible to do
provable inner joins. However performance is not satisfactory in some
cases. Hence we need to fix the problem by implement
`NestedLoopJoinExec` and speed up the code
 for `HashJoinExec`.
-->

# What changes are included in this PR?
- add `range_length_one_evaluation` to `SumcheckMleEvaluations`
- add `range_length` to `FinalRoundBuilder`
- fix the bug in the filter proof.
<!--
There is no need to duplicate the description in the ticket here but it
is sometimes worth providing a summary of the individual changes in this
PR.

Example:
- Add `NestedLoopJoinExec`.
- Speed up `HashJoinExec`.
- Route joins to `NestedLoopJoinExec` if the outer input is sufficiently
small.
-->

# Are these changes tested?
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?

Example:
Yes.
-->
Existing tests should pass.
  • Loading branch information
iajoiner authored Dec 10, 2024
2 parents e1823fd + 95a21e5 commit cb4a93c
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 18 deletions.
19 changes: 19 additions & 0 deletions crates/proof-of-sql/src/base/slice_ops/add_const.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use crate::base::if_rayon;
use core::ops::AddAssign;
#[cfg(feature = "rayon")]
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};

/// This operation does `result[i] += to_add` for `i` in `0..result.len()`.
pub fn add_const<T, S>(result: &mut [T], to_add: S)
where
T: Send + Sync + AddAssign<T> + Copy,
S: Into<T> + Sync + Copy,
{
if_rayon!(
result.par_iter_mut().with_min_len(super::MIN_RAYON_LEN),
result.iter_mut()
)
.for_each(|res_i| {
*res_i += to_add.into();
});
}
9 changes: 9 additions & 0 deletions crates/proof-of-sql/src/base/slice_ops/add_const_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use super::*;

#[test]
fn test_add_const() {
let mut a = vec![1, 2, 3, 4];
add_const(&mut a, 10);
let b = vec![1 + 10, 2 + 10, 3 + 10, 4 + 10];
assert_eq!(a, b);
}
4 changes: 4 additions & 0 deletions crates/proof-of-sql/src/base/slice_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
#[cfg(any(feature = "rayon", test))]
pub const MIN_RAYON_LEN: usize = 1 << 8;

mod add_const;
#[cfg(test)]
mod add_const_test;
mod inner_product;
#[cfg(test)]
mod inner_product_test;
Expand All @@ -16,6 +19,7 @@ mod slice_cast;
#[cfg(test)]
mod slice_cast_test;

pub use add_const::*;
pub use inner_product::*;
pub use mul_add_assign::*;
pub use slice_cast::*;
Expand Down
39 changes: 21 additions & 18 deletions crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ pub(super) fn verify_filter<S: Scalar>(
s_eval: S,
d_evals: &[S],
) -> Result<(), ProofError> {
let c_fold_eval = alpha * one_eval + fold_vals(beta, c_evals);
let d_bar_fold_eval = alpha * one_eval + fold_vals(beta, d_evals);
let c_fold_eval = alpha * fold_vals(beta, c_evals);
let d_fold_eval = alpha * fold_vals(beta, d_evals);
let c_star_eval = builder.consume_intermediate_mle();
let d_star_eval = builder.consume_intermediate_mle();

Expand All @@ -271,16 +271,16 @@ pub(super) fn verify_filter<S: Scalar>(
c_star_eval * s_eval - d_star_eval,
);

// c_fold * c_star - input_ones = 0
// c_star + c_fold * c_star - input_ones = 0
builder.produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::Identity,
c_fold_eval * c_star_eval - one_eval,
c_star_eval + c_fold_eval * c_star_eval - one_eval,
);

// d_bar_fold * d_star - chi = 0
// d_star + d_fold * d_star - chi = 0
builder.produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::Identity,
d_bar_fold_eval * d_star_eval - chi_eval,
d_star_eval + d_fold_eval * d_star_eval - chi_eval,
);

Ok(())
Expand All @@ -299,19 +299,20 @@ pub(super) fn prove_filter<'a, S: Scalar + 'a>(
m: usize,
) {
let input_ones = alloc.alloc_slice_fill_copy(n, true);
let chi = alloc.alloc_slice_fill_copy(n, false);
chi[..m].fill(true);
let chi = alloc.alloc_slice_fill_copy(m, true);

let c_fold = alloc.alloc_slice_fill_copy(n, alpha);
fold_columns(c_fold, One::one(), beta, c);
let d_bar_fold = alloc.alloc_slice_fill_copy(n, alpha);
fold_columns(d_bar_fold, One::one(), beta, d);
let c_fold = alloc.alloc_slice_fill_copy(n, Zero::zero());
fold_columns(c_fold, alpha, beta, c);
let d_fold = alloc.alloc_slice_fill_copy(m, Zero::zero());
fold_columns(d_fold, alpha, beta, d);

let c_star = alloc.alloc_slice_copy(c_fold);
let d_star = alloc.alloc_slice_copy(d_bar_fold);
d_star[m..].fill(Zero::zero());
slice_ops::add_const::<S, S>(c_star, One::one());
slice_ops::batch_inversion(c_star);
slice_ops::batch_inversion(&mut d_star[..m]);

let d_star = alloc.alloc_slice_copy(d_fold);
slice_ops::add_const::<S, S>(d_star, One::one());
slice_ops::batch_inversion(d_star);

builder.produce_intermediate_mle(c_star as &[_]);
builder.produce_intermediate_mle(d_star as &[_]);
Expand All @@ -325,10 +326,11 @@ pub(super) fn prove_filter<'a, S: Scalar + 'a>(
],
);

// c_fold * c_star - input_ones = 0
// c_star + c_fold * c_star - input_ones = 0
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::Identity,
vec![
(S::one(), vec![Box::new(c_star as &[_])]),
(
S::one(),
vec![Box::new(c_star as &[_]), Box::new(c_fold as &[_])],
Expand All @@ -337,13 +339,14 @@ pub(super) fn prove_filter<'a, S: Scalar + 'a>(
],
);

// d_bar_fold * d_star - chi = 0
// d_star + d_fold * d_star - chi = 0
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::Identity,
vec![
(S::one(), vec![Box::new(d_star as &[_])]),
(
S::one(),
vec![Box::new(d_star as &[_]), Box::new(d_bar_fold as &[_])],
vec![Box::new(d_star as &[_]), Box::new(d_fold as &[_])],
),
(-S::one(), vec![Box::new(chi as &[_])]),
],
Expand Down

0 comments on commit cb4a93c

Please sign in to comment.