Skip to content

Commit

Permalink
feat: add SliceExec
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner authored and JayWhite2357 committed Dec 4, 2024
1 parent 078191c commit d96b015
Show file tree
Hide file tree
Showing 6 changed files with 662 additions and 5 deletions.
7 changes: 6 additions & 1 deletion crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{EmptyExec, FilterExec, GroupByExec, ProjectionExec, TableExec};
use super::{EmptyExec, FilterExec, GroupByExec, ProjectionExec, SliceExec, TableExec};
use crate::{
base::{
database::{ColumnField, ColumnRef, OwnedTable, Table, TableEvaluation, TableRef},
Expand Down Expand Up @@ -43,4 +43,9 @@ pub enum DynProofPlan {
/// SELECT <result_expr1>, ..., <result_exprN> FROM <table> WHERE <where_clause>
/// ```
Filter(FilterExec),
/// `ProofPlan` for queries of the form
/// ```ignore
/// <ProofPlan> LIMIT <fetch> [OFFSET <skip>]
/// ```
Slice(SliceExec),
}
4 changes: 2 additions & 2 deletions crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ impl ProverEvaluate for FilterExec {
}

#[allow(clippy::unnecessary_wraps, clippy::too_many_arguments)]
fn verify_filter<S: Scalar>(
pub(crate) fn verify_filter<S: Scalar>(
builder: &mut VerificationBuilder<S>,
alpha: S,
beta: S,
Expand Down Expand Up @@ -286,7 +286,7 @@ fn verify_filter<S: Scalar>(
}

#[allow(clippy::too_many_arguments, clippy::many_single_char_names)]
pub(super) fn prove_filter<'a, S: Scalar + 'a>(
pub(crate) fn prove_filter<'a, S: Scalar + 'a>(
builder: &mut FinalRoundBuilder<'a, S>,
alloc: &'a Bump,
alpha: S,
Expand Down
8 changes: 7 additions & 1 deletion crates/proof-of-sql/src/sql/proof_plans/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ mod projection_exec_test;
pub(crate) mod test_utility;

mod filter_exec;
pub(crate) use filter_exec::FilterExec;
#[cfg(test)]
pub(crate) use filter_exec::OstensibleFilterExec;
pub(crate) use filter_exec::{prove_filter, verify_filter, FilterExec};
#[cfg(all(test, feature = "blitzar"))]
mod filter_exec_test;
#[cfg(all(test, feature = "blitzar"))]
Expand All @@ -35,5 +35,11 @@ pub(crate) use group_by_exec::GroupByExec;
#[cfg(all(test, feature = "blitzar"))]
mod group_by_exec_test;

mod slice_exec;
#[allow(unused_imports)]
pub(crate) use slice_exec::SliceExec;
#[cfg(all(test, feature = "blitzar"))]
mod slice_exec_test;

mod dyn_proof_plan;
pub use dyn_proof_plan::DynProofPlan;
206 changes: 206 additions & 0 deletions crates/proof-of-sql/src/sql/proof_plans/slice_exec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
use super::{prove_filter, verify_filter, DynProofPlan};
use crate::{
base::{
database::{
filter_util::filter_columns, ColumnField, ColumnRef, OwnedTable, Table,
TableEvaluation, TableOptions, TableRef,
},
map::{IndexMap, IndexSet},
proof::ProofError,
scalar::Scalar,
},
sql::proof::{
CountBuilder, FinalRoundBuilder, FirstRoundBuilder, ProofPlan, ProverEvaluate,
VerificationBuilder,
},
};
use alloc::{boxed::Box, vec, vec::Vec};
use bumpalo::Bump;
use core::iter::repeat_with;
use serde::{Deserialize, Serialize};

/// `ProofPlan` for queries of the form
/// ```ignore
/// <ProofPlan> LIMIT <fetch> [OFFSET <skip>]
/// ```
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct SliceExec {
pub(super) input: Box<DynProofPlan>,
pub(super) skip: usize,
pub(super) fetch: Option<usize>,
}

/// Get the boolean slice selection from the number of rows, skip and fetch
fn get_slice_select(num_rows: usize, skip: usize, fetch: Option<usize>) -> Vec<bool> {
if let Some(fetch) = fetch {
let end = skip + fetch;
(0..num_rows).map(|i| i >= skip && i < end).collect()
} else {
(0..num_rows).map(|i| i >= skip).collect()
}
}

impl SliceExec {
/// Creates a new slice execution plan.
#[allow(dead_code)]
pub fn new(input: Box<DynProofPlan>, skip: usize, fetch: Option<usize>) -> Self {
Self { input, skip, fetch }
}
}

impl ProofPlan for SliceExec
where
SliceExec: ProverEvaluate,
{
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> {
self.input.count(builder)?;
builder.count_intermediate_mles(self.input.get_column_result_fields().len());
builder.count_intermediate_mles(2);
builder.count_subpolynomials(3);
builder.count_degree(3);
builder.count_post_result_challenges(2);
Ok(())
}

#[allow(unused_variables)]
fn verifier_evaluate<S: Scalar>(
&self,
builder: &mut VerificationBuilder<S>,
accessor: &IndexMap<ColumnRef, S>,
_result: Option<&OwnedTable<S>>,
one_eval_map: &IndexMap<TableRef, S>,
) -> Result<TableEvaluation<S>, ProofError> {
// 1. columns
// TODO: Make sure `GroupByExec` as self.input is supported
let input_table_eval =
self.input
.verifier_evaluate(builder, accessor, None, one_eval_map)?;
let output_one_eval = builder.consume_one_evaluation();
let columns_evals = input_table_eval.column_evals();
// 2. selection
// The selected range is (offset_index, max_index]
let offset_one_eval = builder.consume_one_evaluation();
let max_one_eval = builder.consume_one_evaluation();
let selection_eval = max_one_eval - offset_one_eval;
// 3. filtered_columns
let filtered_columns_evals: Vec<_> = repeat_with(|| builder.consume_intermediate_mle())
.take(columns_evals.len())
.collect();
let alpha = builder.consume_post_result_challenge();
let beta = builder.consume_post_result_challenge();

verify_filter(
builder,
alpha,
beta,
*input_table_eval.one_eval(),
output_one_eval,
columns_evals,
selection_eval,
&filtered_columns_evals,
)?;
Ok(TableEvaluation::new(
filtered_columns_evals,
output_one_eval,
))
}

fn get_column_result_fields(&self) -> Vec<ColumnField> {
self.input.get_column_result_fields()
}

fn get_column_references(&self) -> IndexSet<ColumnRef> {
self.input.get_column_references()
}

fn get_table_references(&self) -> IndexSet<TableRef> {
self.input.get_table_references()
}
}

impl ProverEvaluate for SliceExec {
#[tracing::instrument(name = "SliceExec::result_evaluate", level = "debug", skip_all)]
fn result_evaluate<'a, S: Scalar>(
&self,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> (Table<'a, S>, Vec<usize>) {
// 1. columns
let (input, input_one_eval_lengths) = self.input.result_evaluate(alloc, table_map);
let input_length = input.num_rows();
let columns = input.columns().copied().collect::<Vec<_>>();
// 2. select
let select = get_slice_select(input_length, self.skip, self.fetch);
// The selected range is (offset_index, max_index]
let offset_index = self.skip.min(input_length);
let max_index = if let Some(fetch) = self.fetch {
(self.skip + fetch).min(input_length)
} else {
input_length
};
let output_length = max_index - offset_index;
// Compute filtered_columns
let (filtered_columns, _) = filter_columns(alloc, &columns, &select);
let res = Table::<'a, S>::try_from_iter_with_options(
self.get_column_result_fields()
.into_iter()
.map(|expr| expr.name())
.zip(filtered_columns),
TableOptions::new(Some(output_length)),
)
.expect("Failed to create table from iterator");
let mut one_eval_lengths = input_one_eval_lengths;
one_eval_lengths.extend(vec![output_length, offset_index, max_index]);
(res, one_eval_lengths)
}

fn first_round_evaluate(&self, builder: &mut FirstRoundBuilder) {
self.input.first_round_evaluate(builder);
builder.request_post_result_challenges(2);
}

#[tracing::instrument(name = "SliceExec::prover_evaluate", level = "debug", skip_all)]
#[allow(unused_variables)]
fn final_round_evaluate<'a, S: Scalar>(
&self,
builder: &mut FinalRoundBuilder<'a, S>,
alloc: &'a Bump,
table_map: &IndexMap<TableRef, Table<'a, S>>,
) -> Table<'a, S> {
// 1. columns
let input = self.input.final_round_evaluate(builder, alloc, table_map);
let columns = input.columns().copied().collect::<Vec<_>>();
// 2. select
let select = get_slice_select(input.num_rows(), self.skip, self.fetch);
let select_ref: &'a [_] = alloc.alloc_slice_copy(&select);
let output_length = select.iter().filter(|b| **b).count();
// Compute filtered_columns and indexes
let (filtered_columns, result_len) = filter_columns(alloc, &columns, &select);
// 3. Produce MLEs
filtered_columns.iter().copied().for_each(|column| {
builder.produce_intermediate_mle(column);
});
let alpha = builder.consume_post_result_challenge();
let beta = builder.consume_post_result_challenge();

prove_filter::<S>(
builder,
alloc,
alpha,
beta,
&columns,
select_ref,
&filtered_columns,
input.num_rows(),
result_len,
);
Table::<'a, S>::try_from_iter_with_options(
self.get_column_result_fields()
.into_iter()
.map(|expr| expr.name())
.zip(filtered_columns),
TableOptions::new(Some(output_length)),
)
.expect("Failed to create table from iterator")
}
}
Loading

0 comments on commit d96b015

Please sign in to comment.