Skip to content

Commit

Permalink
feat: Add WrappingMath trait (#241)
Browse files Browse the repository at this point in the history
## Add WrappingMath trait

Please check the type of change your PR introduces:

- [ ] Bugfix
- [x] Feature
- [ ] Code style update (formatting, renaming)
- [ ] Refactoring (no functional changes, no API changes)
- [ ] Build-related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

Individual functions for wrapping ops in uints.

Issue Number: #214

## What is the new behavior?

Adds a `WrappingMath` trait on math and implements it for uints. It
provides `wrapping_add`, `wrapping_sub` and `wrapping_mul`, removing the
need to use type-specific wrapping operations like `u64_wrapping_add`.

I added it directly to `math`, since I wasn't totally sure where would
be the best module to keep this. Feel free to move / ask me to move it
somewhere else.

(edit: Also adds individual traits for each operation: `WrappingAdd`,
`WrappingSub` and `WrappingMul`)

## Does this introduce a breaking change?

- [ ] Yes
- [x] No
  • Loading branch information
Hyodar authored Jan 15, 2024
1 parent 4941698 commit 98aadc7
Show file tree
Hide file tree
Showing 2 changed files with 349 additions and 2 deletions.
169 changes: 168 additions & 1 deletion src/math/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ mod wad_ray_math;
mod zellers_congruence;
use integer::{
u8_wide_mul, u16_wide_mul, u32_wide_mul, u64_wide_mul, u128_wide_mul, u256_overflow_mul,
BoundedInt
u8_wrapping_add, u16_wrapping_add, u32_wrapping_add, u64_wrapping_add, u128_wrapping_add,
u256_overflowing_add, u8_wrapping_sub, u16_wrapping_sub, u32_wrapping_sub, u64_wrapping_sub,
u128_wrapping_sub, u256_overflow_sub, BoundedInt
};

/// Raise a number to a power.
Expand Down Expand Up @@ -233,3 +235,168 @@ impl U256BitRotate of BitRotate<u256> {
remainder * pow(2, 256 - n) + quotient
}
}

trait WrappingMath<T> {
fn wrapping_add(self: T, rhs: T) -> T;
fn wrapping_sub(self: T, rhs: T) -> T;
fn wrapping_mul(self: T, rhs: T) -> T;
}

impl WrappingMathImpl<T, +WrappingAdd<T>, +WrappingSub<T>, +WrappingMul<T>> of WrappingMath<T> {
#[inline(always)]
fn wrapping_add(self: T, rhs: T) -> T {
WrappingAdd::<T>::wrapping_add(self, rhs)
}

#[inline(always)]
fn wrapping_sub(self: T, rhs: T) -> T {
WrappingSub::<T>::wrapping_sub(self, rhs)
}

#[inline(always)]
fn wrapping_mul(self: T, rhs: T) -> T {
WrappingMul::<T>::wrapping_mul(self, rhs)
}
}

trait WrappingAdd<T> {
fn wrapping_add(self: T, rhs: T) -> T;
}

trait WrappingSub<T> {
fn wrapping_sub(self: T, rhs: T) -> T;
}

trait WrappingMul<T> {
fn wrapping_mul(self: T, rhs: T) -> T;
}

impl U8WrappingAdd of WrappingAdd<u8> {
#[inline(always)]
fn wrapping_add(self: u8, rhs: u8) -> u8 {
u8_wrapping_add(self, rhs)
}
}

impl U8WrappingSub of WrappingSub<u8> {
#[inline(always)]
fn wrapping_sub(self: u8, rhs: u8) -> u8 {
u8_wrapping_sub(self, rhs)
}
}

impl U8WrappingMul of WrappingMul<u8> {
#[inline(always)]
fn wrapping_mul(self: u8, rhs: u8) -> u8 {
(u8_wide_mul(self, rhs) & BoundedInt::<u8>::max().into()).try_into().unwrap()
}
}

impl U16WrappingAdd of WrappingAdd<u16> {
#[inline(always)]
fn wrapping_add(self: u16, rhs: u16) -> u16 {
u16_wrapping_add(self, rhs)
}
}

impl U16WrappingSub of WrappingSub<u16> {
#[inline(always)]
fn wrapping_sub(self: u16, rhs: u16) -> u16 {
u16_wrapping_sub(self, rhs)
}
}

