Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add WrappingMath trait #241

Merged
merged 5 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 168 additions & 1 deletion src/math/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ mod wad_ray_math;
mod zellers_congruence;
use integer::{
u8_wide_mul, u16_wide_mul, u32_wide_mul, u64_wide_mul, u128_wide_mul, u256_overflow_mul,
BoundedInt
u8_wrapping_add, u16_wrapping_add, u32_wrapping_add, u64_wrapping_add, u128_wrapping_add,
u256_overflowing_add, u8_wrapping_sub, u16_wrapping_sub, u32_wrapping_sub, u64_wrapping_sub,
u128_wrapping_sub, u256_overflow_sub, BoundedInt
};

/// Raise a number to a power.
Expand Down Expand Up @@ -233,3 +235,168 @@ impl U256BitRotate of BitRotate<u256> {
remainder * pow(2, 256 - n) + quotient
}
}

trait WrappingMath<T> {
fn wrapping_add(self: T, rhs: T) -> T;
fn wrapping_sub(self: T, rhs: T) -> T;
fn wrapping_mul(self: T, rhs: T) -> T;
}

impl WrappingMathImpl<T, +WrappingAdd<T>, +WrappingSub<T>, +WrappingMul<T>> of WrappingMath<T> {
#[inline(always)]
fn wrapping_add(self: T, rhs: T) -> T {
WrappingAdd::<T>::wrapping_add(self, rhs)
}

#[inline(always)]
fn wrapping_sub(self: T, rhs: T) -> T {
WrappingSub::<T>::wrapping_sub(self, rhs)
}

#[inline(always)]
fn wrapping_mul(self: T, rhs: T) -> T {
WrappingMul::<T>::wrapping_mul(self, rhs)
}
}

trait WrappingAdd<T> {
fn wrapping_add(self: T, rhs: T) -> T;
}

trait WrappingSub<T> {
fn wrapping_sub(self: T, rhs: T) -> T;
}

trait WrappingMul<T> {
fn wrapping_mul(self: T, rhs: T) -> T;
}

impl U8WrappingAdd of WrappingAdd<u8> {
#[inline(always)]
fn wrapping_add(self: u8, rhs: u8) -> u8 {
u8_wrapping_add(self, rhs)
}
}

impl U8WrappingSub of WrappingSub<u8> {
#[inline(always)]
fn wrapping_sub(self: u8, rhs: u8) -> u8 {
u8_wrapping_sub(self, rhs)
}
}

impl U8WrappingMul of WrappingMul<u8> {
#[inline(always)]
fn wrapping_mul(self: u8, rhs: u8) -> u8 {
(u8_wide_mul(self, rhs) & BoundedInt::<u8>::max().into()).try_into().unwrap()
}
}

impl U16WrappingAdd of WrappingAdd<u16> {
#[inline(always)]
fn wrapping_add(self: u16, rhs: u16) -> u16 {
u16_wrapping_add(self, rhs)
}
}

impl U16WrappingSub of WrappingSub<u16> {
#[inline(always)]
fn wrapping_sub(self: u16, rhs: u16) -> u16 {
u16_wrapping_sub(self, rhs)
}
}

impl U16WrappingMul of WrappingMul<u16> {
#[inline(always)]
fn wrapping_mul(self: u16, rhs: u16) -> u16 {
(u16_wide_mul(self, rhs) & BoundedInt::<u16>::max().into()).try_into().unwrap()
}
}

impl U32WrappingAdd of WrappingAdd<u32> {
#[inline(always)]
fn wrapping_add(self: u32, rhs: u32) -> u32 {
u32_wrapping_add(self, rhs)
}
}

impl U32WrappingSub of WrappingSub<u32> {
#[inline(always)]
fn wrapping_sub(self: u32, rhs: u32) -> u32 {
u32_wrapping_sub(self, rhs)
}
}

impl U32WrappingMul of WrappingMul<u32> {
#[inline(always)]
fn wrapping_mul(self: u32, rhs: u32) -> u32 {
(u32_wide_mul(self, rhs) & BoundedInt::<u32>::max().into()).try_into().unwrap()
}
}

