Skip to content

Commit

Permalink
fix: fix id issue
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Dec 10, 2024
1 parent 2ecce43 commit 4a13d33
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 12 deletions.
22 changes: 11 additions & 11 deletions crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::base::{database::Column, if_rayon, scalar::Scalar, slice_ops};
use alloc::{rc::Rc, vec::Vec};
use core::ffi::c_void;
use core::{ffi::c_void, fmt::Debug};
use num_traits::Zero;
#[cfg(feature = "rayon")]
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

/// Interface for operating on multilinear extension's in-place
pub trait MultilinearExtension<S: Scalar> {
pub trait MultilinearExtension<S: Scalar>: Debug {
/// Given an evaluation vector, compute the evaluation of the multilinear
/// extension
fn inner_product(&self, evaluation_vec: &[S]) -> S;
Expand All @@ -18,7 +18,7 @@ pub trait MultilinearExtension<S: Scalar> {
fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>>;

/// pointer to identify the slice forming the MLE
fn id(&self) -> *const c_void;
fn id(&self) -> (*const c_void, usize);

#[cfg(test)]
/// Given an evaluation point, compute the evaluation of the multilinear
Expand All @@ -30,7 +30,7 @@ pub trait MultilinearExtension<S: Scalar> {
}
}

impl<'a, T: Sync, S: Scalar> MultilinearExtension<S> for &'a [T]
impl<'a, T: Sync + Debug, S: Scalar> MultilinearExtension<S> for &'a [T]
where
&'a T: Into<S>,
{
Expand All @@ -56,8 +56,8 @@ where
Rc::new(scalars)
}

fn id(&self) -> *const c_void {
self.as_ptr().cast::<c_void>()
fn id(&self) -> (*const c_void, usize) {
(self.as_ptr().cast::<c_void>(), self.len())
}
}

Expand All @@ -76,20 +76,20 @@ macro_rules! slice_like_mle_impl {
(&self[..]).to_sumcheck_term(num_vars)
}

fn id(&self) -> *const c_void {
fn id(&self) -> (*const c_void, usize) {
(&self[..]).id()
}
};
}

impl<'a, T: Sync, S: Scalar> MultilinearExtension<S> for &'a Vec<T>
impl<'a, T: Sync + Debug, S: Scalar> MultilinearExtension<S> for &'a Vec<T>
where
&'a T: Into<S>,
{
slice_like_mle_impl!();
}

impl<'a, T: Sync, const N: usize, S: Scalar> MultilinearExtension<S> for &'a [T; N]
impl<'a, T: Sync + Debug, const N: usize, S: Scalar> MultilinearExtension<S> for &'a [T; N]
where
&'a T: Into<S>,
{
Expand Down Expand Up @@ -139,7 +139,7 @@ impl<S: Scalar> MultilinearExtension<S> for &Column<'_, S> {
}
}

fn id(&self) -> *const c_void {
fn id(&self) -> (*const c_void, usize) {
match self {
Column::Boolean(c) => MultilinearExtension::<S>::id(c),
Column::Scalar(c) | Column::VarChar((_, c)) | Column::Decimal75(_, _, c) => {
Expand Down Expand Up @@ -167,7 +167,7 @@ impl<S: Scalar> MultilinearExtension<S> for Column<'_, S> {
(&self).to_sumcheck_term(num_vars)
}

fn id(&self) -> *const c_void {
fn id(&self) -> (*const c_void, usize) {
(&self).id()
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
use super::MultilinearExtension;
use crate::base::{database::Column, scalar::test_scalar::TestScalar};
use bumpalo::Bump;

#[test]
fn allocated_slices_must_have_different_ids_even_when_one_is_empty() {
let alloc = Bump::new();
let foo = alloc.alloc_slice_fill_default(5) as &[TestScalar];
let bar = alloc.alloc_slice_fill_default(0) as &[TestScalar];
assert_ne!(
MultilinearExtension::<TestScalar>::id(&foo),
MultilinearExtension::<TestScalar>::id(&bar)
);
}

#[test]
fn we_can_use_multilinear_extension_methods_for_i64_slice() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::base::polynomial::CompositePolynomial;
use crate::base::scalar::Scalar;
use alloc::vec::Vec;

#[derive(Debug)]
pub struct ProverState<S: Scalar> {
/// Stores the list of products that is meant to be added together. Each multiplicand is represented by
/// the index in `flattened_ml_extensions`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct CompositePolynomialBuilder<S: Scalar> {
fr_multiplicands_rest: Vec<(S, Vec<Rc<Vec<S>>>)>,
zerosum_multiplicands: Vec<(S, Vec<Rc<Vec<S>>>)>,
fr: Rc<Vec<S>>,
mles: IndexMap<*const c_void, Rc<Vec<S>>>,
mles: IndexMap<(*const c_void, usize), Rc<Vec<S>>>,
}

impl<S: Scalar> CompositePolynomialBuilder<S> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub type SumcheckSubpolynomialTerm<'a, S> = (S, Vec<Box<dyn MultilinearExtension
///
/// The subpolynomial is represented as a sum of terms, where each term is a
/// product of multilinear extensions and a constant.
#[derive(Debug)]
pub struct SumcheckSubpolynomial<'a, S: Scalar> {
terms: Vec<SumcheckSubpolynomialTerm<'a, S>>,
subpolynomial_type: SumcheckSubpolynomialType,
Expand Down

0 comments on commit 4a13d33

Please sign in to comment.