diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/membership_check.rs b/crates/proof-of-sql/src/sql/proof_gadgets/membership_check.rs new file mode 100644 index 000000000..071954da7 --- /dev/null +++ b/crates/proof-of-sql/src/sql/proof_gadgets/membership_check.rs @@ -0,0 +1,152 @@ +use crate::{ + base::{database::Column, proof::ProofError, scalar::Scalar, slice_ops}, + sql::{ + proof::{ + FinalRoundBuilder, FirstRoundBuilder, SumcheckSubpolynomialType, VerificationBuilder, + }, + proof_plans::{fold_columns, fold_vals}, + }, +}; +use alloc::{boxed::Box, vec, vec::Vec}; +use bumpalo::Bump; +use itertools::Itertools; +use num_traits::{One, Zero}; + +/// Perform first round evaluation of the membership check. +#[allow(dead_code)] +pub(crate) fn first_round_evaluate_membership_check<'a, S: Scalar>( + builder: &mut FirstRoundBuilder<'a, S>, + indexes: &[usize], + num_rows: usize, + alloc: &'a Bump, +) { + let multiplicity_map = indexes.into_iter().counts(); + let multiplicities = (0..num_rows - 1) + .map(|i| multiplicity_map.get(&i).copied().unwrap_or(0) as i128) + .collect::>(); + let alloc_multiplicities = alloc.alloc_slice_copy(&multiplicities); + builder.produce_intermediate_mle(alloc_multiplicities as &[_]); + builder.request_post_result_challenges(2); +} + +/// Perform final round evaluation of the membership check. +#[allow(dead_code)] +pub(crate) fn final_round_evaluate_membership_check<'a, S: Scalar>( + builder: &mut FinalRoundBuilder<'a, S>, + alloc: &'a Bump, + alpha: S, + beta: S, + columns: &[Column<'a, S>], + candidate_subset: &[Column<'a, S>], + indexes: &[usize], + num_rows: usize, + candidate_num_rows: usize, +) { + // 1. Get multiplicity of each index + let multiplicity_map = indexes.into_iter().counts(); + let multiplicities = (0..num_rows - 1) + .map(|i| multiplicity_map.get(&i).copied().unwrap_or(0) as i128) + .collect::>(); + let alloc_multiplicities = alloc.alloc_slice_copy(&multiplicities); + builder.produce_intermediate_mle(alloc_multiplicities as &[_]); + // 2. Fold the columns + let input_ones = alloc.alloc_slice_fill_copy(num_rows, true); + let candidate_ones = alloc.alloc_slice_fill_copy(candidate_num_rows, true); + + let c_fold = alloc.alloc_slice_fill_copy(num_rows, Zero::zero()); + fold_columns(c_fold, alpha, beta, columns); + let d_fold = alloc.alloc_slice_fill_copy(candidate_num_rows, Zero::zero()); + fold_columns(d_fold, alpha, beta, candidate_subset); + + let c_star = alloc.alloc_slice_copy(c_fold); + slice_ops::add_const::(c_star, One::one()); + slice_ops::batch_inversion(c_star); + + let d_star = alloc.alloc_slice_copy(d_fold); + slice_ops::add_const::(d_star, One::one()); + slice_ops::batch_inversion(d_star); + + builder.produce_intermediate_mle(c_star as &[_]); + builder.produce_intermediate_mle(d_star as &[_]); + + // sum c_star * multiplicities - d_star = 0 + builder.produce_sumcheck_subpolynomial( + SumcheckSubpolynomialType::ZeroSum, + vec![ + ( + S::one(), + vec![ + Box::new(c_star as &[_]), + Box::new(alloc_multiplicities as &[_]), + ], + ), + (-S::one(), vec![Box::new(d_star as &[_])]), + ], + ); + + // c_star + c_fold * c_star - input_ones = 0 + builder.produce_sumcheck_subpolynomial( + SumcheckSubpolynomialType::Identity, + vec![ + (S::one(), vec![Box::new(c_star as &[_])]), + ( + S::one(), + vec![Box::new(c_star as &[_]), Box::new(c_fold as &[_])], + ), + (-S::one(), vec![Box::new(input_ones as &[_])]), + ], + ); + + // d_star + d_fold * d_star - candidate_ones = 0 + builder.produce_sumcheck_subpolynomial( + SumcheckSubpolynomialType::Identity, + vec![ + (S::one(), vec![Box::new(d_star as &[_])]), + ( + S::one(), + vec![Box::new(d_star as &[_]), Box::new(d_fold as &[_])], + ), + (-S::one(), vec![Box::new(candidate_ones as &[_])]), + ], + ); +} + +#[allow(dead_code)] +pub(crate) fn verify_membership_check( + builder: &mut VerificationBuilder, + alpha: S, + beta: S, + input_one_eval: S, + candidate_one_eval: S, + column_evals: &[S], + candidate_evals: &[S], + multiplicity_eval: S, +) -> Result<(), ProofError> { + let c_fold_eval = alpha * fold_vals(beta, column_evals); + let d_fold_eval = alpha * fold_vals(beta, candidate_evals); + let c_star_eval = builder.try_consume_final_round_mle_evaluation()?; + let d_star_eval = builder.try_consume_final_round_mle_evaluation()?; + + // sum c_star * multiplicities - d_star = 0 + builder.try_produce_sumcheck_subpolynomial_evaluation( + SumcheckSubpolynomialType::ZeroSum, + c_star_eval * multiplicity_eval - d_star_eval, + 2, + )?; + + // c_star + c_fold * c_star - input_ones = 0 + builder.try_produce_sumcheck_subpolynomial_evaluation( + SumcheckSubpolynomialType::Identity, + c_star_eval + c_fold_eval * c_star_eval - input_one_eval, + 2, + )?; + + // d_star + d_fold * d_star - candidate_ones = 0 + builder.try_produce_sumcheck_subpolynomial_evaluation( + SumcheckSubpolynomialType::Identity, + d_star_eval + d_fold_eval * d_star_eval - candidate_one_eval, + 2, + )?; + + Ok(()) +} diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/membership_check_test.rs b/crates/proof-of-sql/src/sql/proof_gadgets/membership_check_test.rs new file mode 100644 index 000000000..031f97582 --- /dev/null +++ b/crates/proof-of-sql/src/sql/proof_gadgets/membership_check_test.rs @@ -0,0 +1,163 @@ +use super::range_check::{final_round_evaluate_range_check, verifier_evaluate_range_check}; +use crate::{ + base::{ + database::{ColumnField, ColumnRef, OwnedTable, Table, TableEvaluation, TableRef}, + map::{indexset, IndexMap, IndexSet}, + proof::ProofError, + scalar::Scalar, + }, + sql::proof::{ + FinalRoundBuilder, FirstRoundBuilder, ProofPlan, ProverEvaluate, VerificationBuilder, + }, +}; +use bumpalo::Bump; +use serde::Serialize; + +#[derive(Debug, Serialize)] +pub struct MembershipCheckTestPlan<'a, S: Scalar> { + pub columns: Vec>, + pub candidate_subset: Vec>, +} + +impl<'a, S: Scalar> ProverEvaluate for MembershipCheckTestPlan<'a, S> { + #[doc = " Evaluate the query, modify `FirstRoundBuilder` and return the result."] + fn first_round_evaluate<'a, S: Scalar>( + &self, + builder: &mut FirstRoundBuilder<'a, S>, + _alloc: &'a Bump, + table_map: &IndexMap>, + ) -> Table<'a, S> { + candidate_subset.iter().for_each(|column| { + builder.produce_intermediate_mle(column.as_slice()); + }); + builder.request_post_result_challenges(2); + table_map[&self.column.table_ref()].clone() + } + + // extract data to test on from here, feed it into range check + fn final_round_evaluate<'a, S: Scalar>( + &self, + builder: &mut FinalRoundBuilder<'a, S>, + alloc: &'a Bump, + table_map: &IndexMap>, + ) -> Table<'a, S> { + // Get the table from the map using the table reference + let table: &Table<'a, S> = table_map + .get(&self.column.table_ref()) + .expect("Table not found"); + + let scalars = table + .inner_table() + .get(&self.column.column_id()) + .expect("Column not found in table") + .as_scalar() + .expect("Failed to convert column to scalar"); + final_round_evaluate_range_check(builder, scalars, 256, alloc); + table.clone() + } +} + +impl ProofPlan for MembershipCheckTestPlan { + fn get_column_result_fields(&self) -> Vec { + vec![ColumnField::new( + self.column.column_id(), + *self.column.column_type(), + )] + } + + fn get_column_references(&self) -> IndexSet { + indexset! {self.column.clone()} + } + + #[doc = " Return all the tables referenced in the Query"] + fn get_table_references(&self) -> IndexSet { + indexset! {self.column.table_ref()} + } + + #[doc = " Form components needed to verify and proof store into `VerificationBuilder`"] + fn verifier_evaluate( + &self, + builder: &mut VerificationBuilder, + accessor: &IndexMap, + _result: Option<&OwnedTable>, + one_eval_map: &IndexMap, + ) -> Result, ProofError> { + let input_column_eval = accessor[&self.column]; + let input_ones_eval = one_eval_map[&self.column.table_ref()]; + + verifier_evaluate_range_check(builder, input_ones_eval, input_column_eval)?; + + Ok(TableEvaluation::new( + vec![accessor[&self.column]], + one_eval_map[&self.column.table_ref()], + )) + } +} + +#[cfg(all(test, feature = "blitzar"))] +mod tests { + use super::*; + use crate::{ + base::database::{ + owned_table_utility::{owned_table, scalar}, + ColumnRef, ColumnType, OwnedTableTestAccessor, + }, + sql::proof::VerifiableQueryResult, + }; + use blitzar::proof::InnerProductProof; + + #[test] + #[should_panic( + expected = "Range check failed, column contains values outside of the selected range" + )] + fn we_cannot_successfully_verify_invalid_range() { + let data = owned_table([scalar("a", -2..254)]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = MembershipCheckTestPlan { + column: ColumnRef::new(t, "a".into(), ColumnType::Scalar), + }; + let verifiable_res = VerifiableQueryResult::::new(&ast, &accessor, &()); + let _ = verifiable_res.verify(&ast, &accessor, &()); + } + + #[test] + fn we_can_prove_a_range_check_with_range_0_to_256() { + let data = owned_table([scalar("a", 0..256)]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = MembershipCheckTestPlan { + column: ColumnRef::new(t, "a".into(), ColumnType::Scalar), + }; + let verifiable_res = VerifiableQueryResult::::new(&ast, &accessor, &()); + let res: Result< + crate::sql::proof::QueryData>, + crate::sql::proof::QueryError, + > = verifiable_res.verify(&ast, &accessor, &()); + + if let Err(e) = res { + panic!("Verification failed: {e}"); + } + assert!(res.is_ok()); + } + + #[test] + fn we_can_prove_a_range_check_with_range_1000_to_1256() { + let data = owned_table([scalar("a", 1000..1256)]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = MembershipCheckTestPlan { + column: ColumnRef::new(t, "a".into(), ColumnType::Scalar), + }; + let verifiable_res = VerifiableQueryResult::::new(&ast, &accessor, &()); + let res: Result< + crate::sql::proof::QueryData>, + crate::sql::proof::QueryError, + > = verifiable_res.verify(&ast, &accessor, &()); + + if let Err(e) = res { + panic!("Verification failed: {e}"); + } + assert!(res.is_ok()); + } +} diff --git a/crates/proof-of-sql/src/sql/proof_gadgets/mod.rs b/crates/proof-of-sql/src/sql/proof_gadgets/mod.rs index 0ec348767..67a8a9ed0 100644 --- a/crates/proof-of-sql/src/sql/proof_gadgets/mod.rs +++ b/crates/proof-of-sql/src/sql/proof_gadgets/mod.rs @@ -3,6 +3,9 @@ mod bitwise_verification; use bitwise_verification::{verify_constant_abs_decomposition, verify_constant_sign_decomposition}; #[cfg(test)] mod bitwise_verification_test; +mod membership_check; +#[cfg(test)] +mod membership_check_test; mod sign_expr; pub(crate) use sign_expr::{prover_evaluate_sign, result_evaluate_sign, verifier_evaluate_sign}; pub mod range_check;