From 740e7361b93e8b328dd5f05a286644eb85fe0874 Mon Sep 17 00:00:00 2001 From: Ian Joiner <14581281+iajoiner@users.noreply.github.com> Date: Tue, 26 Nov 2024 10:19:52 -0500 Subject: [PATCH] feat: add `SliceExec` --- .../src/sql/proof_plans/dyn_proof_plan.rs | 7 +- .../src/sql/proof_plans/filter_exec.rs | 6 +- .../filter_exec_test_dishonest_prover.rs | 4 +- .../proof-of-sql/src/sql/proof_plans/mod.rs | 8 +- .../src/sql/proof_plans/slice_exec.rs | 192 ++++++++++++ .../src/sql/proof_plans/slice_exec_test.rs | 289 ++++++++++++++++++ .../src/sql/proof_plans/test_utility.rs | 6 +- 7 files changed, 505 insertions(+), 7 deletions(-) create mode 100644 crates/proof-of-sql/src/sql/proof_plans/slice_exec.rs create mode 100644 crates/proof-of-sql/src/sql/proof_plans/slice_exec_test.rs diff --git a/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs b/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs index e117ed8d3..dc988b9da 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs @@ -1,4 +1,4 @@ -use super::{EmptyExec, FilterExec, GroupByExec, ProjectionExec, TableExec}; +use super::{EmptyExec, FilterExec, GroupByExec, ProjectionExec, SliceExec, TableExec}; use crate::{ base::{ database::{ColumnField, ColumnRef, OwnedTable, Table, TableEvaluation, TableRef}, @@ -43,4 +43,9 @@ pub enum DynProofPlan { /// SELECT , ..., FROM WHERE /// ``` Filter(FilterExec), + /// `ProofPlan` for queries of the form + /// ```ignore + /// LIMIT [OFFSET ] + /// ``` + Slice(SliceExec), } diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs index 9469772f3..b477a9d32 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs @@ -223,6 +223,7 @@ impl ProverEvaluate for FilterExec { let alpha = builder.consume_post_result_challenge(); let beta = 
builder.consume_post_result_challenge(); + builder.push_one_evaluation_length(table.num_rows()); prove_filter::( builder, @@ -247,7 +248,7 @@ impl ProverEvaluate for FilterExec { } #[allow(clippy::unnecessary_wraps, clippy::too_many_arguments)] -fn verify_filter( +pub(crate) fn verify_filter( builder: &mut VerificationBuilder, alpha: S, beta: S, @@ -284,7 +285,7 @@ fn verify_filter( } #[allow(clippy::too_many_arguments, clippy::many_single_char_names)] -pub(super) fn prove_filter<'a, S: Scalar + 'a>( +pub(crate) fn prove_filter<'a, S: Scalar + 'a>( builder: &mut FinalRoundBuilder<'a, S>, alloc: &'a Bump, alpha: S, @@ -295,7 +296,6 @@ pub(super) fn prove_filter<'a, S: Scalar + 'a>( n: usize, m: usize, ) { - builder.push_one_evaluation_length(n); builder.push_one_evaluation_length(m); let chi = alloc.alloc_slice_fill_copy(n, false); chi[..m].fill(true); diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs index 69d37a017..21a9415c2 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs @@ -109,6 +109,8 @@ impl ProverEvaluate for DishonestFilterExec { let alpha = builder.consume_post_result_challenge(); let beta = builder.consume_post_result_challenge(); + let input_length = table.num_rows(); + builder.push_one_evaluation_length(input_length); prove_filter( builder, @@ -118,7 +120,7 @@ impl ProverEvaluate for DishonestFilterExec { &columns, selection, &filtered_columns, - table.num_rows(), + input_length, result_len, ); Table::<'a, S>::try_from_iter_with_options( diff --git a/crates/proof-of-sql/src/sql/proof_plans/mod.rs b/crates/proof-of-sql/src/sql/proof_plans/mod.rs index 98ddea04e..9f23adb47 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/mod.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/mod.rs @@ -16,9 +16,9 @@ mod 
projection_exec_test; pub(crate) mod test_utility; mod filter_exec; -pub(crate) use filter_exec::FilterExec; #[cfg(test)] pub(crate) use filter_exec::OstensibleFilterExec; +pub(crate) use filter_exec::{prove_filter, verify_filter, FilterExec}; #[cfg(all(test, feature = "blitzar"))] mod filter_exec_test; #[cfg(all(test, feature = "blitzar"))] @@ -35,5 +35,11 @@ pub(crate) use group_by_exec::GroupByExec; #[cfg(all(test, feature = "blitzar"))] mod group_by_exec_test; +mod slice_exec; +#[allow(unused_imports)] +pub(crate) use slice_exec::SliceExec; +#[cfg(all(test, feature = "blitzar"))] +mod slice_exec_test; + mod dyn_proof_plan; pub use dyn_proof_plan::DynProofPlan; diff --git a/crates/proof-of-sql/src/sql/proof_plans/slice_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/slice_exec.rs new file mode 100644 index 000000000..963c38940 --- /dev/null +++ b/crates/proof-of-sql/src/sql/proof_plans/slice_exec.rs @@ -0,0 +1,192 @@ +use super::{prove_filter, verify_filter, DynProofPlan}; +use crate::{ + base::{ + database::{ + filter_util::filter_columns, ColumnField, ColumnRef, OwnedTable, Table, + TableEvaluation, TableOptions, TableRef, + }, + map::{IndexMap, IndexSet}, + proof::ProofError, + scalar::Scalar, + }, + sql::proof::{ + CountBuilder, FinalRoundBuilder, FirstRoundBuilder, ProofPlan, ProverEvaluate, + VerificationBuilder, + }, +}; +use alloc::{boxed::Box, vec::Vec}; +use bumpalo::Bump; +use core::iter::repeat_with; +use serde::{Deserialize, Serialize}; + +/// `ProofPlan` for queries of the form +/// ```ignore +/// LIMIT [OFFSET ] +/// ``` +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct SliceExec { + pub(super) input: Box, + pub(super) skip: usize, + pub(super) fetch: Option, +} + +/// Get the boolean slice selection from the number of rows, skip and fetch +fn get_slice_select(num_rows: usize, skip: usize, fetch: Option) -> Vec { + if let Some(fetch) = fetch { + let end = skip + fetch; + (0..num_rows).map(|i| i >= skip && i < end).collect() + } 
else { + (0..num_rows).map(|i| i >= skip).collect() + } +} + +impl SliceExec { + /// Creates a new slice execution plan. + #[allow(dead_code)] + pub fn new(input: Box, skip: usize, fetch: Option) -> Self { + Self { input, skip, fetch } + } +} + +impl ProofPlan for SliceExec +where + SliceExec: ProverEvaluate, +{ + fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> { + self.input.count(builder)?; + builder.count_intermediate_mles(self.input.get_column_result_fields().len()); + builder.count_intermediate_mles(3); + builder.count_subpolynomials(3); + builder.count_degree(3); + builder.count_post_result_challenges(2); + Ok(()) + } + + #[allow(unused_variables)] + fn verifier_evaluate( + &self, + builder: &mut VerificationBuilder, + accessor: &IndexMap, + _result: Option<&OwnedTable>, + ) -> Result, ProofError> { + // 1. columns and output length (`prove_filter` pushes one length, so consume it once) + // TODO: Make sure `GroupByExec` as self.input is supported + let input_table_eval = self.input.verifier_evaluate(builder, accessor, None)?; + let output_one_eval = builder.consume_one_evaluation(); + let columns_evals = input_table_eval.column_evals(); + // 2. selection + let selection_eval = builder.consume_intermediate_mle(); + // 3. 
filtered_columns + let filtered_columns_evals: Vec<_> = repeat_with(|| builder.consume_intermediate_mle()) + .take(columns_evals.len()) + .collect(); + let alpha = builder.consume_post_result_challenge(); + let beta = builder.consume_post_result_challenge(); + + verify_filter( + builder, + alpha, + beta, + *input_table_eval.one_eval(), + output_one_eval, + columns_evals, + selection_eval, + &filtered_columns_evals, + )?; + Ok(TableEvaluation::new( + filtered_columns_evals, + output_one_eval, + )) + } + + fn get_column_result_fields(&self) -> Vec { + self.input.get_column_result_fields() + } + + fn get_column_references(&self) -> IndexSet { + self.input.get_column_references() + } + + fn get_table_references(&self) -> IndexSet { + self.input.get_table_references() + } +} + +impl ProverEvaluate for SliceExec { + #[tracing::instrument(name = "SliceExec::result_evaluate", level = "debug", skip_all)] + fn result_evaluate<'a, S: Scalar>( + &self, + alloc: &'a Bump, + table_map: &IndexMap>, + ) -> Table<'a, S> { + // 1. columns + let input = self.input.result_evaluate(alloc, table_map); + let columns = input.columns().copied().collect::>(); + // 2. 
select + let select = get_slice_select(input.num_rows(), self.skip, self.fetch); + let output_length = select.iter().filter(|b| **b).count(); + // Compute filtered_columns + let (filtered_columns, _) = filter_columns(alloc, &columns, &select); + Table::<'a, S>::try_from_iter_with_options( + self.get_column_result_fields() + .into_iter() + .map(|expr| expr.name()) + .zip(filtered_columns), + TableOptions::new(Some(output_length)), + ) + .expect("Failed to create table from iterator") + } + + fn first_round_evaluate(&self, builder: &mut FirstRoundBuilder) { + self.input.first_round_evaluate(builder); + builder.request_post_result_challenges(2); + } + + #[tracing::instrument(name = "SliceExec::prover_evaluate", level = "debug", skip_all)] + #[allow(unused_variables)] + fn final_round_evaluate<'a, S: Scalar>( + &self, + builder: &mut FinalRoundBuilder<'a, S>, + alloc: &'a Bump, + table_map: &IndexMap>, + ) -> Table<'a, S> { + // 1. columns + let input = self.input.final_round_evaluate(builder, alloc, table_map); + let columns = input.columns().copied().collect::>(); + // 2. select + let select = get_slice_select(input.num_rows(), self.skip, self.fetch); + let select_ref: &'a [_] = alloc.alloc_slice_copy(&select); + let output_length = select.iter().filter(|b| **b).count(); + + builder.produce_intermediate_mle(select_ref); + // Compute filtered_columns and indexes + let (filtered_columns, result_len) = filter_columns(alloc, &columns, &select); + // 3. 
Produce MLEs + filtered_columns.iter().copied().for_each(|column| { + builder.produce_intermediate_mle(column); + }); + let alpha = builder.consume_post_result_challenge(); + let beta = builder.consume_post_result_challenge(); + + prove_filter::( + builder, + alloc, + alpha, + beta, + &columns, + select_ref, + &filtered_columns, + input.num_rows(), + result_len, + ); + Table::<'a, S>::try_from_iter_with_options( + self.get_column_result_fields() + .into_iter() + .map(|expr| expr.name()) + .zip(filtered_columns), + TableOptions::new(Some(output_length)), + ) + .expect("Failed to create table from iterator") + } +} diff --git a/crates/proof-of-sql/src/sql/proof_plans/slice_exec_test.rs b/crates/proof-of-sql/src/sql/proof_plans/slice_exec_test.rs new file mode 100644 index 000000000..db22cd271 --- /dev/null +++ b/crates/proof-of-sql/src/sql/proof_plans/slice_exec_test.rs @@ -0,0 +1,289 @@ +use super::test_utility::*; +use crate::{ + base::{ + database::{ + owned_table_utility::*, table_utility::*, ColumnField, ColumnType, OwnedTable, + OwnedTableTestAccessor, TableTestAccessor, TestAccessor, + }, + map::{indexmap, IndexMap}, + math::decimal::Precision, + scalar::Curve25519Scalar, + }, + sql::{ + proof::{ + exercise_verification, ProvableQueryResult, ProverEvaluate, VerifiableQueryResult, + }, + proof_exprs::{test_utility::*, DynProofExpr}, + }, +}; +use blitzar::proof::InnerProductProof; +use bumpalo::Bump; + +#[test] +fn we_can_prove_and_get_the_correct_result_from_a_slice_exec() { + let data = owned_table([ + bigint("a", [1_i64, 2, 3, 4, 5]), + varchar("b", ["1", "2", "3", "4", "5"]), + ]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = slice_exec( + projection(cols_expr_plan(t, &["a", "b"], &accessor), tab(t)), + 1, + Some(2), + ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = 
verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("a", [2_i64, 3]), varchar("b", ["2", "3"])]); + assert_eq!(res, expected_res); +} + +#[test] +fn we_can_prove_and_get_the_correct_empty_result_from_a_slice_exec() { + let data = owned_table([ + bigint("a", [1_i64, 2, 3, 4, 5]), + varchar("b", ["1", "2", "3", "4", "5"]), + ]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let where_clause: DynProofExpr = equal(column(t, "a", &accessor), const_int128(2)); + let ast = slice_exec( + filter( + cols_expr_plan(t, &["a", "b"], &accessor), + tab(t), + where_clause, + ), + 1, + Some(2), + ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("a", [0_i64; 0]), varchar("b", [""; 0])]); + assert_eq!(res, expected_res); +} + +#[test] +fn we_can_get_an_empty_result_from_a_slice_on_an_empty_table_using_result_evaluate() { + let alloc = Bump::new(); + let data = table([ + borrowed_bigint("a", [0; 0], &alloc), + borrowed_bigint("b", [0; 0], &alloc), + borrowed_int128("c", [0; 0], &alloc), + borrowed_varchar("d", [""; 0], &alloc), + borrowed_scalar("e", [0; 0], &alloc), + ]); + let t = "sxt.t".parse().unwrap(); + let table_map = indexmap! 
{ + t => data.clone() + }; + let mut accessor = TableTestAccessor::::new_empty_with_setup(()); + accessor.add_table(t, data, 0); + let where_clause: DynProofExpr = equal(column(t, "a", &accessor), const_int128(999)); + let expr = slice_exec( + filter( + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), + tab(t), + where_clause, + ), + 1, + Some(2), + ); + + let fields = &[ + ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), + ColumnField::new("c".parse().unwrap(), ColumnType::Int128), + ColumnField::new("d".parse().unwrap(), ColumnType::VarChar), + ColumnField::new( + "e".parse().unwrap(), + ColumnType::Decimal75(Precision::new(75).unwrap(), 0), + ), + ]; + let res: OwnedTable = + ProvableQueryResult::from(expr.result_evaluate(&alloc, &table_map)) + .to_owned_table(fields) + .unwrap(); + let expected: OwnedTable = owned_table([ + bigint("b", [0; 0]), + int128("c", [0; 0]), + varchar("d", [""; 0]), + decimal75("e", 75, 0, [0; 0]), + ]); + + assert_eq!(res, expected); +} + +#[test] +fn we_can_get_an_empty_result_from_a_slice_using_result_evaluate() { + let alloc = Bump::new(); + let data = table([ + borrowed_bigint("a", [1, 4, 5, 2, 5], &alloc), + borrowed_bigint("b", [1, 2, 3, 4, 5], &alloc), + borrowed_int128("c", [1, 2, 3, 4, 5], &alloc), + borrowed_varchar("d", ["1", "2", "3", "4", "5"], &alloc), + borrowed_scalar("e", [1, 2, 3, 4, 5], &alloc), + ]); + let t = "sxt.t".parse().unwrap(); + let table_map = indexmap! 
{ + t => data.clone() + }; + let mut accessor = TableTestAccessor::::new_empty_with_setup(()); + accessor.add_table(t, data, 0); + let where_clause: DynProofExpr = equal(column(t, "a", &accessor), const_int128(999)); + let expr = slice_exec( + filter( + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), + tab(t), + where_clause, + ), + 1, + Some(2), + ); + + let fields = &[ + ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), + ColumnField::new("c".parse().unwrap(), ColumnType::Int128), + ColumnField::new("d".parse().unwrap(), ColumnType::VarChar), + ColumnField::new( + "e".parse().unwrap(), + ColumnType::Decimal75(Precision::new(1).unwrap(), 0), + ), + ]; + let res: OwnedTable = + ProvableQueryResult::from(expr.result_evaluate(&alloc, &table_map)) + .to_owned_table(fields) + .unwrap(); + let expected: OwnedTable = owned_table([ + bigint("b", [0; 0]), + int128("c", [0; 0]), + varchar("d", [""; 0]), + decimal75("e", 1, 0, [0; 0]), + ]); + + assert_eq!(res, expected); +} + +#[test] +fn we_can_get_no_columns_from_a_slice_with_empty_input_using_result_evaluate() { + let alloc = Bump::new(); + let data = table([ + borrowed_bigint("a", [1, 4, 5, 2, 5], &alloc), + borrowed_bigint("b", [1, 2, 3, 4, 5], &alloc), + borrowed_int128("c", [1, 2, 3, 4, 5], &alloc), + borrowed_varchar("d", ["1", "2", "3", "4", "5"], &alloc), + borrowed_scalar("e", [1, 2, 3, 4, 5], &alloc), + ]); + let t = "sxt.t".parse().unwrap(); + let table_map = indexmap! 
{ + t => data.clone() + }; + let mut accessor = TableTestAccessor::::new_empty_with_setup(()); + accessor.add_table(t, data, 0); + let where_clause: DynProofExpr = equal(column(t, "a", &accessor), const_int128(5)); + let expr = slice_exec( + filter(cols_expr_plan(t, &[], &accessor), tab(t), where_clause), + 2, + None, + ); + let fields = &[]; + let res: OwnedTable = + ProvableQueryResult::from(expr.result_evaluate(&alloc, &table_map)) + .to_owned_table(fields) + .unwrap(); + let expected = OwnedTable::try_new(IndexMap::default()).unwrap(); + assert_eq!(res, expected); +} + +#[test] +fn we_can_get_the_correct_result_from_a_slice_using_result_evaluate() { + let alloc = Bump::new(); + let data = table([ + borrowed_bigint("a", [1, 4, 5, 2, 5], &alloc), + borrowed_bigint("b", [1, 2, 3, 4, 5], &alloc), + borrowed_int128("c", [1, 2, 3, 4, 5], &alloc), + borrowed_varchar("d", ["1", "2", "3", "4", "5"], &alloc), + borrowed_scalar("e", [1, 2, 3, 4, 5], &alloc), + ]); + let t = "sxt.t".parse().unwrap(); + let table_map = indexmap! 
{ + t => data.clone() + }; + let mut accessor = TableTestAccessor::::new_empty_with_setup(()); + accessor.add_table(t, data, 0); + let where_clause: DynProofExpr = equal(column(t, "a", &accessor), const_int128(5)); + let expr = slice_exec( + filter( + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), + tab(t), + where_clause, + ), + 1, + None, + ); + let fields = &[ + ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), + ColumnField::new("c".parse().unwrap(), ColumnType::Int128), + ColumnField::new("d".parse().unwrap(), ColumnType::VarChar), + ColumnField::new( + "e".parse().unwrap(), + ColumnType::Decimal75(Precision::new(1).unwrap(), 0), + ), + ]; + let res: OwnedTable = + ProvableQueryResult::from(expr.result_evaluate(&alloc, &table_map)) + .to_owned_table(fields) + .unwrap(); + let expected: OwnedTable = owned_table([ + bigint("b", [5]), + int128("c", [5]), + varchar("d", ["5"]), + decimal75("e", 1, 0, [5]), + ]); + assert_eq!(res, expected); +} + +#[test] +fn we_can_prove_a_slice_exec() { + let data = owned_table([ + bigint("a", [101, 105, 105, 105, 105]), + bigint("b", [1, 2, 3, 4, 7]), + int128("c", [1, 3, 3, 4, 5]), + varchar("d", ["1", "2", "3", "4", "5"]), + scalar("e", [1, 2, 3, 4, 5]), + ]); + let t = "sxt.t".parse().unwrap(); + let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); + accessor.add_table(t, data, 0); + let expr = slice_exec( + filter( + vec![ + col_expr_plan(t, "b", &accessor), + col_expr_plan(t, "c", &accessor), + col_expr_plan(t, "d", &accessor), + col_expr_plan(t, "e", &accessor), + aliased_plan(const_int128(105), "const"), + aliased_plan( + equal(column(t, "b", &accessor), column(t, "c", &accessor)), + "bool", + ), + ], + tab(t), + equal(column(t, "a", &accessor), const_int128(105)), + ), + 2, + Some(1), + ); + let res = VerifiableQueryResult::new(&expr, &accessor, &()); + exercise_verification(&res, &expr, &accessor, t); + let res = res.verify(&expr, &accessor, &()).unwrap().table; + let expected = 
owned_table([ + bigint("b", [4]), + int128("c", [4]), + varchar("d", ["4"]), + scalar("e", [4]), + int128("const", [105]), + boolean("bool", [true]), + ]); + assert_eq!(res, expected); +} diff --git a/crates/proof-of-sql/src/sql/proof_plans/test_utility.rs b/crates/proof-of-sql/src/sql/proof_plans/test_utility.rs index b22a61846..93f27db53 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/test_utility.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/test_utility.rs @@ -1,4 +1,4 @@ -use super::{DynProofPlan, FilterExec, GroupByExec, ProjectionExec, TableExec}; +use super::{DynProofPlan, FilterExec, GroupByExec, ProjectionExec, SliceExec, TableExec}; use crate::{ base::database::{ColumnField, TableRef}, sql::proof_exprs::{AliasedDynProofExpr, ColumnExpr, DynProofExpr, TableExpr}, @@ -38,3 +38,7 @@ pub fn group_by( where_clause, )) } + +pub fn slice_exec(input: DynProofPlan, skip: usize, fetch: Option) -> DynProofPlan { + DynProofPlan::Slice(SliceExec::new(Box::new(input), skip, fetch)) +}