Skip to content

Commit

Permalink
feat: add membership check
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Dec 20, 2024
1 parent a74cda8 commit b72cbc1
Show file tree
Hide file tree
Showing 3 changed files with 316 additions and 0 deletions.
152 changes: 152 additions & 0 deletions crates/proof-of-sql/src/sql/proof_gadgets/membership_check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
use crate::{
base::{database::Column, proof::ProofError, scalar::Scalar, slice_ops},
sql::{
proof::{
FinalRoundBuilder, FirstRoundBuilder, SumcheckSubpolynomialType, VerificationBuilder,
},
proof_plans::{fold_columns, fold_vals},
},
};
use alloc::{boxed::Box, vec, vec::Vec};
use bumpalo::Bump;
use itertools::Itertools;
use num_traits::{One, Zero};

/// Perform first round evaluation of the membership check.
#[allow(dead_code)]
pub(crate) fn first_round_evaluate_membership_check<'a, S: Scalar>(
builder: &mut FirstRoundBuilder<'a, S>,
indexes: &[usize],
num_rows: usize,
alloc: &'a Bump,
) {
let multiplicity_map = indexes.into_iter().counts();
let multiplicities = (0..num_rows - 1)
.map(|i| multiplicity_map.get(&i).copied().unwrap_or(0) as i128)
.collect::<Vec<_>>();
let alloc_multiplicities = alloc.alloc_slice_copy(&multiplicities);
builder.produce_intermediate_mle(alloc_multiplicities as &[_]);
builder.request_post_result_challenges(2);
}

/// Perform final round evaluation of the membership check.
#[allow(dead_code)]
pub(crate) fn final_round_evaluate_membership_check<'a, S: Scalar>(
builder: &mut FinalRoundBuilder<'a, S>,
alloc: &'a Bump,
alpha: S,
beta: S,
columns: &[Column<'a, S>],
candidate_subset: &[Column<'a, S>],
indexes: &[usize],
num_rows: usize,
candidate_num_rows: usize,
) {
// 1. Get multiplicity of each index
let multiplicity_map = indexes.into_iter().counts();
let multiplicities = (0..num_rows - 1)
.map(|i| multiplicity_map.get(&i).copied().unwrap_or(0) as i128)
.collect::<Vec<_>>();
let alloc_multiplicities = alloc.alloc_slice_copy(&multiplicities);
builder.produce_intermediate_mle(alloc_multiplicities as &[_]);
// 2. Fold the columns
let input_ones = alloc.alloc_slice_fill_copy(num_rows, true);
let candidate_ones = alloc.alloc_slice_fill_copy(candidate_num_rows, true);

let c_fold = alloc.alloc_slice_fill_copy(num_rows, Zero::zero());
fold_columns(c_fold, alpha, beta, columns);
let d_fold = alloc.alloc_slice_fill_copy(candidate_num_rows, Zero::zero());
fold_columns(d_fold, alpha, beta, candidate_subset);

let c_star = alloc.alloc_slice_copy(c_fold);
slice_ops::add_const::<S, S>(c_star, One::one());
slice_ops::batch_inversion(c_star);

let d_star = alloc.alloc_slice_copy(d_fold);
slice_ops::add_const::<S, S>(d_star, One::one());
slice_ops::batch_inversion(d_star);

builder.produce_intermediate_mle(c_star as &[_]);
builder.produce_intermediate_mle(d_star as &[_]);

// sum c_star * multiplicities - d_star = 0
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::ZeroSum,
vec![
(
S::one(),
vec![
Box::new(c_star as &[_]),
Box::new(alloc_multiplicities as &[_]),
],
),
(-S::one(), vec![Box::new(d_star as &[_])]),
],
);

// c_star + c_fold * c_star - input_ones = 0
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::Identity,
vec![
(S::one(), vec![Box::new(c_star as &[_])]),
(
S::one(),
vec![Box::new(c_star as &[_]), Box::new(c_fold as &[_])],
),
(-S::one(), vec![Box::new(input_ones as &[_])]),
],
);

