Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bernstein yang modular multiplicative inverter #2

Merged
merged 11 commits into from
Aug 29, 2023
351 changes: 351 additions & 0 deletions src/bernsteinyang.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,351 @@
use core::cmp::PartialEq;
use std::ops::{Add, Mul, Sub, Neg};

/// Big signed (B * L)-bit integer type, whose variables store
/// numbers in the two's complement code as arrays of B-bit chunks.
/// The ordering of the chunks in these arrays is little-endian.
/// The arithmetic operations for this type are wrapping ones.
#[derive(Clone)]
struct ChunkInt<const B:usize, const L:usize>(pub [u64; L]);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's usually called unsaturated integers see:

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that this term is rarely used. On the other hand, "saturated integer" is usually used to describe the type, for which the saturation arithmetic is implemented: https://en.wikipedia.org/wiki/Saturation_arithmetic A least, my first thought was about the saturation arithmetic, when I saw this term somewhere before. Thus, I will leave this comment "as is" to avoid the ambiguation.


impl<const B:usize, const L:usize> ChunkInt<B,L> {
/// Mask, in which the B lowest bits are 1 and only they
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"and only they"

Seems like the sentence is cut.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it means "they and only they are 1".

pub const MASK: u64 = u64::MAX >> (64 - B);

/// Representation of 0
pub const ZERO: Self = Self([0; L]);

/// Representation of 1
pub const ONE: Self = {
let mut data = [0; L];
data[0] = 1; Self(data)
};

/// Representation of -1
pub const MINUS_ONE: Self = Self([Self::MASK; L]);

/// Returns the result of applying B-bit right
/// arithmetical shift to the current number
pub fn shift(&self) -> Self {
let mut data = [0; L];
for i in 1..L {
data[i - 1] = self.0[i];
}
if self.is_negative() {
data[L - 1] = Self::MASK;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does the mask in data[L - 2] need to be cancelled?

A comment is needed, at least to attract attention for a potential bug during review/audit and also test input design.

Edit: Ah I understand the representation now. All limbs are 2-complement and negated, not just the most significant word.

Copy link
Author

@AlekseiVambol AlekseiVambol Aug 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To dispel any doubts, I perform this:
image

}
Self(data)
}

/// Returns the lowest B bits of the current number
pub fn lowest(&self) -> u64 { self.0[0] }

/// Returns "true" iff the current number is negative
pub fn is_negative(&self) -> bool {
self.0[L - 1] > (Self::MASK >> 1)
}
}

impl <const B:usize, const L:usize> PartialEq for ChunkInt<B,L> {
fn eq(&self, other: &Self) -> bool { self.0 == other.0 }
fn ne(&self, other: &Self) -> bool { self.0 != other.0 }
}

impl<const B:usize, const L:usize> Add for &ChunkInt<B,L> {
type Output = ChunkInt<B,L>;
fn add(self, other: Self) -> Self::Output {
let (mut data, mut carry) = ([0; L], 0);
for i in 0..L {
let sum = self.0[i] + other.0[i] + carry;
data[i] = sum & ChunkInt::<B,L>::MASK;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the sum is negative? You remove the negative tag there.

Copy link
Author

@AlekseiVambol AlekseiVambol Aug 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The value of "sum" is never negative, since actually I operate on 2 non-negative integers in [0 .. 2 ^ (B * L) - 1] splited into B-bit non-negative chunks, but any x in [2 ^ (B * L - 1) .. 2 ^ (B * L) - 1] is considered to be the representation of -|2 ^ (B * L) - x|. The Mul, Add and Sub algorithms for the two's complement code do not care about the sign, and this is the main reason for them being used by the majority of processors. Since ChunkInt has a fixed size, it may afford to use this convenient code.

carry = sum >> B;
}
Self::Output { 0: data }
}
}

impl<const B:usize, const L:usize> Add<&ChunkInt<B,L>> for ChunkInt<B,L> {
type Output = ChunkInt<B,L>;
fn add(self, other: &Self) -> Self::Output {
&self + other
}
}