impl U16WrappingMul of WrappingMul<u16> {
#[inline(always)]
fn wrapping_mul(self: u16, rhs: u16) -> u16 {
(u16_wide_mul(self, rhs) & BoundedInt::<u16>::max().into()).try_into().unwrap()
}
}

impl U32WrappingAdd of WrappingAdd<u32> {
#[inline(always)]
fn wrapping_add(self: u32, rhs: u32) -> u32 {
u32_wrapping_add(self, rhs)
}
}

impl U32WrappingSub of WrappingSub<u32> {
#[inline(always)]
fn wrapping_sub(self: u32, rhs: u32) -> u32 {
u32_wrapping_sub(self, rhs)
}
}

impl U32WrappingMul of WrappingMul<u32> {
#[inline(always)]
fn wrapping_mul(self: u32, rhs: u32) -> u32 {
(u32_wide_mul(self, rhs) & BoundedInt::<u32>::max().into()).try_into().unwrap()
}
}

impl U64WrappingAdd of WrappingAdd<u64> {
#[inline(always)]
fn wrapping_add(self: u64, rhs: u64) -> u64 {
u64_wrapping_add(self, rhs)
}
}

impl U64WrappingSub of WrappingSub<u64> {
#[inline(always)]
fn wrapping_sub(self: u64, rhs: u64) -> u64 {
u64_wrapping_sub(self, rhs)
}
}

impl U64WrappingMul of WrappingMul<u64> {
#[inline(always)]
fn wrapping_mul(self: u64, rhs: u64) -> u64 {
(u64_wide_mul(self, rhs) & BoundedInt::<u64>::max().into()).try_into().unwrap()
}
}

impl U128WrappingAdd of WrappingAdd<u128> {
#[inline(always)]
fn wrapping_add(self: u128, rhs: u128) -> u128 {
u128_wrapping_add(self, rhs)
}
}

impl U128WrappingSub of WrappingSub<u128> {
#[inline(always)]
fn wrapping_sub(self: u128, rhs: u128) -> u128 {
u128_wrapping_sub(self, rhs)
}
}

impl U128WrappingMul of WrappingMul<u128> {
#[inline(always)]
fn wrapping_mul(self: u128, rhs: u128) -> u128 {
let (_, low) = u128_wide_mul(self, rhs);
low
}
}

impl U256WrappingAdd of WrappingAdd<u256> {
#[inline(always)]
fn wrapping_add(self: u256, rhs: u256) -> u256 {
let (val, _) = u256_overflowing_add(self, rhs);
val
}
}

impl U256WrappingSub of WrappingSub<u256> {
#[inline(always)]
fn wrapping_sub(self: u256, rhs: u256) -> u256 {
let (val, _) = u256_overflow_sub(self, rhs);
val
}
}

impl U256WrappingMul of WrappingMul<u256> {
#[inline(always)]
fn wrapping_mul(self: u256, rhs: u256) -> u256 {
let (val, _) = u256_overflow_mul(self, rhs);
val
}
}
182 changes: 181 additions & 1 deletion src/math/src/tests/math_test.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use alexandria_math::{count_digits_of_base, pow, BitShift, BitRotate};
use alexandria_math::{count_digits_of_base, pow, BitShift, BitRotate, WrappingMath};
use integer::BoundedInt;

// Test power function
Expand Down Expand Up @@ -224,3 +224,183 @@ fn test_rotr_max() {
assert(BitRotate::rotate_right(0b101_u128, 127) == 0b1010, 'invalid result');
assert(BitRotate::rotate_right(0b101_u256, 255) == 0b1010, 'invalid result');
}