// d_star + d_fold * d_star - candidate_ones = 0
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::Identity,
vec![
(S::one(), vec![Box::new(d_star as &[_])]),
(
S::one(),
vec![Box::new(d_star as &[_]), Box::new(d_fold as &[_])],
),
(-S::one(), vec![Box::new(candidate_ones as &[_])]),
],
);
}

#[allow(dead_code)]
pub(crate) fn verify_membership_check<S: Scalar>(
builder: &mut VerificationBuilder<S>,
alpha: S,
beta: S,
input_one_eval: S,
candidate_one_eval: S,
column_evals: &[S],
candidate_evals: &[S],
multiplicity_eval: S,
) -> Result<(), ProofError> {
let c_fold_eval = alpha * fold_vals(beta, column_evals);
let d_fold_eval = alpha * fold_vals(beta, candidate_evals);
let c_star_eval = builder.try_consume_final_round_mle_evaluation()?;
let d_star_eval = builder.try_consume_final_round_mle_evaluation()?;

// sum c_star * multiplicities - d_star = 0
builder.try_produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::ZeroSum,
c_star_eval * multiplicity_eval - d_star_eval,
2,
)?;

// c_star + c_fold * c_star - input_ones = 0
builder.try_produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::Identity,
c_star_eval + c_fold_eval * c_star_eval - input_one_eval,
2,
)?;

// d_star + d_fold * d_star - candidate_ones = 0
builder.try_produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::Identity,
d_star_eval + d_fold_eval * d_star_eval - candidate_one_eval,
2,
)?;

Ok(())
}
161 changes: 161 additions & 0 deletions crates/proof-of-sql/src/sql/proof_gadgets/membership_check_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
use super::range_check::{final_round_evaluate_range_check, verifier_evaluate_range_check};
use crate::{
base::{
database::{ColumnField, ColumnRef, OwnedTable, Table, TableEvaluation, TableRef},
map::{indexset, IndexMap, IndexSet},
proof::ProofError,
scalar::Scalar,
},
sql::proof::{
FinalRoundBuilder, FirstRoundBuilder, ProofPlan, ProverEvaluate, VerificationBuilder,
},
};
use bumpalo::Bump;
use serde::Serialize;

#[derive(Debug, Serialize)]
pub struct MembershipCheckTestPlan<'a, S: Scalar> {
pub columns: Vec<Column<'a, S>>,
pub candidate_subset: Vec<Column<'a, S>>,
}

impl<'a, S: Scalar> ProverEvaluate for MembershipCheckTestPlan<'a, S> {
#[doc = " Evaluate the query, modify `FirstRoundBuilder` and return the result."]
fn first_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FirstRoundBuilder<'a, S>,
_alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> Table<'a, S> {

builder.request_post_result_challenges(2);
table_map[&self.column.table_ref()].clone()
}

// extract data to test on from here, feed it into range check
fn final_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FinalRoundBuilder<'a, S>,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> Table<'a, S> {
// Get the table from the map using the table reference
let table: &Table<'a, S> = table_map
.get(&self.column.table_ref())
.expect("Table not found");

let scalars = table
.inner_table()
.get(&self.column.column_id())
.expect("Column not found in table")
.as_scalar()
.expect("Failed to convert column to scalar");
final_round_evaluate_range_check(builder, scalars, 256, alloc);
table.clone()
}
}

