From dd5b9d65ea3b419349a88cb84c571dda18b80aa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michele=20Orr=C3=B9?= Date: Mon, 1 Aug 2022 23:29:28 +0200 Subject: [PATCH] Improve ergonomics of scalar field multiplication. (#443) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Michele Orrù Co-authored-by: George Kadianakis Co-authored-by: Pratyush Mishra Co-authored-by: Weikeng Chen --- CHANGELOG.md | 7 +++++ ec/src/lib.rs | 12 +++++--- ec/src/models/short_weierstrass.rs | 30 ++++++++++++------- ec/src/models/twisted_edwards.rs | 32 ++++++++++++-------- test-curves/src/bls12_381/g2.rs | 4 +-- test-templates/src/curves.rs | 47 +++++++++++++++--------------- test-templates/src/msm.rs | 2 +- 7 files changed, 80 insertions(+), 54 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ad9ddb9b..f0bd6c184 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,13 @@ - `TEModelParameters` → `TECurveConfig` - `MontgomeryModelParameters` → `MontCurveConfig` - [\#440](https://github.com/arkworks-rs/algebra/pull/440) (`ark-ff`) Add a method to construct a field element from an element of the underlying base prime field. +- [\#443](https://github.com/arkworks-rs/algebra/pull/443) (`ark-ec`) Improve ergonomics of scalar multiplication. + - Rename `ProjectiveCurve::mul(AsRef[u64])` to `ProjectiveCurve::mul_bigint(AsRef[u64])`. + - Bound `ProjectiveCurve` by + - `Mul`, + - `for<'a> Mul<&'a ScalarField>` + - `MulAssign`, + - `for<'a> MulAssign<&'a ScalarField>` ### Features diff --git a/ec/src/lib.rs b/ec/src/lib.rs index c85ae86b2..d57a41e70 100644 --- a/ec/src/lib.rs +++ b/ec/src/lib.rs @@ -27,7 +27,7 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::{ fmt::{Debug, Display}, hash::Hash, - ops::{Add, AddAssign, MulAssign, Neg, Sub, SubAssign}, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, vec::Vec, }; use msm::VariableBaseMSM; @@ -149,13 +149,16 @@ pub trait ProjectiveCurve: + Neg + Add + Sub + + Mul<::ScalarField, Output = Self> + AddAssign + SubAssign + MulAssign<::ScalarField> + for<'a> Add<&'a Self, Output = Self> + for<'a> Sub<&'a Self, Output = Self> + + for<'a> Mul<&'a ::ScalarField, Output = Self> + for<'a> AddAssign<&'a Self> + for<'a> SubAssign<&'a Self> + + for<'a> MulAssign<&'a ::ScalarField> + core::iter::Sum + for<'a> core::iter::Sum<&'a Self> + From<::Affine> @@ -220,7 +223,7 @@ pub trait ProjectiveCurve: fn add_assign_mixed(&mut self, other: &Self::Affine); /// Performs scalar multiplication of this element. - fn mul>(self, other: S) -> Self; + fn mul_bigint>(self, other: S) -> Self; } /// Affine representation of an elliptic curve point guaranteed to be @@ -285,7 +288,7 @@ pub trait AffineCurve: /// Performs scalar multiplication of this element with mixed addition. #[must_use] - fn mul::BigInt>>(&self, by: S) -> Self::Projective; + fn mul_bigint>(&self, by: S) -> Self::Projective; /// Performs cofactor clearing. /// The default method is simply to multiply by the cofactor. @@ -308,7 +311,8 @@ pub trait AffineCurve: /// `Self::ScalarField`. #[must_use] fn mul_by_cofactor_inv(&self) -> Self { - self.mul(Self::Config::COFACTOR_INV).into() + self.mul_bigint(&Self::Config::COFACTOR_INV.into_bigint()) + .into() } } diff --git a/ec/src/models/short_weierstrass.rs b/ec/src/models/short_weierstrass.rs index ae5b90d9e..a7d7bca9d 100644 --- a/ec/src/models/short_weierstrass.rs +++ b/ec/src/models/short_weierstrass.rs @@ -3,17 +3,15 @@ use ark_serialize::{ CanonicalSerializeWithFlags, SWFlags, SerializationError, }; use ark_std::{ + borrow::Borrow, fmt::{Display, Formatter, Result as FmtResult}, hash::{Hash, Hasher}, io::{Read, Write}, - ops::{Add, AddAssign, MulAssign, Neg, Sub, SubAssign}, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, vec::Vec, }; -use ark_ff::{ - fields::{Field, PrimeField}, - ToConstraintField, UniformRand, -}; +use ark_ff::{fields::Field, PrimeField, ToConstraintField, UniformRand}; use crate::{msm::VariableBaseMSM, AffineCurve, ProjectiveCurve}; @@ -326,8 +324,8 @@ impl AffineCurve for Affine

{ }) } - fn mul::BigInt>>(&self, by: S) -> Self::Projective { - P::mul_affine(self, by.into().as_ref()) + fn mul_bigint>(&self, by: S) -> Self::Projective { + P::mul_affine(self, by.as_ref()) } /// Multiplies this element by the cofactor and output the @@ -688,7 +686,7 @@ impl ProjectiveCurve for Projective

{ } #[inline] - fn mul>(self, other: S) -> Self { + fn mul_bigint>(self, other: S) -> Self { P::mul_projective(&self, other.as_ref()) } } @@ -796,9 +794,19 @@ impl<'a, P: SWCurveConfig> SubAssign<&'a Self> for Projective

{ } } -impl MulAssign for Projective