#[test]
fn test_wrapping_math_non_wrapping() {
assert_eq!(10_u8.wrapping_add(10_u8), 20_u8);
assert_eq!(0_u8.wrapping_add(10_u8), 10_u8);
assert_eq!(10_u8.wrapping_add(0_u8), 10_u8);
assert_eq!(0_u8.wrapping_add(0_u8), 0_u8);
assert_eq!(20_u8.wrapping_sub(10_u8), 10_u8);
assert_eq!(10_u8.wrapping_sub(0_u8), 10_u8);
assert_eq!(0_u8.wrapping_sub(0_u8), 0_u8);
assert_eq!(10_u8.wrapping_mul(10_u8), 100_u8);
assert_eq!(10_u8.wrapping_mul(0_u8), 0_u8);
assert_eq!(0_u8.wrapping_mul(10_u8), 0_u8);
assert_eq!(0_u8.wrapping_mul(0_u8), 0_u8);

assert_eq!(10_u16.wrapping_add(10_u16), 20_u16);
assert_eq!(0_u16.wrapping_add(10_u16), 10_u16);
assert_eq!(10_u16.wrapping_add(0_u16), 10_u16);
assert_eq!(0_u16.wrapping_add(0_u16), 0_u16);
assert_eq!(20_u16.wrapping_sub(10_u16), 10_u16);
assert_eq!(10_u16.wrapping_sub(0_u16), 10_u16);
assert_eq!(0_u16.wrapping_sub(0_u16), 0_u16);
assert_eq!(10_u16.wrapping_mul(10_u16), 100_u16);
assert_eq!(10_u16.wrapping_mul(0_u16), 0_u16);
assert_eq!(0_u16.wrapping_mul(10_u16), 0_u16);
assert_eq!(0_u16.wrapping_mul(0_u16), 0_u16);

assert_eq!(10_u32.wrapping_add(10_u32), 20_u32);
assert_eq!(0_u32.wrapping_add(10_u32), 10_u32);
assert_eq!(10_u32.wrapping_add(0_u32), 10_u32);
assert_eq!(0_u32.wrapping_add(0_u32), 0_u32);
assert_eq!(20_u32.wrapping_sub(10_u32), 10_u32);
assert_eq!(10_u32.wrapping_sub(0_u32), 10_u32);
assert_eq!(0_u32.wrapping_sub(0_u32), 0_u32);
assert_eq!(10_u32.wrapping_mul(10_u32), 100_u32);
assert_eq!(10_u32.wrapping_mul(0_u32), 0_u32);
assert_eq!(0_u32.wrapping_mul(10_u32), 0_u32);
assert_eq!(0_u32.wrapping_mul(0_u32), 0_u32);

assert_eq!(10_u64.wrapping_add(10_u64), 20_u64);
assert_eq!(0_u64.wrapping_add(10_u64), 10_u64);
assert_eq!(10_u64.wrapping_add(0_u64), 10_u64);
assert_eq!(0_u64.wrapping_add(0_u64), 0_u64);
assert_eq!(20_u64.wrapping_sub(10_u64), 10_u64);
assert_eq!(10_u64.wrapping_sub(0_u64), 10_u64);
assert_eq!(0_u64.wrapping_sub(0_u64), 0_u64);
assert_eq!(10_u64.wrapping_mul(10_u64), 100_u64);
assert_eq!(10_u64.wrapping_mul(0_u64), 0_u64);
assert_eq!(0_u64.wrapping_mul(10_u64), 0_u64);
assert_eq!(0_u64.wrapping_mul(0_u64), 0_u64);

assert_eq!(10_u128.wrapping_add(10_u128), 20_u128);
assert_eq!(0_u128.wrapping_add(10_u128), 10_u128);
assert_eq!(10_u128.wrapping_add(0_u128), 10_u128);
assert_eq!(0_u128.wrapping_add(0_u128), 0_u128);
assert_eq!(20_u128.wrapping_sub(10_u128), 10_u128);
assert_eq!(10_u128.wrapping_sub(0_u128), 10_u128);
assert_eq!(0_u128.wrapping_sub(0_u128), 0_u128);
assert_eq!(10_u128.wrapping_mul(10_u128), 100_u128);
assert_eq!(10_u128.wrapping_mul(0_u128), 0_u128);
assert_eq!(0_u128.wrapping_mul(10_u128), 0_u128);
assert_eq!(0_u128.wrapping_mul(0_u128), 0_u128);

assert_eq!(10_u256.wrapping_add(10_u256), 20_u256);
assert_eq!(0_u256.wrapping_add(10_u256), 10_u256);
assert_eq!(10_u256.wrapping_add(0_u256), 10_u256);
assert_eq!(0_u256.wrapping_add(0_u256), 0_u256);
assert_eq!(20_u256.wrapping_sub(10_u256), 10_u256);
assert_eq!(10_u256.wrapping_sub(0_u256), 10_u256);
assert_eq!(0_u256.wrapping_sub(0_u256), 0_u256);
assert_eq!(10_u256.wrapping_mul(10_u256), 100_u256);
assert_eq!(10_u256.wrapping_mul(0_u256), 0_u256);
assert_eq!(0_u256.wrapping_mul(10_u256), 0_u256);
assert_eq!(0_u256.wrapping_mul(0_u256), 0_u256);
}

