Skip to content

Commit

Permalink
dev: optimized bitshifts by using a lookup table for powers of two
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Oct 2, 2024
1 parent 7048c0d commit 96dea68
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions crates/utils/src/math.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use core::integer::{u512};
use core::num::traits::{Zero, One, BitSize, OverflowingAdd, OverflowingMul, Bounded};
use core::panic_with_felt252;
use core::traits::{BitAnd};
use utils::constants::POW_2_256;

// === Exponentiation ===

Expand Down Expand Up @@ -245,6 +246,14 @@ impl BitshiftImpl<
if shift > BitSize::<T>::bits() - One::one() {
panic_with_felt252('mul Overflow');
}
// if the shift is within the bit size of u256 (<= 255 bits),
// use the POW_2 lookup table to get 2^shift for efficient multiplication
if shift <= BitSize::<u256>::bits() - One::<u32>::one() {
// In case the pow2 is greater than the max value of T, we have an overflow
// so we can panic
return self * (*POW_2_256.span().at(shift)).try_into().expect('mul Overflow');
}
// for shifts greater than 255 bits, perform the shift manually
let two = One::one() + One::one();
self * two.pow(shift.try_into().expect('mul Overflow'))
}
Expand All @@ -254,6 +263,10 @@ impl BitshiftImpl<
if shift > BitSize::<T>::bits() - One::one() {
panic_with_felt252('mul Overflow');
}
// use the POW_2 lookup table when the bit size
if shift <= BitSize::<u256>::bits() - One::<u32>::one() {
return self / (*POW_2_256.span().at(shift)).try_into().expect('mul Overflow');
}
let two = One::one() + One::one();
self / two.pow(shift.try_into().expect('mul Overflow'))
}
Expand Down Expand Up @@ -309,12 +322,23 @@ pub impl WrappingBitshiftImpl<
+Into<T, u256>
> of WrappingBitshift<T> {
fn wrapping_shl(self: T, shift: usize) -> T {
if shift <= BitSize::<u256>::bits() - One::<u32>::one() {
let pow_2: u256 = (*POW_2_256.span().at(shift));
let pow2_mod_t: u256 = pow_2 % Bounded::<T>::MAX.into();
let (result, _) = self.overflowing_mul(pow2_mod_t.try_into().unwrap());
return result;
}
let two = One::<T>::one() + One::<T>::one();
let (result, _) = self.overflowing_mul(two.wrapping_pow(shift.try_into().unwrap()));
result
}

fn wrapping_shr(self: T, shift: usize) -> T {
if shift <= BitSize::<u256>::bits() - One::<u32>::one() {
let pow_2: u256 = (*POW_2_256.span().at(shift));
let pow2_mod_t: u256 = pow_2 % Bounded::<T>::MAX.into();
return self / pow2_mod_t.try_into().unwrap();
}
let two = One::<T>::one() + One::<T>::one();

if shift > BitSize::<T>::bits() - One::one() {
Expand Down

0 comments on commit 96dea68

Please sign in to comment.