impl<const B:usize, const L:usize> Add for ChunkInt<B,L> {
type Output = ChunkInt<B,L>;
fn add(self, other: Self) -> Self::Output {
&self + &other
}
}

impl<const B:usize, const L:usize> Sub for &ChunkInt<B,L> {
type Output = ChunkInt<B,L>;
fn sub(self, other: Self) -> Self::Output {
let (mut data, mut carry) = ([0; L], 1);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a comment to explain that

-x = flip the bits of x and add 1.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. This part of code has not been commented yet.

for i in 0..L {
let sum = self.0[i] + (other.0[i] ^ ChunkInt::<B,L>::MASK) + carry;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The positive/negative interference is non-trivial here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly yes. It is my own approach, I do not know how it is done in other implementations of fixed-length signed big integers. Of course, I planned to comment this.

data[i] = sum & ChunkInt::<B,L>::MASK;
carry = sum >> B;
}
Self::Output { 0: data }
}
}

impl<const B:usize, const L:usize> Sub<&ChunkInt<B,L>> for ChunkInt<B,L> {
type Output = ChunkInt<B,L>;
fn sub(self, other: &Self) -> Self::Output {
&self - other
}
}

impl<const B:usize, const L:usize> Sub for ChunkInt<B,L> {
type Output = ChunkInt<B,L>;
fn sub(self, other: Self) -> Self::Output {
&self - &other
}
}

impl<const B:usize, const L:usize> Neg for &ChunkInt<B,L> {
type Output = ChunkInt<B,L>;
fn neg(self) -> Self::Output {
let (mut data, mut carry) = ([0; L], 1);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a comment to explain that

-x = flip the bits of x and add 1.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. This part of code has not been commented yet.

for i in 0..L {
let sum = (self.0[i] ^ ChunkInt::<B,L>::MASK) + carry;
data[i] = sum & ChunkInt::<B,L>::MASK;
carry = sum >> B;
}
Self::Output { 0: data }
}
}

impl<const B:usize, const L:usize> Neg for ChunkInt<B,L> {
type Output = ChunkInt<B,L>;
fn neg(self) -> Self::Output {
-&self
}
}

impl<const B:usize, const L:usize> Mul for &ChunkInt<B,L> {
type Output = ChunkInt<B,L>;
fn mul(self, other: Self) -> Self::Output {
let mut data = [0; L];
for i in 0..L {
let mut carry = 0;
for k in 0..(L - i) {
let sum = (data[i + k] as u128) + (carry as u128) +
(self.0[i] as u128) * (other.0[k] as u128);
data[i + k] = sum as u64 & ChunkInt::<B,L>::MASK;
carry = (sum >> B) as u64;
}
}
Self::Output { 0: data }
}
}

impl<const B:usize, const L:usize> Mul<&ChunkInt<B,L>> for ChunkInt<B,L> {
type Output = ChunkInt<B,L>;
fn mul(self, other: &Self) -> Self::Output {
&self * other
}
}

impl<const B:usize, const L:usize> Mul for ChunkInt<B,L> {
type Output = ChunkInt<B,L>;
fn mul(self, other: Self) -> Self::Output {
&self * &other
}
}

impl<const B:usize, const L:usize> Mul<i64> for &ChunkInt<B,L> {
type Output = ChunkInt<B,L>;
fn mul(self, other: i64) -> Self::Output {
let mut data = [0; L];
let (other, mut carry, mask) = if other < 0 {
(-other, -other as u64, ChunkInt::<B,L>::MASK)
} else { (other, 0, 0) };
for i in 0..L {
let sum = (carry as u128) + ((self.0[i] ^ mask) as u128) * (other as u128);
data[i] = sum as u64 & ChunkInt::<B,L>::MASK;
carry = (sum >> B) as u64;
}
Self::Output { 0: data }
}
}

impl<const B:usize, const L:usize> Mul<i64> for ChunkInt<B,L> {
type Output = ChunkInt<B,L>;
fn mul(self, other: i64) -> Self::Output {
&self * other
}
}

