Skip to content

Commit

Permalink
feat: add SliceExec (#379)
Browse files Browse the repository at this point in the history
Please be sure to look over the pull request guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md#submit-pr.

# Please go through the following checklist
- [x] The PR title and commit messages adhere to guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md.
In particular `!` is used if and only if at least one breaking change
has been introduced.
- [x] I have run the ci check script with `source
scripts/run_ci_checks.sh`.
- The following upstream PRs have been merged:
  - [x] #381
  - [x] #401
  - [x] #404

# Rationale for this change
This PR replaces #121 and is designed to test whether our `ProofPlan`s
are truly composable now.
<!--
Why are you proposing this change? If this is already explained clearly
in the linked issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.

 Example:
 Add `NestedLoopJoinExec`.
 Closes #345.

Since we added `HashJoinExec` in #323 it has been possible to do
provable inner joins. However performance is not satisfactory in some
cases. Hence we need to fix the problem by implementing
`NestedLoopJoinExec` and speeding up the code
 for `HashJoinExec`.
-->

# What changes are included in this PR?
- add `SliceExec`.
<!--
There is no need to duplicate the description in the ticket here but it
is sometimes worth providing a summary of the individual changes in this
PR.

Example:
- Add `NestedLoopJoinExec`.
- Speed up `HashJoinExec`.
- Route joins to `NestedLoopJoinExec` if the outer input is sufficiently
small.
-->

# Are these changes tested?
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?

Example:
Yes.
-->
Yes.
  • Loading branch information
iajoiner authored Dec 4, 2024
2 parents 078191c + 383320c commit 3eb74b3
Show file tree
Hide file tree
Showing 7 changed files with 785 additions and 3 deletions.
3 changes: 3 additions & 0 deletions crates/proof-of-sql/src/base/proof/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ pub enum ProofError {
#[snafu(display("Verification error: {error}"))]
/// This error occurs when a proof failed to verify.
VerificationError { error: &'static str },
/// This error occurs when a query plan is not supported.
#[snafu(display("Unsupported query plan: {error}"))]
UnsupportedQueryPlan { error: &'static str },
}
7 changes: 6 additions & 1 deletion crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{EmptyExec, FilterExec, GroupByExec, ProjectionExec, TableExec};
use super::{EmptyExec, FilterExec, GroupByExec, ProjectionExec, SliceExec, TableExec};
use crate::{
base::{
database::{ColumnField, ColumnRef, OwnedTable, Table, TableEvaluation, TableRef},
Expand Down Expand Up @@ -43,4 +43,9 @@ pub enum DynProofPlan {
/// SELECT <result_expr1>, ..., <result_exprN> FROM <table> WHERE <where_clause>
/// ```
Filter(FilterExec),
/// `ProofPlan` for queries of the form
/// ```ignore
/// <ProofPlan> LIMIT <fetch> [OFFSET <skip>]
/// ```
Slice(SliceExec),
}
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ impl ProverEvaluate for FilterExec {
}

#[allow(clippy::unnecessary_wraps, clippy::too_many_arguments)]
fn verify_filter<S: Scalar>(
pub(super) fn verify_filter<S: Scalar>(
builder: &mut VerificationBuilder<S>,
alpha: S,
beta: S,
Expand Down
6 changes: 6 additions & 0 deletions crates/proof-of-sql/src/sql/proof_plans/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,11 @@ pub(crate) use group_by_exec::GroupByExec;
#[cfg(all(test, feature = "blitzar"))]
mod group_by_exec_test;

mod slice_exec;
#[allow(unused_imports)]
pub(crate) use slice_exec::SliceExec;
#[cfg(all(test, feature = "blitzar"))]
mod slice_exec_test;

mod dyn_proof_plan;
pub use dyn_proof_plan::DynProofPlan;
212 changes: 212 additions & 0 deletions crates/proof-of-sql/src/sql/proof_plans/slice_exec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
use super::{
filter_exec::{prove_filter, verify_filter},
DynProofPlan,
};
use crate::{
base::{
database::{
filter_util::filter_columns, ColumnField, ColumnRef, OwnedTable, Table,
TableEvaluation, TableOptions, TableRef,
},
map::{IndexMap, IndexSet},
proof::ProofError,
scalar::Scalar,
},
sql::proof::{
CountBuilder, FinalRoundBuilder, FirstRoundBuilder, ProofPlan, ProverEvaluate,
VerificationBuilder,
},
};
use alloc::{boxed::Box, vec, vec::Vec};
use bumpalo::Bump;
use core::iter::{repeat, repeat_with};
use itertools::repeat_n;
use serde::{Deserialize, Serialize};

/// `ProofPlan` for queries of the form
/// ```ignore
/// <ProofPlan> LIMIT <fetch> [OFFSET <skip>]
/// ```
///
/// The plan evaluates `input` and keeps a single contiguous window of its rows:
/// the first `skip` rows are dropped and at most `fetch` of the rows that follow
/// are retained.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct SliceExec {
/// The plan whose output rows are sliced.
pub(super) input: Box<DynProofPlan>,
/// Number of leading input rows to drop (the `OFFSET`).
pub(super) skip: usize,
/// Maximum number of rows to keep after skipping (the `LIMIT`);
/// `None` keeps every remaining row.
pub(super) fetch: Option<usize>,
}

/// Builds the row-selection mask for a slice over `num_rows` rows.
///
/// The returned vector always has exactly `num_rows` entries. Entry `i` is `true`
/// when row `i` lies inside the window that starts after the first `skip` rows and
/// spans at most `fetch` rows; a `fetch` of `None` keeps every row after the skip.
fn get_slice_select(num_rows: usize, skip: usize, fetch: Option<usize>) -> Vec<bool> {
    // With no explicit limit the window can never need more than `num_rows` entries.
    let window = fetch.unwrap_or(num_rows);
    let skipped = repeat_n(false, skip);
    let kept = repeat_n(true, window);
    // Endless padding; `take` below trims the mask to the table length.
    let padding = repeat(false);
    skipped.chain(kept).chain(padding).take(num_rows).collect()
}

impl SliceExec {
/// Creates a new slice execution plan.
#[allow(dead_code)]
pub fn new(input: Box<DynProofPlan>, skip: usize, fetch: Option<usize>) -> Self {
Self { input, skip, fetch }
}
}

impl ProofPlan for SliceExec
where
    SliceExec: ProverEvaluate,
{
    /// Tallies the proof resources this plan needs on top of those of its input.
    fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> {
        // The input plan registers its own requirements first.
        self.input.count(builder)?;
        // One intermediate MLE is produced for every sliced output column.
        let num_result_columns = self.input.get_column_result_fields().len();
        builder.count_intermediate_mles(num_result_columns);
        // NOTE(review): the remaining counts presumably mirror what `prove_filter` /
        // `verify_filter` produce and consume — confirm against `filter_exec`.
        builder.count_intermediate_mles(2);
        builder.count_subpolynomials(3);
        builder.count_degree(3);
        builder.count_post_result_challenges(2);
        Ok(())
    }

    /// Verifies the slice on top of the verified input plan and returns the
    /// evaluations describing the sliced table.
    ///
    /// # Errors
    /// Returns `ProofError::UnsupportedQueryPlan` when the input is a `GroupByExec`,
    /// and forwards any error from verifying the input plan or the filter argument.
    #[allow(unused_variables)]
    fn verifier_evaluate<S: Scalar>(
        &self,
        builder: &mut VerificationBuilder<S>,
        accessor: &IndexMap<ColumnRef, S>,
        _result: Option<&OwnedTable<S>>,
        one_eval_map: &IndexMap<TableRef, S>,
    ) -> Result<TableEvaluation<S>, ProofError> {
        // `GroupByExec` cannot feed another plan yet, so reject it before doing any work.
        if let DynProofPlan::GroupBy(_) = &*self.input {
            return Err(ProofError::UnsupportedQueryPlan {
                error: "GroupByExec as input for another plan is not supported",
            });
        }

        // Input columns: verify the inner plan and borrow its column evaluations.
        let input_eval = self
            .input
            .verifier_evaluate(builder, accessor, None, one_eval_map)?;
        let input_column_evals = input_eval.column_evals();

        // One-evaluations are consumed in the order the prover registered the lengths
        // in `first_round_evaluate`: output length, then offset index, then max index.
        let output_one_eval = builder.consume_one_evaluation();
        let offset_one_eval = builder.consume_one_evaluation();
        let max_one_eval = builder.consume_one_evaluation();
        // Selection: the difference of the max and offset evaluations.
        let selection_eval = max_one_eval - offset_one_eval;

        // Sliced columns: one intermediate MLE evaluation per input column.
        let num_columns = input_column_evals.len();
        let sliced_column_evals: Vec<_> = repeat_with(|| builder.consume_intermediate_mle())
            .take(num_columns)
            .collect();

        let alpha = builder.consume_post_result_challenge();
        let beta = builder.consume_post_result_challenge();

        // The slice is checked as a filter whose selection is the window above.
        verify_filter(
            builder,
            alpha,
            beta,
            *input_eval.one_eval(),
            output_one_eval,
            input_column_evals,
            selection_eval,
            &sliced_column_evals,
        )?;

        Ok(TableEvaluation::new(sliced_column_evals, output_one_eval))
    }

    /// The result columns are exactly those of the input plan.
    fn get_column_result_fields(&self) -> Vec<ColumnField> {
        self.input.get_column_result_fields()
    }

    /// The referenced columns are exactly those of the input plan.
    fn get_column_references(&self) -> IndexSet<ColumnRef> {
        self.input.get_column_references()
    }

    /// The referenced tables are exactly those of the input plan.
    fn get_table_references(&self) -> IndexSet<TableRef> {
        self.input.get_table_references()
    }
}

impl ProverEvaluate for SliceExec {
    /// First prover round: evaluates the input plan, applies the slice and records
    /// the lengths the verifier later consumes as one-evaluations.
    ///
    /// Returns the sliced table together with the input's one-evaluation lengths
    /// extended by `[output_length, offset_index, max_index]`, in that order.
    ///
    /// # Panics
    /// Panics if the sliced table cannot be assembled from the filtered columns.
    #[tracing::instrument(name = "SliceExec::first_round_evaluate", level = "debug", skip_all)]
    fn first_round_evaluate<'a, S: Scalar>(
        &self,
        builder: &mut FirstRoundBuilder,
        alloc: &'a Bump,
        table_map: &IndexMap<TableRef, Table<'a, S>>,
    ) -> (Table<'a, S>, Vec<usize>) {
        // 1. columns
        let (input, input_one_eval_lengths) =
            self.input.first_round_evaluate(builder, alloc, table_map);
        let input_length = input.num_rows();
        let columns = input.columns().copied().collect::<Vec<_>>();
        // 2. select
        let select = get_slice_select(input_length, self.skip, self.fetch);
        // The selected rows form the 0-based half-open range [offset_index, max_index),
        // with both bounds clamped to the input length.
        let offset_index = self.skip.min(input_length);
        // `skip + fetch` can exceed `usize::MAX`. Saturating keeps `max_index` at or
        // above `offset_index`, so the subtraction below cannot underflow, and keeps
        // the bounds in agreement with `get_slice_select`, which never overflows.
        let max_index = self.fetch.map_or(input_length, |fetch| {
            self.skip.saturating_add(fetch).min(input_length)
        });
        let output_length = max_index - offset_index;
        // Compute filtered_columns
        let (filtered_columns, _) = filter_columns(alloc, &columns, &select);
        let res = Table::<'a, S>::try_from_iter_with_options(
            self.get_column_result_fields()
                .into_iter()
                .map(|expr| expr.name())
                .zip(filtered_columns),
            TableOptions::new(Some(output_length)),
        )
        .expect("Failed to create table from iterator");
        // The verifier consumes these lengths in exactly this order.
        let mut one_eval_lengths = input_one_eval_lengths;
        one_eval_lengths.extend(vec![output_length, offset_index, max_index]);
        builder.request_post_result_challenges(2);
        (res, one_eval_lengths)
    }

    /// Final prover round: re-evaluates the input plan, produces one intermediate
    /// MLE per sliced column and proves the slice as a filter over the input rows.
    ///
    /// Returns the sliced table.
    ///
    /// # Panics
    /// Panics if the sliced table cannot be assembled from the filtered columns.
    // NOTE(review): the span name differs from the method name; it is kept as-is so
    // existing trace consumers are unaffected.
    #[tracing::instrument(name = "SliceExec::prover_evaluate", level = "debug", skip_all)]
    #[allow(unused_variables)]
    fn final_round_evaluate<'a, S: Scalar>(
        &self,
        builder: &mut FinalRoundBuilder<'a, S>,
        alloc: &'a Bump,
        table_map: &IndexMap<TableRef, Table<'a, S>>,
    ) -> Table<'a, S> {
        // 1. columns
        let input = self.input.final_round_evaluate(builder, alloc, table_map);
        let columns = input.columns().copied().collect::<Vec<_>>();
        // 2. select
        let select = get_slice_select(input.num_rows(), self.skip, self.fetch);
        // `prove_filter` needs the mask to live as long as the arena.
        let select_ref: &'a [_] = alloc.alloc_slice_copy(&select);
        let output_length = select.iter().filter(|b| **b).count();
        // Compute filtered_columns and indexes
        let (filtered_columns, result_len) = filter_columns(alloc, &columns, &select);
        // 3. Produce MLEs
        filtered_columns.iter().copied().for_each(|column| {
            builder.produce_intermediate_mle(column);
        });
        let alpha = builder.consume_post_result_challenge();
        let beta = builder.consume_post_result_challenge();

        prove_filter::<S>(
            builder,
            alloc,
            alpha,
            beta,
            &columns,
            select_ref,
            &filtered_columns,
            input.num_rows(),
            result_len,
        );
        Table::<'a, S>::try_from_iter_with_options(
            self.get_column_result_fields()
                .into_iter()
                .map(|expr| expr.name())
                .zip(filtered_columns),
            TableOptions::new(Some(output_length)),
        )
        .expect("Failed to create table from iterator")
    }
}
Loading

0 comments on commit 3eb74b3

Please sign in to comment.