From ae4c5567a20f5a4457bd7ad3137c217cf87072af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= Date: Sat, 23 Dec 2023 11:12:27 -0500 Subject: [PATCH] refactor: Refactor `DlogGroup` trait and optimize batch operations - the DlogGroup trait is now group-crate aware, and requires traits in those terms, - the requirements will be further streamlined when https://github.com/zkcrypto/group/pull/48 merges - simplified declarations boilerplate in halo2curves & pasta macros - removed boilerplate macro duplication for grumpkin_msm. --- Cargo.toml | 1 + src/provider/bn256_grumpkin.rs | 202 ++------------------------- src/provider/kzg_commitment.rs | 2 +- src/provider/mlkzg.rs | 16 +-- src/provider/mod.rs | 2 +- src/provider/non_hiding_kzg.rs | 4 +- src/provider/non_hiding_zeromorph.rs | 12 +- src/provider/pasta.rs | 92 +++--------- src/provider/pedersen.rs | 55 ++++---- src/provider/secp_secq.rs | 19 +-- src/provider/traits.rs | 136 ++++++------------ src/provider/util/msm.rs | 4 +- 12 files changed, 134 insertions(+), 411 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 060df8ddd..011d586a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,7 @@ once_cell = "1.18.0" itertools = "0.12.0" rand = "0.8.5" ref-cast = "1.0.20" +derive_more = "0.99.17" [target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies] pasta-msm = { git = "https://github.com/lurk-lab/pasta-msm", branch = "dev", version = "0.1.4" } diff --git a/src/provider/bn256_grumpkin.rs b/src/provider/bn256_grumpkin.rs index 8b7ac9348..dbd9a421c 100644 --- a/src/provider/bn256_grumpkin.rs +++ b/src/provider/bn256_grumpkin.rs @@ -1,14 +1,13 @@ //! This module implements the Nova traits for `bn256::Point`, `bn256::Scalar`, `grumpkin::Point`, `grumpkin::Scalar`. use crate::{ - provider::{ - traits::{CompressedGroup, DlogGroup}, - util::msm::cpu_best_msm, - }, + impl_traits, + provider::{traits::DlogGroup, util::msm::cpu_best_msm}, traits::{Group, PrimeFieldExt, TranscriptReprTrait}, }; use digest::{ExtendableOutput, Update}; use ff::{FromUniformBytes, PrimeField}; -use group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup, GroupEncoding}; +use group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup}; +use grumpkin_msm::{bn256 as bn256_msm, grumpkin as grumpkin_msm}; use num_bigint::BigInt; use num_traits::Num; // Remove this when https://github.com/zcash/pasta_curves/issues/41 resolves @@ -17,203 +16,32 @@ use rayon::prelude::*; use sha3::Shake256; use std::io::Read; -use halo2curves::bn256::{ - G1Affine as Bn256Affine, G1Compressed as Bn256Compressed, G1 as Bn256Point, -}; -use halo2curves::grumpkin::{ - G1Affine as GrumpkinAffine, G1Compressed as GrumpkinCompressed, G1 as GrumpkinPoint, -}; - /// Re-exports that give access to the standard aliases used in the code base, for bn256 pub mod bn256 { - pub use halo2curves::bn256::{Fq as Base, Fr as Scalar, G1Affine as Affine, G1 as Point}; + pub use halo2curves::bn256::{ + Fq as Base, Fr as Scalar, G1Affine as Affine, G1Compressed as Compressed, G1 as Point, + }; } /// Re-exports that give access to the standard aliases used in the code base, for grumpkin pub mod grumpkin { - pub use halo2curves::grumpkin::{Fq as Base, Fr as Scalar, G1Affine as Affine, G1 as Point}; -} - -macro_rules! impl_traits { - ( - $name:ident, - $name_compressed:ident, - $name_curve:ident, - $name_curve_affine:ident, - $order_str:literal, - $base_str:literal - ) => { - impl Group for $name::Point { - type Base = $name::Base; - type Scalar = $name::Scalar; - - fn group_params() -> (Self::Base, Self::Base, BigInt, BigInt) { - let A = $name::Point::a(); - let B = $name::Point::b(); - let order = BigInt::from_str_radix($order_str, 16).unwrap(); - let base = BigInt::from_str_radix($base_str, 16).unwrap(); - - (A, B, order, base) - } - } - - impl DlogGroup for $name::Point { - type CompressedGroupElement = $name_compressed; - type PreprocessedGroupElement = $name::Affine; - - #[tracing::instrument( - skip_all, - level = "trace", - name = "<_ as Group>::vartime_multiscalar_mul" - )] - fn vartime_multiscalar_mul( - scalars: &[Self::Scalar], - bases: &[Self::PreprocessedGroupElement], - ) -> Self { - #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] - if scalars.len() >= 128 { - grumpkin_msm::$name(bases, scalars) - } else { - cpu_best_msm(scalars, bases) - } - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - cpu_best_msm(scalars, bases) - } - fn preprocessed(&self) -> Self::PreprocessedGroupElement { - self.to_affine() - } - - fn compress(&self) -> Self::CompressedGroupElement { - self.to_bytes() - } - - fn from_label(label: &'static [u8], n: usize) -> Vec { - let mut shake = Shake256::default(); - shake.update(label); - let mut reader = shake.finalize_xof(); - let mut uniform_bytes_vec = Vec::new(); - for _ in 0..n { - let mut uniform_bytes = [0u8; 32]; - reader.read_exact(&mut uniform_bytes).unwrap(); - uniform_bytes_vec.push(uniform_bytes); - } - let gens_proj: Vec<$name_curve> = (0..n) - .into_par_iter() - .map(|i| { - let hash = $name_curve::hash_to_curve("from_uniform_bytes"); - hash(&uniform_bytes_vec[i]) - }) - .collect(); - - let num_threads = rayon::current_num_threads(); - if gens_proj.len() > num_threads { - let chunk = (gens_proj.len() as f64 / num_threads as f64).ceil() as usize; - (0..num_threads) - .into_par_iter() - .flat_map(|i| { - let start = i * chunk; - let end = if i == num_threads - 1 { - gens_proj.len() - } else { - core::cmp::min((i + 1) * chunk, gens_proj.len()) - }; - if end > start { - let mut gens = vec![$name_curve_affine::identity(); end - start]; - ::batch_normalize(&gens_proj[start..end], &mut gens); - gens - } else { - vec![] - } - }) - .collect() - } else { - let mut gens = vec![$name_curve_affine::identity(); n]; - ::batch_normalize(&gens_proj, &mut gens); - gens - } - } - - fn zero() -> Self { - $name::Point::identity() - } - - fn to_coordinates(&self) -> (Self::Base, Self::Base, bool) { - let coordinates = self.to_affine().coordinates(); - if coordinates.is_some().unwrap_u8() == 1 - && ($name_curve_affine::identity() != self.to_affine()) - { - (*coordinates.unwrap().x(), *coordinates.unwrap().y(), false) - } else { - (Self::Base::zero(), Self::Base::zero(), true) - } - } - } - - impl PrimeFieldExt for $name::Scalar { - fn from_uniform(bytes: &[u8]) -> Self { - let bytes_arr: [u8; 64] = bytes.try_into().unwrap(); - $name::Scalar::from_uniform_bytes(&bytes_arr) - } - } - - impl TranscriptReprTrait for $name_compressed { - fn to_transcript_bytes(&self) -> Vec { - self.as_ref().to_vec() - } - } - - impl CompressedGroup for $name_compressed { - type GroupElement = $name::Point; - - fn decompress(&self) -> Option<$name::Point> { - Some($name_curve::from_bytes(&self).unwrap()) - } - } - - impl TranscriptReprTrait for $name::Scalar { - fn to_transcript_bytes(&self) -> Vec { - self.to_repr().to_vec() - } - } - - impl TranscriptReprTrait for $name::Affine { - fn to_transcript_bytes(&self) -> Vec { - let (x, y, is_infinity_byte) = { - let coordinates = self.coordinates(); - if coordinates.is_some().unwrap_u8() == 1 && ($name_curve_affine::identity() != *self) { - let c = coordinates.unwrap(); - (*c.x(), *c.y(), u8::from(false)) - } else { - ($name::Base::zero(), $name::Base::zero(), u8::from(false)) - } - }; - - x.to_repr() - .into_iter() - .chain(y.to_repr().into_iter()) - .chain(std::iter::once(is_infinity_byte)) - .collect() - } - } + pub use halo2curves::grumpkin::{ + Fq as Base, Fr as Scalar, G1Affine as Affine, G1Compressed as Compressed, G1 as Point, }; } impl_traits!( bn256, - Bn256Compressed, - Bn256Point, - Bn256Affine, "30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", - "30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47" + "30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", + bn256_msm ); impl_traits!( grumpkin, - GrumpkinCompressed, - GrumpkinPoint, - GrumpkinAffine, "30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", - "30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001" + "30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + grumpkin_msm ); #[cfg(test)] @@ -237,7 +65,7 @@ mod tests { .map(|_| bn256::Scalar::random(&mut rng)) .collect::>(); - let cpu_msm = cpu_best_msm(&scalars, &points); + let cpu_msm = cpu_best_msm(&points, &scalars); let gpu_msm = bn256::Point::vartime_multiscalar_mul(&scalars, &points); assert_eq!(cpu_msm, gpu_msm); @@ -253,7 +81,7 @@ mod tests { .map(|_| grumpkin::Scalar::random(&mut rng)) .collect::>(); - let cpu_msm = cpu_best_msm(&scalars, &points); + let cpu_msm = cpu_best_msm(&points, &scalars); let gpu_msm = grumpkin::Point::vartime_multiscalar_mul(&scalars, &points); assert_eq!(cpu_msm, gpu_msm); diff --git a/src/provider/kzg_commitment.rs b/src/provider/kzg_commitment.rs index 3663a379b..6d17cbfd6 100644 --- a/src/provider/kzg_commitment.rs +++ b/src/provider/kzg_commitment.rs @@ -30,7 +30,7 @@ pub struct KZGCommitmentEngine { impl> CommitmentEngineTrait for KZGCommitmentEngine where - E::G1: DlogGroup, + E::G1: DlogGroup, E::G1Affine: Serialize + for<'de> Deserialize<'de>, E::G2Affine: Serialize + for<'de> Deserialize<'de>, E::Fr: PrimeFieldBits, // TODO due to use of gen_srs_for_testing, make optional diff --git a/src/provider/mlkzg.rs b/src/provider/mlkzg.rs index 3f3485cb1..164fbae5e 100644 --- a/src/provider/mlkzg.rs +++ b/src/provider/mlkzg.rs @@ -49,7 +49,7 @@ impl EvaluationEngine where E: Engine, NE: NovaEngine, - E::G1: DlogGroup, + E::G1: DlogGroup, E::Fr: TranscriptReprTrait, E::G1Affine: TranscriptReprTrait, // TODO: this bound on DlogGroup is really unusable! { @@ -104,7 +104,7 @@ where E::Fr: Serialize + DeserializeOwned, E::G1Affine: Serialize + DeserializeOwned, E::G2Affine: Serialize + DeserializeOwned, - E::G1: DlogGroup, + E::G1: DlogGroup, ::Base: TranscriptReprTrait, // Note: due to the move of the bound TranscriptReprTrait on G::Base from Group to Engine E::Fr: PrimeFieldBits, // TODO due to use of gen_srs_for_testing, make optional E::Fr: TranscriptReprTrait, @@ -161,7 +161,7 @@ where >::commit(ck, &h) .comm - .preprocessed() + .to_affine() }; let kzg_open_batch = |C: &[E::G1Affine], @@ -253,7 +253,7 @@ where .map(|i| { >::commit(ck, &polys[i]) .comm - .preprocessed() + .to_affine() }) .collect(); @@ -265,7 +265,7 @@ where // Phase 3 -- create response let mut com_all = comms.clone(); - com_all.insert(0, C.comm.preprocessed()); + com_all.insert(0, C.comm.to_affine()); let (w, evals) = kzg_open_batch(&com_all, &polys, &u, transcript); Ok(EvaluationArgument { comms, w, evals }) @@ -300,7 +300,7 @@ where // Compute the commitment to the batched polynomial B(X) let c_0: E::G1 = C[0].into(); - let C_B = (c_0 + NE::GE::vartime_multiscalar_mul(&q_powers[1..k], &C[1..k])).preprocessed(); + let C_B = (c_0 + NE::GE::vartime_multiscalar_mul(&q_powers[1..k], &C[1..k])).to_affine(); // Compute the batched openings // compute B(u_i) = v[i][0] + q*v[i][1] + ... + q^(t-1) * v[i][t-1] @@ -356,10 +356,10 @@ where // obtained from the transcript let r = Self::compute_challenge(&com, transcript); - if r == E::Fr::ZERO || C.comm == E::G1::zero() { + if r == E::Fr::ZERO || C.comm == E::G1::identity() { return Err(NovaError::ProofVerifyError); } - com.insert(0, C.comm.preprocessed()); // set com_0 = C, shifts other commitments to the right + com.insert(0, C.comm.to_affine()); // set com_0 = C, shifts other commitments to the right let u = vec![r, -r, r * r]; diff --git a/src/provider/mod.rs b/src/provider/mod.rs index 6bbf9bec5..2e05b92ea 100644 --- a/src/provider/mod.rs +++ b/src/provider/mod.rs @@ -205,7 +205,7 @@ mod tests { acc + *base * coeff }); - assert_eq!(naive, cpu_best_msm(&coeffs, &bases)) + assert_eq!(naive, cpu_best_msm(&bases, &coeffs)) } #[test] diff --git a/src/provider/non_hiding_kzg.rs b/src/provider/non_hiding_kzg.rs index f9b7af9a2..e8fdadc0a 100644 --- a/src/provider/non_hiding_kzg.rs +++ b/src/provider/non_hiding_kzg.rs @@ -234,7 +234,7 @@ pub struct UVKZGPCS { impl UVKZGPCS where - E::G1: DlogGroup, + E::G1: DlogGroup, { /// Generate a commitment for a polynomial /// Note that the scheme is not hidding @@ -327,7 +327,7 @@ mod tests { fn end_to_end_test_template() -> Result<(), NovaError> where E: MultiMillerLoop, - E::G1: DlogGroup, + E::G1: DlogGroup, E::Fr: PrimeFieldBits, { for _ in 0..100 { diff --git a/src/provider/non_hiding_zeromorph.rs b/src/provider/non_hiding_zeromorph.rs index d0d30b745..cd80860f4 100644 --- a/src/provider/non_hiding_zeromorph.rs +++ b/src/provider/non_hiding_zeromorph.rs @@ -28,7 +28,7 @@ use rayon::{ prelude::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, }; use ref_cast::RefCast; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde::{Deserialize, Serialize}; use std::{borrow::Borrow, iter, marker::PhantomData}; use crate::provider::kzg_commitment::KZGCommitmentEngine; @@ -145,7 +145,7 @@ pub struct ZMPCS { impl> ZMPCS where - E::G1: DlogGroup, + E::G1: DlogGroup, // Note: due to the move of the bound TranscriptReprTrait on G::Base from Group to Engine ::Base: TranscriptReprTrait, { @@ -463,9 +463,9 @@ fn eval_and_quotient_scalars(y: F, x: F, z: F, point: &[F]) -> (F, (Ve impl>> EvaluationEngineTrait for ZMPCS where - E::G1: DlogGroup, - E::G1Affine: Serialize + DeserializeOwned, - E::G2Affine: Serialize + DeserializeOwned, + E::G1: DlogGroup, + E::G1Affine: Serialize + for<'de> Deserialize<'de>, + E::G2Affine: Serialize + for<'de> Deserialize<'de>, ::Base: TranscriptReprTrait, // Note: due to the move of the bound TranscriptReprTrait on G::Base from Group to Engine E::Fr: PrimeFieldBits, // TODO due to use of gen_srs_for_testing, make optional { @@ -542,7 +542,7 @@ mod test { fn commit_open_verify_with>() where - E::G1: DlogGroup, + E::G1: DlogGroup, ::Base: TranscriptReprTrait, // Note: due to the move of the bound TranscriptReprTrait on G::Base from Group to Engine E::Fr: PrimeFieldBits, { diff --git a/src/provider/pasta.rs b/src/provider/pasta.rs index db5ad3afa..f24a894dc 100644 --- a/src/provider/pasta.rs +++ b/src/provider/pasta.rs @@ -1,20 +1,18 @@ //! This module implements the Nova traits for `pallas::Point`, `pallas::Scalar`, `vesta::Point`, `vesta::Scalar`. use crate::{ - provider::{ - traits::{CompressedGroup, DlogGroup}, - util::msm::cpu_best_msm, - }, + provider::{traits::DlogGroup, util::msm::cpu_best_msm}, traits::{Group, PrimeFieldExt, TranscriptReprTrait}, }; +use derive_more::{From, Into}; use digest::{ExtendableOutput, Update}; use ff::{FromUniformBytes, PrimeField}; +use group::{prime::PrimeCurveAffine, Curve}; use num_bigint::BigInt; use num_traits::Num; use pasta_curves::{ self, arithmetic::{CurveAffine, CurveExt}, - group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup, GroupEncoding}, - pallas, vesta, Ep, EpAffine, Eq, EqAffine, + pallas, vesta, }; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -22,37 +20,17 @@ use sha3::Shake256; use std::io::Read; /// A wrapper for compressed group elements of pallas -#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)] -pub struct PallasCompressedElementWrapper { - repr: [u8; 32], -} - -impl PallasCompressedElementWrapper { - /// Wraps repr into the wrapper - pub const fn new(repr: [u8; 32]) -> Self { - Self { repr } - } -} +#[derive(Clone, Copy, Debug, Eq, From, Into, PartialEq, Serialize, Deserialize)] +pub struct PallasCompressedElementWrapper([u8; 32]); /// A wrapper for compressed group elements of vesta -#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)] -pub struct VestaCompressedElementWrapper { - repr: [u8; 32], -} - -impl VestaCompressedElementWrapper { - /// Wraps repr into the wrapper - pub const fn new(repr: [u8; 32]) -> Self { - Self { repr } - } -} +#[derive(Clone, Copy, Debug, Eq, From, Into, PartialEq, Serialize, Deserialize)] +pub struct VestaCompressedElementWrapper([u8; 32]); macro_rules! impl_traits { ( $name:ident, $name_compressed:ident, - $name_curve:ident, - $name_curve_affine:ident, $order_str:literal, $base_str:literal ) => { @@ -71,37 +49,27 @@ macro_rules! impl_traits { } impl DlogGroup for $name::Point { - type CompressedGroupElement = $name_compressed; - type PreprocessedGroupElement = $name::Affine; + type ScalarExt = $name::Scalar; + type AffineExt = $name::Affine; + type Compressed = $name_compressed; #[tracing::instrument( skip_all, level = "trace", name = "<_ as Group>::vartime_multiscalar_mul" )] - fn vartime_multiscalar_mul( - scalars: &[Self::Scalar], - bases: &[Self::PreprocessedGroupElement], - ) -> Self { + fn vartime_multiscalar_mul(scalars: &[Self::ScalarExt], bases: &[Self::Affine]) -> Self { #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] if scalars.len() >= 128 { pasta_msm::$name(bases, scalars) } else { - cpu_best_msm(scalars, bases) + cpu_best_msm(bases, scalars) } #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - cpu_best_msm(scalars, bases) - } - - fn preprocessed(&self) -> Self::PreprocessedGroupElement { - self.to_affine() - } - - fn compress(&self) -> Self::CompressedGroupElement { - $name_compressed::new(self.to_bytes()) + cpu_best_msm(bases, scalars) } - fn from_label(label: &'static [u8], n: usize) -> Vec { + fn from_label(label: &'static [u8], n: usize) -> Vec { let mut shake = Shake256::default(); shake.update(label); let mut reader = shake.finalize_xof(); @@ -111,10 +79,10 @@ macro_rules! impl_traits { reader.read_exact(&mut uniform_bytes).unwrap(); uniform_bytes_vec.push(uniform_bytes); } - let ck_proj: Vec<$name_curve> = (0..n) + let ck_proj: Vec<$name::Point> = (0..n) .into_par_iter() .map(|i| { - let hash = $name_curve::hash_to_curve("from_uniform_bytes"); + let hash = $name::Point::hash_to_curve("from_uniform_bytes"); hash(&uniform_bytes_vec[i]) }) .collect(); @@ -132,7 +100,7 @@ macro_rules! impl_traits { core::cmp::min((i + 1) * chunk, ck_proj.len()) }; if end > start { - let mut ck = vec![$name_curve_affine::identity(); end - start]; + let mut ck = vec![$name::Affine::identity(); end - start]; ::batch_normalize(&ck_proj[start..end], &mut ck); ck } else { @@ -141,16 +109,12 @@ macro_rules! impl_traits { }) .collect() } else { - let mut ck = vec![$name_curve_affine::identity(); n]; + let mut ck = vec![$name::Affine::identity(); n]; ::batch_normalize(&ck_proj, &mut ck); ck } } - fn zero() -> Self { - $name::Point::identity() - } - fn to_coordinates(&self) -> (Self::Base, Self::Base, bool) { let coordinates = self.to_affine().coordinates(); if coordinates.is_some().unwrap_u8() == 1 { @@ -168,17 +132,9 @@ macro_rules! impl_traits { } } - impl CompressedGroup for $name_compressed { - type GroupElement = $name::Point; - - fn decompress(&self) -> Option<$name::Point> { - Some($name_curve::from_bytes(&self.repr).unwrap()) - } - } - impl TranscriptReprTrait for $name_compressed { fn to_transcript_bytes(&self) -> Vec { - self.repr.to_vec() + self.0.to_vec() } } @@ -216,8 +172,6 @@ macro_rules! impl_traits { impl_traits!( pallas, PallasCompressedElementWrapper, - Ep, - EpAffine, "40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001", "40000000000000000000000000000000224698fc094cf91b992d30ed00000001" ); @@ -225,8 +179,6 @@ impl_traits!( impl_traits!( vesta, VestaCompressedElementWrapper, - Eq, - EqAffine, "40000000000000000000000000000000224698fc094cf91b992d30ed00000001", "40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001" ); @@ -249,7 +201,7 @@ mod tests { .map(|_| pallas::Scalar::random(&mut rng)) .collect::>(); - let cpu_msm = cpu_best_msm(&scalars, &points); + let cpu_msm = cpu_best_msm(&points, &scalars); let gpu_msm = pallas::Point::vartime_multiscalar_mul(&scalars, &points); assert_eq!(cpu_msm, gpu_msm); @@ -265,7 +217,7 @@ mod tests { .map(|_| vesta::Scalar::random(&mut rng)) .collect::>(); - let cpu_msm = cpu_best_msm(&scalars, &points); + let cpu_msm = cpu_best_msm(&points, &scalars); let gpu_msm = vesta::Point::vartime_multiscalar_mul(&scalars, &points); assert_eq!(cpu_msm, gpu_msm); diff --git a/src/provider/pedersen.rs b/src/provider/pedersen.rs index daea2594c..6b631c65d 100644 --- a/src/provider/pedersen.rs +++ b/src/provider/pedersen.rs @@ -1,7 +1,7 @@ //! This module provides an implementation of a commitment engine use crate::{ errors::NovaError, - provider::traits::{CompressedGroup, DlogGroup}, + provider::traits::DlogGroup, traits::{ commitment::{CommitmentEngineTrait, CommitmentTrait, Len}, AbsorbInROTrait, Engine, ROTrait, TranscriptReprTrait, @@ -14,6 +14,7 @@ use core::{ ops::{Add, Mul, MulAssign}, }; use ff::Field; +use group::{prime::PrimeCurve, Curve, Group, GroupEncoding}; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -23,21 +24,21 @@ use serde::{Deserialize, Serialize}; pub struct CommitmentKey where E: Engine, - E::GE: DlogGroup, + E::GE: DlogGroup, { #[abomonate_with(Vec<[u64; 8]>)] // this is a hack; we just assume the size of the element. - ck: Vec<::PreprocessedGroupElement>, + ck: Vec<::Affine>, } /// [CommitmentKey]s are often large, and this helps with cloning bottlenecks impl Clone for CommitmentKey where E: Engine, - E::GE: DlogGroup, + E::GE: DlogGroup, { fn clone(&self) -> Self { Self { - ck: self.ck.par_iter().cloned().collect(), + ck: self.ck[..].par_iter().cloned().collect(), } } } @@ -45,7 +46,7 @@ where impl Len for CommitmentKey where E: Engine, - E::GE: DlogGroup, + E::GE: DlogGroup, { fn length(&self) -> usize { self.ck.len() @@ -62,26 +63,26 @@ pub struct Commitment { } /// A type that holds a compressed commitment -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(bound = "")] pub struct CompressedCommitment where E: Engine, - E::GE: DlogGroup, + E::GE: DlogGroup, { - comm: ::CompressedGroupElement, + comm: ::Compressed, } impl CommitmentTrait for Commitment where E: Engine, - E::GE: DlogGroup, + E::GE: DlogGroup, { type CompressedCommitment = CompressedCommitment; fn compress(&self) -> Self::CompressedCommitment { CompressedCommitment { - comm: self.comm.compress(), + comm: ::to_bytes(&self.comm).into(), } } @@ -90,13 +91,11 @@ where } fn decompress(c: &Self::CompressedCommitment) -> Result { - let comm = <::GE as DlogGroup>::CompressedGroupElement::decompress(&c.comm); - if comm.is_none() { + let opt_comm = <::GE as GroupEncoding>::from_bytes(&c.comm.clone().into()); + let Some(comm) = Option::from(opt_comm) else { return Err(NovaError::DecompressionError); - } - Ok(Commitment { - comm: comm.unwrap(), - }) + }; + Ok(Commitment { comm }) } } @@ -107,7 +106,7 @@ where { fn default() -> Self { Commitment { - comm: E::GE::zero(), + comm: E::GE::identity(), } } } @@ -149,7 +148,7 @@ where impl TranscriptReprTrait for CompressedCommitment where E: Engine, - E::GE: DlogGroup, + E::GE: DlogGroup, { fn to_transcript_bytes(&self) -> Vec { self.comm.to_transcript_bytes() @@ -159,7 +158,7 @@ where impl MulAssign for Commitment where E: Engine, - E::GE: DlogGroup, + E::GE: DlogGroup, { fn mul_assign(&mut self, scalar: E::Scalar) { *self = Commitment { @@ -171,7 +170,7 @@ where impl<'a, 'b, E> Mul<&'b E::Scalar> for &'a Commitment where E: Engine, - E::GE: DlogGroup, + E::GE: DlogGroup, { type Output = Commitment; fn mul(self, scalar: &'b E::Scalar) -> Commitment { @@ -184,7 +183,7 @@ where impl Mul for Commitment where E: Engine, - E::GE: DlogGroup, + E::GE: DlogGroup, { type Output = Commitment; @@ -198,7 +197,7 @@ where impl Add for Commitment where E: Engine, - E::GE: DlogGroup, + E::GE: DlogGroup, { type Output = Commitment; @@ -218,7 +217,7 @@ pub struct CommitmentEngine { impl CommitmentEngineTrait for CommitmentEngine where E: Engine, - E::GE: DlogGroup, + E::GE: DlogGroup, { type CommitmentKey = CommitmentKey; type Commitment = Commitment; @@ -268,7 +267,7 @@ where impl CommitmentKeyExtTrait for CommitmentKey where E: Engine>, - E::GE: DlogGroup, + E::GE: DlogGroup, { fn split_at(&self, n: usize) -> (CommitmentKey, CommitmentKey) { ( @@ -299,7 +298,7 @@ where .into_par_iter() .map(|i| { let bases = [L.ck[i].clone(), R.ck[i].clone()].to_vec(); - E::GE::vartime_multiscalar_mul(&w, &bases).preprocessed() + E::GE::vartime_multiscalar_mul(&w, &bases).to_affine() }) .collect(); @@ -312,7 +311,7 @@ where .ck .clone() .into_par_iter() - .map(|g| E::GE::vartime_multiscalar_mul(&[*r], &[g]).preprocessed()) + .map(|g| E::GE::vartime_multiscalar_mul(&[*r], &[g]).to_affine()) .collect(); CommitmentKey { ck: ck_scaled } @@ -326,7 +325,7 @@ where .collect::>, NovaError>>()?; let ck = (0..d.len()) .into_par_iter() - .map(|i| d[i].comm.preprocessed()) + .map(|i| d[i].comm.to_affine()) .collect(); Ok(CommitmentKey { ck }) } diff --git a/src/provider/secp_secq.rs b/src/provider/secp_secq.rs index 5be35bd6b..d8b0f6cf4 100644 --- a/src/provider/secp_secq.rs +++ b/src/provider/secp_secq.rs @@ -1,29 +1,23 @@ //! This module implements the Nova traits for `secp::Point`, `secp::Scalar`, `secq::Point`, `secq::Scalar`. use crate::{ impl_traits, - provider::{ - traits::{CompressedGroup, DlogGroup}, - util::msm::cpu_best_msm, - }, + provider::{traits::DlogGroup, util::msm::cpu_best_msm}, traits::{Group, PrimeFieldExt, TranscriptReprTrait}, }; use digest::{ExtendableOutput, Update}; use ff::{FromUniformBytes, PrimeField}; -use group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup, GroupEncoding}; +use group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup}; use num_bigint::BigInt; use num_traits::Num; use pasta_curves::arithmetic::{CurveAffine, CurveExt}; use rayon::prelude::*; use sha3::Shake256; use std::io::Read; - -use halo2curves::secp256k1::{Secp256k1, Secp256k1Affine, Secp256k1Compressed}; -use halo2curves::secq256k1::{Secq256k1, Secq256k1Affine, Secq256k1Compressed}; - /// Re-exports that give access to the standard aliases used in the code base, for secp pub mod secp256k1 { pub use halo2curves::secp256k1::{ Fp as Base, Fq as Scalar, Secp256k1 as Point, Secp256k1Affine as Affine, + Secp256k1Compressed as Compressed, }; } @@ -31,23 +25,18 @@ pub mod secp256k1 { pub mod secq256k1 { pub use halo2curves::secq256k1::{ Fp as Base, Fq as Scalar, Secq256k1 as Point, Secq256k1Affine as Affine, + Secq256k1Compressed as Compressed, }; } impl_traits!( secp256k1, - Secp256k1Compressed, - Secp256k1, - Secp256k1Affine, "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f" ); impl_traits!( secq256k1, - Secq256k1Compressed, - Secq256k1, - Secq256k1Affine, "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141" ); diff --git a/src/provider/traits.rs b/src/provider/traits.rs index 0ed61f420..caea6a747 100644 --- a/src/provider/traits.rs +++ b/src/provider/traits.rs @@ -1,69 +1,33 @@ -use crate::traits::{commitment::ScalarMul, Group, TranscriptReprTrait}; -use core::fmt::Debug; -use group::{GroupOps, GroupOpsOwned, ScalarMulOwned}; +use crate::traits::{Group, TranscriptReprTrait}; +use group::{prime::PrimeCurve, GroupEncoding}; use serde::{Deserialize, Serialize}; - -/// Represents a compressed version of a group element -pub trait CompressedGroup: - Clone - + Copy - + Debug - + Eq - + Send - + Sync - + TranscriptReprTrait - + Serialize - + for<'de> Deserialize<'de> - + 'static -{ - /// A type that holds the decompressed version of the compressed group element - type GroupElement: DlogGroup; - - /// Decompresses the compressed group element - fn decompress(&self) -> Option; -} +use std::fmt::Debug; /// A trait that defines extensions to the Group trait pub trait DlogGroup: - Group + Group::ScalarExt> + Serialize + for<'de> Deserialize<'de> - + GroupOps - + GroupOpsOwned - + ScalarMul<::Scalar> - + ScalarMulOwned<::Scalar> + + PrimeCurve::ScalarExt, Affine = ::AffineExt> { - /// A type representing the compressed version of the group element - type CompressedGroupElement: CompressedGroup; - - /// A type representing preprocessed group element - type PreprocessedGroupElement: Clone + type ScalarExt; + type AffineExt: Clone + Debug + Eq + Serialize + for<'de> Deserialize<'de> + Sync + Send; + type Compressed: Clone + Debug - + PartialEq + Eq - + Send - + Sync + + From<::Repr> + + Into<::Repr> + Serialize + for<'de> Deserialize<'de> + + Sync + + Send + TranscriptReprTrait; /// A method to compute a multiexponentation - fn vartime_multiscalar_mul( - scalars: &[Self::Scalar], - bases: &[Self::PreprocessedGroupElement], - ) -> Self; + fn vartime_multiscalar_mul(scalars: &[Self::ScalarExt], bases: &[Self::AffineExt]) -> Self; /// Produce a vector of group elements using a static label - fn from_label(label: &'static [u8], n: usize) -> Vec; - - /// Compresses the group element - fn compress(&self) -> Self::CompressedGroupElement; - - /// Produces a preprocessed element - fn preprocessed(&self) -> Self::PreprocessedGroupElement; - - /// Returns an element that is the additive identity of the group - fn zero() -> Self; + fn from_label(label: &'static [u8], n: usize) -> Vec; /// Returns the affine coordinates (x, y, infinty) for the point fn to_coordinates(&self) -> (::Base, ::Base, bool); @@ -77,11 +41,16 @@ pub trait DlogGroup: macro_rules! impl_traits { ( $name:ident, - $name_compressed:ident, - $name_curve:ident, - $name_curve_affine:ident, $order_str:literal, $base_str:literal + ) => { + $crate::impl_traits!($name, $order_str, $base_str, cpu_best_msm); + }; + ( + $name:ident, + $order_str:literal, + $base_str:literal, + $large_msm_method: ident ) => { impl Group for $name::Point { type Base = $name::Base; @@ -98,25 +67,24 @@ macro_rules! impl_traits { } impl DlogGroup for $name::Point { - type CompressedGroupElement = $name_compressed; - type PreprocessedGroupElement = $name::Affine; - - fn vartime_multiscalar_mul( - scalars: &[Self::Scalar], - bases: &[Self::PreprocessedGroupElement], - ) -> Self { - cpu_best_msm(scalars, bases) - } - - fn preprocessed(&self) -> Self::PreprocessedGroupElement { - self.to_affine() - } - - fn compress(&self) -> Self::CompressedGroupElement { - self.to_bytes() + type ScalarExt = $name::Scalar; + type AffineExt = $name::Affine; + // note: for halo2curves implementations, $name::Compressed == <$name::Point as GroupEncoding>::Repr + // so the blanket impl From for T and impl Into apply. + type Compressed = $name::Compressed; + + fn vartime_multiscalar_mul(scalars: &[Self::ScalarExt], bases: &[Self::AffineExt]) -> Self { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + if scalars.len() >= 128 { + $large_msm_method(bases, scalars) + } else { + cpu_best_msm(bases, scalars) + } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + cpu_best_msm(bases, scalars) } - fn from_label(label: &'static [u8], n: usize) -> Vec { + fn from_label(label: &'static [u8], n: usize) -> Vec { let mut shake = Shake256::default(); shake.update(label); let mut reader = shake.finalize_xof(); @@ -126,10 +94,10 @@ macro_rules! impl_traits { reader.read_exact(&mut uniform_bytes).unwrap(); uniform_bytes_vec.push(uniform_bytes); } - let gens_proj: Vec<$name_curve> = (0..n) + let gens_proj: Vec<$name::Point> = (0..n) .into_par_iter() .map(|i| { - let hash = $name_curve::hash_to_curve("from_uniform_bytes"); + let hash = $name::Point::hash_to_curve("from_uniform_bytes"); hash(&uniform_bytes_vec[i]) }) .collect(); @@ -147,7 +115,7 @@ macro_rules! impl_traits { core::cmp::min((i + 1) * chunk, gens_proj.len()) }; if end > start { - let mut gens = vec![$name_curve_affine::identity(); end - start]; + let mut gens = vec![$name::Affine::identity(); end - start]; ::batch_normalize(&gens_proj[start..end], &mut gens); gens } else { @@ -156,21 +124,15 @@ macro_rules! impl_traits { }) .collect() } else { - let mut gens = vec![$name_curve_affine::identity(); n]; + let mut gens = vec![$name::Affine::identity(); n]; ::batch_normalize(&gens_proj, &mut gens); gens } } - fn zero() -> Self { - $name::Point::identity() - } - fn to_coordinates(&self) -> (Self::Base, Self::Base, bool) { let coordinates = self.to_affine().coordinates(); - if coordinates.is_some().unwrap_u8() == 1 - && ($name_curve_affine::identity() != self.to_affine()) - { + if coordinates.is_some().unwrap_u8() == 1 && ($name::Point::identity() != *self) { (*coordinates.unwrap().x(), *coordinates.unwrap().y(), false) } else { (Self::Base::zero(), Self::Base::zero(), true) @@ -185,20 +147,12 @@ macro_rules! impl_traits { } } - impl TranscriptReprTrait for $name_compressed { + impl TranscriptReprTrait for $name::Compressed { fn to_transcript_bytes(&self) -> Vec { self.as_ref().to_vec() } } - impl CompressedGroup for $name_compressed { - type GroupElement = $name::Point; - - fn decompress(&self) -> Option<$name::Point> { - Some($name_curve::from_bytes(&self).unwrap()) - } - } - impl TranscriptReprTrait for $name::Scalar { fn to_transcript_bytes(&self) -> Vec { self.to_repr().to_vec() @@ -209,7 +163,7 @@ macro_rules! impl_traits { fn to_transcript_bytes(&self) -> Vec { let (x, y, is_infinity_byte) = { let coordinates = self.coordinates(); - if coordinates.is_some().unwrap_u8() == 1 && ($name_curve_affine::identity() != *self) { + if coordinates.is_some().unwrap_u8() == 1 && ($name::Affine::identity() != *self) { let c = coordinates.unwrap(); (*c.x(), *c.y(), u8::from(false)) } else { diff --git a/src/provider/util/msm.rs b/src/provider/util/msm.rs index fa2432688..e206575ea 100644 --- a/src/provider/util/msm.rs +++ b/src/provider/util/msm.rs @@ -95,7 +95,7 @@ fn cpu_msm_serial(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve /// /// This will use multithreading if beneficial. /// Adapted from zcash/halo2 -pub(crate) fn cpu_best_msm(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { +pub(crate) fn cpu_best_msm(bases: &[C], coeffs: &[C::Scalar]) -> C::Curve { assert_eq!(coeffs.len(), bases.len()); let num_threads = current_num_threads(); @@ -137,7 +137,7 @@ mod tests { .fold(A::CurveExt::identity(), |acc, (coeff, base)| { acc + *base * coeff }); - let msm = cpu_best_msm(&coeffs, &bases); + let msm = cpu_best_msm(&bases, &coeffs); assert_eq!(naive, msm) }