From 4648de93fbf013fc426c6f3de6b18416f6f4a1e7 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Mon, 26 Aug 2024 12:52:58 -0700 Subject: [PATCH 1/3] feat(wip): msm optimizations --- src/fp.rs | 115 +++++++++++---------- src/g1.rs | 289 +++++++++++++++++++++++++++++++++++++---------------- src/lib.rs | 2 +- 3 files changed, 264 insertions(+), 142 deletions(-) diff --git a/src/fp.rs b/src/fp.rs index 08c4803e..30394143 100644 --- a/src/fp.rs +++ b/src/fp.rs @@ -473,13 +473,18 @@ impl Fp { } #[inline] - #[cfg(target_os = "zkvm")] pub fn add_inp(&mut self, rhs: &Fp) { - unsafe { - syscall_bls12381_fp_addmod( - self.0.as_mut_ptr() as *mut u32, - rhs.0.as_ptr() as *const u32, - ); + cfg_if! { + if #[cfg(target_os = "zkvm")] { + unsafe { + syscall_bls12381_fp_addmod( + self.0.as_mut_ptr() as *mut u32, + rhs.0.as_ptr() as *const u32, + ); + } + } else { + *self = self.cpu_add(rhs); + } } } @@ -500,15 +505,19 @@ impl Fp { #[inline] pub fn add(&self, rhs: &Fp) -> Fp { + let mut out = *self; + out.add_inp(rhs); + out + } + + #[inline] + pub fn double_inp(&mut self) { cfg_if::cfg_if! { if #[cfg(target_os = "zkvm")] { - let mut out = self.clone(); - unsafe { - syscall_bls12381_fp_addmod(out.0.as_mut_ptr() as *mut u32, rhs.0.as_ptr() as *const u32); - } - out + let tmp = *self; + self.add_inp(&tmp); } else { - self.cpu_add(rhs) + *self = self.cpu_add(self); } } } @@ -554,31 +563,28 @@ impl Fp { } #[inline] - #[cfg(target_os = "zkvm")] pub fn sub_inp(&mut self, rhs: &Fp) { - unsafe { - syscall_bls12381_fp_submod( - self.0.as_mut_ptr() as *mut u32, - rhs.0.as_ptr() as *const u32, - ); - } - } - - #[inline] - pub fn sub(&self, rhs: &Fp) -> Fp { - cfg_if::cfg_if! { + cfg_if! { if #[cfg(target_os = "zkvm")] { - let mut out = self.clone(); unsafe { - syscall_bls12381_fp_submod(out.0.as_mut_ptr() as *mut u32, rhs.0.as_ptr() as *const u32); + syscall_bls12381_fp_submod( + self.0.as_mut_ptr() as *mut u32, + rhs.0.as_ptr() as *const u32, + ); } - out } else { - rhs.neg().add(self) + *self = self.cpu_sub(rhs); } } } + #[inline] + pub fn sub(&self, rhs: &Fp) -> Fp { + let mut out = *self; + out.sub_inp(rhs); + out + } + #[inline] /// CPU version of the subtraction operation. Necessary to prevent syscalls in unconstrained mode. pub(crate) fn cpu_sub(&self, rhs: &Fp) -> Fp { @@ -740,15 +746,20 @@ impl Fp { } #[inline] - #[cfg(target_os = "zkvm")] pub fn mul_inp(&mut self, rhs: &Fp) { - unsafe { - syscall_bls12381_fp_mulmod( - self.0.as_mut_ptr() as *mut u32, - rhs.0.as_ptr() as *const u32, - ); + cfg_if! { + if #[cfg(target_os = "zkvm")] { + unsafe { + syscall_bls12381_fp_mulmod( + self.0.as_mut_ptr() as *mut u32, + rhs.0.as_ptr() as *const u32, + ); + } + self.mul_r_inv_internal(); + } else { + *self = self.cpu_mul(rhs); + } } - self.mul_r_inv_internal(); } #[inline] @@ -801,18 +812,9 @@ impl Fp { #[inline] pub fn mul(&self, rhs: &Fp) -> Fp { - cfg_if::cfg_if! { - if #[cfg(target_os = "zkvm")] { - let mut out = self.clone(); - unsafe { - syscall_bls12381_fp_mulmod(out.0.as_mut_ptr() as *mut u32, rhs.0.as_ptr() as *const u32); - } - out.mul_r_inv_internal(); - out - } else { - self.cpu_mul(rhs) - } - } + let mut out = *self; + out.mul_inp(rhs); + out } /// Internal function to multiply the internal representation by `R_INV`, equivalent to transforming from @@ -841,15 +843,20 @@ impl Fp { } #[inline] - #[cfg(target_os = "zkvm")] pub fn square_inp(&mut self) { - unsafe { - syscall_bls12381_fp_mulmod( - self.0.as_mut_ptr() as *mut u32, - self.0.as_ptr() as *const u32, - ); + cfg_if! { + if #[cfg(target_os = "zkvm")] { + unsafe { + syscall_bls12381_fp_mulmod( + self.0.as_mut_ptr() as *mut u32, + self.0.as_ptr() as *const u32, + ); + } + self.mul_r_inv_internal(); + } else { + *self = self.cpu_square(); + } } - self.mul_r_inv_internal(); } /// CPU version of the squaring operation. Necessary to prevent syscalls in unconstrained mode. diff --git a/src/g1.rs b/src/g1.rs index 74e94f4e..ff6100dd 100644 --- a/src/g1.rs +++ b/src/g1.rs @@ -138,7 +138,15 @@ impl<'a, 'b> Add<&'b G1Projective> for &'a G1Affine { #[inline] fn add(self, rhs: &'b G1Projective) -> G1Projective { - rhs.add_mixed(self) + cfg_if::cfg_if! { + if #[cfg(target_os = "zkvm")] { + let affine_rhs = G1Affine::from(rhs); + let affine_sum = self + &affine_rhs; + G1Projective::from(affine_sum) + } else { + rhs.add_mixed(self) + } + } } } @@ -147,7 +155,24 @@ impl<'a, 'b> Add<&'b G1Affine> for &'a G1Projective { #[inline] fn add(self, rhs: &'b G1Affine) -> G1Projective { - self.add_mixed(rhs) + cfg_if::cfg_if! { + if #[cfg(target_os = "zkvm")] { + let affine_self = G1Affine::from(self); + let affine_sum = &affine_self + rhs; + G1Projective::from(affine_sum) + } else { + self.add_mixed(rhs) + } + } + } +} + +impl<'a, 'b> Add<&'b G1Affine> for &'a G1Affine { + type Output = G1Affine; + + #[inline] + fn add(self, rhs: &'b G1Affine) -> G1Affine { + self.add_affine(rhs) } } @@ -619,7 +644,16 @@ impl<'a, 'b> Add<&'b G1Projective> for &'a G1Projective { #[inline] fn add(self, rhs: &'b G1Projective) -> G1Projective { - self.add(rhs) + cfg_if::cfg_if! { + if #[cfg(target_os = "zkvm")] { + let affine_self = &G1Affine::from(*self); + let affine_rhs = &G1Affine::from(*rhs); + let affine_sum = affine_self + affine_rhs; + G1Projective::from(affine_sum) + } else { + self.add_proj(rhs) + } + } } } @@ -672,11 +706,10 @@ impl_binops_multiplicative_mixed!(G1Affine, Scalar, G1Projective); impl_binops_multiplicative_mixed!(Scalar, G1Affine, G1Projective); impl_binops_multiplicative_mixed!(Scalar, G1Projective, G1Projective); -#[inline(always)] -fn mul_by_3b(a: Fp) -> Fp { - let a = a + a; // 2 - let a = a + a; // 4 - a + a + a // 12 +fn mul_by_3b_inp(a: &mut Fp) { + a.double_inp(); + a.double_inp(); + a.double_inp(); } impl G1Projective { @@ -716,25 +749,44 @@ impl G1Projective { /// Computes the doubling of this point. pub fn double(&self) -> G1Projective { // Algorithm 9, https://eprint.iacr.org/2015/1060.pdf + let mut t0 = self.y; + t0.square_inp(); + + let mut z3 = t0; + z3.double_inp(); + z3.double_inp(); + z3.double_inp(); + + let mut t1 = self.y; + t1.mul_inp(&self.z); + + let mut t2 = self.z; + t2.square_inp(); + mul_by_3b_inp(&mut t2); + + let mut x3 = t2; + x3.mul_inp(&z3); - let t0 = self.y.square(); - let z3 = t0 + t0; - let z3 = z3 + z3; - let z3 = z3 + z3; - let t1 = self.y * self.z; - let t2 = self.z.square(); - let t2 = mul_by_3b(t2); - let x3 = t2 * z3; - let y3 = t0 + t2; - let z3 = t1 * z3; - let t1 = t2 + t2; - let t2 = t1 + t2; - let t0 = t0 - t2; - let y3 = t0 * y3; - let y3 = x3 + y3; - let t1 = self.x * self.y; - let x3 = t0 * t1; - let x3 = x3 + x3; + let mut y3 = t0; + y3.add_inp(&t2); + + z3.mul_inp(&t1); + + let mut t1 = t2; + t1.double_inp(); + + t2.add_inp(&t1); + t0.sub_inp(&t2); + + y3.mul_inp(&t0); + y3.add_inp(&x3); + + let mut t1 = self.x; + t1.mul_inp(&self.y); + + x3 = t0; + x3.mul_inp(&t1); + x3.double_inp(); let tmp = G1Projective { x: x3, @@ -746,42 +798,74 @@ impl G1Projective { } /// Adds this point to another point. - pub fn add(&self, rhs: &G1Projective) -> G1Projective { + pub fn add_proj(&self, rhs: &G1Projective) -> G1Projective { // Algorithm 7, https://eprint.iacr.org/2015/1060.pdf + let mut t0 = self.x; + t0.mul_inp(&rhs.x); + + let mut t1 = self.y; + t1.mul_inp(&rhs.y); + + let mut t2 = self.z; + t2.mul_inp(&rhs.z); - let t0 = self.x * rhs.x; - let t1 = self.y * rhs.y; - let t2 = self.z * rhs.z; - let t3 = self.x + self.y; - let t4 = rhs.x + rhs.y; - let t3 = t3 * t4; - let t4 = t0 + t1; - let t3 = t3 - t4; - let t4 = self.y + self.z; - let x3 = rhs.y + rhs.z; - let t4 = t4 * x3; - let x3 = t1 + t2; - let t4 = t4 - x3; - let x3 = self.x + self.z; - let y3 = rhs.x + rhs.z; - let x3 = x3 * y3; - let y3 = t0 + t2; - let y3 = x3 - y3; - let x3 = t0 + t0; - let t0 = x3 + t0; - let t2 = mul_by_3b(t2); - let z3 = t1 + t2; - let t1 = t1 - t2; - let y3 = mul_by_3b(y3); - let x3 = t4 * y3; - let t2 = t3 * t1; - let x3 = t2 - x3; - let y3 = y3 * t0; - let t1 = t1 * z3; - let y3 = t1 + y3; - let t0 = t0 * t3; - let z3 = z3 * t4; - let z3 = z3 + t0; + let mut t3 = self.x; + t3.add_inp(&self.y); + + let mut t4 = rhs.x; + t4.add_inp(&rhs.y); + t3.mul_inp(&t4); + + t4 = t0; + t4.add_inp(&t1); + t3.sub_inp(&t4); + + t4 = self.y; + t4.add_inp(&self.z); + + let mut x3 = rhs.y; + x3.add_inp(&rhs.z); + t4.mul_inp(&x3); + + x3 = t1; + x3.add_inp(&t2); + t4.sub_inp(&x3); + + x3 = self.x; + x3.add_inp(&self.z); + + let mut y3 = rhs.x; + y3.add_inp(&rhs.z); + x3.mul_inp(&y3); + + y3 = t0; + y3.add_inp(&t2); + x3.sub_inp(&y3); + + y3 = t0; + y3.double_inp(); + t0.add_inp(&y3); + + mul_by_3b_inp(&mut t2); + + let mut z3 = t1; + z3.add_inp(&t2); + t1.sub_inp(&t2); + + mul_by_3b_inp(&mut y3); + x3.mul_inp(&y3); + + t2 = t3; + t2.mul_inp(&t1); + x3 = t2; + + y3.mul_inp(&t0); + t1.mul_inp(&z3); + y3.add_inp(&t1); + + t0.mul_inp(&t3); + z3.mul_inp(&t4); + z3.add_inp(&t0); G1Projective { x: x3, @@ -793,33 +877,56 @@ impl G1Projective { /// Adds this point to another point in the affine model. pub fn add_mixed(&self, rhs: &G1Affine) -> G1Projective { // Algorithm 8, https://eprint.iacr.org/2015/1060.pdf + let mut t0 = self.x; + t0.mul_inp(&rhs.x); + + let mut t1 = self.y; + t1.mul_inp(&rhs.y); - let t0 = self.x * rhs.x; - let t1 = self.y * rhs.y; - let t3 = rhs.x + rhs.y; - let t4 = self.x + self.y; - let t3 = t3 * t4; - let t4 = t0 + t1; - let t3 = t3 - t4; - let t4 = rhs.y * self.z; - let t4 = t4 + self.y; - let y3 = rhs.x * self.z; - let y3 = y3 + self.x; - let x3 = t0 + t0; - let t0 = x3 + t0; - let t2 = mul_by_3b(self.z); - let z3 = t1 + t2; - let t1 = t1 - t2; - let y3 = mul_by_3b(y3); - let x3 = t4 * y3; - let t2 = t3 * t1; - let x3 = t2 - x3; - let y3 = y3 * t0; - let t1 = t1 * z3; - let y3 = t1 + y3; - let t0 = t0 * t3; - let z3 = z3 * t4; - let z3 = z3 + t0; + let mut t3 = rhs.x; + t3.add_inp(&rhs.y); + + let mut t4 = self.x; + t4.add_inp(&self.y); + t3.mul_inp(&t4); + + t4 = t0; + t4.add_inp(&t1); + t3.sub_inp(&t4); + + t4 = rhs.y; + t4.mul_inp(&self.z); + t4.add_inp(&self.y); + + let mut y3 = rhs.x; + y3.mul_inp(&self.z); + y3.add_inp(&self.x); + + let mut x3 = t0; + x3.double_inp(); + x3.add_inp(&t0); + + let mut t2 = self.z; + mul_by_3b_inp(&mut t2); + + let mut z3 = t1; + z3.add_inp(&t2); + t1.sub_inp(&t2); + + mul_by_3b_inp(&mut y3); + x3.mul_inp(&t4); + + t2 = t3; + t2.mul_inp(&t1); + x3.sub_inp(&t2); + + y3.mul_inp(&t0); + t1.mul_inp(&z3); + y3.add_inp(&t1); + + t0.mul_inp(&t3); + z3.mul_inp(&t4); + z3.add_inp(&t0); let tmp = G1Projective { x: x3, @@ -1148,7 +1255,15 @@ impl Group for G1Projective { #[must_use] fn double(&self) -> Self { - self.double() + cfg_if::cfg_if! { + if #[cfg(target_os = "zkvm")] { + let affine = G1Affine::from(*self); + let doubled_affine = &affine + &affine; + G1Projective::from(doubled_affine) + } else { + self.double() + } + } } } diff --git a/src/lib.rs b/src/lib.rs index abc351bd..fcb90c3e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,7 @@ //! * This implementation does not require the Rust standard library. //! * All operations are constant time unless explicitly noted. -#![no_std] +// #![no_std] #![cfg_attr(docsrs, feature(doc_cfg))] // Catch documentation errors caused by code changes. #![deny(rustdoc::broken_intra_doc_links)] From 485cb1aee7130ea6e7d1f7708e4a8bac9c58eaa3 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Mon, 26 Aug 2024 14:01:24 -0700 Subject: [PATCH 2/3] feat(wip): msm optimizations --- src/g1.rs | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/g1.rs b/src/g1.rs index ff6100dd..e0f727ca 100644 --- a/src/g1.rs +++ b/src/g1.rs @@ -644,16 +644,7 @@ impl<'a, 'b> Add<&'b G1Projective> for &'a G1Projective { #[inline] fn add(self, rhs: &'b G1Projective) -> G1Projective { - cfg_if::cfg_if! { - if #[cfg(target_os = "zkvm")] { - let affine_self = &G1Affine::from(*self); - let affine_rhs = &G1Affine::from(*rhs); - let affine_sum = affine_self + affine_rhs; - G1Projective::from(affine_sum) - } else { - self.add_proj(rhs) - } - } + self.add_proj(rhs) } } @@ -1041,6 +1032,7 @@ impl G1Projective { /// Performs a Variable Base Multiscalar Multiplication. pub fn msm_variable_base(points: &[G1Projective], scalars: &[Scalar]) -> G1Projective { + println!("cycle-tracker-start: msm_variable_base"); let c = if scalars.len() < 32 { 3 } else { @@ -1114,7 +1106,7 @@ impl G1Projective { // We store the sum for the lowest window. let lowest = *window_sums.first().unwrap(); // We're traversing windows from high to low. - window_sums[1..] + let out = window_sums[1..] .iter() .rev() .fold(zero, |mut total, sum_i| { @@ -1124,7 +1116,9 @@ impl G1Projective { } total }) - + lowest + + lowest; + println!("cycle-tracker-end: msm_variable_base"); + out } } From 93e69065295df18bb8f6428baa4c3797789bd9f9 Mon Sep 17 00:00:00 2001 From: Bhargav Annem Date: Mon, 26 Aug 2024 14:15:29 -0700 Subject: [PATCH 3/3] chore(wip): debug msm --- src/g1.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/g1.rs b/src/g1.rs index e0f727ca..70baa4fa 100644 --- a/src/g1.rs +++ b/src/g1.rs @@ -1032,7 +1032,7 @@ impl G1Projective { /// Performs a Variable Base Multiscalar Multiplication. pub fn msm_variable_base(points: &[G1Projective], scalars: &[Scalar]) -> G1Projective { - println!("cycle-tracker-start: msm_variable_base"); + println!("cycle-tracker-report-start: msm_variable_base"); let c = if scalars.len() < 32 { 3 } else { @@ -1117,7 +1117,7 @@ impl G1Projective { total }) + lowest; - println!("cycle-tracker-end: msm_variable_base"); + println!("cycle-tracker-report-end: msm_variable_base"); out } }