impl ProofPlan for MembershipCheckTestPlan {
fn get_column_result_fields(&self) -> Vec<ColumnField> {
vec![ColumnField::new(
self.column.column_id(),
*self.column.column_type(),
)]
}

fn get_column_references(&self) -> IndexSet<ColumnRef> {
indexset! {self.column.clone()}
}

#[doc = " Return all the tables referenced in the Query"]
fn get_table_references(&self) -> IndexSet<TableRef> {
indexset! {self.column.table_ref()}
}

#[doc = " Form components needed to verify and proof store into `VerificationBuilder`"]
fn verifier_evaluate<S: Scalar>(
&self,
builder: &mut VerificationBuilder<S>,
accessor: &IndexMap<ColumnRef, S>,
_result: Option<&OwnedTable<S>>,
one_eval_map: &IndexMap<TableRef, S>,
) -> Result<TableEvaluation<S>, ProofError> {
let input_column_eval = accessor[&self.column];
let input_ones_eval = one_eval_map[&self.column.table_ref()];

verifier_evaluate_range_check(builder, input_ones_eval, input_column_eval)?;

Ok(TableEvaluation::new(
vec![accessor[&self.column]],
one_eval_map[&self.column.table_ref()],
))
}
}

#[cfg(all(test, feature = "blitzar"))]
mod tests {
use super::*;
use crate::{
base::database::{
owned_table_utility::{owned_table, scalar},
ColumnRef, ColumnType, OwnedTableTestAccessor,
},
sql::proof::VerifiableQueryResult,
};
use blitzar::proof::InnerProductProof;

#[test]
#[should_panic(
expected = "Range check failed, column contains values outside of the selected range"
)]
fn we_cannot_successfully_verify_invalid_range() {
let data = owned_table([scalar("a", -2..254)]);
let t = "sxt.t".parse().unwrap();
let accessor = OwnedTableTestAccessor::<InnerProductProof>::new_from_table(t, data, 0, ());
let ast = MembershipCheckTestPlan {
column: ColumnRef::new(t, "a".into(), ColumnType::Scalar),
};
let verifiable_res = VerifiableQueryResult::<InnerProductProof>::new(&ast, &accessor, &());
let _ = verifiable_res.verify(&ast, &accessor, &());
}

#[test]
fn we_can_prove_a_range_check_with_range_0_to_256() {
let data = owned_table([scalar("a", 0..256)]);
let t = "sxt.t".parse().unwrap();
let accessor = OwnedTableTestAccessor::<InnerProductProof>::new_from_table(t, data, 0, ());
let ast = MembershipCheckTestPlan {
column: ColumnRef::new(t, "a".into(), ColumnType::Scalar),
};
let verifiable_res = VerifiableQueryResult::<InnerProductProof>::new(&ast, &accessor, &());
let res: Result<
crate::sql::proof::QueryData<crate::base::scalar::MontScalar<ark_curve25519::FrConfig>>,
crate::sql::proof::QueryError,
> = verifiable_res.verify(&ast, &accessor, &());

if let Err(e) = res {
panic!("Verification failed: {e}");
}
assert!(res.is_ok());
}

#[test]
fn we_can_prove_a_range_check_with_range_1000_to_1256() {
let data = owned_table([scalar("a", 1000..1256)]);
let t = "sxt.t".parse().unwrap();
let accessor = OwnedTableTestAccessor::<InnerProductProof>::new_from_table(t, data, 0, ());
let ast = MembershipCheckTestPlan {
column: ColumnRef::new(t, "a".into(), ColumnType::Scalar),
};
let verifiable_res = VerifiableQueryResult::<InnerProductProof>::new(&ast, &accessor, &());
let res: Result<
crate::sql::proof::QueryData<crate::base::scalar::MontScalar<ark_curve25519::FrConfig>>,
crate::sql::proof::QueryError,
> = verifiable_res.verify(&ast, &accessor, &());

if let Err(e) = res {
panic!("Verification failed: {e}");
}
assert!(res.is_ok());
}
}
3 changes: 3 additions & 0 deletions crates/proof-of-sql/src/sql/proof_gadgets/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ mod bitwise_verification;
use bitwise_verification::{verify_constant_abs_decomposition, verify_constant_sign_decomposition};
#[cfg(test)]
mod bitwise_verification_test;
mod membership_check;
#[cfg(test)]
mod membership_check_test;
mod sign_expr;
pub(crate) use sign_expr::{prover_evaluate_sign, result_evaluate_sign, verifier_evaluate_sign};
pub mod range_check;
Expand Down

0 comments on commit b72cbc1

Please sign in to comment.