impl<const B:usize, const L:usize> Mul<&ChunkInt<B,L>> for i64 {
type Output = ChunkInt<B,L>;
fn mul(self, other: &ChunkInt<B,L>) -> Self::Output {
other * self
}
}

impl<const B:usize, const L:usize> Mul<ChunkInt<B,L>> for i64 {
type Output = ChunkInt<B,L>;
fn mul(self, other: ChunkInt<B,L>) -> Self::Output {
other * self
}
}

/// Type of the modular multiplicative inverter based on the Bernstein-Yang method.
/// The inverter can be created for a specified modulus M and adjusting parameter A
/// to compute the adjusted multiplicative inverses of positive integers, i.e. for
/// computing (1 / x) * A mod M for a positive integer x.
///
/// The adjusting parameter allows computing the multiplicative inverses in the case of
/// using the Montgomery representation for the input or the expected output. If R is
/// the Montgomery factor, the multiplicative inverses in the appropriate representation
/// can be computed provided that the value of A is chosen as follows:
/// - A = 1, if both the input and the expected output are in the trivial form
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trivial => canonical

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. "Canonical" or "standard" suits better.

/// - A = R^2 mod M, if both the input and the expected output are in the Montgomery form
/// - A = R mod M, if either the input or the expected output is in the Montgomery form,
/// but not both of them
///
/// The public methods of this type receive and return unsigned big integers as arrays of
/// 64-bit chunks, the ordering of which is little-endian. Both the modulus and the integer
/// to be inverted should not exceed 2 ^ (B * (L - 1) - 2)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 ^ (B * (L - 1) - 2) "bits".

Also mention that B should be 62 on 64-bit platforms.
And for BN254 and secp256k1, L should be 5

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's important to put references here:

Copy link
Author

@AlekseiVambol AlekseiVambol Aug 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also mention that B should be 62 on 64-bit platforms.

Agree.

And for BN254 and secp256k1, L should be 5

No, details are provided here #2 (comment)

Update: 62 is hardcoded instead of B.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 ^ (B * (L - 1) - 2) "bits".

No, for bits it is going to be rather B * (L - 1) - 2, but this bound will be a bit lower.

pub struct BYInverter<const B:usize, const L:usize> {
/// Modulus
modulus: ChunkInt<B,L>,

/// Adjusting parameter
adjuster: ChunkInt<B,L>,

/// Multiplicative inverse of the modulus modulo 2^B
inverse: i64
}

/// Type of the Bernstein-Yang transition matrix multiplied by 2^B
type Matrix = [[i64; 2]; 2];

impl<const B:usize, const L:usize> BYInverter<B,L> {
fn step(f: &ChunkInt<B,L>, g: &ChunkInt<B,L>, mut delta: i64) -> (i64, Matrix) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In BY paper there are divsteps and jumpdivsteps.

The jumpdivstep is fusing many divsteps at once and summarizing the effect in a transition matrix.

Using just step will be confusing for reviewers/auditors comparing with the paper.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I will choose some other short name like "jump".

let (mut steps, mut f, mut g) = (B as i64, f.lowest() as i64, g.lowest() as i128);
let mut matrix: Matrix = [[1, 0], [0, 1]];

loop {
let zeros = steps.min(g.trailing_zeros() as i64);
(steps, delta, g) = (steps - zeros, delta + zeros, g >> zeros);
matrix[0] = [matrix[0][0] << zeros, matrix[0][1] << zeros];

if steps == 0 { break; }

if delta > 0 {
(delta, f, g) = (-delta, g as i64, -f as i128);
(matrix[0], matrix[1]) = (matrix[1], [-matrix[0][0], -matrix[0][1]]);
}

let mask = (1 << steps.min(1 - delta).min(4)) - 1;
let w = (g as i64).wrapping_mul(f.wrapping_mul(3) ^ 12) & mask;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a comment in the vein of

Find the multiple of f to add to cancel the bottom min(steps, 4) bits of g

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. This part of code has not been commented yet; "f.wrapping_mul(3) ^ 12" will also been explained a bit.


matrix[1] = [matrix[0][0] * w + matrix[1][0], matrix[0][1] * w + matrix[1][1]];
g += w as i128 * f as i128;
}

(delta, matrix)
}

