From 827a8e2a151c2179e80f94ae0d3c6a10305b85d4 Mon Sep 17 00:00:00 2001 From: Jay White Date: Sat, 16 Nov 2024 15:44:39 -0500 Subject: [PATCH] refactor: simplify `result_evaluations` calculation --- .../src/base/database/owned_column.rs | 19 +++++++++++++++++++ .../src/base/database/owned_table.rs | 12 +++++++++++- .../src/base/slice_ops/inner_product.rs | 12 ++++++++++++ .../proof-of-sql/src/sql/proof/query_proof.rs | 15 ++++----------- 4 files changed, 46 insertions(+), 12 deletions(-) diff --git a/crates/proof-of-sql/src/base/database/owned_column.rs b/crates/proof-of-sql/src/base/database/owned_column.rs index afe30a142..26a5eaffb 100644 --- a/crates/proof-of-sql/src/base/database/owned_column.rs +++ b/crates/proof-of-sql/src/base/database/owned_column.rs @@ -9,6 +9,7 @@ use crate::base::{ permutation::{Permutation, PermutationError}, }, scalar::Scalar, + slice_ops::inner_product_ref_cast, }; use alloc::{ string::{String, ToString}, @@ -48,6 +49,24 @@ pub enum OwnedColumn { } impl OwnedColumn { + /// Compute the inner product of the column with a vector of scalars. + pub(crate) fn inner_product(&self, vec: &[S]) -> S { + match self { + OwnedColumn::Boolean(col) => inner_product_ref_cast(col, vec), + OwnedColumn::TinyInt(col) => inner_product_ref_cast(col, vec), + OwnedColumn::SmallInt(col) => inner_product_ref_cast(col, vec), + OwnedColumn::Int(col) => inner_product_ref_cast(col, vec), + OwnedColumn::BigInt(col) | OwnedColumn::TimestampTZ(_, _, col) => { + inner_product_ref_cast(col, vec) + } + OwnedColumn::VarChar(col) => inner_product_ref_cast(col, vec), + OwnedColumn::Int128(col) => inner_product_ref_cast(col, vec), + OwnedColumn::Decimal75(_, _, col) | OwnedColumn::Scalar(col) => { + inner_product_ref_cast(col, vec) + } + } + } + /// Returns the length of the column. #[must_use] pub fn len(&self) -> usize { diff --git a/crates/proof-of-sql/src/base/database/owned_table.rs b/crates/proof-of-sql/src/base/database/owned_table.rs index 82f3d1d35..eb4dfdd31 100644 --- a/crates/proof-of-sql/src/base/database/owned_table.rs +++ b/crates/proof-of-sql/src/base/database/owned_table.rs @@ -1,5 +1,6 @@ use super::OwnedColumn; -use crate::base::{map::IndexMap, scalar::Scalar}; +use crate::base::{map::IndexMap, polynomial::compute_evaluation_vector, scalar::Scalar}; +use alloc::{vec, vec::Vec}; use proof_of_sql_parser::Identifier; use serde::{Deserialize, Serialize}; use snafu::Snafu; @@ -72,6 +73,15 @@ impl OwnedTable { pub fn column_names(&self) -> impl Iterator { self.table.keys() } + + pub(crate) fn mle_evaluations(&self, evaluation_point: &[S]) -> Vec { + let mut evaluation_vector = vec![S::ZERO; self.num_rows()]; + compute_evaluation_vector(&mut evaluation_vector, evaluation_point); + self.table + .values() + .map(|column| column.inner_product(&evaluation_vector)) + .collect() + } } // Note: we modify the default PartialEq for IndexMap to also check for column ordering. diff --git a/crates/proof-of-sql/src/base/slice_ops/inner_product.rs b/crates/proof-of-sql/src/base/slice_ops/inner_product.rs index 0c54ac749..b9ca9af96 100644 --- a/crates/proof-of-sql/src/base/slice_ops/inner_product.rs +++ b/crates/proof-of-sql/src/base/slice_ops/inner_product.rs @@ -14,3 +14,15 @@ where .map(|(&a, &b)| a * b) .sum() } + +pub fn inner_product_ref_cast(a: &[F], b: &[T]) -> T +where + for<'a> &'a F: Into, + F: Send + Sync, + T: Sync + Send + Mul + Sum + Copy, +{ + if_rayon!(a.par_iter().with_min_len(super::MIN_RAYON_LEN), a.iter()) + .zip(b) + .map(|(a, b)| a.into() * *b) + .sum() +} diff --git a/crates/proof-of-sql/src/sql/proof/query_proof.rs b/crates/proof-of-sql/src/sql/proof/query_proof.rs index 2b406d8ce..5df8c68b9 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof.rs @@ -173,13 +173,13 @@ impl QueryProof { result: &ProvableQueryResult, setup: &CP::VerifierPublicSetup<'_>, ) -> QueryResult { + let owned_table_result = result.to_owned_table(&expr.get_column_result_fields())?; + let (min_row_num, max_row_num) = get_index_range(accessor, expr.get_table_references()); let range_length = max_row_num - min_row_num; let num_sumcheck_variables = cmp::max(log2_up(range_length), 1); assert!(num_sumcheck_variables > 0); - let output_length = result.table_length(); - // validate bit decompositions for dist in &self.bit_distributions { if !dist.is_valid() { @@ -253,12 +253,10 @@ impl QueryProof { .take(self.pcs_proof_evaluations.len()) .collect(); - let column_result_fields = expr.get_column_result_fields(); - // pass over the provable AST to fill in the verification builder let sumcheck_evaluations = SumcheckMleEvaluations::new( range_length, - output_length, + owned_table_result.num_rows(), &subclaim.evaluation_point, &sumcheck_random_scalars, &self.pcs_proof_evaluations, @@ -271,7 +269,6 @@ impl QueryProof { &evaluation_random_scalars, post_result_challenges, ); - let owned_table_result = result.to_owned_table(&column_result_fields[..])?; let pcs_proof_commitments: Vec<_> = column_references .iter() @@ -289,11 +286,7 @@ impl QueryProof { Some(&owned_table_result), )?; // compute the evaluation of the result MLEs - let result_evaluations = result.evaluate( - &subclaim.evaluation_point, - output_length, - &column_result_fields[..], - )?; + let result_evaluations = owned_table_result.mle_evaluations(&subclaim.evaluation_point); // check the evaluation of the result MLEs if verifier_evaluations != result_evaluations { Err(ProofError::VerificationError {