diff --git a/src/math/src/mod_arithmetics.cairo b/src/math/src/mod_arithmetics.cairo index 8f8cd4fa..22ec60d5 100644 --- a/src/math/src/mod_arithmetics.cairo +++ b/src/math/src/mod_arithmetics.cairo @@ -7,6 +7,7 @@ use integer::u512; /// * `modulo` - modulo. /// # Returns /// * `u256` - result of modular addition +#[inline(always)] fn add_mod(a: u256, b: u256, modulo: u256) -> u256 { let mod_non_zero: NonZero = integer::u256_try_as_non_zero(modulo).unwrap(); let low: u256 = a.low.into() + b.low.into(); @@ -25,10 +26,12 @@ fn add_mod(a: u256, b: u256, modulo: u256) -> u256 { /// * `modulo` - modulo. /// # Returns /// * `u256` - modular multiplicative inverse +#[inline(always)] fn mult_inverse(b: u256, modulo: u256) -> u256 { - // From Fermat's little theorem, a ^ (p - 1) = 1 when p is prime and a != 0. Since a ^ (p - 1) = a ยท a ^ (p - 2) we have that - // a ^ (p - 2) is the multiplicative inverse of a modulo p. - pow_mod(b, modulo - 2, modulo) + match math::u256_guarantee_inv_mod_n(b, modulo.try_into().expect('inverse non zero')) { + Result::Ok((inv_a, _, _, _, _, _, _, _, _)) => inv_a.into(), + Result::Err(_) => 0 + } } /// Function that return the modular additive inverse. @@ -37,6 +40,7 @@ fn mult_inverse(b: u256, modulo: u256) -> u256 { /// * `modulo` - modulo. /// # Returns /// * `u256` - modular additive inverse +#[inline(always)] fn add_inverse_mod(b: u256, modulo: u256) -> u256 { modulo - b } @@ -48,6 +52,7 @@ fn add_inverse_mod(b: u256, modulo: u256) -> u256 { /// * `modulo` - modulo. /// # Returns /// * `u256` - result of modular substraction +#[inline(always)] fn sub_mod(mut a: u256, mut b: u256, modulo: u256) -> u256 { // reduce values a = a % modulo; @@ -65,6 +70,7 @@ fn sub_mod(mut a: u256, mut b: u256, modulo: u256) -> u256 { /// * `modulo` - modulo. /// # Returns /// * `u256` - result of modular multiplication +#[inline(always)] fn mult_mod(a: u256, b: u256, modulo: u256) -> u256 { let mult: u512 = integer::u256_wide_mul(a, b); let mod_non_zero: NonZero = integer::u256_try_as_non_zero(modulo).unwrap(); @@ -79,8 +85,11 @@ fn mult_mod(a: u256, b: u256, modulo: u256) -> u256 { /// * `modulo` - modulo. /// # Returns /// * `u256` - result of modular division +#[inline(always)] fn div_mod(a: u256, b: u256, modulo: u256) -> u256 { - mult_mod(a, mult_inverse(b, modulo), modulo) + let modulo_nz = modulo.try_into().expect('0 modulo'); + let inv = math::u256_inv_mod(b, modulo_nz).unwrap().into(); + math::u256_mul_mod_n(a, inv, modulo_nz) } /// Function that performs modular exponentiation.