impl U64WrappingAdd of WrappingAdd<u64> {
#[inline(always)]
fn wrapping_add(self: u64, rhs: u64) -> u64 {
u64_wrapping_add(self, rhs)
}
}

impl U64WrappingSub of WrappingSub<u64> {
#[inline(always)]
fn wrapping_sub(self: u64, rhs: u64) -> u64 {
u64_wrapping_sub(self, rhs)
}
}

impl U64WrappingMul of WrappingMul<u64> {
#[inline(always)]
fn wrapping_mul(self: u64, rhs: u64) -> u64 {
(u64_wide_mul(self, rhs) & BoundedInt::<u64>::max().into()).try_into().unwrap()
}
}

impl U128WrappingAdd of WrappingAdd<u128> {
#[inline(always)]
fn wrapping_add(self: u128, rhs: u128) -> u128 {
u128_wrapping_add(self, rhs)
}
}

impl U128WrappingSub of WrappingSub<u128> {
#[inline(always)]
fn wrapping_sub(self: u128, rhs: u128) -> u128 {
u128_wrapping_sub(self, rhs)
}
}

impl U128WrappingMul of WrappingMul<u128> {
#[inline(always)]
fn wrapping_mul(self: u128, rhs: u128) -> u128 {
let (_, low) = u128_wide_mul(self, rhs);
low
}
}

impl U256WrappingAdd of WrappingAdd<u256> {
#[inline(always)]
fn wrapping_add(self: u256, rhs: u256) -> u256 {
let (val, _) = u256_overflowing_add(self, rhs);
val
}
}

impl U256WrappingSub of WrappingSub<u256> {
#[inline(always)]
fn wrapping_sub(self: u256, rhs: u256) -> u256 {
let (val, _) = u256_overflow_sub(self, rhs);
val
}
}

impl U256WrappingMul of WrappingMul<u256> {
#[inline(always)]
fn wrapping_mul(self: u256, rhs: u256) -> u256 {
let (val, _) = u256_overflow_mul(self, rhs);
val
}
}
182 changes: 181 additions & 1 deletion src/math/src/tests/math_test.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use alexandria_math::{count_digits_of_base, pow, BitShift, BitRotate};
use alexandria_math::{count_digits_of_base, pow, BitShift, BitRotate, WrappingMath};
use integer::BoundedInt;

// Test power function
Expand Down Expand Up @@ -224,3 +224,183 @@ fn test_rotr_max() {
assert(BitRotate::rotate_right(0b101_u128, 127) == 0b1010, 'invalid result');
assert(BitRotate::rotate_right(0b101_u256, 255) == 0b1010, 'invalid result');
}

