From a83509660676f252d8534684fcc718603e72b1ee Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Thu, 19 Oct 2023 03:26:57 +0200 Subject: [PATCH] Fast modular inverse - 9.4x acceleration (#83) * Bernstein yang modular multiplicative inverter (#2) * rename similar to https://github.com/privacy-scaling-explorations/halo2curves/pull/95 --------- Co-authored-by: Aleksei Vambol <77882392+AlekseiVambol@users.noreply.github.com> --- benches/bn256_field.rs | 3 + src/bn256/fq.rs | 13 +- src/bn256/fr.rs | 13 +- src/derive/field.rs | 16 ++ src/ff_inverse.rs | 424 +++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/secp256k1/fp.rs | 13 +- src/secp256k1/fq.rs | 13 +- src/secp256r1/fp.rs | 13 +- src/secp256r1/fq.rs | 13 +- 10 files changed, 462 insertions(+), 60 deletions(-) create mode 100644 src/ff_inverse.rs diff --git a/benches/bn256_field.rs b/benches/bn256_field.rs index 8a17aef3..814d13b4 100644 --- a/benches/bn256_field.rs +++ b/benches/bn256_field.rs @@ -40,6 +40,9 @@ pub fn bench_bn256_field(c: &mut Criterion) { group.bench_function("bn256_fq_square", |bencher| { bencher.iter(|| black_box(&a).square()) }); + group.bench_function("bn256_fq_invert", |bencher| { + bencher.iter(|| black_box(&a).invert()) + }); } criterion_group!(benches, bench_bn256_field); diff --git a/src/bn256/fq.rs b/src/bn256/fq.rs index fec8d863..56be690f 100644 --- a/src/bn256/fq.rs +++ b/src/bn256/fq.rs @@ -198,17 +198,10 @@ impl ff::Field for Fq { ff::helpers::sqrt_ratio_generic(num, div) } - /// Computes the multiplicative inverse of this element, - /// failing if the element is zero. + /// Returns the multiplicative inverse of the + /// element. If it is zero, the method fails. fn invert(&self) -> CtOption { - let tmp = self.pow([ - 0x3c208c16d87cfd45, - 0x97816a916871ca8d, - 0xb85045b68181585d, - 0x30644e72e131a029, - ]); - - CtOption::new(tmp, !self.ct_eq(&Self::zero())) + self.invert() } } diff --git a/src/bn256/fr.rs b/src/bn256/fr.rs index 7e3b5ae8..9e02fcdc 100644 --- a/src/bn256/fr.rs +++ b/src/bn256/fr.rs @@ -231,17 +231,10 @@ impl ff::Field for Fr { self.square() } - /// Computes the multiplicative inverse of this element, - /// failing if the element is zero. + /// Returns the multiplicative inverse of the + /// element. If it is zero, the method fails. fn invert(&self) -> CtOption { - let tmp = self.pow([ - 0x43e1f593efffffff, - 0x2833e84879b97091, - 0xb85045b68181585d, - 0x30644e72e131a029, - ]); - - CtOption::new(tmp, !self.ct_eq(&Self::zero())) + self.invert() } fn sqrt(&self) -> CtOption { diff --git a/src/derive/field.rs b/src/derive/field.rs index da962457..a9f57c4a 100644 --- a/src/derive/field.rs +++ b/src/derive/field.rs @@ -24,6 +24,11 @@ macro_rules! field_common { $r2:ident, $r3:ident ) => { + /// Bernstein-Yang modular multiplicative inverter created for the modulus equal to + /// the characteristic of the field to invert positive integers in the Montgomery form. + const BYINVERTOR: $crate::ff_inverse::BYInverter<6> = + $crate::ff_inverse::BYInverter::<6>::new(&$modulus.0, &$r2.0); + impl $field { /// Returns zero, the additive identity. #[inline] @@ -37,6 +42,16 @@ macro_rules! field_common { $r } + /// Returns the multiplicative inverse of the + /// element. If it is zero, the method fails. + pub fn invert(&self) -> CtOption { + if let Some(inverse) = BYINVERTOR.invert(&self.0) { + CtOption::new(Self(inverse), Choice::from(1)) + } else { + CtOption::new(Self::zero(), Choice::from(0)) + } + } + fn from_u512(limbs: [u64; 8]) -> $field { // We reduce an arbitrary 512-bit number by decomposing it into two 256-bit digits // with the higher bits multiplied by 2^256. Thus, we perform two reductions @@ -345,6 +360,7 @@ macro_rules! field_common { macro_rules! field_arithmetic { ($field:ident, $modulus:ident, $inv:ident, $field_type:ident) => { field_specific!($field, $modulus, $inv, $field_type); + impl $field { /// Doubles this field element. #[inline] diff --git a/src/ff_inverse.rs b/src/ff_inverse.rs new file mode 100644 index 00000000..53285e6c --- /dev/null +++ b/src/ff_inverse.rs @@ -0,0 +1,424 @@ +use core::cmp::PartialEq; +use std::ops::{Add, Mul, Neg, Sub}; + +/// Big signed (B * L)-bit integer type, whose variables store +/// numbers in the two's complement code as arrays of B-bit chunks. +/// The ordering of the chunks in these arrays is little-endian. +/// The arithmetic operations for this type are wrapping ones. +#[derive(Clone)] +struct CInt(pub [u64; L]); + +impl CInt { + /// Mask, in which the B lowest bits are 1 and only they + pub const MASK: u64 = u64::MAX >> (64 - B); + + /// Representation of -1 + pub const MINUS_ONE: Self = Self([Self::MASK; L]); + + /// Representation of 0 + pub const ZERO: Self = Self([0; L]); + + /// Representation of 1 + pub const ONE: Self = { + let mut data = [0; L]; + data[0] = 1; + Self(data) + }; + + /// Returns the result of applying B-bit right + /// arithmetical shift to the current number + pub fn shift(&self) -> Self { + let mut data = [0; L]; + if self.is_negative() { + data[L - 1] = Self::MASK; + } + data[..L - 1].copy_from_slice(&self.0[1..]); + Self(data) + } + + /// Returns the lowest B bits of the current number + pub fn lowest(&self) -> u64 { + self.0[0] + } + + /// Returns "true" iff the current number is negative + pub fn is_negative(&self) -> bool { + self.0[L - 1] > (Self::MASK >> 1) + } +} + +impl PartialEq for CInt { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Add for &CInt { + type Output = CInt; + fn add(self, other: Self) -> Self::Output { + let (mut data, mut carry) = ([0; L], 0); + for i in 0..L { + let sum = self.0[i] + other.0[i] + carry; + data[i] = sum & CInt::::MASK; + carry = sum >> B; + } + Self::Output { 0: data } + } +} + +impl Add<&CInt> for CInt { + type Output = CInt; + fn add(self, other: &Self) -> Self::Output { + &self + other + } +} + +impl Add for CInt { + type Output = CInt; + fn add(self, other: Self) -> Self::Output { + &self + &other + } +} + +impl Sub for &CInt { + type Output = CInt; + fn sub(self, other: Self) -> Self::Output { + // For the two's complement code the additive negation is the result of + // adding 1 to the bitwise inverted argument's representation. Thus, for + // any encoded integers x and y we have x - y = x + !y + 1, where "!" is + // the bitwise inversion and addition is done according to the rules of + // the code. The algorithm below uses this formula and is the modified + // addition algorithm, where the carry flag is initialized with 1 and + // the chunks of the second argument are bitwise inverted + let (mut data, mut carry) = ([0; L], 1); + for i in 0..L { + let sum = self.0[i] + (other.0[i] ^ CInt::::MASK) + carry; + data[i] = sum & CInt::::MASK; + carry = sum >> B; + } + Self::Output { 0: data } + } +} + +impl Sub<&CInt> for CInt { + type Output = CInt; + fn sub(self, other: &Self) -> Self::Output { + &self - other + } +} + +impl Sub for CInt { + type Output = CInt; + fn sub(self, other: Self) -> Self::Output { + &self - &other + } +} + +impl Neg for &CInt { + type Output = CInt; + fn neg(self) -> Self::Output { + // For the two's complement code the additive negation is the result + // of adding 1 to the bitwise inverted argument's representation + let (mut data, mut carry) = ([0; L], 1); + for i in 0..L { + let sum = (self.0[i] ^ CInt::::MASK) + carry; + data[i] = sum & CInt::::MASK; + carry = sum >> B; + } + Self::Output { 0: data } + } +} + +impl Neg for CInt { + type Output = CInt; + fn neg(self) -> Self::Output { + -&self + } +} + +impl Mul for &CInt { + type Output = CInt; + fn mul(self, other: Self) -> Self::Output { + let mut data = [0; L]; + for i in 0..L { + let mut carry = 0; + for k in 0..(L - i) { + let sum = (data[i + k] as u128) + + (carry as u128) + + (self.0[i] as u128) * (other.0[k] as u128); + data[i + k] = sum as u64 & CInt::::MASK; + carry = (sum >> B) as u64; + } + } + Self::Output { 0: data } + } +} + +impl Mul<&CInt> for CInt { + type Output = CInt; + fn mul(self, other: &Self) -> Self::Output { + &self * other + } +} + +impl Mul for CInt { + type Output = CInt; + fn mul(self, other: Self) -> Self::Output { + &self * &other + } +} + +impl Mul for &CInt { + type Output = CInt; + fn mul(self, other: i64) -> Self::Output { + let mut data = [0; L]; + // If the short multiplicand is non-negative, the standard multiplication + // algorithm is performed. Otherwise, the product of the additively negated + // multiplicands is found as follows. Since for the two's complement code + // the additive negation is the result of adding 1 to the bitwise inverted + // argument's representation, for any encoded integers x and y we have + // x * y = (-x) * (-y) = (!x + 1) * (-y) = !x * (-y) + (-y), where "!" is + // the bitwise inversion and arithmetic operations are performed according + // to the rules of the code. If the short multiplicand is negative, the + // algorithm below uses this formula by substituting the short multiplicand + // for y and turns into the modified standard multiplication algorithm, + // where the carry flag is initialized with the additively negated short + // multiplicand and the chunks of the long multiplicand are bitwise inverted + let (other, mut carry, mask) = if other < 0 { + (-other, -other as u64, CInt::::MASK) + } else { + (other, 0, 0) + }; + for i in 0..L { + let sum = (carry as u128) + ((self.0[i] ^ mask) as u128) * (other as u128); + data[i] = sum as u64 & CInt::::MASK; + carry = (sum >> B) as u64; + } + Self::Output { 0: data } + } +} + +impl Mul for CInt { + type Output = CInt; + fn mul(self, other: i64) -> Self::Output { + &self * other + } +} + +impl Mul<&CInt> for i64 { + type Output = CInt; + fn mul(self, other: &CInt) -> Self::Output { + other * self + } +} + +impl Mul> for i64 { + type Output = CInt; + fn mul(self, other: CInt) -> Self::Output { + other * self + } +} + +/// Type of the modular multiplicative inverter based on the Bernstein-Yang method. +/// The inverter can be created for a specified modulus M and adjusting parameter A +/// to compute the adjusted multiplicative inverses of positive integers, i.e. for +/// computing (1 / x) * A (mod M) for a positive integer x. +/// +/// The adjusting parameter allows computing the multiplicative inverses in the case of +/// using the Montgomery representation for the input or the expected output. If R is +/// the Montgomery factor, the multiplicative inverses in the appropriate representation +/// can be computed provided that the value of A is chosen as follows: +/// - A = 1, if both the input and the expected output are in the standard form +/// - A = R^2 mod M, if both the input and the expected output are in the Montgomery form +/// - A = R mod M, if either the input or the expected output is in the Montgomery form, +/// but not both of them +/// +/// The public methods of this type receive and return unsigned big integers as arrays of +/// 64-bit chunks, the ordering of which is little-endian. Both the modulus and the integer +/// to be inverted should not exceed 2 ^ (62 * L - 64) +/// +/// For better understanding the implementation, the following resources are recommended: +/// - D. Bernstein, B.-Y. Yang, "Fast constant-time gcd computation and modular inversion", +/// https://gcd.cr.yp.to/safegcd-20190413.pdf +/// - P. Wuille, "The safegcd implementation in libsecp256k1 explained", +/// https://github.com/bitcoin-core/secp256k1/blob/master/doc/safegcd_implementation.md +pub struct BYInverter { + /// Modulus + modulus: CInt<62, L>, + + /// Adjusting parameter + adjuster: CInt<62, L>, + + /// Multiplicative inverse of the modulus modulo 2^62 + inverse: i64, +} + +/// Type of the Bernstein-Yang transition matrix multiplied by 2^62 +type Matrix = [[i64; 2]; 2]; + +impl BYInverter { + /// Returns the Bernstein-Yang transition matrix multiplied by 2^62 and the new value + /// of the delta variable for the 62 basic steps of the Bernstein-Yang method, which + /// are to be performed sequentially for specified initial values of f, g and delta + fn jump(f: &CInt<62, L>, g: &CInt<62, L>, mut delta: i64) -> (i64, Matrix) { + let (mut steps, mut f, mut g) = (62, f.lowest() as i64, g.lowest() as i128); + let mut t: Matrix = [[1, 0], [0, 1]]; + + loop { + let zeros = steps.min(g.trailing_zeros() as i64); + (steps, delta, g) = (steps - zeros, delta + zeros, g >> zeros); + t[0] = [t[0][0] << zeros, t[0][1] << zeros]; + + if steps == 0 { + break; + } + if delta > 0 { + (delta, f, g) = (-delta, g as i64, -f as i128); + (t[0], t[1]) = (t[1], [-t[0][0], -t[0][1]]); + } + + // The formula (3 * x) xor 28 = -1 / x (mod 32) for an odd integer x + // in the two's complement code has been derived from the formula + // (3 * x) xor 2 = 1 / x (mod 32) attributed to Peter Montgomery + let mask = (1 << steps.min(1 - delta).min(5)) - 1; + let w = (g as i64).wrapping_mul(f.wrapping_mul(3) ^ 28) & mask; + + t[1] = [t[0][0] * w + t[1][0], t[0][1] * w + t[1][1]]; + g += w as i128 * f as i128; + } + + (delta, t) + } + + /// Returns the updated values of the variables f and g for specified initial ones and Bernstein-Yang transition + /// matrix multiplied by 2^62. The returned vector is "matrix * (f, g)' / 2^62", where "'" is the transpose operator + fn fg(f: CInt<62, L>, g: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) { + ( + (t[0][0] * &f + t[0][1] * &g).shift(), + (t[1][0] * &f + t[1][1] * &g).shift(), + ) + } + + /// Returns the updated values of the variables d and e for specified initial ones and Bernstein-Yang transition + /// matrix multiplied by 2^62. The returned vector is congruent modulo M to "matrix * (d, e)' / 2^62 (mod M)", + /// where M is the modulus the inverter was created for and "'" stands for the transpose operator. Both the input + /// and output values lie in the interval (-2 * M, M) + fn de(&self, d: CInt<62, L>, e: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) { + let mask = CInt::<62, L>::MASK as i64; + let mut md = t[0][0] * d.is_negative() as i64 + t[0][1] * e.is_negative() as i64; + let mut me = t[1][0] * d.is_negative() as i64 + t[1][1] * e.is_negative() as i64; + + let cd = t[0][0] + .wrapping_mul(d.lowest() as i64) + .wrapping_add(t[0][1].wrapping_mul(e.lowest() as i64)) + & mask; + let ce = t[1][0] + .wrapping_mul(d.lowest() as i64) + .wrapping_add(t[1][1].wrapping_mul(e.lowest() as i64)) + & mask; + + md -= (self.inverse.wrapping_mul(cd).wrapping_add(md)) & mask; + me -= (self.inverse.wrapping_mul(ce).wrapping_add(me)) & mask; + + let cd = t[0][0] * &d + t[0][1] * &e + md * &self.modulus; + let ce = t[1][0] * &d + t[1][1] * &e + me * &self.modulus; + + (cd.shift(), ce.shift()) + } + + /// Returns either "value (mod M)" or "-value (mod M)", where M is the modulus the + /// inverter was created for, depending on "negate", which determines the presence + /// of "-" in the used formula. The input integer lies in the interval (-2 * M, M) + fn norm(&self, mut value: CInt<62, L>, negate: bool) -> CInt<62, L> { + if value.is_negative() { + value = value + &self.modulus; + } + + if negate { + value = -value; + } + + if value.is_negative() { + value = value + &self.modulus; + } + + value + } + + /// Returns a big unsigned integer as an array of O-bit chunks, which is equal modulo + /// 2 ^ (O * S) to the input big unsigned integer stored as an array of I-bit chunks. + /// The ordering of the chunks in these arrays is little-endian + const fn convert(input: &[u64]) -> [u64; S] { + // This function is defined because the method "min" of the usize type is not constant + const fn min(a: usize, b: usize) -> usize { + if a > b { + b + } else { + a + } + } + + let (total, mut output, mut bits) = (min(input.len() * I, S * O), [0; S], 0); + + while bits < total { + let (i, o) = (bits % I, bits % O); + output[bits / O] |= (input[bits / I] >> i) << o; + bits += min(I - i, O - o); + } + + let mask = u64::MAX >> (64 - O); + let mut filled = total / O + if total % O > 0 { 1 } else { 0 }; + + while filled > 0 { + filled -= 1; + output[filled] &= mask; + } + + output + } + + /// Returns the multiplicative inverse of the argument modulo 2^62. The implementation is based + /// on the Hurchalla's method for computing the multiplicative inverse modulo a power of two. + /// For better understanding the implementation, the following paper is recommended: + /// J. Hurchalla, "An Improved Integer Multiplicative Inverse (modulo 2^w)", + /// https://arxiv.org/pdf/2204.04342.pdf + const fn inv(value: u64) -> i64 { + let x = value.wrapping_mul(3) ^ 2; + let y = 1u64.wrapping_sub(x.wrapping_mul(value)); + let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y)); + let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y)); + let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y)); + (x.wrapping_mul(y.wrapping_add(1)) & CInt::<62, L>::MASK) as i64 + } + + /// Creates the inverter for specified modulus and adjusting parameter + pub const fn new(modulus: &[u64], adjuster: &[u64]) -> Self { + Self { + modulus: CInt::<62, L>(Self::convert::<64, 62, L>(modulus)), + adjuster: CInt::<62, L>(Self::convert::<64, 62, L>(adjuster)), + inverse: Self::inv(modulus[0]), + } + } + + /// Returns either the adjusted modular multiplicative inverse for the argument or None + /// depending on invertibility of the argument, i.e. its coprimality with the modulus + pub fn invert(&self, value: &[u64]) -> Option<[u64; S]> { + let (mut d, mut e) = (CInt::ZERO, self.adjuster.clone()); + let mut g = CInt::<62, L>(Self::convert::<64, 62, L>(value)); + let (mut delta, mut f) = (1, self.modulus.clone()); + let mut matrix; + while g != CInt::ZERO { + (delta, matrix) = Self::jump(&f, &g, delta); + (f, g) = Self::fg(f, g, matrix); + (d, e) = self.de(d, e, matrix); + } + // At this point the absolute value of "f" equals the greatest common divisor + // of the integer to be inverted and the modulus the inverter was created for. + // Thus, if "f" is neither 1 nor -1, then the sought inverse does not exist + let antiunit = f == CInt::MINUS_ONE; + if (f != CInt::ONE) && !antiunit { + return None; + } + Some(Self::convert::<62, 64, S>(&self.norm(d, antiunit).0)) + } +} diff --git a/src/lib.rs b/src/lib.rs index 5bd7d506..afc95fdb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ mod arithmetic; +mod ff_inverse; pub mod fft; pub mod hash_to_curve; pub mod msm; diff --git a/src/secp256k1/fp.rs b/src/secp256k1/fp.rs index c346dc6c..db496559 100644 --- a/src/secp256k1/fp.rs +++ b/src/secp256k1/fp.rs @@ -178,17 +178,10 @@ impl ff::Field for Fp { CtOption::new(tmp, tmp.square().ct_eq(self)) } - /// Computes the multiplicative inverse of this element, - /// failing if the element is zero. + /// Returns the multiplicative inverse of the + /// element. If it is zero, the method fails. fn invert(&self) -> CtOption { - let tmp = self.pow_vartime([ - 0xfffffffefffffc2d, - 0xffffffffffffffff, - 0xffffffffffffffff, - 0xffffffffffffffff, - ]); - - CtOption::new(tmp, !self.ct_eq(&Self::zero())) + self.invert() } fn pow_vartime>(&self, exp: S) -> Self { diff --git a/src/secp256k1/fq.rs b/src/secp256k1/fq.rs index 189daaba..c6cb8e06 100644 --- a/src/secp256k1/fq.rs +++ b/src/secp256k1/fq.rs @@ -178,17 +178,10 @@ impl ff::Field for Fq { self.square() } - /// Computes the multiplicative inverse of this element, - /// failing if the element is zero. + /// Returns the multiplicative inverse of the + /// element. If it is zero, the method fails. fn invert(&self) -> CtOption { - let tmp = self.pow_vartime([ - 0xbfd25e8cd036413f, - 0xbaaedce6af48a03b, - 0xfffffffffffffffe, - 0xffffffffffffffff, - ]); - - CtOption::new(tmp, !self.ct_eq(&Self::zero())) + self.invert() } fn pow_vartime>(&self, exp: S) -> Self { diff --git a/src/secp256r1/fp.rs b/src/secp256r1/fp.rs index bf86e157..331545c3 100644 --- a/src/secp256r1/fp.rs +++ b/src/secp256r1/fp.rs @@ -196,17 +196,10 @@ impl ff::Field for Fp { CtOption::new(tmp, tmp.square().ct_eq(self)) } - /// Computes the multiplicative inverse of this element, - /// failing if the element is zero. + /// Returns the multiplicative inverse of the + /// element. If it is zero, the method fails. fn invert(&self) -> CtOption { - let tmp = self.pow_vartime([ - 0xfffffffffffffffd, - 0x00000000ffffffff, - 0x0000000000000000, - 0xffffffff00000001, - ]); - - CtOption::new(tmp, !self.ct_eq(&Self::zero())) + self.invert() } fn pow_vartime>(&self, exp: S) -> Self { diff --git a/src/secp256r1/fq.rs b/src/secp256r1/fq.rs index d1a7b809..077ec331 100644 --- a/src/secp256r1/fq.rs +++ b/src/secp256r1/fq.rs @@ -172,17 +172,10 @@ impl ff::Field for Fq { self.square() } - /// Computes the multiplicative inverse of this element, - /// failing if the element is zero. + /// Returns the multiplicative inverse of the + /// element. If it is zero, the method fails. fn invert(&self) -> CtOption { - let tmp = self.pow_vartime([ - 0xf3b9cac2fc63254f, - 0xbce6faada7179e84, - 0xffffffffffffffff, - 0xffffffff00000000, - ]); - - CtOption::new(tmp, !self.ct_eq(&Self::zero())) + self.invert() } fn pow_vartime>(&self, exp: S) -> Self {