-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Please be sure to look over the pull request guidelines here: https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md#submit-pr.

# Please go through the following checklist

- [x] The PR title and commit messages adhere to guidelines here: https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md. In particular `!` is used if and only if at least one breaking change has been introduced.
- [x] I have run the ci check script with `source scripts/run_ci_checks.sh`.
- The following upstream PRs have been merged:
  - [x] #381
  - [x] #401
  - [x] #404

# Rationale for this change

This PR replaces #121 and is designed to test whether our `ProofPlan`s are truly composable now.

<!-- Why are you proposing this change? If this is already explained clearly in the linked issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. Example: Add `NestedLoopJoinExec`. Closes #345. Since we added `HashJoinExec` in #323 it has been possible to do provable inner joins. However performance is not satisfactory in some cases. Hence we need to fix the problem by implementing `NestedLoopJoinExec` and speed up the code for `HashJoinExec`. -->

# What changes are included in this PR?

- add `SliceExec`.

<!-- There is no need to duplicate the description in the ticket here but it is sometimes worth providing a summary of the individual changes in this PR. Example: - Add `NestedLoopJoinExec`. - Speed up `HashJoinExec`. - Route joins to `NestedLoopJoinExec` if the outer input is sufficiently small. -->

# Are these changes tested?

<!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? Example: Yes. -->