#[test]
fn test_wrapping_math_non_wrapping() {
assert_eq!(10_u8.wrapping_add(10_u8), 20_u8);
assert_eq!(0_u8.wrapping_add(10_u8), 10_u8);
assert_eq!(10_u8.wrapping_add(0_u8), 10_u8);
assert_eq!(0_u8.wrapping_add(0_u8), 0_u8);
assert_eq!(20_u8.wrapping_sub(10_u8), 10_u8);
assert_eq!(10_u8.wrapping_sub(0_u8), 10_u8);
assert_eq!(0_u8.wrapping_sub(0_u8), 0_u8);
assert_eq!(10_u8.wrapping_mul(10_u8), 100_u8);
assert_eq!(10_u8.wrapping_mul(0_u8), 0_u8);
assert_eq!(0_u8.wrapping_mul(10_u8), 0_u8);
assert_eq!(0_u8.wrapping_mul(0_u8), 0_u8);

assert_eq!(10_u16.wrapping_add(10_u16), 20_u16);
assert_eq!(0_u16.wrapping_add(10_u16), 10_u16);
assert_eq!(10_u16.wrapping_add(0_u16), 10_u16);
assert_eq!(0_u16.wrapping_add(0_u16), 0_u16);
assert_eq!(20_u16.wrapping_sub(10_u16), 10_u16);
assert_eq!(10_u16.wrapping_sub(0_u16), 10_u16);
assert_eq!(0_u16.wrapping_sub(0_u16), 0_u16);
assert_eq!(10_u16.wrapping_mul(10_u16), 100_u16);
assert_eq!(10_u16.wrapping_mul(0_u16), 0_u16);
assert_eq!(0_u16.wrapping_mul(10_u16), 0_u16);
assert_eq!(0_u16.wrapping_mul(0_u16), 0_u16);

assert_eq!(10_u32.wrapping_add(10_u32), 20_u32);
assert_eq!(0_u32.wrapping_add(10_u32), 10_u32);
assert_eq!(10_u32.wrapping_add(0_u32), 10_u32);
assert_eq!(0_u32.wrapping_add(0_u32), 0_u32);
assert_eq!(20_u32.wrapping_sub(10_u32), 10_u32);
assert_eq!(10_u32.wrapping_sub(0_u32), 10_u32);
assert_eq!(0_u32.wrapping_sub(0_u32), 0_u32);
assert_eq!(10_u32.wrapping_mul(10_u32), 100_u32);
assert_eq!(10_u32.wrapping_mul(0_u32), 0_u32);
assert_eq!(0_u32.wrapping_mul(10_u32), 0_u32);
assert_eq!(0_u32.wrapping_mul(0_u32), 0_u32);

assert_eq!(10_u64.wrapping_add(10_u64), 20_u64);
assert_eq!(0_u64.wrapping_add(10_u64), 10_u64);
assert_eq!(10_u64.wrapping_add(0_u64), 10_u64);
assert_eq!(0_u64.wrapping_add(0_u64), 0_u64);
assert_eq!(20_u64.wrapping_sub(10_u64), 10_u64);
assert_eq!(10_u64.wrapping_sub(0_u64), 10_u64);
assert_eq!(0_u64.wrapping_sub(0_u64), 0_u64);
assert_eq!(10_u64.wrapping_mul(10_u64), 100_u64);
assert_eq!(10_u64.wrapping_mul(0_u64), 0_u64);
assert_eq!(0_u64.wrapping_mul(10_u64), 0_u64);
assert_eq!(0_u64.wrapping_mul(0_u64), 0_u64);

assert_eq!(10_u128.wrapping_add(10_u128), 20_u128);
assert_eq!(0_u128.wrapping_add(10_u128), 10_u128);
assert_eq!(10_u128.wrapping_add(0_u128), 10_u128);
assert_eq!(0_u128.wrapping_add(0_u128), 0_u128);
assert_eq!(20_u128.wrapping_sub(10_u128), 10_u128);
assert_eq!(10_u128.wrapping_sub(0_u128), 10_u128);
assert_eq!(0_u128.wrapping_sub(0_u128), 0_u128);
assert_eq!(10_u128.wrapping_mul(10_u128), 100_u128);
assert_eq!(10_u128.wrapping_mul(0_u128), 0_u128);
assert_eq!(0_u128.wrapping_mul(10_u128), 0_u128);
assert_eq!(0_u128.wrapping_mul(0_u128), 0_u128);

assert_eq!(10_u256.wrapping_add(10_u256), 20_u256);
assert_eq!(0_u256.wrapping_add(10_u256), 10_u256);
assert_eq!(10_u256.wrapping_add(0_u256), 10_u256);
assert_eq!(0_u256.wrapping_add(0_u256), 0_u256);
assert_eq!(20_u256.wrapping_sub(10_u256), 10_u256);
assert_eq!(10_u256.wrapping_sub(0_u256), 10_u256);
assert_eq!(0_u256.wrapping_sub(0_u256), 0_u256);
assert_eq!(10_u256.wrapping_mul(10_u256), 100_u256);
assert_eq!(10_u256.wrapping_mul(0_u256), 0_u256);
assert_eq!(0_u256.wrapping_mul(10_u256), 0_u256);
assert_eq!(0_u256.wrapping_mul(0_u256), 0_u256);
}

