diff --git a/src/edwards.rs b/src/edwards.rs index 04c656510..3ab3e60b2 100644 --- a/src/edwards.rs +++ b/src/edwards.rs @@ -677,17 +677,17 @@ impl VartimeMultiscalarMul for EdwardsPoint { // This wraps the inner implementation in a facade type so that we can // decouple stability of the inner type from the stability of the // outer type. -#[cfg(all(feature = "alloc", feature = "yolocrypto"))] +#[cfg(feature = "alloc")] pub struct EdwardsPrecomputation(scalar_mul::precomputed_straus::PrecomputedStraus); /// Precomputation for variable-time multiscalar multiplication with `EdwardsPoint`s. // This wraps the inner implementation in a facade type so that we can // decouple stability of the inner type from the stability of the // outer type. -#[cfg(all(feature = "alloc", feature = "yolocrypto"))] +#[cfg(feature = "alloc")] pub struct VartimeEdwardsPrecomputation(scalar_mul::precomputed_straus::VartimePrecomputedStraus); -#[cfg(all(feature = "alloc", feature = "yolocrypto"))] +#[cfg(feature = "alloc")] impl PrecomputedMultiscalarMul for EdwardsPrecomputation { type Point = EdwardsPoint; @@ -720,7 +720,7 @@ impl PrecomputedMultiscalarMul for EdwardsPrecomputation { } } -#[cfg(all(feature = "alloc", feature = "yolocrypto"))] +#[cfg(feature = "alloc")] impl VartimePrecomputedMultiscalarMul for VartimeEdwardsPrecomputation { type Point = EdwardsPoint; @@ -1281,7 +1281,6 @@ mod test { } #[test] - #[cfg(feature = "yolocrypto")] fn precomputed_vs_nonprecomputed_multiscalar() { let mut rng = rand::thread_rng(); @@ -1325,7 +1324,6 @@ mod test { } #[test] - #[cfg(feature = "yolocrypto")] fn vartime_precomputed_vs_nonprecomputed_multiscalar() { let mut rng = rand::thread_rng(); diff --git a/src/ristretto.rs b/src/ristretto.rs index 2daabaf1c..34e3c332d 100644 --- a/src/ristretto.rs +++ b/src/ristretto.rs @@ -187,7 +187,21 @@ use scalar::Scalar; use traits::Identity; #[cfg(any(feature = "alloc", feature = "std"))] -use traits::{MultiscalarMul, VartimeMultiscalarMul}; +use traits::{ + MultiscalarMul, PrecomputedMultiscalarMul, VartimeMultiscalarMul, + VartimePrecomputedMultiscalarMul, +}; + +#[cfg(not(all( + feature = "simd_backend", + any(target_feature = "avx2", target_feature = "avx512ifma") +)))] +use backend::serial::scalar_mul; +#[cfg(all( + feature = "simd_backend", + any(target_feature = "avx2", target_feature = "avx512ifma") +))] +use backend::vector::scalar_mul; // ------------------------------------------------------------------------ // Compressed points @@ -891,8 +905,96 @@ impl VartimeMultiscalarMul for RistrettoPoint { { let extended_points = points.into_iter().map(|opt_P| opt_P.map(|P| P.borrow().0)); - EdwardsPoint::optional_multiscalar_mul(scalars, extended_points) - .map(|P| RistrettoPoint(P)) + EdwardsPoint::optional_multiscalar_mul(scalars, extended_points).map(|P| RistrettoPoint(P)) + } +} + +/// Precomputation for multiscalar multiplication with `RistrettoPoint`s. +// This wraps the inner implementation in a facade type so that we can +// decouple stability of the inner type from the stability of the +// outer type. +#[cfg(feature = "alloc")] +pub struct RistrettoPrecomputation(scalar_mul::precomputed_straus::PrecomputedStraus); + +/// Precomputation for variable-time multiscalar multiplication with `RistrettoPoint`s. +// This wraps the inner implementation in a facade type so that we can +// decouple stability of the inner type from the stability of the +// outer type. +#[cfg(feature = "alloc")] +pub struct VartimeRistrettoPrecomputation(scalar_mul::precomputed_straus::VartimePrecomputedStraus); + +#[cfg(feature = "alloc")] +impl PrecomputedMultiscalarMul for RistrettoPrecomputation { + type Point = RistrettoPoint; + + fn new(static_points: I) -> Self + where + I: IntoIterator, + I::Item: Borrow, + { + Self(scalar_mul::precomputed_straus::PrecomputedStraus::new( + static_points.into_iter().map(|P| P.borrow().0), + )) + } + + fn mixed_multiscalar_mul( + &self, + static_scalars: I, + dynamic_scalars: J, + dynamic_points: K, + ) -> Self::Point + where + I: IntoIterator, + I::Item: Borrow, + J: IntoIterator, + J::Item: Borrow, + K: IntoIterator, + K::Item: Borrow, + { + RistrettoPoint(self.0.mixed_multiscalar_mul( + static_scalars, + dynamic_scalars, + dynamic_points.into_iter().map(|P| P.borrow().0), + )) + } +} + +#[cfg(feature = "alloc")] +impl VartimePrecomputedMultiscalarMul for VartimeRistrettoPrecomputation { + type Point = RistrettoPoint; + + fn new(static_points: I) -> Self + where + I: IntoIterator, + I::Item: Borrow, + { + Self( + scalar_mul::precomputed_straus::VartimePrecomputedStraus::new( + static_points.into_iter().map(|P| P.borrow().0), + ), + ) + } + + fn optional_mixed_multiscalar_mul( + &self, + static_scalars: I, + dynamic_scalars: J, + dynamic_points: K, + ) -> Option + where + I: IntoIterator, + I::Item: Borrow, + J: IntoIterator, + J::Item: Borrow, + K: IntoIterator>, + { + self.0 + .optional_mixed_multiscalar_mul( + static_scalars, + dynamic_scalars, + dynamic_points.into_iter().map(|P_opt| P_opt.map(|P| P.0)), + ) + .map(|P_ed| RistrettoPoint(P_ed)) } } @@ -1263,4 +1365,90 @@ mod test { P.compress(); } } + + #[test] + fn precomputed_vs_nonprecomputed_multiscalar() { + let mut rng = rand::thread_rng(); + + let B = &::constants::RISTRETTO_BASEPOINT_TABLE; + + let static_scalars = (0..128) + .map(|_| Scalar::random(&mut rng)) + .collect::>(); + + let dynamic_scalars = (0..128) + .map(|_| Scalar::random(&mut rng)) + .collect::>(); + + let check_scalar: Scalar = static_scalars + .iter() + .chain(dynamic_scalars.iter()) + .map(|s| s * s) + .sum(); + + let static_points = static_scalars.iter().map(|s| s * B).collect::>(); + let dynamic_points = dynamic_scalars.iter().map(|s| s * B).collect::>(); + + let precomputation = RistrettoPrecomputation::new(static_points.iter()); + + let P = precomputation.mixed_multiscalar_mul( + &static_scalars, + &dynamic_scalars, + &dynamic_points, + ); + + use traits::MultiscalarMul; + let Q = RistrettoPoint::multiscalar_mul( + static_scalars.iter().chain(dynamic_scalars.iter()), + static_points.iter().chain(dynamic_points.iter()), + ); + + let R = &check_scalar * B; + + assert_eq!(P.compress(), R.compress()); + assert_eq!(Q.compress(), R.compress()); + } + + #[test] + fn vartime_precomputed_vs_nonprecomputed_multiscalar() { + let mut rng = rand::thread_rng(); + + let B = &::constants::RISTRETTO_BASEPOINT_TABLE; + + let static_scalars = (0..128) + .map(|_| Scalar::random(&mut rng)) + .collect::>(); + + let dynamic_scalars = (0..128) + .map(|_| Scalar::random(&mut rng)) + .collect::>(); + + let check_scalar: Scalar = static_scalars + .iter() + .chain(dynamic_scalars.iter()) + .map(|s| s * s) + .sum(); + + let static_points = static_scalars.iter().map(|s| s * B).collect::>(); + let dynamic_points = dynamic_scalars.iter().map(|s| s * B).collect::>(); + + let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter()); + + let P = precomputation.vartime_mixed_multiscalar_mul( + &static_scalars, + &dynamic_scalars, + &dynamic_points, + ); + + use traits::VartimeMultiscalarMul; + let Q = RistrettoPoint::vartime_multiscalar_mul( + static_scalars.iter().chain(dynamic_scalars.iter()), + static_points.iter().chain(dynamic_points.iter()), + ); + + let R = &check_scalar * B; + + assert_eq!(P.compress(), R.compress()); + assert_eq!(Q.compress(), R.compress()); + } }