Skip to content

Commit

Permalink
refactor: simplify result_evaluations calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
JayWhite2357 committed Nov 15, 2024
1 parent b5a2a7b commit b817779
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 12 deletions.
19 changes: 19 additions & 0 deletions crates/proof-of-sql/src/base/database/owned_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::base::{
permutation::{Permutation, PermutationError},
},
scalar::Scalar,
slice_ops::inner_product_ref_cast,
};
use alloc::{
string::{String, ToString},
Expand Down Expand Up @@ -48,6 +49,24 @@ pub enum OwnedColumn<S: Scalar> {
}

impl<S: Scalar> OwnedColumn<S> {
/// Compute the inner product of the column with a vector of scalars.
pub(crate) fn inner_product(&self, vec: &[S]) -> S {
match self {
OwnedColumn::Boolean(col) => inner_product_ref_cast(col, vec),
OwnedColumn::TinyInt(col) => inner_product_ref_cast(col, vec),
OwnedColumn::SmallInt(col) => inner_product_ref_cast(col, vec),
OwnedColumn::Int(col) => inner_product_ref_cast(col, vec),
OwnedColumn::BigInt(col) | OwnedColumn::TimestampTZ(_, _, col) => {
inner_product_ref_cast(col, vec)
}
OwnedColumn::VarChar(col) => inner_product_ref_cast(col, vec),
OwnedColumn::Int128(col) => inner_product_ref_cast(col, vec),
OwnedColumn::Decimal75(_, _, col) | OwnedColumn::Scalar(col) => {
inner_product_ref_cast(col, vec)
}
}
}

/// Returns the length of the column.
#[must_use]
pub fn len(&self) -> usize {
Expand Down
11 changes: 10 additions & 1 deletion crates/proof-of-sql/src/base/database/owned_table.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::OwnedColumn;
use crate::base::{map::IndexMap, scalar::Scalar};
use crate::base::{map::IndexMap, polynomial::compute_evaluation_vector, scalar::Scalar};
use proof_of_sql_parser::Identifier;
use serde::{Deserialize, Serialize};
use snafu::Snafu;
Expand Down Expand Up @@ -72,6 +72,15 @@ impl<S: Scalar> OwnedTable<S> {
pub fn column_names(&self) -> impl Iterator<Item = &Identifier> {
self.table.keys()
}

pub(crate) fn mle_evaluations(&self, evaluation_point: &[S]) -> Vec<S> {
let mut evaluation_vector = vec![S::ZERO; self.num_rows()];
compute_evaluation_vector(&mut evaluation_vector, evaluation_point);
self.table
.values()
.map(|column| column.inner_product(&evaluation_vector))
.collect()
}
}

// Note: we modify the default PartialEq for IndexMap to also check for column ordering.
Expand Down
12 changes: 12 additions & 0 deletions crates/proof-of-sql/src/base/slice_ops/inner_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,15 @@ where
.map(|(&a, &b)| a * b)
.sum()
}

pub fn inner_product_ref_cast<F, T>(a: &[F], b: &[T]) -> T
where
for<'a> &'a F: Into<T>,
F: Send + Sync,
T: Sync + Send + Mul<Output = T> + Sum + Copy,
{
if_rayon!(a.par_iter().with_min_len(super::MIN_RAYON_LEN), a.iter())
.zip(b)
.map(|(a, b)| a.into() * *b)
.sum()
}
15 changes: 4 additions & 11 deletions crates/proof-of-sql/src/sql/proof/query_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,13 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
result: &ProvableQueryResult,
setup: &CP::VerifierPublicSetup<'_>,
) -> QueryResult<CP::Scalar> {
let owned_table_result = result.to_owned_table(&expr.get_column_result_fields())?;

let (min_row_num, max_row_num) = get_index_range(accessor, expr.get_table_references());
let range_length = max_row_num - min_row_num;
let num_sumcheck_variables = cmp::max(log2_up(range_length), 1);
assert!(num_sumcheck_variables > 0);

let output_length = result.table_length();

// validate bit decompositions
for dist in &self.bit_distributions {
if !dist.is_valid() {
Expand Down Expand Up @@ -253,12 +253,10 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
.take(self.pcs_proof_evaluations.len())
.collect();

let column_result_fields = expr.get_column_result_fields();

// pass over the provable AST to fill in the verification builder
let sumcheck_evaluations = SumcheckMleEvaluations::new(
range_length,
output_length,
owned_table_result.num_rows(),
&subclaim.evaluation_point,
&sumcheck_random_scalars,
&self.pcs_proof_evaluations,
Expand All @@ -271,7 +269,6 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
&evaluation_random_scalars,
post_result_challenges,
);
let owned_table_result = result.to_owned_table(&column_result_fields[..])?;

let pcs_proof_commitments: Vec<_> = column_references
.iter()
Expand All @@ -289,11 +286,7 @@ impl<CP: CommitmentEvaluationProof> QueryProof<CP> {
Some(&owned_table_result),
)?;
// compute the evaluation of the result MLEs
let result_evaluations = result.evaluate(
&subclaim.evaluation_point,
output_length,
&column_result_fields[..],
)?;
let result_evaluations = owned_table_result.mle_evaluations(&subclaim.evaluation_point);
// check the evaluation of the result MLEs
if verifier_evaluations != result_evaluations {
Err(ProofError::VerificationError {
Expand Down

0 comments on commit b817779

Please sign in to comment.