Skip to content

Commit

Permalink
Feat: Adds bit rotation to math (#227)
Browse files Browse the repository at this point in the history
<!--- Please provide a general summary of your changes in the title
above -->

## Pull Request type

<!-- Please try to limit your pull request to one type; submit multiple
pull requests if needed. -->

Please check the type of change your PR introduces:

- [ ] Bugfix
- [x] Feature
- [ ] Code style update (formatting, renaming)
- [ ] Refactoring (no functional changes, no API changes)
- [ ] Build-related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

<!-- Please describe the current behavior that you are modifying, or
link to a relevant issue. -->

Issue Number: N/A

## What is the new behavior?

<!-- Please describe the behavior or changes that are being added by
this PR. -->

-
-
-

## Does this introduce a breaking change?

- [ ] Yes
- [ ] No

<!-- If this does introduce a breaking change, please describe the
impact and migration path for existing applications below. -->

## Other information

<!-- Any other information that is important to this PR, such as
screenshots of how the component looks before and after the change. -->
  • Loading branch information
sveamarcus authored Dec 12, 2023
1 parent 47dbd48 commit ef5e11d
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 1 deletion.
106 changes: 106 additions & 0 deletions src/math/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,109 @@ impl U256BitShift of BitShift<u256> {
x / pow(2, n)
}
}

/// Rotate the bits of an unsigned integer of type T
trait BitRotate<T> {
/// Take the bits of an unsigned integer and rotate in the left direction
/// # Arguments
/// * `x` - rotate its bit representation in the leftward direction
/// * `n` - number of steps to rotate
/// # Returns
/// * `T` - the result of rotating the bits of number `x` left, `n` number of steps
fn rotate_left(x: T, n: T) -> T;
/// Take the bits of an unsigned integer and rotate in the right direction
/// # Arguments
/// * `x` - rotate its bit representation in the rightward direction
/// * `n` - number of steps to rotate
/// # Returns
/// * `T` - the result of rotating the bits of number `x` right, `n` number of steps
fn rotate_right(x: T, n: T) -> T;
}

impl U8BitRotate of BitRotate<u8> {
fn rotate_left(x: u8, n: u8) -> u8 {
let word = u8_wide_mul(x, pow(2, n));
let (quotient, remainder) = DivRem::div_rem(word, 0x100_u16.try_into().unwrap());
(quotient + remainder).try_into().unwrap()
}

fn rotate_right(x: u8, n: u8) -> u8 {
let step = pow(2, n);
let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap());
remainder * pow(2, 8 - n) + quotient
}
}

impl U16BitRotate of BitRotate<u16> {
fn rotate_left(x: u16, n: u16) -> u16 {
let word = u16_wide_mul(x, pow(2, n));
let (quotient, remainder) = DivRem::div_rem(word, 0x10000_u32.try_into().unwrap());
(quotient + remainder).try_into().unwrap()
}

fn rotate_right(x: u16, n: u16) -> u16 {
let step = pow(2, n);
let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap());
remainder * pow(2, 16 - n) + quotient
}
}

impl U32BitRotate of BitRotate<u32> {
fn rotate_left(x: u32, n: u32) -> u32 {
let word = u32_wide_mul(x, pow(2, n));
let (quotient, remainder) = DivRem::div_rem(word, 0x100000000_u64.try_into().unwrap());
(quotient + remainder).try_into().unwrap()
}

fn rotate_right(x: u32, n: u32) -> u32 {
let step = pow(2, n);
let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap());
remainder * pow(2, 32 - n) + quotient
}
}

impl U64BitRotate of BitRotate<u64> {
fn rotate_left(x: u64, n: u64) -> u64 {
let word = u64_wide_mul(x, pow(2, n));
let (quotient, remainder) = DivRem::div_rem(
word, 0x10000000000000000_u128.try_into().unwrap()
);
(quotient + remainder).try_into().unwrap()
}

fn rotate_right(x: u64, n: u64) -> u64 {
let step = pow(2, n);
let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap());
remainder * pow(2, 64 - n) + quotient
}
}

impl U128BitRotate of BitRotate<u128> {
fn rotate_left(x: u128, n: u128) -> u128 {
let (high, low) = u128_wide_mul(x, pow(2, n));
let word = u256 { low, high };
let (quotient, remainder) = DivRem::div_rem(
word, u256 { low: 0, high: 1 }.try_into().unwrap()
);
(quotient + remainder).try_into().unwrap()
}

fn rotate_right(x: u128, n: u128) -> u128 {
let step = pow(2, n);
let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap());
remainder * pow(2, 128 - n) + quotient
}
}

