Skip to content

Commit

Permalink
refactor!: generalize one eval to arbitrary lengths
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Nov 26, 2024
1 parent e04dd5c commit 5ff0e59
Show file tree
Hide file tree
Showing 30 changed files with 317 additions and 121 deletions.
3 changes: 3 additions & 0 deletions crates/proof-of-sql/src/base/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ pub use table::{Table, TableOptions};
mod table_test;
pub mod table_utility;

mod table_evaluation;
pub use table_evaluation::TableEvaluation;

/// TODO: add docs
pub(crate) mod expression_evaluation;
mod expression_evaluation_error;
Expand Down
34 changes: 34 additions & 0 deletions crates/proof-of-sql/src/base/database/table_evaluation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use crate::base::scalar::Scalar;
use alloc::vec::Vec;

/// The result of evaluating a table
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct TableEvaluation<S: Scalar> {
/// Evaluation of each column in the table
column_evals: Vec<S>,
/// Evaluation of an all-one column with the same length as the table
one_eval: S,
}

impl<S: Scalar> TableEvaluation<S> {
/// Creates a new [`TableEvaluation`].
#[must_use]
pub fn new(column_evals: Vec<S>, one_eval: S) -> Self {
Self {
column_evals,
one_eval,
}
}

/// Returns the evaluation of each column in the table.
#[must_use]
pub fn column_evals(&self) -> &[S] {
&self.column_evals
}

/// Returns the evaluation of an all-one column with the same length as the table.
#[must_use]
pub fn one_eval(&self) -> &S {
&self.one_eval
}
}
11 changes: 11 additions & 0 deletions crates/proof-of-sql/src/sql/proof/final_round_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub struct FinalRoundBuilder<'a, S: Scalar> {
commitment_descriptor: Vec<CommittableColumn<'a>>,
pcs_proof_mles: Vec<Box<dyn MultilinearExtension<S> + 'a>>,
sumcheck_subpolynomials: Vec<SumcheckSubpolynomial<'a, S>>,
one_evaluation_lengths: Vec<usize>,
/// The challenges used in creation of the constraints in the proof.
/// Specifically, these are the challenges that the verifier sends to
/// the prover after the prover sends the result, but before the prover
Expand All @@ -36,6 +37,7 @@ impl<'a, S: Scalar> FinalRoundBuilder<'a, S> {
pcs_proof_mles: Vec::new(),
sumcheck_subpolynomials: Vec::new(),
post_result_challenges,
one_evaluation_lengths: Vec::new(),
}
}

Expand All @@ -51,6 +53,15 @@ impl<'a, S: Scalar> FinalRoundBuilder<'a, S> {
&self.pcs_proof_mles
}

pub fn one_evaluation_lengths(&self) -> &[usize] {
&self.one_evaluation_lengths
}

/// Whenever we need to evaluate a column of 1s with a given length, we push the length
pub fn push_one_evaluation_length(&mut self, length: usize) {
self.one_evaluation_lengths.push(length);
}

