From 557c669711d18373e74bc16eff4ee594b929da01 Mon Sep 17 00:00:00 2001 From: Ian Joiner <14581281+iajoiner@users.noreply.github.com> Date: Mon, 18 Nov 2024 18:37:24 -0500 Subject: [PATCH] feat: add `SliceExec` --- .../src/sql/proof_plans/dyn_proof_plan.rs | 7 +- .../src/sql/proof_plans/filter_exec.rs | 4 +- .../proof-of-sql/src/sql/proof_plans/mod.rs | 7 +- .../src/sql/proof_plans/slice_exec.rs | 182 +++++++++++ .../src/sql/proof_plans/slice_exec_test.rs | 297 ++++++++++++++++++ .../src/sql/proof_plans/test_utility.rs | 7 +- 6 files changed, 499 insertions(+), 5 deletions(-) create mode 100644 crates/proof-of-sql/src/sql/proof_plans/slice_exec.rs create mode 100644 crates/proof-of-sql/src/sql/proof_plans/slice_exec_test.rs diff --git a/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs b/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs index e402708aa..912071da1 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs @@ -1,4 +1,4 @@ -use super::{EmptyExec, FilterExec, GroupByExec, ProjectionExec, TableExec}; +use super::{EmptyExec, FilterExec, GroupByExec, ProjectionExec, SliceExec, TableExec}; use crate::{ base::{ database::{ColumnField, ColumnRef, OwnedTable, Table, TableRef}, @@ -43,4 +43,9 @@ pub enum DynProofPlan { /// SELECT , ..., FROM WHERE /// ``` Filter(FilterExec), + /// `ProofPlan` for queries of the form + /// ```ignore + /// LIMIT [OFFSET ] + /// ``` + Slice(SliceExec), } diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs index 4040ff225..4829d8c47 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs @@ -233,7 +233,7 @@ impl ProverEvaluate for FilterExec { } #[allow(clippy::unnecessary_wraps)] -fn verify_filter( +pub(crate) fn verify_filter( builder: &mut VerificationBuilder, alpha: S, beta: S, @@ -271,7 +271,7 @@ fn verify_filter( } #[allow(clippy::too_many_arguments, clippy::many_single_char_names)] -pub(super) fn prove_filter<'a, S: Scalar + 'a>( +pub(crate) fn prove_filter<'a, S: Scalar + 'a>( builder: &mut FinalRoundBuilder<'a, S>, alloc: &'a Bump, alpha: S, diff --git a/crates/proof-of-sql/src/sql/proof_plans/mod.rs b/crates/proof-of-sql/src/sql/proof_plans/mod.rs index 98ddea04e..f468204bd 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/mod.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/mod.rs @@ -16,14 +16,19 @@ mod projection_exec_test; pub(crate) mod test_utility; mod filter_exec; -pub(crate) use filter_exec::FilterExec; #[cfg(test)] pub(crate) use filter_exec::OstensibleFilterExec; +pub(crate) use filter_exec::{prove_filter, verify_filter, FilterExec}; #[cfg(all(test, feature = "blitzar"))] mod filter_exec_test; #[cfg(all(test, feature = "blitzar"))] mod filter_exec_test_dishonest_prover; +mod slice_exec; +pub(crate) use slice_exec::SliceExec; +#[cfg(all(test, feature = "blitzar"))] +mod slice_exec_test; + mod fold_util; pub(crate) use fold_util::{fold_columns, fold_vals}; #[cfg(test)] diff --git a/crates/proof-of-sql/src/sql/proof_plans/slice_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/slice_exec.rs new file mode 100644 index 000000000..5ed7d1ad7 --- /dev/null +++ b/crates/proof-of-sql/src/sql/proof_plans/slice_exec.rs @@ -0,0 +1,182 @@ +use super::{prove_filter, verify_filter, DynProofPlan}; +use crate::{ + base::{ + database::{ + filter_util::filter_columns, ColumnField, ColumnRef, OwnedTable, Table, TableOptions, + TableRef, + }, + map::{IndexMap, IndexSet}, + proof::ProofError, + scalar::Scalar, + }, + sql::proof::{ + CountBuilder, FinalRoundBuilder, FirstRoundBuilder, ProofPlan, ProverEvaluate, + VerificationBuilder, + }, +}; +use alloc::{boxed::Box, vec::Vec}; +use bumpalo::Bump; +use core::iter::repeat_with; +use serde::{Deserialize, Serialize}; + +/// `ProofPlan` for queries of the form +/// ```ignore +/// LIMIT [OFFSET ] +/// ``` +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct SliceExec { + pub(super) input: Box, + pub(super) skip: usize, + pub(super) fetch: Option, +} + +/// Get the boolean slice selection from the number of rows, skip and fetch +fn get_slice_select(num_rows: usize, skip: usize, fetch: Option) -> Vec { + if let Some(fetch) = fetch { + let end = skip + fetch; + (0..num_rows).map(|i| i >= skip && i < end).collect() + } else { + (0..num_rows).map(|i| i >= skip).collect() + } +} + +impl SliceExec { + /// Creates a new slice execution plan. + pub fn new(input: Box, skip: usize, fetch: Option) -> Self { + Self { input, skip, fetch } + } +} + +impl ProofPlan for SliceExec +where + SliceExec: ProverEvaluate, +{ + fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> { + self.input.count(builder)?; + builder.count_intermediate_mles(self.input.get_column_result_fields().len()); + builder.count_intermediate_mles(3); + builder.count_subpolynomials(3); + builder.count_degree(3); + builder.count_post_result_challenges(2); + Ok(()) + } + + #[allow(unused_variables)] + fn verifier_evaluate( + &self, + builder: &mut VerificationBuilder, + accessor: &IndexMap, + _result: Option<&OwnedTable>, + ) -> Result, ProofError> { + // 1. columns + // TODO: Make sure `GroupByExec` as self.input is supported + let columns_evals = self.input.verifier_evaluate(builder, accessor, None)?; + // 2. selection + let selection_eval = builder.consume_intermediate_mle(); + // 3. filtered_columns + let filtered_columns_evals: Vec<_> = repeat_with(|| builder.consume_intermediate_mle()) + .take(columns_evals.len()) + .collect(); + let alpha = builder.consume_post_result_challenge(); + let beta = builder.consume_post_result_challenge(); + + verify_filter( + builder, + alpha, + beta, + &columns_evals, + selection_eval, + &filtered_columns_evals, + )?; + Ok(filtered_columns_evals) + } + + fn get_column_result_fields(&self) -> Vec { + self.input.get_column_result_fields() + } + + fn get_column_references(&self) -> IndexSet { + self.input.get_column_references() + } + + fn get_table_references(&self) -> IndexSet { + self.input.get_table_references() + } +} + +impl ProverEvaluate for SliceExec { + #[tracing::instrument(name = "SliceExec::result_evaluate", level = "debug", skip_all)] + fn result_evaluate<'a, S: Scalar>( + &self, + alloc: &'a Bump, + table_map: &IndexMap>, + ) -> Table<'a, S> { + // 1. columns + let input = self.input.result_evaluate(alloc, table_map); + let columns = input.columns().copied().collect::>(); + // 2. select + let select = get_slice_select(input.num_rows(), self.skip, self.fetch); + let output_length = select.iter().filter(|b| **b).count(); + // Compute filtered_columns + let (filtered_columns, _) = filter_columns(alloc, &columns, &select); + Table::<'a, S>::try_from_iter_with_options( + self.get_column_result_fields() + .into_iter() + .map(|expr| expr.name()) + .zip(filtered_columns), + TableOptions::new(Some(output_length)), + ) + .expect("Failed to create table from iterator") + } + + fn first_round_evaluate(&self, builder: &mut FirstRoundBuilder) { + self.input.first_round_evaluate(builder); + builder.request_post_result_challenges(2); + } + + #[tracing::instrument(name = "SliceExec::prover_evaluate", level = "debug", skip_all)] + #[allow(unused_variables)] + fn final_round_evaluate<'a, S: Scalar>( + &self, + builder: &mut FinalRoundBuilder<'a, S>, + alloc: &'a Bump, + table_map: &IndexMap>, + ) -> Table<'a, S> { + // 1. columns + let input = self.input.final_round_evaluate(builder, alloc, table_map); + let columns = input.columns().copied().collect::>(); + // 2. select + let select = get_slice_select(input.num_rows(), self.skip, self.fetch); + let select_ref: &'a [_] = alloc.alloc_slice_copy(&select); + let output_length = select.iter().filter(|b| **b).count(); + + builder.produce_intermediate_mle(select_ref); + // Compute filtered_columns and indexes + let (filtered_columns, result_len) = filter_columns(alloc, &columns, &select); + // 3. Produce MLEs + for column in &filtered_columns { + builder.produce_intermediate_mle(column.as_scalar(alloc)); + } + let alpha = builder.consume_post_result_challenge(); + let beta = builder.consume_post_result_challenge(); + + prove_filter::( + builder, + alloc, + alpha, + beta, + &columns, + select_ref, + &filtered_columns, + result_len, + ); + Table::<'a, S>::try_from_iter_with_options( + self.get_column_result_fields() + .into_iter() + .map(|expr| expr.name()) + .zip(filtered_columns), + TableOptions::new(Some(output_length)), + ) + .expect("Failed to create table from iterator") + } +} diff --git a/crates/proof-of-sql/src/sql/proof_plans/slice_exec_test.rs b/crates/proof-of-sql/src/sql/proof_plans/slice_exec_test.rs new file mode 100644 index 000000000..d3a0703c7 --- /dev/null +++ b/crates/proof-of-sql/src/sql/proof_plans/slice_exec_test.rs @@ -0,0 +1,297 @@ +use super::test_utility::*; +use crate::{ + base::{ + database::{ + owned_table_utility::*, table_utility::*, ColumnField, ColumnType, OwnedTable, + OwnedTableTestAccessor, TableTestAccessor, TestAccessor, + }, + map::{indexmap, IndexMap}, + math::decimal::Precision, + scalar::Curve25519Scalar, + }, + sql::{ + proof::{ + exercise_verification, FirstRoundBuilder, ProvableQueryResult, ProverEvaluate, + VerifiableQueryResult, + }, + proof_exprs::{test_utility::*, DynProofExpr}, + }, +}; +use blitzar::proof::InnerProductProof; +use bumpalo::Bump; + +#[test] +fn we_can_prove_and_get_the_correct_result_from_a_slice_exec() { + let data = owned_table([ + bigint("a", [1_i64, 2, 3, 4, 5]), + varchar("b", ["1", "2", "3", "4", "5"]), + ]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = slice_exec( + projection(cols_expr_plan(t, &["a", "b"], &accessor), tab(t)), + 1, + Some(2), + ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("a", [2_i64, 3]), varchar("b", ["2", "3"])]); + assert_eq!(res, expected_res); +} + +#[test] +fn we_can_prove_and_get_the_correct_empty_result_from_a_slice_exec() { + let data = owned_table([ + bigint("a", [1_i64, 2, 3, 4, 5]), + varchar("b", ["1", "2", "3", "4", "5"]), + ]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let where_clause: DynProofExpr = equal(column(t, "a", &accessor), const_int128(2)); + let ast = slice_exec( + filter( + cols_expr_plan(t, &["a", "b"], &accessor), + tab(t), + where_clause, + ), + 1, + Some(2), + ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("a", [0_i64; 0]), varchar("b", [""; 0])]); + assert_eq!(res, expected_res); +} + +#[test] +fn we_can_get_an_empty_result_from_a_slice_on_an_empty_table_using_result_evaluate() { + let alloc = Bump::new(); + let data = table([ + borrowed_bigint("a", [0; 0], &alloc), + borrowed_bigint("b", [0; 0], &alloc), + borrowed_int128("c", [0; 0], &alloc), + borrowed_varchar("d", [""; 0], &alloc), + borrowed_scalar("e", [0; 0], &alloc), + ]); + let t = "sxt.t".parse().unwrap(); + let table_map = indexmap! { + t => data.clone() + }; + let mut accessor = TableTestAccessor::::new_empty_with_setup(()); + accessor.add_table(t, data, 0); + let where_clause: DynProofExpr = equal(column(t, "a", &accessor), const_int128(999)); + let expr = slice_exec( + filter( + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), + tab(t), + where_clause, + ), + 1, + Some(2), + ); + + let mut builder = FirstRoundBuilder::new(); + expr.first_round_evaluate(&mut builder); + let fields = &[ + ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), + ColumnField::new("c".parse().unwrap(), ColumnType::Int128), + ColumnField::new("d".parse().unwrap(), ColumnType::VarChar), + ColumnField::new( + "e".parse().unwrap(), + ColumnType::Decimal75(Precision::new(75).unwrap(), 0), + ), + ]; + let res: OwnedTable = + ProvableQueryResult::from(expr.result_evaluate(&alloc, &table_map)) + .to_owned_table(fields) + .unwrap(); + let expected: OwnedTable = owned_table([ + bigint("b", [0; 0]), + int128("c", [0; 0]), + varchar("d", [""; 0]), + decimal75("e", 75, 0, [0; 0]), + ]); + + assert_eq!(res, expected); +} + +#[test] +fn we_can_get_an_empty_result_from_a_slice_using_result_evaluate() { + let alloc = Bump::new(); + let data = table([ + borrowed_bigint("a", [1, 4, 5, 2, 5], &alloc), + borrowed_bigint("b", [1, 2, 3, 4, 5], &alloc), + borrowed_int128("c", [1, 2, 3, 4, 5], &alloc), + borrowed_varchar("d", ["1", "2", "3", "4", "5"], &alloc), + borrowed_scalar("e", [1, 2, 3, 4, 5], &alloc), + ]); + let t = "sxt.t".parse().unwrap(); + let table_map = indexmap! { + t => data.clone() + }; + let mut accessor = TableTestAccessor::::new_empty_with_setup(()); + accessor.add_table(t, data, 0); + let where_clause: DynProofExpr = equal(column(t, "a", &accessor), const_int128(999)); + let expr = slice_exec( + filter( + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), + tab(t), + where_clause, + ), + 1, + Some(2), + ); + let mut builder = FirstRoundBuilder::new(); + expr.first_round_evaluate(&mut builder); + let fields = &[ + ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), + ColumnField::new("c".parse().unwrap(), ColumnType::Int128), + ColumnField::new("d".parse().unwrap(), ColumnType::VarChar), + ColumnField::new( + "e".parse().unwrap(), + ColumnType::Decimal75(Precision::new(1).unwrap(), 0), + ), + ]; + let res: OwnedTable = + ProvableQueryResult::from(expr.result_evaluate(&alloc, &table_map)) + .to_owned_table(fields) + .unwrap(); + let expected: OwnedTable = owned_table([ + bigint("b", [0; 0]), + int128("c", [0; 0]), + varchar("d", [""; 0]), + decimal75("e", 1, 0, [0; 0]), + ]); + + assert_eq!(res, expected); +} + +#[test] +fn we_can_get_no_columns_from_a_slice_with_empty_input_using_result_evaluate() { + let alloc = Bump::new(); + let data = table([ + borrowed_bigint("a", [1, 4, 5, 2, 5], &alloc), + borrowed_bigint("b", [1, 2, 3, 4, 5], &alloc), + borrowed_int128("c", [1, 2, 3, 4, 5], &alloc), + borrowed_varchar("d", ["1", "2", "3", "4", "5"], &alloc), + borrowed_scalar("e", [1, 2, 3, 4, 5], &alloc), + ]); + let t = "sxt.t".parse().unwrap(); + let table_map = indexmap! { + t => data.clone() + }; + let mut accessor = TableTestAccessor::::new_empty_with_setup(()); + accessor.add_table(t, data, 0); + let where_clause: DynProofExpr = equal(column(t, "a", &accessor), const_int128(5)); + let expr = slice_exec( + filter(cols_expr_plan(t, &[], &accessor), tab(t), where_clause), + 2, + None, + ); + let mut builder = FirstRoundBuilder::new(); + expr.first_round_evaluate(&mut builder); + let fields = &[]; + let res: OwnedTable = + ProvableQueryResult::from(expr.result_evaluate(&alloc, &table_map)) + .to_owned_table(fields) + .unwrap(); + let expected = OwnedTable::try_new(IndexMap::default()).unwrap(); + assert_eq!(res, expected); +} + +#[test] +fn we_can_get_the_correct_result_from_a_slice_using_result_evaluate() { + let alloc = Bump::new(); + let data = table([ + borrowed_bigint("a", [1, 4, 5, 2, 5], &alloc), + borrowed_bigint("b", [1, 2, 3, 4, 5], &alloc), + borrowed_int128("c", [1, 2, 3, 4, 5], &alloc), + borrowed_varchar("d", ["1", "2", "3", "4", "5"], &alloc), + borrowed_scalar("e", [1, 2, 3, 4, 5], &alloc), + ]); + let t = "sxt.t".parse().unwrap(); + let table_map = indexmap! { + t => data.clone() + }; + let mut accessor = TableTestAccessor::::new_empty_with_setup(()); + accessor.add_table(t, data, 0); + let where_clause: DynProofExpr = equal(column(t, "a", &accessor), const_int128(5)); + let expr = slice_exec( + filter( + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), + tab(t), + where_clause, + ), + 1, + None, + ); + let mut builder = FirstRoundBuilder::new(); + expr.first_round_evaluate(&mut builder); + let fields = &[ + ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), + ColumnField::new("c".parse().unwrap(), ColumnType::Int128), + ColumnField::new("d".parse().unwrap(), ColumnType::VarChar), + ColumnField::new( + "e".parse().unwrap(), + ColumnType::Decimal75(Precision::new(1).unwrap(), 0), + ), + ]; + let res: OwnedTable = + ProvableQueryResult::from(expr.result_evaluate(&alloc, &table_map)) + .to_owned_table(fields) + .unwrap(); + let expected: OwnedTable = owned_table([ + bigint("b", [5]), + int128("c", [5]), + varchar("d", ["5"]), + decimal75("e", 1, 0, [5]), + ]); + assert_eq!(res, expected); +} + +#[test] +fn we_can_prove_a_slice_exec() { + let data = owned_table([ + bigint("a", [101, 105, 105, 105, 105]), + bigint("b", [1, 2, 3, 4, 7]), + int128("c", [1, 3, 3, 4, 5]), + varchar("d", ["1", "2", "3", "4", "5"]), + scalar("e", [1, 2, 3, 4, 5]), + ]); + let t = "sxt.t".parse().unwrap(); + let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); + accessor.add_table(t, data, 0); + let expr = slice_exec( + filter( + vec![ + col_expr_plan(t, "b", &accessor), + col_expr_plan(t, "c", &accessor), + col_expr_plan(t, "d", &accessor), + col_expr_plan(t, "e", &accessor), + aliased_plan(const_int128(105), "const"), + aliased_plan( + equal(column(t, "b", &accessor), column(t, "c", &accessor)), + "bool", + ), + ], + tab(t), + equal(column(t, "a", &accessor), const_int128(105)), + ), + 2, + Some(1), + ); + let res = VerifiableQueryResult::new(&expr, &accessor, &()); + exercise_verification(&res, &expr, &accessor, t); + let res = res.verify(&expr, &accessor, &()).unwrap().table; + let expected = owned_table([ + bigint("b", [4]), + int128("c", [4]), + varchar("d", ["4"]), + scalar("e", [4]), + int128("const", [105]), + boolean("bool", [true]), + ]); + assert_eq!(res, expected); +} diff --git a/crates/proof-of-sql/src/sql/proof_plans/test_utility.rs b/crates/proof-of-sql/src/sql/proof_plans/test_utility.rs index b22a61846..33f35481a 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/test_utility.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/test_utility.rs @@ -1,8 +1,9 @@ -use super::{DynProofPlan, FilterExec, GroupByExec, ProjectionExec, TableExec}; +use super::{DynProofPlan, FilterExec, GroupByExec, ProjectionExec, SliceExec, TableExec}; use crate::{ base::database::{ColumnField, TableRef}, sql::proof_exprs::{AliasedDynProofExpr, ColumnExpr, DynProofExpr, TableExpr}, }; +use alloc::boxed::Box; pub fn table_exec(table_ref: TableRef, schema: Vec) -> DynProofPlan { DynProofPlan::Table(TableExec::new(table_ref, schema)) @@ -38,3 +39,7 @@ pub fn group_by( where_clause, )) } + +pub fn slice_exec(input: DynProofPlan, skip: usize, fetch: Option) -> DynProofPlan { + DynProofPlan::Slice(SliceExec::new(Box::new(input), skip, fetch)) +}