diff --git a/crates/utils/src/math.cairo b/crates/utils/src/math.cairo index ec1344b2..fab2f7dd 100644 --- a/crates/utils/src/math.cairo +++ b/crates/utils/src/math.cairo @@ -2,6 +2,7 @@ use core::integer::{u512}; use core::num::traits::{Zero, One, BitSize, OverflowingAdd, OverflowingMul, Bounded}; use core::panic_with_felt252; use core::traits::{BitAnd}; +use utils::constants::POW_2_256; // === Exponentiation === @@ -245,6 +246,14 @@ impl BitshiftImpl< if shift > BitSize::::bits() - One::one() { panic_with_felt252('mul Overflow'); } + // if the shift is within the bit size of u256 (<= 255 bits), + // use the POW_2 lookup table to get 2^shift for efficient multiplication + if shift <= BitSize::::bits() - One::::one() { + // In case the pow2 is greater than the max value of T, we have an overflow + // so we can panic + return self * (*POW_2_256.span().at(shift)).try_into().expect('mul Overflow'); + } + // for shifts greater than 255 bits, perform the shift manually let two = One::one() + One::one(); self * two.pow(shift.try_into().expect('mul Overflow')) } @@ -254,6 +263,10 @@ impl BitshiftImpl< if shift > BitSize::::bits() - One::one() { panic_with_felt252('mul Overflow'); } + // use the POW_2 lookup table when the bit size + if shift <= BitSize::::bits() - One::::one() { + return self / (*POW_2_256.span().at(shift)).try_into().expect('mul Overflow'); + } let two = One::one() + One::one(); self / two.pow(shift.try_into().expect('mul Overflow')) } @@ -309,12 +322,23 @@ pub impl WrappingBitshiftImpl< +Into > of WrappingBitshift { fn wrapping_shl(self: T, shift: usize) -> T { + if shift <= BitSize::::bits() - One::::one() { + let pow_2: u256 = (*POW_2_256.span().at(shift)); + let pow2_mod_t: u256 = pow_2 % Bounded::::MAX.into(); + let (result, _) = self.overflowing_mul(pow2_mod_t.try_into().unwrap()); + return result; + } let two = One::::one() + One::::one(); let (result, _) = self.overflowing_mul(two.wrapping_pow(shift.try_into().unwrap())); result } fn wrapping_shr(self: T, shift: usize) -> T { + if shift <= BitSize::::bits() - One::::one() { + let pow_2: u256 = (*POW_2_256.span().at(shift)); + let pow2_mod_t: u256 = pow_2 % Bounded::::MAX.into(); + return self / pow2_mod_t.try_into().unwrap(); + } let two = One::::one() + One::::one(); if shift > BitSize::::bits() - One::one() {