diff --git a/ec/src/msm/fixed_base.rs b/ec/src/msm/fixed_base.rs index 6d0d3f6b6..7f10a9ade 100644 --- a/ec/src/msm/fixed_base.rs +++ b/ec/src/msm/fixed_base.rs @@ -1,6 +1,7 @@ -use crate::ProjectiveCurve; +use crate::{AffineCurve, ProjectiveCurve}; use ark_ff::{BigInteger, FpParameters, PrimeField}; use ark_std::vec::Vec; +use ark_std::{cfg_iter, cfg_iter_mut}; #[cfg(feature = "parallel")] use rayon::prelude::*; @@ -16,12 +17,11 @@ impl FixedBaseMSM { } } - // TODO: parallelize this by extracting out `g_outer` computation. pub fn get_window_table( scalar_size: usize, window: usize, g: T, - ) -> Vec> { + ) -> Vec> { let in_window = 1 << window; let outerc = (scalar_size + window - 1) / window; let last_in_window = 1 << (scalar_size - (outerc - 1) * window); @@ -29,7 +29,18 @@ impl FixedBaseMSM { let mut multiples_of_g = vec![vec![T::zero(); in_window]; outerc]; let mut g_outer = g; - for (outer, multiples_of_g) in multiples_of_g.iter_mut().enumerate().take(outerc) { + let mut g_outers = Vec::with_capacity(outerc); + for _ in 0..outerc { + g_outers.push(g_outer); + for _ in 0..window { + g_outer.double_in_place(); + } + } + for ((outer, multiples_of_g), g_outer) in cfg_iter_mut!(multiples_of_g) + .enumerate() + .take(outerc) + .zip(g_outers) + { let mut g_inner = T::zero(); let cur_in_window = if outer == outerc - 1 { last_in_window @@ -40,23 +51,22 @@ impl FixedBaseMSM { *inner = g_inner; g_inner += &g_outer; } - for _ in 0..window { - g_outer.double_in_place(); - } } - multiples_of_g + cfg_iter!(multiples_of_g) + .map(|s| T::batch_normalization_into_affine(&s)) + .collect() } pub fn windowed_mul( outerc: usize, window: usize, - multiples_of_g: &[Vec], + multiples_of_g: &[Vec], scalar: &T::ScalarField, ) -> T { let mut scalar_val = scalar.into_repr().to_bits(); scalar_val.reverse(); - let mut res = multiples_of_g[0][0]; + let mut res = multiples_of_g[0][0].into_projective(); for outer in 0..outerc { let mut inner = 0usize; for i in 0..window { @@ -67,7 +77,7 @@ impl FixedBaseMSM { inner |= 1 << i; } } - res += &multiples_of_g[outer][inner]; + res.add_assign_mixed(&multiples_of_g[outer][inner]); } res } @@ -75,13 +85,13 @@ impl FixedBaseMSM { pub fn multi_scalar_mul( scalar_size: usize, window: usize, - table: &[Vec], + table: &[Vec], v: &[T::ScalarField], ) -> Vec { let outerc = (scalar_size + window - 1) / window; assert!(outerc <= table.len()); - ark_std::cfg_iter!(v) + cfg_iter!(v) .map(|e| Self::windowed_mul::(outerc, window, table, e)) .collect::>() }