Skip to content

Commit

Permalink
f32x32 test
Browse files Browse the repository at this point in the history
  • Loading branch information
chachaleo committed Jul 8, 2024
1 parent c2cbfa1 commit c68df1a
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 19 deletions.
2 changes: 1 addition & 1 deletion packages/orion-algo/src/algo.cairo
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pub mod linear_fit;
pub mod linear_fit;
91 changes: 78 additions & 13 deletions packages/orion-algo/src/algo/linear_fit.cairo
Original file line number Diff line number Diff line change
@@ -1,20 +1,34 @@
use orion_numbers::{f16x16::core::{f16x16}, FixedTrait};
use orion_algo::span_math::SpanMathTrait;
use orion_numbers::core_trait::I32Div;

pub fn linear_fit(x: Span<f16x16>, y: Span<f16x16>) -> (f16x16, f16x16) {
use orion_numbers::core_trait::{I32Div, I64Div};


pub fn linear_fit<
T,
+SpanMathTrait<T>,
+TryInto<u32, T>,
+FixedTrait<T>,
+Mul<T>,
+Sub<T>,
+PartialEq<T>,
+Div<T>,
+Drop<T>,
+Copy<T>
>(
x: Span<T>, y: Span<T>
) -> (T, T) {
if x.len() != y.len() || x.len() == 0 {
panic!("x and y should be of the same lenght")
}

let n: f16x16 = x.len().try_into().unwrap();
let n: T = x.len().try_into().unwrap();
let sum_x = x.sum();
let sum_y = y.sum();
let sum_xx = x.dot(x);
let sum_xy = x.dot(y);

let denominator = n * sum_xx - (sum_x.mul(sum_x));
if denominator == 0 {
if denominator == FixedTrait::ZERO() {
panic!("division by zero exception")
}

Expand All @@ -27,7 +41,9 @@ pub fn linear_fit(x: Span<f16x16>, y: Span<f16x16>) -> (f16x16, f16x16) {
#[cfg(test)]
mod tests {
use super::linear_fit;
use orion_numbers::f16x16::helpers::{assert_precise, assert_relative};
use orion_numbers::f16x16;
use orion_numbers::f32x32;
use orion_numbers::core_trait::{I32Div, I64Div};

#[test]
fn linear_fit_line_test() {
Expand All @@ -37,10 +53,10 @@ mod tests {
let (slope_expected, intercept_expected) = (131072, 0);
let (slope_actual, intercept_actual) = linear_fit(x, y);

assert_precise(
f16x16::helpers::assert_precise(
slope_actual, slope_expected, 'slopes should be equal', Option::None(())
);
assert_precise(
f16x16::helpers::assert_precise(
intercept_actual, intercept_expected, 'intercepts should be equal', Option::None(())
);
}
Expand All @@ -53,13 +69,12 @@ mod tests {
let (slope_expected, intercept_expected) = (130698, 5305);
let (slope_actual, intercept_actual) = linear_fit(x, y);

assert_precise(
f16x16::helpers::assert_precise(
slope_actual, slope_expected, 'slopes should be equal', Option::None(())
);
assert_precise(
f16x16::helpers::assert_precise(
intercept_actual, intercept_expected, 'intercepts should be equal', Option::None(())
);

}

#[test]
Expand All @@ -70,12 +85,62 @@ mod tests {
let (slope_expected, intercept_expected) = (98866, 119837);
let (slope_actual, intercept_actual) = linear_fit(x, y);

assert_precise(
f16x16::helpers::assert_precise(
slope_actual, slope_expected, 'slopes should be equal', Option::None(())
);
f16x16::helpers::assert_precise(
intercept_actual, intercept_expected, 'intercepts should be equal', Option::None(())
);
}


#[test]
fn linear_fit_line_test_f32x32() {
let x = array![0, 4294967296, 8589934592, 12884901888, 17179869184, 21474836480].span();
let y = array![0, 8589934592, 17179869184, 25769803776, 34359738368, 42949672960].span();

let (slope_expected, intercept_expected) = (8589934592, 0);
let (slope_actual, intercept_actual) = linear_fit(x, y);

f32x32::helpers::assert_precise(
slope_actual, slope_expected, 'slopes should be equal', Option::None(())
);
f32x32::helpers::assert_precise(
intercept_actual, intercept_expected, 'intercepts should be equal', Option::None(())
);
}

#[test]
fn linear_fit_line_with_noise_test_f32x32() {
let x = array![0, 4294967296, 8589934592, 12884901888, 17179869184, 21474836480].span();
let y = array![430014464, 9448922319, 16754909120, 26198093840, 33983924224, 43786887168]
.span();

let (slope_expected, intercept_expected) = (8566644398, 350514194);
let (slope_actual, intercept_actual) = linear_fit(x, y);

f32x32::helpers::assert_precise(
slope_actual, slope_expected, 'slopes should be equal', Option::None(())
);
assert_precise(
f32x32::helpers::assert_precise(
intercept_actual, intercept_expected, 'intercepts should be equal', Option::None(())
);
}

#[test]
fn linear_fit_test_f32x32() {
let x = array![0, 4294967296, 8589934592, 12884901888, 17179869184, 21474836480].span();
let y = array![12458487808, 12884901888, 20111880192, 21474836480, 30031216640, 47185920000]
.span();

let (slope_expected, intercept_expected) = (6479333376, 7850820812);
let (slope_actual, intercept_actual) = linear_fit(x, y);

f32x32::helpers::assert_precise(
slope_actual, slope_expected, 'slopes should be equal', Option::None(())
);
f32x32::helpers::assert_precise(
intercept_actual, intercept_expected, 'intercepts should be equal', Option::None(())
);
}
}
2 changes: 1 addition & 1 deletion packages/orion-algo/src/span_math.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ pub trait SpanMathTrait<T> {
fn min(self: Span<T>) -> T;
fn prod(self: Span<T>) -> T;
fn sum(self: Span<T>) -> T;
}
}
6 changes: 2 additions & 4 deletions packages/orion-algo/src/span_math/span_f16x16.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ pub impl F16x16SpanMath of SpanMathTrait<f16x16> {
fn sum(self: Span<f16x16>) -> f16x16 {
sum(self)
}

}


Expand Down Expand Up @@ -123,15 +122,14 @@ mod tests {
assert_precise(*res.at(3), *x.at(3), 'should be equal', Option::None);
assert_precise(*res.at(4), *x.at(4), 'should be equal', Option::None);
assert_precise(*res.at(5), *x.at(5), 'should be equal', Option::None);

}

#[test]
fn test_dot() {
let x = array![0, 65536, 131072, 196608, 262144, 327680].span(); // 0, 1, 2, 3, 4, 5
let y = array![0, 131072, 262144, 393216, 524288, 655360].span(); // 0, 2, 4, 6, 8, 10
let result = dot(x, y);

assert_precise(result, (110 * ONE).into(), 'should be equal', Option::None);
}

Expand Down Expand Up @@ -170,4 +168,4 @@ mod tests {

assert_precise(result, (15 * ONE).into(), 'should be equal', Option::None);
}
}
}
73 changes: 73 additions & 0 deletions packages/orion-algo/src/span_math/span_f32x32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,76 @@ pub fn linear_fit(x: Span<f32x32>, y: Span<f32x32>) -> (f32x32, f32x32) {

(a, b)
}


#[cfg(test)]
mod tests {
use super::{arange, dot, max, min, prod, sum, ONE};
use orion_numbers::f32x32::helpers::assert_precise;

#[test]
fn test_arange() {
let n = 6;
let res = arange(n);

let x = array![0, 4294967296, 8589934592, 12884901888, 17179869184, 21474836480].span();

assert_precise(*res.at(0), *x.at(0), 'should be equal', Option::None);
assert_precise(*res.at(1), *x.at(1), 'should be equal', Option::None);
assert_precise(*res.at(2), *x.at(2), 'should be equal', Option::None);
assert_precise(*res.at(3), *x.at(3), 'should be equal', Option::None);
assert_precise(*res.at(4), *x.at(4), 'should be equal', Option::None);
assert_precise(*res.at(5), *x.at(5), 'should be equal', Option::None);
}

#[test]
fn test_dot() {
let x = array![0, 4294967296, 8589934592, 12884901888, 17179869184, 21474836480]
.span(); // 0, 1, 2, 3, 4, 5
let y = array![0, 8589934592, 17179869184, 25769803776, 34359738368, 42949672960]
.span(); // 0, 2, 4, 6, 8, 10
let result = dot(x, y);

assert_precise(result, (7208960 * ONE).into(), 'should be equal', Option::None);
}

#[test]
fn test_max() {
let x = array![0, 4294967296, 8589934592, 12884901888, 17179869184, 21474836480]
.span(); // 0, 1, 2, 3, 4, 5

let result = max(x);

assert_precise(result, 21474836480, 'should be equal', Option::None);
}

#[test]
fn test_min() {
let x = array![0, 4294967296, 8589934592, 12884901888, 17179869184, 21474836480]
.span(); // 0, 1, 2, 3, 4, 5

let result = min(x);

assert_precise(result, 0, 'should be equal', Option::None);
}

#[test]
fn test_prod() {
let x = array![0, 4294967296, 8589934592, 12884901888, 17179869184, 21474836480]
.span(); // 0, 1, 2, 3, 4, 5

let result = prod(x);

assert_precise(result, 0, 'should be equal', Option::None);
}

#[test]
fn test_sum() {
let x = array![0, 4294967296, 8589934592, 12884901888, 17179869184, 21474836480]
.span(); // 0, 1, 2, 3, 4, 5

let result = sum(x);

assert_precise(result, (98304000 * ONE).into(), 'should be equal', Option::None);
}
}

0 comments on commit c68df1a

Please sign in to comment.