impl U256BitRotate of BitRotate<u256> {
fn rotate_left(x: u256, n: u256) -> u256 {
// Alternative solution since we cannot divide u512 yet
BitShift::shl(x, n) + BitShift::shr(x, 256 - n)
}

fn rotate_right(x: u256, n: u256) -> u256 {
let step = pow(2, n);
let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap());
remainder * pow(2, 256 - n) + quotient
}
}
56 changes: 55 additions & 1 deletion src/math/src/tests/math_test.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use alexandria_math::{pow, BitShift, count_digits_of_base};
use alexandria_math::{count_digits_of_base, pow, BitShift, BitRotate};
use integer::BoundedInt;

// Test power function
Expand Down Expand Up @@ -170,3 +170,57 @@ fn shl_should_not_overflow() {
assert(BitShift::shl(pow::<u128>(2, 127), 1) == 0, 'invalid result');
assert(BitShift::shl(pow::<u256>(2, 255), 1) == 0, 'invalid result');
}

#[test]
#[available_gas(3000000)]
fn test_rotl_min() {
assert(BitRotate::rotate_left(pow::<u8>(2, 7) + 1, 1) == 3, 'invalid result');
assert(BitRotate::rotate_left(pow::<u16>(2, 15) + 1, 1) == 3, 'invalid result');
assert(BitRotate::rotate_left(pow::<u32>(2, 31) + 1, 1) == 3, 'invalid result');
assert(BitRotate::rotate_left(pow::<u64>(2, 63) + 1, 1) == 3, 'invalid result');
assert(BitRotate::rotate_left(pow::<u128>(2, 127) + 1, 1) == 3, 'invalid result');
assert(BitRotate::rotate_left(pow::<u256>(2, 255) + 1, 1) == 3, 'invalid result');
}

#[test]
#[available_gas(3000000)]
fn test_rotl_max() {
assert(BitRotate::rotate_left(0b101, 7) == pow::<u8>(2, 7) + 0b10, 'invalid result');
assert(BitRotate::rotate_left(0b101, 15) == pow::<u16>(2, 15) + 0b10, 'invalid result');
assert(BitRotate::rotate_left(0b101, 31) == pow::<u32>(2, 31) + 0b10, 'invalid result');
assert(BitRotate::rotate_left(0b101, 63) == pow::<u64>(2, 63) + 0b10, 'invalid result');
assert(BitRotate::rotate_left(0b101, 127) == pow::<u128>(2, 127) + 0b10, 'invalid result');
assert(BitRotate::rotate_left(0b101, 255) == pow::<u256>(2, 255) + 0b10, 'invalid result');
}

#[test]
#[available_gas(4000000)]
fn test_rotr_min() {
assert(BitRotate::rotate_right(pow::<u8>(2, 7) + 1, 1) == 0b11 * pow(2, 6), 'invalid result');
assert(
BitRotate::rotate_right(pow::<u16>(2, 15) + 1, 1) == 0b11 * pow(2, 14), 'invalid result'
);
assert(
BitRotate::rotate_right(pow::<u32>(2, 31) + 1, 1) == 0b11 * pow(2, 30), 'invalid result'
);
assert(
BitRotate::rotate_right(pow::<u64>(2, 63) + 1, 1) == 0b11 * pow(2, 62), 'invalid result'
);
assert(
BitRotate::rotate_right(pow::<u128>(2, 127) + 1, 1) == 0b11 * pow(2, 126), 'invalid result'
);
assert(
BitRotate::rotate_right(pow::<u256>(2, 255) + 1, 1) == 0b11 * pow(2, 254), 'invalid result'
);
}

#[test]
#[available_gas(2000000)]
fn test_rotr_max() {
assert(BitRotate::rotate_right(0b101_u8, 7) == 0b1010, 'invalid result');
assert(BitRotate::rotate_right(0b101_u16, 15) == 0b1010, 'invalid result');
assert(BitRotate::rotate_right(0b101_u32, 31) == 0b1010, 'invalid result');
assert(BitRotate::rotate_right(0b101_u64, 63) == 0b1010, 'invalid result');
assert(BitRotate::rotate_right(0b101_u128, 127) == 0b1010, 'invalid result');
assert(BitRotate::rotate_right(0b101_u256, 255) == 0b1010, 'invalid result');
}

0 comments on commit ef5e11d

Please sign in to comment.