fn fg(f: ChunkInt<B,L>, g: ChunkInt<B,L>, matrix: Matrix) -> (ChunkInt<B,L>, ChunkInt<B,L>) {
((matrix[0][0] * &f + matrix[0][1] * &g).shift(), (matrix[1][0] * &f + matrix[1][1] * &g).shift())
}

fn de(&self, d: ChunkInt<B,L>, e: ChunkInt<B,L>, matrix: Matrix) -> (ChunkInt<B,L>, ChunkInt<B,L>) {
let mask = ChunkInt::<B,L>::MASK as i64;
let mut md = matrix[0][0] * d.is_negative() as i64 + matrix[0][1] * e.is_negative() as i64;
let mut me = matrix[1][0] * d.is_negative() as i64 + matrix[1][1] * e.is_negative() as i64;

let cd = matrix[0][0].wrapping_mul(d.lowest() as i64).wrapping_add(matrix[0][1].wrapping_mul(e.lowest() as i64)) & mask;
let ce = matrix[1][0].wrapping_mul(d.lowest() as i64).wrapping_add(matrix[1][1].wrapping_mul(e.lowest() as i64)) & mask;

md -= (self.inverse.wrapping_mul(cd).wrapping_add(md)) & mask;
me -= (self.inverse.wrapping_mul(ce).wrapping_add(me)) & mask;

let cd = matrix[0][0] * &d + matrix[0][1] * &e + md * &self.modulus;
let ce = matrix[1][0] * &d + matrix[1][1] * &e + me * &self.modulus;

(cd.shift(), ce.shift())
}

fn norm(&self, mut value: ChunkInt<B,L>, negate: bool) -> ChunkInt<B,L> {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should indicate the input range and output range

AFAIK:

Compute a = sign*a (mod M)

with a in range (-2*M, M)
result in range [0, M)

also why is the value mutated and returned as well?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. I was going to comment this part of the code.

also why is the value mutated and returned as well?

It is mutable in order to mutate it within the function without introducing a new temporary variable.
It is returned, because the function takes the ownership of the argument instead of borrowing it mutably.
It is returned instead of being changed using a mutable reference because I want all function to be pure: https://en.wikipedia.org/wiki/Pure_function

if value.is_negative() {
value = value + &self.modulus;
}

if negate {
value = -value;
}

if value.is_negative() {
value = value + &self.modulus;
}

value
}

/// Returns a big unsigned integer as an array of O-bit chunks, which is equal modulo
/// 2 ^ (O * S) to the input big unsigned integer stored as an array of I-bit chunks.
/// The ordering of the chunks in these arrays is little-endian
const fn convert<const I:usize, const O:usize, const S:usize>(input: &[u64]) -> [u64; S] {
const fn min(a: usize, b: usize) -> usize { if a > b { b } else { a } }
let (total, mut output, mut bits) = (min(input.len() * I, S * O), [0; S], 0);

while bits < total {
let (i, o) = (bits % I, bits % O);
output[bits / O] |= (input[bits / I] >> i) << o;
bits += min(I - i, O - o);
}

let mask = u64::MAX >> (64 - O);
let mut filled = total / O + if total % O > 0 { 1 } else { 0 };

while filled > 0 {
filled -= 1;
output[filled] &= mask;
}

output
}

/// Returns the multiplicative inverse of the argument modulo 2^B. The implementation is based
/// on the Hurchalla's method for computing the multiplicative inverse modulo a power of two
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing link to paper

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually I try to avoid overwhelming my code with direct links to papers easy to find, since I try to achieve better readability and aesthetics) However, I can add this one.

const fn inv(value: u64) -> i64 {
let x = value.wrapping_mul(3) ^ 2;
let y = 1u64.wrapping_sub(x.wrapping_mul(value));
let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
(x.wrapping_mul(y.wrapping_add(1)) & ChunkInt::<B,L>::MASK) as i64
}

