Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refact: remove Into<[u64; 4]>, From<[u64; 4]>, and RefInto<[u64; 4]> bounds for trait Scalar #258

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions crates/proof-of-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ zerocopy = { workspace = true }
[dev-dependencies]
alloy-sol-types = { workspace = true }
arrow-csv = { workspace = true }
blitzar = { workspace = true }
#blitzar = { workspace = true }
clap = { workspace = true, features = ["derive"] }
criterion = { workspace = true, features = ["html_reports"] }
merlin = { workspace = true }
Expand All @@ -74,7 +74,8 @@ default = ["arrow", "perf"]
arrow = ["dep:arrow", "std"]
blitzar = ["dep:blitzar", "dep:merlin", "std"]
test = ["dep:rand", "std"]
perf = ["blitzar", "rayon", "ark-ec/parallel", "ark-poly/parallel", "ark-ff/asm"]
#perf = ["blitzar", "rayon", "ark-ec/parallel", "ark-poly/parallel", "ark-ff/asm"]
perf = ["rayon", "ark-ec/parallel", "ark-poly/parallel", "ark-ff/asm"]
rayon = ["dep:rayon", "std"]
std = ["snafu/std"]

Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/base/bit/abs_bit_mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::base::scalar::Scalar;

pub fn make_abs_bit_mask<S: Scalar>(x: S) -> [u64; 4] {
let (sign, x) = if S::MAX_SIGNED < x { (1, -x) } else { (0, x) };
let mut res: [u64; 4] = x.into();
let mut res: [u64; 4] = x.to_limbs();
res[3] |= sign << 63;
res
}
15 changes: 5 additions & 10 deletions crates/proof-of-sql/src/base/commitment/committable_column.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::base::{
database::{Column, ColumnType, OwnedColumn},
math::decimal::Precision,
ref_into::RefInto,
scalar::Scalar,
};
use alloc::vec::Vec;
Expand Down Expand Up @@ -111,12 +110,12 @@ impl<'a, S: Scalar> From<&Column<'a, S>> for CommittableColumn<'a> {
Column::BigInt(ints) => CommittableColumn::BigInt(ints),
Column::Int128(ints) => CommittableColumn::Int128(ints),
Column::Decimal75(precision, scale, decimals) => {
let as_limbs: Vec<_> = decimals.iter().map(RefInto::<[u64; 4]>::ref_into).collect();
let as_limbs: Vec<_> = decimals.iter().map(|s| s.to_limbs()).collect();
CommittableColumn::Decimal75(*precision, *scale, as_limbs)
}
Column::Scalar(scalars) => (scalars as &[_]).into(),
Column::VarChar((_, scalars)) => {
let as_limbs: Vec<_> = scalars.iter().map(RefInto::<[u64; 4]>::ref_into).collect();
let as_limbs: Vec<_> = scalars.iter().map(|s| s.to_limbs()).collect();
CommittableColumn::VarChar(as_limbs)
}
Column::TimestampTZ(tu, tz, times) => CommittableColumn::TimestampTZ(*tu, *tz, times),
Expand All @@ -142,18 +141,14 @@ impl<'a, S: Scalar> From<&'a OwnedColumn<S>> for CommittableColumn<'a> {
OwnedColumn::Decimal75(precision, scale, decimals) => CommittableColumn::Decimal75(
*precision,
*scale,
decimals
.iter()
.map(Into::<S>::into)
.map(Into::<[u64; 4]>::into)
.collect(),
decimals.iter().map(|s| s.to_limbs()).collect(),
),
OwnedColumn::Scalar(scalars) => (scalars as &[_]).into(),
OwnedColumn::VarChar(strings) => CommittableColumn::VarChar(
strings
.iter()
.map(Into::<S>::into)
.map(Into::<[u64; 4]>::into)
.map(|s| s.to_limbs())
.collect(),
),
OwnedColumn::TimestampTZ(tu, tz, times) => {
Expand Down Expand Up @@ -197,7 +192,7 @@ impl<'a> From<&'a [i128]> for CommittableColumn<'a> {
}
impl<'a, S: Scalar> From<&'a [S]> for CommittableColumn<'a> {
fn from(value: &'a [S]) -> Self {
CommittableColumn::Scalar(value.iter().map(RefInto::<[u64; 4]>::ref_into).collect())
CommittableColumn::Scalar(value.iter().map(|s| s.to_limbs()).collect())
}
}
impl<'a> From<&'a [bool]> for CommittableColumn<'a> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const MAX_SUPPORTED_I256: i256 = i256::from_parts(
pub fn convert_scalar_to_i256<S: Scalar>(val: &S) -> i256 {
let is_negative = val > &S::MAX_SIGNED;
let abs_scalar = if is_negative { -*val } else { *val };
let limbs: [u64; 4] = abs_scalar.into();
let limbs: [u64; 4] = abs_scalar.to_limbs();

let low = (limbs[0] as u128) | ((limbs[1] as u128) << 64);
let high = i128::from(limbs[2]) | (i128::from(limbs[3]) << 64);
Expand Down Expand Up @@ -46,7 +46,7 @@ pub fn convert_i256_to_scalar<S: Scalar>(value: &i256) -> Option<S> {
];

// Convert limbs to Scalar and adjust for sign
let scalar: S = limbs.into();
let scalar: S = S::from_limbs(limbs);
Some(if value.is_negative() { -scalar } else { scalar })
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/base/encode/u256.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::base::scalar::MontScalar;
use ark_ff::MontConfig;
use ark_ff::{MontConfig, PrimeField};

/// U256 represents an unsigned 256-bits integer number
///
Expand Down
6 changes: 3 additions & 3 deletions crates/proof-of-sql/src/base/proof/transcript_core.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::Transcript;
use crate::base::{ref_into::RefInto, scalar::Scalar};
use crate::base::scalar::Scalar;
use zerocopy::{AsBytes, FromBytes};

/// A trait used to facilitate implementation of [Transcript](super::Transcript).
Expand Down Expand Up @@ -48,10 +48,10 @@ impl<T: TranscriptCore> Transcript for T {
&mut self,
messages: impl IntoIterator<Item = &'a S>,
) {
self.extend_as_be::<[u64; 4]>(messages.into_iter().map(RefInto::ref_into));
self.extend_as_be::<[u64; 4]>(messages.into_iter().map(|s| s.to_limbs()));
}
fn scalar_challenge_as_be<S: Scalar>(&mut self) -> S {
receive_challenge_as_be::<[u64; 4]>(self).into()
S::from_limbs(receive_challenge_as_be::<[u64; 4]>(self))
}
fn challenge_as_le(&mut self) -> [u8; 32] {
self.raw_challenge()
Expand Down
8 changes: 5 additions & 3 deletions crates/proof-of-sql/src/base/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ pub trait Scalar:
+ core::convert::TryInto <i32>
+ core::convert::TryInto <i64>
+ core::convert::TryInto <i128>
+ core::convert::Into<[u64; 4]>
+ core::convert::From<[u64; 4]>
+ core::cmp::Ord
+ core::ops::Neg<Output = Self>
+ num_traits::Zero
Expand All @@ -57,7 +55,6 @@ pub trait Scalar:
+ ark_std::UniformRand //This enables us to get `Scalar`s as challenges from the transcript
+ num_traits::Inv<Output = Option<Self>> // Note: `inv` should return `None` exactly when the element is zero.
+ core::ops::SubAssign
+ super::ref_into::RefInto<[u64; 4]>
+ for<'a> core::convert::From<&'a String>
+ super::encode::VarInt
+ core::convert::From<String>
Expand Down Expand Up @@ -87,4 +84,9 @@ pub trait Scalar:
_ => Ordering::Greater,
}
}