#[test]
fn test_wrapping_math_wrapping() {
assert_eq!(BoundedInt::<u8>::max().wrapping_add(1_u8), 0_u8);
assert_eq!(1_u8.wrapping_add(BoundedInt::<u8>::max()), 0_u8);
assert_eq!(BoundedInt::<u8>::max().wrapping_add(2_u8), 1_u8);
assert_eq!(2_u8.wrapping_add(BoundedInt::<u8>::max()), 1_u8);
assert_eq!(
BoundedInt::<u8>::max().wrapping_add(BoundedInt::<u8>::max()),
BoundedInt::<u8>::max() - 1_u8
);
assert_eq!(BoundedInt::<u8>::min().wrapping_sub(1_u8), BoundedInt::<u8>::max());
assert_eq!(BoundedInt::<u8>::min().wrapping_sub(2_u8), BoundedInt::<u8>::max() - 1_u8);
assert_eq!(1_u8.wrapping_sub(BoundedInt::<u8>::max()), 2_u8);
assert_eq!(0_u8.wrapping_sub(BoundedInt::<u8>::max()), 1_u8);
assert_eq!(BoundedInt::<u8>::max().wrapping_mul(BoundedInt::<u8>::max()), 1_u8);
assert_eq!((BoundedInt::<u8>::max() - 1_u8).wrapping_mul(2_u8), BoundedInt::<u8>::max() - 3_u8);

assert_eq!(BoundedInt::<u16>::max().wrapping_add(1_u16), 0_u16);
assert_eq!(1_u16.wrapping_add(BoundedInt::<u16>::max()), 0_u16);
assert_eq!(BoundedInt::<u16>::max().wrapping_add(2_u16), 1_u16);
assert_eq!(2_u16.wrapping_add(BoundedInt::<u16>::max()), 1_u16);
assert_eq!(
BoundedInt::<u16>::max().wrapping_add(BoundedInt::<u16>::max()),
BoundedInt::<u16>::max() - 1_u16
);
assert_eq!(BoundedInt::<u16>::min().wrapping_sub(1_u16), BoundedInt::<u16>::max());
assert_eq!(BoundedInt::<u16>::min().wrapping_sub(2_u16), BoundedInt::<u16>::max() - 1_u16);
assert_eq!(1_u16.wrapping_sub(BoundedInt::<u16>::max()), 2_u16);
assert_eq!(0_u16.wrapping_sub(BoundedInt::<u16>::max()), 1_u16);
assert_eq!(BoundedInt::<u16>::max().wrapping_mul(BoundedInt::<u16>::max()), 1_u16);
assert_eq!(
(BoundedInt::<u16>::max() - 1_u16).wrapping_mul(2_u16), BoundedInt::<u16>::max() - 3_u16
);

assert_eq!(BoundedInt::<u32>::max().wrapping_add(1_u32), 0_u32);
assert_eq!(1_u32.wrapping_add(BoundedInt::<u32>::max()), 0_u32);
assert_eq!(BoundedInt::<u32>::max().wrapping_add(2_u32), 1_u32);
assert_eq!(2_u32.wrapping_add(BoundedInt::<u32>::max()), 1_u32);
assert_eq!(
BoundedInt::<u32>::max().wrapping_add(BoundedInt::<u32>::max()),
BoundedInt::<u32>::max() - 1_u32
);
assert_eq!(BoundedInt::<u32>::min().wrapping_sub(1_u32), BoundedInt::<u32>::max());
assert_eq!(BoundedInt::<u32>::min().wrapping_sub(2_u32), BoundedInt::<u32>::max() - 1_u32);
assert_eq!(1_u32.wrapping_sub(BoundedInt::<u32>::max()), 2_u32);
assert_eq!(0_u32.wrapping_sub(BoundedInt::<u32>::max()), 1_u32);
assert_eq!(BoundedInt::<u32>::max().wrapping_mul(BoundedInt::<u32>::max()), 1_u32);
assert_eq!(
(BoundedInt::<u32>::max() - 1_u32).wrapping_mul(2_u32), BoundedInt::<u32>::max() - 3_u32
);

assert_eq!(BoundedInt::<u64>::max().wrapping_add(1_u64), 0_u64);
assert_eq!(1_u64.wrapping_add(BoundedInt::<u64>::max()), 0_u64);
assert_eq!(BoundedInt::<u64>::max().wrapping_add(2_u64), 1_u64);
assert_eq!(2_u64.wrapping_add(BoundedInt::<u64>::max()), 1_u64);
assert_eq!(
BoundedInt::<u64>::max().wrapping_add(BoundedInt::<u64>::max()),
BoundedInt::<u64>::max() - 1_u64
);
assert_eq!(BoundedInt::<u64>::min().wrapping_sub(1_u64), BoundedInt::<u64>::max());
assert_eq!(BoundedInt::<u64>::min().wrapping_sub(2_u64), BoundedInt::<u64>::max() - 1_u64);
assert_eq!(1_u64.wrapping_sub(BoundedInt::<u64>::max()), 2_u64);
assert_eq!(0_u64.wrapping_sub(BoundedInt::<u64>::max()), 1_u64);
assert_eq!(BoundedInt::<u64>::max().wrapping_mul(BoundedInt::<u64>::max()), 1_u64);
assert_eq!(
(BoundedInt::<u64>::max() - 1_u64).wrapping_mul(2_u64), BoundedInt::<u64>::max() - 3_u64
);

assert_eq!(BoundedInt::<u128>::max().wrapping_add(1_u128), 0_u128);
assert_eq!(1_u128.wrapping_add(BoundedInt::<u128>::max()), 0_u128);
assert_eq!(BoundedInt::<u128>::max().wrapping_add(2_u128), 1_u128);
assert_eq!(2_u128.wrapping_add(BoundedInt::<u128>::max()), 1_u128);
assert_eq!(
BoundedInt::<u128>::max().wrapping_add(BoundedInt::<u128>::max()),
BoundedInt::<u128>::max() - 1_u128
);
assert_eq!(BoundedInt::<u128>::min().wrapping_sub(1_u128), BoundedInt::<u128>::max());
assert_eq!(BoundedInt::<u128>::min().wrapping_sub(2_u128), BoundedInt::<u128>::max() - 1_u128);
assert_eq!(1_u128.wrapping_sub(BoundedInt::<u128>::max()), 2_u128);
assert_eq!(0_u128.wrapping_sub(BoundedInt::<u128>::max()), 1_u128);
assert_eq!(BoundedInt::<u128>::max().wrapping_mul(BoundedInt::<u128>::max()), 1_u128);
assert_eq!(
(BoundedInt::<u128>::max() - 1_u128).wrapping_mul(2_u128),
BoundedInt::<u128>::max() - 3_u128
);

assert_eq!(BoundedInt::<u256>::max().wrapping_add(1_u256), 0_u256);
assert_eq!(1_u256.wrapping_add(BoundedInt::<u256>::max()), 0_u256);
assert_eq!(BoundedInt::<u256>::max().wrapping_add(2_u256), 1_u256);
assert_eq!(2_u256.wrapping_add(BoundedInt::<u256>::max()), 1_u256);
assert_eq!(
BoundedInt::<u256>::max().wrapping_add(BoundedInt::<u256>::max()),
BoundedInt::<u256>::max() - 1_u256
);
assert_eq!(BoundedInt::<u256>::min().wrapping_sub(1_u256), BoundedInt::<u256>::max());
assert_eq!(BoundedInt::<u256>::min().wrapping_sub(2_u256), BoundedInt::<u256>::max() - 1_u256);
assert_eq!(1_u256.wrapping_sub(BoundedInt::<u256>::max()), 2_u256);
assert_eq!(0_u256.wrapping_sub(BoundedInt::<u256>::max()), 1_u256);
assert_eq!(BoundedInt::<u256>::max().wrapping_mul(BoundedInt::<u256>::max()), 1_u256);
assert_eq!(
(BoundedInt::<u256>::max() - 1_u256).wrapping_mul(2_u256),
BoundedInt::<u256>::max() - 3_u256
);
}

0 comments on commit 98aadc7

Please sign in to comment.