{ - fn mul_assign(&mut self, other: P::ScalarField) { - *self = self.mul(other.into_bigint()) +impl> MulAssign for Projective

{ + fn mul_assign(&mut self, other: T) { + *self = self.mul_bigint(other.borrow().into_bigint()) + } +} + +impl<'a, P: SWCurveConfig, T: Borrow> Mul for Projective

{ + type Output = Self; + + #[inline] + fn mul(mut self, other: T) -> Self { + self *= other; + self } } diff --git a/ec/src/models/twisted_edwards.rs b/ec/src/models/twisted_edwards.rs index 5ccd536fb..0a56037c3 100644 --- a/ec/src/models/twisted_edwards.rs +++ b/ec/src/models/twisted_edwards.rs @@ -4,10 +4,11 @@ use ark_serialize::{ CanonicalSerializeWithFlags, EdwardsFlags, SerializationError, }; use ark_std::{ + borrow::Borrow, fmt::{Display, Formatter, Result as FmtResult}, hash::{Hash, Hasher}, io::{Read, Write}, - ops::{Add, AddAssign, MulAssign, Neg, Sub, SubAssign}, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, rand::{ distributions::{Distribution, Standard}, Rng, @@ -17,10 +18,7 @@ use ark_std::{ use num_traits::{One, Zero}; use zeroize::Zeroize; -use ark_ff::{ - fields::{Field, PrimeField}, - ToConstraintField, UniformRand, -}; +use ark_ff::{fields::Field, PrimeField, ToConstraintField, UniformRand}; #[cfg(feature = "parallel")] use rayon::prelude::*; @@ -234,8 +232,8 @@ impl AffineCurve for Affine

{ }) } - fn mul::BigInt>>(&self, by: S) -> Self::Projective { - P::mul_affine(self, by.into().as_ref()) + fn mul_bigint>(&self, by: S) -> Self::Projective { + P::mul_affine(self, by.as_ref()) } /// Multiplies this element by the cofactor and output the @@ -315,7 +313,7 @@ impl<'a, P: TECurveConfig> SubAssign<&'a Self> for Affine

{ impl MulAssign for Affine

{ fn mul_assign(&mut self, other: P::ScalarField) { - *self = self.mul(other.into_bigint()).into() + *self = self.mul_bigint(&other.into_bigint()).into() } } @@ -573,7 +571,7 @@ impl ProjectiveCurve for Projective

{ } #[inline] - fn mul>(self, other: S) -> Self { + fn mul_bigint>(self, other: S) -> Self { P::mul_projective(&self, other.as_ref()) } } @@ -655,9 +653,19 @@ impl<'a, P: TECurveConfig> SubAssign<&'a Self> for Projective

{ } } -impl MulAssign for Projective

{ - fn mul_assign(&mut self, other: P::ScalarField) { - *self = self.mul(other.into_bigint()) +impl> MulAssign for Projective

{ + fn mul_assign(&mut self, other: T) { + *self = self.mul_bigint(other.borrow().into_bigint()) + } +} + +impl> Mul for Projective

{ + type Output = Self; + + #[inline] + fn mul(mut self, other: T) -> Self { + self *= other; + self } } diff --git a/test-curves/src/bls12_381/g2.rs b/test-curves/src/bls12_381/g2.rs index 07e2f0ccb..af04a3fad 100644 --- a/test-curves/src/bls12_381/g2.rs +++ b/test-curves/src/bls12_381/g2.rs @@ -64,7 +64,7 @@ impl short_weierstrass::SWCurveConfig for Parameters { // Checks that [p]P = [X]P let mut x_times_point = - point.mul(BigInt::new([crate::bls12_381::Parameters::X[0], 0, 0, 0])); + point.mul_bigint(BigInt::new([crate::bls12_381::Parameters::X[0], 0, 0, 0])); if crate::bls12_381::Parameters::X_IS_NEGATIVE { x_times_point = -x_times_point; } @@ -99,7 +99,7 @@ impl short_weierstrass::SWCurveConfig for Parameters { // tmp2 = [x^2]P + [x]ψ(P) let mut tmp2: Projective = tmp; - tmp2 = tmp2.mul(x).neg(); + tmp2 = tmp2.mul_bigint(x).neg(); // add up all the terms psi2_p2 += tmp2; diff --git a/test-templates/src/curves.rs b/test-templates/src/curves.rs index 1a4ca5d37..07965bfa3 100644 --- a/test-templates/src/curves.rs +++ b/test-templates/src/curves.rs @@ -127,11 +127,14 @@ fn random_multiplication_test() { tmp2.add_assign(&b); // Affine multiplication - let mut tmp3 = a_affine.mul(s.into_bigint()); - tmp3.add_assign(&b_affine.mul(s.into_bigint())); - + let mut tmp3 = a_affine.mul_bigint(&s.into_bigint()); + tmp3.add_assign(&b_affine.mul_bigint(&s.into_bigint())); assert_eq!(tmp1, tmp2); assert_eq!(tmp1, tmp3); + + let expected = a_affine.mul_bigint(s.into_bigint()); + let got = a_affine.mul_bigint(&s.into_bigint()); + assert_eq!(expected, got); } } @@ -304,12 +307,12 @@ pub fn curve_tests() { assert_eq!(zero, zero); assert_eq!(zero.is_zero(), true); - assert_eq!(a.mul(&fr_one.into_bigint()), a); - assert_eq!(a.mul(&fr_two.into_bigint()), a + &a); - assert_eq!(a.mul(&fr_zero.into_bigint()), zero); - assert_eq!(a.mul(&fr_zero.into_bigint()) - &a, -a); - assert_eq!(a.mul(&fr_one.into_bigint()) - &a, zero); - assert_eq!(a.mul(&fr_two.into_bigint()) - &a, a); + assert_eq!(a.mul(&fr_one), a); + assert_eq!(a.mul(&fr_two), a + &a); + assert_eq!(a.mul(&fr_zero), zero); + assert_eq!(a.mul(&fr_zero) - &a, -a); + assert_eq!(a.mul(&fr_one) - &a, zero); + assert_eq!(a.mul(&fr_two) - &a, a); // a == a assert_eq!(a, a); @@ -341,31 +344,27 @@ pub fn curve_tests() { let fr_rand1 = G::ScalarField::rand(&mut rng); let fr_rand2 = G::ScalarField::rand(&mut rng); - let a_rand1 = a.mul(&fr_rand1.into_bigint()); - let a_rand2 = a.mul(&fr_rand2.into_bigint()); + let a_rand1 = a.mul(&fr_rand1); + let a_rand2 = a.mul(&fr_rand2); let fr_three = fr_two + &fr_rand1; - let a_two = a.mul(&fr_two.into_bigint()); + let a_two = a.mul(&fr_two); assert_eq!(a_two, a.double(), "(a * 2) != a.double()"); - let a_six = a.mul(&(fr_three * &fr_two).into_bigint()); - assert_eq!( - a_two.mul(&fr_three.into_bigint()), - a_six, - "(a * 2) * 3 != a * (2 * 3)" - ); + let a_six = a.mul(&(fr_three * &fr_two)); + assert_eq!(a_two.mul(&fr_three), a_six, "(a * 2) * 3 != a * (2 * 3)"); assert_eq!( - a_rand1.mul(&fr_rand2.into_bigint()), - a_rand2.mul(&fr_rand1.into_bigint()), + a_rand1.mul(&fr_rand2), + a_rand2.mul(&fr_rand1), "(a * r1) * r2 != (a * r2) * r1" ); assert_eq!( - a_rand2.mul(&fr_rand1.into_bigint()), - a.mul(&(fr_rand1 * &fr_rand2).into_bigint()), + a_rand2.mul(&fr_rand1), + a.mul(&(fr_rand1 * &fr_rand2)), "(a * r2) * r1 != a * (r1 * r2)" ); assert_eq!( - a_rand1.mul(&fr_rand2.into_bigint()), - a.mul(&(fr_rand1 * &fr_rand2).into_bigint()), + a_rand1.mul(&fr_rand2), + a.mul(&(fr_rand1 * &fr_rand2)), "(a * r1) * r2 != a * (r1 * r2)" ); } diff --git a/test-templates/src/msm.rs b/test-templates/src/msm.rs index fe75bab12..246b26334 100644 --- a/test-templates/src/msm.rs +++ b/test-templates/src/msm.rs @@ -8,7 +8,7 @@ fn naive_var_base_msm(bases: &[G], scalars: &[G::ScalarField]) - let mut acc = G::Projective::zero(); for (base, scalar) in bases.iter().zip(scalars.iter()) { - acc += &base.mul(scalar.into_bigint()); + acc += base.mul_bigint(&scalar.into_bigint()); } acc }