#[test]
fn test_wrapping_math_wrapping() {
assert_eq!(BoundedInt::<u8>::max().wrapping_add(1_u8), 0_u8);
assert_eq!(1_u8.wrapping_add(BoundedInt::<u8>::max()), 0_u8);
assert_eq!(BoundedInt::<u8>::max().wrapping_add(2_u8), 1_u8);
assert_eq!(2_u8.wrapping_add(BoundedInt::<u8>::max()), 1_u8);
assert_eq!(
BoundedInt::<u8>::max().wrapping_add(BoundedInt::<u8>::max()),
BoundedInt::<u8>::max() - 1_u8
);
assert_eq!(BoundedInt::<u8>::min().wrapping_sub(1_u8), BoundedInt::<u8>::max());
assert_eq!(BoundedInt::<u8>::min().wrapping_sub(2_u8), BoundedInt::<u8>::max() - 1_u8);
assert_eq!(1_u8.wrapping_sub(BoundedInt::<u8>::max()), 2_u8);
assert_eq!(0_u8.wrapping_sub(BoundedInt::<u8>::max()), 1_u8);
assert_eq!(BoundedInt::<u8>::max().wrapping_mul(BoundedInt::<u8>::max()), 1_u8);
assert_eq!((BoundedInt::<u8>::max() - 1_u8).wrapping_mul(2_u8), BoundedInt::<u8>::max() - 3_u8);

assert_eq!(BoundedInt::<u16>::max().wrapping_add(1_u16), 0_u16);
assert_eq!(1_u16.wrapping_add(BoundedInt::<u16>::max()), 0_u16);
assert_eq!(BoundedInt::<u16>::max().wrapping_add(2_u16), 1_u16);
assert_eq!(2_u16.wrapping_add(BoundedInt::<u16>::max()), 1_u16);
assert_eq!(
BoundedInt::<u16>::max().wrapping_add(BoundedInt::<u16>::max()),
BoundedInt::<u16>::max() - 1_u16
);
assert_eq!(BoundedInt::<u16>::min().wrapping_sub(1_u16), BoundedInt::<u16>::max());
assert_eq!(BoundedInt::<u16>::min().wrapping_sub(2_u16), BoundedInt::<u16>::max() - 1_u16);
assert_eq!(1_u16.wrapping_sub(BoundedInt::<u16>::max()), 2_u16);
assert_eq!(0_u16.wrapping_sub(BoundedInt::<u16>::max()), 1_u16);
assert_eq!(BoundedInt::<u16>::max().wrapping_mul(BoundedInt::<u16>::max()), 1_u16);
assert_eq!(
(BoundedInt::<u16>::max() - 1_u16).wrapping_mul(2_u16), BoundedInt::<u16>::max() - 3_u16
);

assert_eq!(BoundedInt::<u32>::max().wrapping_add(1_u32), 0_u32);
assert_eq!(1_u32.wrapping_add(BoundedInt::<u32>::max()), 0_u32);
assert_eq!(BoundedInt::<u32>::max().wrapping_add(2_u32), 1_u32);
assert_eq!(2_u32.wrapping_add(BoundedInt::<u32>::max()), 1_u32);
assert_eq!(
BoundedInt::<u32>::max().wrapping_add(BoundedInt::<u32>::max()),
BoundedInt::<u32>::max() - 1_u32
);
assert_eq!(BoundedInt::<u32>::min().wrapping_sub(1_u32), BoundedInt::<u32>::max());
assert_eq!(BoundedInt::<u32>::min().wrapping_sub(2_u32), BoundedInt::<u32>::max() - 1_u32);
assert_eq!(1_u32.wrapping_sub(BoundedInt::<u32>::max()), 2_u32);
assert_eq!(0_u32.wrapping_sub(BoundedInt::<u32>::max()), 1_u32);
assert_eq!(BoundedInt::<u32>::max().wrapping_mul(BoundedInt::<u32>::max()), 1_u32);
assert_eq!(
(BoundedInt::<u32>::max() - 1_u32).wrapping_mul(2_u32), BoundedInt::<u32>::max() - 3_u32
);

assert_eq!(BoundedInt::<u64>::max().wrapping_add(1_u64), 0_u64);
assert_eq!(1_u64.wrapping_add(BoundedInt::<u64>::max()), 0_u64);
assert_eq!(BoundedInt::<u64>::max().wrapping_add(2_u64), 1_u64);
assert_eq!(2_u64.wrapping_add(BoundedInt::<u64>::max()), 1_u64);
assert_eq!(
BoundedInt::<u64>::max().wrapping_add(BoundedInt::<u64>::max()),
BoundedInt::<u64>::max() - 1_u64
);
assert_eq!(BoundedInt::<u64>::min().wrapping_sub(1_u64), BoundedInt::<u64>::max());
assert_eq!(BoundedInt::<u64>::min().wrapping_sub(2_u64), BoundedInt::<u64>::max() - 1_u64);
assert_eq!(1_u64.wrapping_sub(BoundedInt::<u64>::max()), 2_u64);
assert_eq!(0_u64.wrapping_sub(BoundedInt::<u64>::max()), 1_u64);
assert_eq!(BoundedInt::<u64>::max().wrapping_mul(BoundedInt::<u64>::max()), 1_u64);
assert_eq!(
(BoundedInt::<u64>::max() - 1_u64).wrapping_mul(2_u64), BoundedInt::<u64>::max() - 3_u64
);

assert_eq!(BoundedInt::<u128>::max().wrapping_add(1_u128), 0_u128);
assert_eq!(1_u128.wrapping_add(BoundedInt::<u128>::max()), 0_u128);
assert_eq!(BoundedInt::<u128>::max().wrapping_add(2_u128), 1_u128);
assert_eq!(2_u128.wrapping_add(BoundedInt::<u128>::max()), 1_u128);
assert_eq!(
BoundedInt::<u128>::max().wrapping_add(BoundedInt::<u128>::max()),
BoundedInt::<u128>::max() - 1_u128
);
assert_eq!(BoundedInt::<u128>::min().wrapping_sub(1_u128), BoundedInt::<u128>::max());
assert_eq!(BoundedInt::<u128>::min().wrapping_sub(2_u128), BoundedInt::<u128>::max() - 1_u128);
assert_eq!(1_u128.wrapping_sub(BoundedInt::<u128>::max()), 2_u128);
assert_eq!(0_u128.wrapping_sub(BoundedInt::<u128>::max()), 1_u128);
assert_eq!(BoundedInt::<u128>::max().wrapping_mul(BoundedInt::<u128>::max()), 1_u128);
assert_eq!(
(BoundedInt::<u128>::max() - 1_u128).wrapping_mul(2_u128),
BoundedInt::<u128>::max() - 3_u128
);

assert_eq!(BoundedInt::<u256>::max().wrapping_add(1_u256), 0_u256);
assert_eq!(1_u256.wrapping_add(BoundedInt::<u256>::max()), 0_u256);
assert_eq!(BoundedInt::<u256>::max().wrapping_add(2_u256), 1_u256);
assert_eq!(2_u256.wrapping_add(BoundedInt::<u256>::max()), 1_u256);
assert_eq!(
BoundedInt::<u256>::max().wrapping_add(BoundedInt::<u256>::max()),
BoundedInt::<u256>::max() - 1_u256
);
assert_eq!(BoundedInt::<u256>::min().wrapping_sub(1_u256), BoundedInt::<u256>::max());
assert_eq!(BoundedInt::<u256>::min().wrapping_sub(2_u256), BoundedInt::<u256>::max() - 1_u256);
assert_eq!(1_u256.wrapping_sub(BoundedInt::<u256>::max()), 2_u256);
assert_eq!(0_u256.wrapping_sub(BoundedInt::<u256>::max()), 1_u256);
assert_eq!(BoundedInt::<u256>::max().wrapping_mul(BoundedInt::<u256>::max()), 1_u256);
assert_eq!(
(BoundedInt::<u256>::max() - 1_u256).wrapping_mul(2_u256),
BoundedInt::<u256>::max() - 3_u256
);
}