/// Produce a bit distribution that describes which bits are constant
/// and which bits varying in a column of data
pub fn produce_bit_distribution(&mut self, dist: BitDistribution) {
Expand Down
5 changes: 3 additions & 2 deletions crates/proof-of-sql/src/sql/proof/proof_plan.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{CountBuilder, FinalRoundBuilder, FirstRoundBuilder, VerificationBuilder};
use crate::base::{
database::{ColumnField, ColumnRef, OwnedTable, Table, TableRef},
database::{ColumnField, ColumnRef, OwnedTable, Table, TableEvaluation, TableRef},
map::{IndexMap, IndexSet},
proof::ProofError,
scalar::Scalar,
Expand All @@ -21,7 +21,8 @@ pub trait ProofPlan: Debug + Send + Sync + ProverEvaluate {
builder: &mut VerificationBuilder<S>,
accessor: &IndexMap<ColumnRef, S>,
result: Option<&OwnedTable<S>>,
) -> Result<Vec<S>, ProofError>;
one_eval_map: &IndexMap<TableRef, S>,
) -> Result<TableEvaluation<S>, ProofError>;

/// Return all the result column fields
fn get_column_result_fields(&self) -> Vec<ColumnField>;
Expand Down
28 changes: 25 additions & 3 deletions crates/proof-of-sql/src/sql/proof/query_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ fn get_index_range(
pub struct QueryProof<CP: CommitmentEvaluationProof> {
/// Bit distributions
pub bit_distributions: Vec<BitDistribution>,
/// One evaluation lengths
pub one_evaluation_lengths: Vec<usize>,
/// Commitments
pub commitments: Vec<CP::Commitment>,
/// Sumcheck Proof
Expand Down Expand Up @@ -178,6 +180,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {

let proof = Self {
bit_distributions: builder.bit_distributions().to_vec(),
one_evaluation_lengths: builder.one_evaluation_lengths().to_vec(),
commitments,
sumcheck_proof,
pcs_proof_evaluations,
Expand All @@ -197,7 +200,8 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
setup: &CP::VerifierPublicSetup<'_>,
) -> QueryResult<CP::Scalar> {
let owned_table_result = result.to_owned_table(&expr.get_column_result_fields())?;
let (min_row_num, _) = get_index_range(accessor, expr.get_table_references());
let table_refs = expr.get_table_references();
let (min_row_num, _) = get_index_range(accessor, table_refs.clone());
let num_sumcheck_variables = cmp::max(log2_up(self.range_length), 1);
assert!(num_sumcheck_variables > 0);

Expand Down Expand Up @@ -274,21 +278,38 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
.take(self.pcs_proof_evaluations.len())
.collect();

// Always prepend input lengths to the one evaluation lengths
let table_length_map = table_refs
.iter()
.map(|table_ref| (table_ref, accessor.get_length(*table_ref)))
.collect::<IndexMap<_, _>>();

let one_evaluation_lengths = table_length_map
.values()
.chain(self.one_evaluation_lengths.clone().iter())
.copied()
.collect::<Vec<_>>();

// pass over the provable AST to fill in the verification builder
let sumcheck_evaluations = SumcheckMleEvaluations::new(
self.range_length,
owned_table_result.num_rows(),
&one_evaluation_lengths,
&subclaim.evaluation_point,
&sumcheck_random_scalars,
&self.pcs_proof_evaluations,
);
let one_eval_map: IndexMap<TableRef, CP::Scalar> = table_length_map
.iter()
.map(|(table_ref, length)| (**table_ref, sumcheck_evaluations.one_evaluations[length]))
.collect();
let mut builder = VerificationBuilder::new(
min_row_num,
sumcheck_evaluations,
&self.bit_distributions,
sumcheck_random_scalars.subpolynomial_multipliers,
&evaluation_random_scalars,
post_result_challenges,
self.one_evaluation_lengths.clone(),
);

let pcs_proof_commitments: Vec<_> = column_references
Expand All @@ -305,11 +326,12 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
&mut builder,
&evaluation_accessor,
Some(&owned_table_result),
&one_eval_map,
)?;
// compute the evaluation of the result MLEs
let result_evaluations = owned_table_result.mle_evaluations(&subclaim.evaluation_point);
// check the evaluation of the result MLEs
if verifier_evaluations != result_evaluations {
if verifier_evaluations.column_evals() != result_evaluations {
Err(ProofError::VerificationError {
error: "result evaluation check failed",
})?;
Expand Down
38 changes: 29 additions & 9 deletions crates/proof-of-sql/src/sql/proof/query_proof_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
owned_table_utility::{bigint, owned_table},
table_utility::*,
ColumnField, ColumnRef, ColumnType, OwnedTable, OwnedTableTestAccessor, Table,
TableRef,
TableEvaluation, TableRef,
},
map::{indexset, IndexMap, IndexSet},
proof::ProofError,
Expand Down Expand Up @@ -65,6 +65,7 @@ impl ProverEvaluate for TrivialTestProofPlan {
SumcheckSubpolynomialType::Identity,
vec![(S::ONE, vec![Box::new(col as &[_])])],
);
builder.push_one_evaluation_length(self.length);
table([borrowed_bigint(
"a1",
vec![self.column_fill_value; self.length],
Expand All @@ -85,13 +86,17 @@ impl ProofPlan for TrivialTestProofPlan {
builder: &mut VerificationBuilder<S>,
_accessor: &IndexMap<ColumnRef, S>,
_result: Option<&OwnedTable<S>>,
) -> Result<Vec<S>, ProofError> {
_one_eval_map: &IndexMap<TableRef, S>,
) -> Result<TableEvaluation<S>, ProofError> {
assert_eq!(builder.consume_intermediate_mle(), S::ZERO);
builder.produce_sumcheck_subpolynomial_evaluation(
&SumcheckSubpolynomialType::ZeroSum,
S::from(self.evaluation),
);
Ok(vec![S::ZERO])
Ok(TableEvaluation::new(
vec![S::ZERO],
builder.consume_one_evaluation(),
))
}
///
/// # Panics
Expand Down Expand Up @@ -246,6 +251,7 @@ impl ProverEvaluate for SquareTestProofPlan {
(-S::ONE, vec![Box::new(x), Box::new(x)]),
],
);
builder.push_one_evaluation_length(2);
table([borrowed_bigint("a1", self.res, alloc)])
}
}
Expand All @@ -261,7 +267,8 @@ impl ProofPlan for SquareTestProofPlan {
builder: &mut VerificationBuilder<S>,
accessor: &IndexMap<ColumnRef, S>,
_result: Option<&OwnedTable<S>>,
) -> Result<Vec<S>, ProofError> {
_one_eval_map: &IndexMap<TableRef, S>,
) -> Result<TableEvaluation<S>, ProofError> {
let x_eval = S::from(self.anchored_commit_multiplier)
* *accessor
.get(&ColumnRef::new(
Expand All @@ -275,7 +282,10 @@ impl ProofPlan for SquareTestProofPlan {
&SumcheckSubpolynomialType::Identity,
res_eval - x_eval * x_eval,
);
Ok(vec![res_eval])
Ok(TableEvaluation::new(
vec![res_eval],
builder.consume_one_evaluation(),
))
}
fn get_column_result_fields(&self) -> Vec<ColumnField> {
vec![ColumnField::new("a1".parse().unwrap(), ColumnType::BigInt)]
Expand Down Expand Up @@ -436,6 +446,7 @@ impl ProverEvaluate for DoubleSquareTestProofPlan {
],
);
builder.produce_intermediate_mle(res);
builder.push_one_evaluation_length(2);
table([borrowed_bigint("a1", self.res, alloc)])
}
}
Expand All @@ -451,7 +462,8 @@ impl ProofPlan for DoubleSquareTestProofPlan {
builder: &mut VerificationBuilder<S>,
accessor: &IndexMap<ColumnRef, S>,
_result: Option<&OwnedTable<S>>,
) -> Result<Vec<S>, ProofError> {
_one_eval_map: &IndexMap<TableRef, S>,
) -> Result<TableEvaluation<S>, ProofError> {
let x_eval = *accessor
.get(&ColumnRef::new(
"sxt.test".parse().unwrap(),
Expand All @@ -473,7 +485,10 @@ impl ProofPlan for DoubleSquareTestProofPlan {
&SumcheckSubpolynomialType::Identity,
res_eval - z_eval * z_eval,
);
Ok(vec![res_eval])
Ok(TableEvaluation::new(
vec![res_eval],
builder.consume_one_evaluation(),
))
}
fn get_column_result_fields(&self) -> Vec<ColumnField> {
vec![ColumnField::new("a1".parse().unwrap(), ColumnType::BigInt)]
Expand Down Expand Up @@ -634,6 +649,7 @@ impl ProverEvaluate for ChallengeTestProofPlan {
(-alpha, vec![Box::new(x), Box::new(x)]),
],
);
builder.push_one_evaluation_length(2);
table([borrowed_bigint("a1", [9, 25], alloc)])
}
}
Expand All @@ -650,7 +666,8 @@ impl ProofPlan for ChallengeTestProofPlan {
builder: &mut VerificationBuilder<S>,
accessor: &IndexMap<ColumnRef, S>,
_result: Option<&OwnedTable<S>>,
) -> Result<Vec<S>, ProofError> {
_one_eval_map: &IndexMap<TableRef, S>,
) -> Result<TableEvaluation<S>, ProofError> {
let alpha = builder.consume_post_result_challenge();
let _beta = builder.consume_post_result_challenge();
let x_eval = *accessor
Expand All @@ -665,7 +682,10 @@ impl ProofPlan for ChallengeTestProofPlan {
&SumcheckSubpolynomialType::Identity,
alpha * res_eval - alpha * x_eval * x_eval,
);
Ok(vec![res_eval])
Ok(TableEvaluation::new(
vec![res_eval],
builder.consume_one_evaluation(),
))
}
fn get_column_result_fields(&self) -> Vec<ColumnField> {
vec![ColumnField::new("a1".parse().unwrap(), ColumnType::BigInt)]
Expand Down
47 changes: 23 additions & 24 deletions crates/proof-of-sql/src/sql/proof/sumcheck_mle_evaluations.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::SumcheckRandomScalars;
use crate::base::{
map::{IndexMap, IndexSet},
polynomial::{
compute_truncated_lagrange_basis_inner_product, compute_truncated_lagrange_basis_sum,
},
Expand All @@ -9,22 +10,14 @@ use crate::base::{
/// Evaluations for different MLEs at the random point chosen for sumcheck
#[derive(Default)]
pub struct SumcheckMleEvaluations<'a, S: Scalar> {
/// The length of the input table for a basic filter. When we support more complex queries, this may need to split.
pub input_length: usize,
/// The length of the output table for a basic filter. When we support more complex queries, this may need to split.
pub output_length: usize,
/// The number of sumcheck variables.
pub num_sumcheck_variables: usize,
/// The evaluation (at the random point generated by sumcheck) of an MLE `{x_i}` where
/// `x_i = 1` if `i < input_length;`
/// `x_i = 1` if `i < length;`
/// = 0, otherwise
pub input_one_evaluation: S,

/// The evaluation (at the random point generated by sumcheck) of an MLE `{x_i}` where
/// `x_i = 1` if `i < output_length;`
/// = 0, otherwise
pub output_one_evaluation: S,

pub one_evaluations: IndexMap<usize, S>,
/// The evaluation (at the random point generated by sumcheck) of the MLE formed from all ones with length 1.
pub singleton_one_evaluation: S,
/// The evaluation (at the random point generated by sumcheck) of the MLE formed from entrywise random scalars.
///
/// This is used within sumcheck to establish that a given expression
Expand All @@ -46,8 +39,8 @@ impl<'a, S: Scalar> SumcheckMleEvaluations<'a, S> {
/// - `sumcheck_random_scalars` - the random scalars used to batch the evaluations that are proven via IPA
/// - `pcs_proof_evaluations` - the evaluations of the MLEs that are proven via IPA
pub fn new(
input_length: usize,
output_length: usize,
range_length: usize,
one_evaluation_lengths: &[usize],
evaluation_point: &[S],
sumcheck_random_scalars: &SumcheckRandomScalars<S>,
pcs_proof_evaluations: &'a [S],
Expand All @@ -56,22 +49,28 @@ impl<'a, S: Scalar> SumcheckMleEvaluations<'a, S> {
evaluation_point.len(),
sumcheck_random_scalars.entrywise_point.len()
);
assert_eq!(input_length, sumcheck_random_scalars.table_length);
assert_eq!(range_length, sumcheck_random_scalars.table_length);
let random_evaluation = compute_truncated_lagrange_basis_inner_product(
input_length,
range_length,
evaluation_point,
sumcheck_random_scalars.entrywise_point,
);
let input_one_evaluation =
compute_truncated_lagrange_basis_sum(input_length, evaluation_point);
let output_one_evaluation =
compute_truncated_lagrange_basis_sum(output_length, evaluation_point);
let unique_one_evaluation_lengths: IndexSet<usize> =
one_evaluation_lengths.iter().copied().collect();
let one_evaluations = unique_one_evaluation_lengths
.iter()
.map(|&length| {
(
length,
compute_truncated_lagrange_basis_sum(length, evaluation_point),
)
})
.collect();
let singleton_one_evaluation = compute_truncated_lagrange_basis_sum(1, evaluation_point);
Self {
input_length,
output_length,
num_sumcheck_variables: evaluation_point.len(),
input_one_evaluation,
output_one_evaluation,
one_evaluations,
singleton_one_evaluation,
random_evaluation,
pcs_proof_evaluations,
}
Expand Down
Loading

0 comments on commit 5ff0e59

Please sign in to comment.