/// Creates the inverter for specified modulus and adjusting parameter
pub const fn new(modulus: &[u64], adjuster: &[u64]) -> Self {
Self {
modulus: ChunkInt::<B, L>(Self::convert::<64, B, L>(modulus)),
adjuster: ChunkInt::<B, L>(Self::convert::<64, B, L>(adjuster)),
inverse: Self::inv(modulus[0])
}
}

/// Returns either the adjusted modular multiplicative inverse for the argument or None
/// depending on invertibility of the argument, i.e. its coprimality with the modulus
pub fn invert<const S:usize>(&self, value: &[u64]) -> Option<[u64; S]> {
let (mut d, mut e) = (ChunkInt::ZERO, self.adjuster.clone());
let mut g = ChunkInt::<B, L>(Self::convert::<64, B, L>(value));
let (mut delta, mut f) = (1, self.modulus.clone());
let mut matrix;
while g != ChunkInt::ZERO {
(delta, matrix) = Self::step(&f, &g, delta);
(f, g) = Self::fg(f, g, matrix);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment: fg updates can take a parameter "limbsLeft" to avoid iterating on the full L all the time.

This might be left as a future optimization.

See: https://github.com/mratsim/constantine/blob/f57d071f1192a4039979a3baf6c835b89841bcfa/constantine/math/arithmetic/limbs_exgcd.nim#L836-L839

Copy link
Author

@AlekseiVambol AlekseiVambol Aug 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about this during planing the implementation, but decided not to do this, because:

  • It will require some extra time to determine the number of unused limbs after each arithmetic operation;
  • Only the f,g-update can benefit from it;
  • For the two's complement code working with fixed-length register-like structure is much more convenient.

For much larger modulus this optimization makes sense, but I am not sure that it is needed here.

"Programmers waste enormous amounts of time thinking about, or worrying about, the speed of noncritical parts of their programs, and these attempts at efficiency actually have a strong negative impact when debugging and maintenance are considered. We should forget about small efficiencies, say about 97% of the time: premature optimization is the root of all evil. Yet we should not pass up our opportunities in that critical 3%." - Donald Knuth

(d, e) = self.de(d, e, matrix);
}
let antiunit = f == ChunkInt::MINUS_ONE;
if (f != ChunkInt::ONE) && !antiunit { return None; }
Some(Self::convert::<B, 64, S>(&self.norm(d, antiunit).0))
}
}
15 changes: 3 additions & 12 deletions src/bn256/fq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,18 +217,9 @@ impl ff::Field for Fq {
ff::helpers::sqrt_ratio_generic(num, div)
}

/// Computes the multiplicative inverse of this element,
/// failing if the element is zero.
fn invert(&self) -> CtOption<Self> {
let tmp = self.pow([
0x3c208c16d87cfd45,
0x97816a916871ca8d,
0xb85045b68181585d,
0x30644e72e131a029,
]);

CtOption::new(tmp, !self.ct_eq(&Self::zero()))
}
/// Returns the multiplicative inverse of the
/// element. If it is zero, the method fails.
fn invert(&self) -> CtOption<Self> { self.invert() }
}

impl ff::PrimeField for Fq {
Expand Down
15 changes: 3 additions & 12 deletions src/bn256/fr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,18 +231,9 @@ impl ff::Field for Fr {
self.square()
}

/// Computes the multiplicative inverse of this element,
/// failing if the element is zero.
fn invert(&self) -> CtOption<Self> {
let tmp = self.pow([
0x43e1f593efffffff,
0x2833e84879b97091,
0xb85045b68181585d,
0x30644e72e131a029,
]);

CtOption::new(tmp, !self.ct_eq(&Self::zero()))
}
/// Returns the multiplicative inverse of the
/// element. If it is zero, the method fails.
fn invert(&self) -> CtOption<Self> { self.invert() }

fn sqrt(&self) -> CtOption<Self> {
/// `(t - 1) // 2` where t * 2^s + 1 = p with t odd.
Expand Down
Loading