Yes.
- Loading branch information
Showing
7 changed files
with
785 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
use super::{ | ||
filter_exec::{prove_filter, verify_filter}, | ||
DynProofPlan, | ||
}; | ||
use crate::{ | ||
base::{ | ||
database::{ | ||
filter_util::filter_columns, ColumnField, ColumnRef, OwnedTable, Table, | ||
TableEvaluation, TableOptions, TableRef, | ||
}, | ||
map::{IndexMap, IndexSet}, | ||
proof::ProofError, | ||
scalar::Scalar, | ||
}, | ||
sql::proof::{ | ||
CountBuilder, FinalRoundBuilder, FirstRoundBuilder, ProofPlan, ProverEvaluate, | ||
VerificationBuilder, | ||
}, | ||
}; | ||
use alloc::{boxed::Box, vec, vec::Vec}; | ||
use bumpalo::Bump; | ||
use core::iter::{repeat, repeat_with}; | ||
use itertools::repeat_n; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
/// `ProofPlan` for queries of the form | ||
/// ```ignore | ||
/// <ProofPlan> LIMIT <fetch> [OFFSET <skip>] | ||
/// ``` | ||
#[derive(Debug, PartialEq, Serialize, Deserialize)] | ||
pub struct SliceExec { | ||
pub(super) input: Box<DynProofPlan>, | ||
pub(super) skip: usize, | ||
pub(super) fetch: Option<usize>, | ||
} | ||
|
||
/// Builds the row-selection mask for a `LIMIT fetch OFFSET skip` slice.
///
/// The returned vector always has exactly `num_rows` entries. The first
/// `skip` entries are `false`, the next `fetch` entries are `true`
/// (`fetch == None` means "no limit"), and whatever remains is `false`.
/// Anything that would fall past `num_rows` is simply cut off, so oversized
/// `skip` / `fetch` values are harmless here.
fn get_slice_select(num_rows: usize, skip: usize, fetch: Option<usize>) -> Vec<bool> {
    // Without an explicit limit the window may be as long as the whole input.
    let window_len = fetch.unwrap_or(num_rows);

    let skipped_prefix = repeat_n(false, skip);
    let selected_window = repeat_n(true, window_len);
    // Endless padding; the `take` below trims it to the input length.
    let unselected_tail = repeat(false);

    let mut mask = Vec::with_capacity(num_rows);
    mask.extend(
        skipped_prefix
            .chain(selected_window)
            .chain(unselected_tail)
            .take(num_rows),
    );
    mask
}
|
||
impl SliceExec { | ||
/// Creates a new slice execution plan. | ||
#[allow(dead_code)] | ||
pub fn new(input: Box<DynProofPlan>, skip: usize, fetch: Option<usize>) -> Self { | ||
Self { input, skip, fetch } | ||
} | ||
} | ||
|
||
impl ProofPlan for SliceExec | ||
where | ||
SliceExec: ProverEvaluate, | ||
{ | ||
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> { | ||
self.input.count(builder)?; | ||
builder.count_intermediate_mles(self.input.get_column_result_fields().len()); | ||
builder.count_intermediate_mles(2); | ||
builder.count_subpolynomials(3); | ||
builder.count_degree(3); | ||
builder.count_post_result_challenges(2); | ||
Ok(()) | ||
} | ||
|
||
#[allow(unused_variables)] | ||
fn verifier_evaluate<S: Scalar>( | ||
&self, | ||
builder: &mut VerificationBuilder<S>, | ||
accessor: &IndexMap<ColumnRef, S>, | ||
_result: Option<&OwnedTable<S>>, | ||
one_eval_map: &IndexMap<TableRef, S>, | ||
) -> Result<TableEvaluation<S>, ProofError> { | ||
// 1. columns | ||
// We do not support `GroupByExec` as input for now | ||
if matches!(*self.input, DynProofPlan::GroupBy(_)) { | ||
return Err(ProofError::UnsupportedQueryPlan { | ||
error: "GroupByExec as input for another plan is not supported", | ||
}); | ||
} | ||
let input_table_eval = | ||
self.input | ||
.verifier_evaluate(builder, accessor, None, one_eval_map)?; | ||
let output_one_eval = builder.consume_one_evaluation(); | ||
let columns_evals = input_table_eval.column_evals(); | ||
// 2. selection | ||
// The selected range is (offset_index, max_index] | ||
let offset_one_eval = builder.consume_one_evaluation(); | ||
let max_one_eval = builder.consume_one_evaluation(); | ||
let selection_eval = max_one_eval - offset_one_eval; | ||
// 3. filtered_columns | ||
let filtered_columns_evals: Vec<_> = repeat_with(|| builder.consume_intermediate_mle()) | ||
.take(columns_evals.len()) | ||
.collect(); | ||
let alpha = builder.consume_post_result_challenge(); | ||
let beta = builder.consume_post_result_challenge(); | ||
|
||
verify_filter( | ||
builder, | ||
alpha, | ||
beta, | ||
*input_table_eval.one_eval(), | ||
output_one_eval, | ||
columns_evals, | ||
selection_eval, | ||
&filtered_columns_evals, | ||
)?; | ||
Ok(TableEvaluation::new( | ||
filtered_columns_evals, | ||
output_one_eval, | ||
)) | ||
} | ||
|
||
fn get_column_result_fields(&self) -> Vec<ColumnField> { | ||
self.input.get_column_result_fields() | ||
} | ||
|
||
fn get_column_references(&self) -> IndexSet<ColumnRef> { | ||
self.input.get_column_references() | ||
} | ||
|
||
fn get_table_references(&self) -> IndexSet<TableRef> { | ||
self.input.get_table_references() | ||
} | ||
} | ||
|
||
impl ProverEvaluate for SliceExec { | ||
#[tracing::instrument(name = "SliceExec::first_round_evaluate", level = "debug", skip_all)] | ||
fn first_round_evaluate<'a, S: Scalar>( | ||
&self, | ||
builder: &mut FirstRoundBuilder, | ||
alloc: &'a Bump, | ||
table_map: &IndexMap<TableRef, Table<'a, S>>, | ||
) -> (Table<'a, S>, Vec<usize>) { | ||
// 1. columns | ||
let (input, input_one_eval_lengths) = | ||
self.input.first_round_evaluate(builder, alloc, table_map); | ||
let input_length = input.num_rows(); | ||
let columns = input.columns().copied().collect::<Vec<_>>(); | ||
// 2. select | ||
let select = get_slice_select(input_length, self.skip, self.fetch); | ||
// The selected range is (offset_index, max_index] | ||
let offset_index = self.skip.min(input_length); | ||
let max_index = if let Some(fetch) = self.fetch { | ||
(self.skip + fetch).min(input_length) | ||
} else { | ||
input_length | ||
}; | ||
let output_length = max_index - offset_index; | ||
// Compute filtered_columns | ||
let (filtered_columns, _) = filter_columns(alloc, &columns, &select); | ||
let res = Table::<'a, S>::try_from_iter_with_options( | ||
self.get_column_result_fields() | ||
.into_iter() | ||
.map(|expr| expr.name()) | ||
.zip(filtered_columns), | ||
TableOptions::new(Some(output_length)), | ||
) | ||
.expect("Failed to create table from iterator"); | ||
let mut one_eval_lengths = input_one_eval_lengths; | ||
one_eval_lengths.extend(vec![output_length, offset_index, max_index]); | ||
builder.request_post_result_challenges(2); | ||
(res, one_eval_lengths) | ||
} | ||
|
||
#[tracing::instrument(name = "SliceExec::prover_evaluate", level = "debug", skip_all)] | ||
#[allow(unused_variables)] | ||
fn final_round_evaluate<'a, S: Scalar>( | ||
&self, | ||
builder: &mut FinalRoundBuilder<'a, S>, | ||
alloc: &'a Bump, | ||
table_map: &IndexMap<TableRef, Table<'a, S>>, | ||
) -> Table<'a, S> { | ||
// 1. columns | ||
let input = self.input.final_round_evaluate(builder, alloc, table_map); | ||
let columns = input.columns().copied().collect::<Vec<_>>(); | ||
// 2. select | ||
let select = get_slice_select(input.num_rows(), self.skip, self.fetch); | ||
let select_ref: &'a [_] = alloc.alloc_slice_copy(&select); | ||
let output_length = select.iter().filter(|b| **b).count(); | ||
// Compute filtered_columns and indexes | ||
let (filtered_columns, result_len) = filter_columns(alloc, &columns, &select); | ||
// 3. Produce MLEs | ||
filtered_columns.iter().copied().for_each(|column| { | ||
builder.produce_intermediate_mle(column); | ||
}); | ||
let alpha = builder.consume_post_result_challenge(); | ||
let beta = builder.consume_post_result_challenge(); | ||
|
||
prove_filter::<S>( | ||
builder, | ||
alloc, | ||
alpha, | ||
beta, | ||
&columns, | ||
select_ref, | ||
&filtered_columns, | ||
input.num_rows(), | ||
result_len, | ||
); | ||
Table::<'a, S>::try_from_iter_with_options( | ||
self.get_column_result_fields() | ||
.into_iter() | ||
.map(|expr| expr.name()) | ||
.zip(filtered_columns), | ||
TableOptions::new(Some(output_length)), | ||
) | ||
.expect("Failed to create table from iterator") | ||
} | ||
} |
Oops, something went wrong.