fn from_limbs(val: [u64; 4]) -> Self;

fn to_limbs(&self) -> [u64; 4];

}
52 changes: 28 additions & 24 deletions crates/proof-of-sql/src/base/scalar/mont_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,13 @@ impl<T: MontConfig<4>> TryFrom<BigInt> for MontScalar<T> {

// Check if the number of digits exceeds the maximum precision allowed
if digits.len() > MAX_SUPPORTED_PRECISION.into() {
return Err(ScalarConversionError::Overflow{ error: format!(
"Attempted to parse a number with {} digits, which exceeds the max supported precision of {}",
digits.len(),
MAX_SUPPORTED_PRECISION
)});
return Err(ScalarConversionError::Overflow {
error: format!(
"Attempted to parse a number with {} digits, which exceeds the max supported precision of {}",
digits.len(),
MAX_SUPPORTED_PRECISION
)
});
}

// Continue with the previous logic
Expand Down Expand Up @@ -349,12 +351,6 @@ impl From<&Curve25519Scalar> for curve25519_dalek::scalar::Scalar {
}
}

impl<T: MontConfig<4>> From<MontScalar<T>> for [u64; 4] {
fn from(value: MontScalar<T>) -> Self {
(&value).into()
}
}

impl<T: MontConfig<4>> From<&MontScalar<T>> for [u64; 4] {
fn from(value: &MontScalar<T>) -> Self {
value.0.into_bigint().0
Expand Down Expand Up @@ -433,6 +429,14 @@ impl super::Scalar for Curve25519Scalar {
const ZERO: Self = Self(ark_ff::MontFp!("0"));
const ONE: Self = Self(ark_ff::MontFp!("1"));
const TWO: Self = Self(ark_ff::MontFp!("2"));

fn from_limbs(val: [u64; 4]) -> Self {
Self(Fp::new(ark_ff::BigInt(val)))
}

fn to_limbs(&self) -> [u64; 4] {
self.0.into_bigint().0
}
}

impl<T> TryFrom<MontScalar<T>> for bool
Expand All @@ -443,9 +447,9 @@ where
type Error = ScalarConversionError;
fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
(-1, (-value).into())
(-1, (-value).to_limbs())
} else {
(1, value.into())
(1, value.to_limbs())
};
if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
return Err(ScalarConversionError::Overflow {
Expand All @@ -471,9 +475,9 @@ where
type Error = ScalarConversionError;
fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
(-1, (-value).into())
(-1, (-value).to_limbs())
} else {
(1, value.into())
(1, value.to_limbs())
};
if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
return Err(ScalarConversionError::Overflow {
Expand All @@ -495,9 +499,9 @@ where
type Error = ScalarConversionError;
fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
(-1, (-value).into())
(-1, (-value).to_limbs())
} else {
(1, value.into())
(1, value.to_limbs())
};
if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
return Err(ScalarConversionError::Overflow {
Expand All @@ -519,9 +523,9 @@ where
type Error = ScalarConversionError;
fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
(-1, (-value).into())
(-1, (-value).to_limbs())
} else {
(1, value.into())
(1, value.to_limbs())
};
if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
return Err(ScalarConversionError::Overflow {
Expand All @@ -543,9 +547,9 @@ where
type Error = ScalarConversionError;
fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
(-1, (-value).into())
(-1, (-value).to_limbs())
} else {
(1, value.into())
(1, value.to_limbs())
};
if abs[1] != 0 || abs[2] != 0 || abs[3] != 0 {
return Err(ScalarConversionError::Overflow {
Expand All @@ -567,9 +571,9 @@ where
type Error = ScalarConversionError;
fn try_from(value: MontScalar<T>) -> Result<Self, Self::Error> {
let (sign, abs): (i128, [u64; 4]) = if value > <MontScalar<T>>::MAX_SIGNED {
(-1, (-value).into())
(-1, (-value).to_limbs())
} else {
(1, value.into())
(1, value.to_limbs())
};
if abs[2] != 0 || abs[3] != 0 {
return Err(ScalarConversionError::Overflow {
Expand Down Expand Up @@ -601,7 +605,7 @@ where
} else {
num_bigint::Sign::Plus
};
let value_abs: [u64; 4] = (if is_negative { -value } else { value }).into();
let value_abs: [u64; 4] = (if is_negative { -value } else { value }).to_limbs();
let bits: &[u8] = bytemuck::cast_slice(&value_abs);
BigInt::from_bytes_le(sign, bits)
}
Expand Down
10 changes: 9 additions & 1 deletion crates/proof-of-sql/src/base/scalar/test_scalar.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{MontScalar, Scalar};
use ark_ff::{Fp, MontBackend, MontConfig};
use ark_ff::{Fp, MontBackend, MontConfig, PrimeField};

/// An implementation of `Scalar` intended for use in testing when a concrete implementation is required.
///
Expand All @@ -13,6 +13,14 @@ impl Scalar for TestScalar {
const ZERO: Self = Self(ark_ff::MontFp!("0"));
const ONE: Self = Self(ark_ff::MontFp!("1"));
const TWO: Self = Self(ark_ff::MontFp!("2"));

fn from_limbs(val: [u64; 4]) -> Self {
Self(Fp::new(ark_ff::BigInt(val)))
}

fn to_limbs(&self) -> [u64; 4] {
self.0.into_bigint().0
}
}

pub struct TestMontConfig(pub ark_curve25519::FrConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use crate::base::{
};
use alloc::vec::Vec;
use ark_ec::pairing::PairingOutput;
use ark_ff::{Fp, PrimeField};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use core::ops::Mul;
use derive_more::{AddAssign, Neg, Sub, SubAssign};
Expand All @@ -45,6 +46,14 @@ impl Scalar for DoryScalar {
const ZERO: Self = Self(ark_ff::MontFp!("0"));
const ONE: Self = Self(ark_ff::MontFp!("1"));
const TWO: Self = Self(ark_ff::MontFp!("2"));

fn from_limbs(val: [u64; 4]) -> Self {
Self(Fp::new(ark_ff::BigInt(val)))
}

fn to_limbs(&self) -> [u64; 4] {
self.0.into_bigint().0
}
}

#[derive(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
#[cfg(not(feature = "blitzar"))]
use super::G1Projective;
use super::{transpose, G1Affine, ProverSetup, F};
use super::{G1Affine, ProverSetup, F};
use crate::base::polynomial::compute_evaluation_vector;
#[cfg(feature = "blitzar")]
use crate::base::slice_ops::slice_cast;
use alloc::{vec, vec::Vec};
#[cfg(not(feature = "blitzar"))]
use ark_ec::{AffineRepr, VariableBaseMSM};
use ark_ff::{BigInt, MontBackend};
#[cfg(feature = "blitzar")]
use blitzar::compute::ElementP2;
#[cfg(feature = "blitzar")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ where
&'a T: Into<DoryScalar>,
T: Sync,
{
let a = column;
let Gamma_1 = setup.Gamma_1.last().unwrap();
let Gamma_2 = setup.Gamma_2.last().unwrap();
let (first_row, _) = row_and_column_from_index(offset);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ pub fn verify_constant_sign_decomposition<S: Scalar>(
&& !dist.has_varying_sign_bit()
);
let lhs = if dist.sign_bit() { -eval } else { eval };
let mut rhs = S::from(dist.constant_part()) * one_eval;
let mut rhs = S::from_limbs(dist.constant_part()) * one_eval;
let mut vary_index = 0;
dist.for_each_abs_varying_bit(|int_index: usize, bit_index: usize| {
let mut mult = [0u64; 4];
mult[int_index] = 1u64 << bit_index;
rhs += S::from(mult) * bit_evals[vary_index];
rhs += S::from_limbs(mult) * bit_evals[vary_index];
vary_index += 1;
});
if lhs == rhs {
Expand Down Expand Up @@ -72,7 +72,7 @@ pub fn verify_constant_abs_decomposition<S: Scalar>(
&& dist.has_varying_sign_bit()
);
let t = one_eval - S::TWO * sign_eval;
if S::from(dist.constant_part()) * t == eval {
if S::from_limbs(dist.constant_part()) * t == eval {
Ok(())
} else {
Err(ProofError::VerificationError {
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/proof_exprs/range_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fn decompose_scalar_to_words<'a, S: Scalar + 'a>(
byte_counts: &mut [u64],
) {
for (i, scalar) in scalars.iter().enumerate() {
let scalar_array: [u64; 4] = (*scalar).into(); // Convert scalar to u64 array
let scalar_array: [u64; 4] = (*scalar).to_limbs(); // Convert scalar to u64 array
let scalar_bytes_full = cast_slice::<u64, u8>(&scalar_array); // Cast u64 array to u8 slice
let scalar_bytes = &scalar_bytes_full[..31];

Expand Down
Loading
Loading