Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: simplify result_evaluations calculation #373

Merged
merged 1 commit into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions crates/proof-of-sql/src/base/database/owned_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::base::{
permutation::{Permutation, PermutationError},
},
scalar::Scalar,
slice_ops::inner_product_ref_cast,
};
use alloc::{
string::{String, ToString},
Expand Down Expand Up @@ -48,6 +49,24 @@ pub enum OwnedColumn<S: Scalar> {
}

impl<S: Scalar> OwnedColumn<S> {
/// Compute the inner product of the column with a vector of scalars.
pub(crate) fn inner_product(&self, vec: &[S]) -> S {
match self {
OwnedColumn::Boolean(col) => inner_product_ref_cast(col, vec),
OwnedColumn::TinyInt(col) => inner_product_ref_cast(col, vec),
OwnedColumn::SmallInt(col) => inner_product_ref_cast(col, vec),
OwnedColumn::Int(col) => inner_product_ref_cast(col, vec),
OwnedColumn::BigInt(col) | OwnedColumn::TimestampTZ(_, _, col) => {
inner_product_ref_cast(col, vec)
}
OwnedColumn::VarChar(col) => inner_product_ref_cast(col, vec),
OwnedColumn::Int128(col) => inner_product_ref_cast(col, vec),
OwnedColumn::Decimal75(_, _, col) | OwnedColumn::Scalar(col) => {
inner_product_ref_cast(col, vec)
}
}
}

/// Returns the length of the column.
#[must_use]
pub fn len(&self) -> usize {
Expand Down
12 changes: 11 additions & 1 deletion crates/proof-of-sql/src/base/database/owned_table.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::OwnedColumn;
use crate::base::{map::IndexMap, scalar::Scalar};
use crate::base::{map::IndexMap, polynomial::compute_evaluation_vector, scalar::Scalar};
use alloc::{vec, vec::Vec};
use proof_of_sql_parser::Identifier;
use serde::{Deserialize, Serialize};
use snafu::Snafu;
Expand Down Expand Up @@ -72,6 +73,15 @@ impl<S: Scalar> OwnedTable<S> {
pub fn column_names(&self) -> impl Iterator<Item = &Identifier> {
self.table.keys()
}

pub(crate) fn mle_evaluations(&self, evaluation_point: &[S]) -> Vec<S> {
let mut evaluation_vector = vec![S::ZERO; self.num_rows()];
compute_evaluation_vector(&mut evaluation_vector, evaluation_point);
self.table
.values()
.map(|column| column.inner_product(&evaluation_vector))
.collect()
}
}

// Note: we modify the default PartialEq for IndexMap to also check for column ordering.
Expand Down
12 changes: 12 additions & 0 deletions crates/proof-of-sql/src/base/slice_ops/inner_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,15 @@ where
.map(|(&a, &b)| a * b)
.sum()
}

pub fn inner_product_ref_cast<F, T>(a: &[F], b: &[T]) -> T
where
for<'a> &'a F: Into<T>,
F: Send + Sync,
T: Sync + Send + Mul<Output = T> + Sum + Copy,
{
if_rayon!(a.par_iter().with_min_len(super::MIN_RAYON_LEN), a.iter())
.zip(b)
.map(|(a, b)| a.into() * *b)
.sum()
}
15 changes: 4 additions & 11 deletions crates/proof-of-sql/src/sql/proof/query_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,13 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
result: &ProvableQueryResult,
setup: &CP::VerifierPublicSetup<'_>,
) -> QueryResult<CP::Scalar> {
let owned_table_result = result.to_owned_table(&expr.get_column_result_fields())?;

let (min_row_num, max_row_num) = get_index_range(accessor, expr.get_table_references());
let range_length = max_row_num - min_row_num;
let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
assert!(num_sumcheck_variables > 0);

let output_length = result.table_length();

// validate bit decompositions
for dist in &self.bit_distributions {
if !dist.is_valid() {
Expand Down Expand Up @@ -253,12 +253,10 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
.take(self.pcs_proof_evaluations.len())
.collect();

let column_result_fields = expr.get_column_result_fields();

// pass over the provable AST to fill in the verification builder
let sumcheck_evaluations = SumcheckMleEvaluations::new(
range_length,
output_length,
owned_table_result.num_rows(),
&subclaim.evaluation_point,
&sumcheck_random_scalars,
&self.pcs_proof_evaluations,
Expand All @@ -271,7 +269,6 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
&evaluation_random_scalars,
post_result_challenges,
);
let owned_table_result = result.to_owned_table(&column_result_fields[..])?;

let pcs_proof_commitments: Vec<_> = column_references
.iter()
Expand All @@ -289,11 +286,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
Some(&owned_table_result),
)?;
// compute the evaluation of the result MLEs
let result_evaluations = result.evaluate(
&subclaim.evaluation_point,
output_length,
&column_result_fields[..],
)?;
let result_evaluations = owned_table_result.mle_evaluations(&subclaim.evaluation_point);
// check the evaluation of the result MLEs
if verifier_evaluations != result_evaluations {
Err(ProofError::VerificationError {
Expand Down
Loading