Skip to content

Commit

Permalink
span math trait f32x32
Browse files Browse the repository at this point in the history
  • Loading branch information
chachaleo committed Jul 8, 2024
1 parent 439eb9a commit c2cbfa1
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 50 deletions.
7 changes: 3 additions & 4 deletions packages/orion-algo/src/algo/linear_fit.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use orion_numbers::f16x16::core::{f16x16, FixedTrait};
use orion_algo::span_math::core::SpanMathTrait;
use orion_numbers::f16x16::core_trait::I32Div;
use orion_numbers::{f16x16::core::{f16x16}, FixedTrait};
use orion_algo::span_math::SpanMathTrait;
use orion_numbers::core_trait::I32Div;

pub fn linear_fit(x: Span<f16x16>, y: Span<f16x16>) -> (f16x16, f16x16) {
if x.len() != y.len() || x.len() == 0 {
Expand All @@ -24,7 +24,6 @@ pub fn linear_fit(x: Span<f16x16>, y: Span<f16x16>) -> (f16x16, f16x16) {
(a, b)
}


#[cfg(test)]
mod tests {
use super::linear_fit;
Expand Down
2 changes: 1 addition & 1 deletion packages/orion-algo/src/lib.cairo
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pub mod span_math;
pub mod algo;
pub mod algo;
16 changes: 14 additions & 2 deletions packages/orion-algo/src/span_math.cairo
Original file line number Diff line number Diff line change
@@ -1,2 +1,14 @@
pub mod core;
mod math;
pub mod span_f32x32;
pub mod span_f16x16;

use span_f16x16::F16x16SpanMath;
use span_f32x32::F32x32SpanMath;

pub trait SpanMathTrait<T> {
fn arange(n: u32) -> Span<T>;
fn dot(self: Span<T>, other: Span<T>) -> T;
fn max(self: Span<T>) -> T;
fn min(self: Span<T>) -> T;
fn prod(self: Span<T>) -> T;
fn sum(self: Span<T>) -> T;
}
38 changes: 0 additions & 38 deletions packages/orion-algo/src/span_math/core.cairo

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,8 +1,36 @@
use core::array::ArrayTrait;
use core::option::OptionTrait;
use core::traits::TryInto;
use orion_numbers::f16x16::core::{f16x16, FixedTrait, ONE};
use orion_numbers::f16x16::core_trait::{I32Rem, I32Div};
use orion_numbers::{f16x16::core::{f16x16, ONE}, FixedTrait};
use orion_numbers::core_trait::{I32Rem, I32Div};

use orion_algo::span_math::SpanMathTrait;


pub impl F16x16SpanMath of SpanMathTrait<f16x16> {
fn arange(n: u32) -> Span<f16x16> {
arange(n)
}

fn dot(self: Span<f16x16>, other: Span<f16x16>) -> f16x16 {
dot(self, other)
}

fn max(self: Span<f16x16>) -> f16x16 {
max(self)
}

fn min(self: Span<f16x16>) -> f16x16 {
min(self)
}

fn prod(self: Span<f16x16>) -> f16x16 {
prod(self)
}

fn sum(self: Span<f16x16>) -> f16x16 {
sum(self)
}

}


pub(crate) fn arange(n: u32) -> Span<f16x16> {
let mut i = 0;
Expand Down
126 changes: 126 additions & 0 deletions packages/orion-algo/src/span_math/span_f32x32.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
use orion_numbers::{core_trait::{I64Rem, I64Div}, FixedTrait};
use orion_numbers::f32x32::core::{f32x32, ONE};

use orion_algo::span_math::SpanMathTrait;


pub impl F32x32SpanMath of SpanMathTrait<f32x32> {
fn arange(n: u32) -> Span<f32x32> {
arange(n)
}

fn dot(self: Span<f32x32>, other: Span<f32x32>) -> f32x32 {
dot(self, other)
}

fn max(self: Span<f32x32>) -> f32x32 {
max(self)
}

fn min(self: Span<f32x32>) -> f32x32 {
min(self)
}

fn prod(self: Span<f32x32>) -> f32x32 {
prod(self)
}

fn sum(self: Span<f32x32>) -> f32x32 {
sum(self)
}
}

fn arange(n: u32) -> Span<f32x32> {
let mut i = 0;
let mut arr = array![];
while i < n {
arr.append(i.try_into().unwrap() * ONE);
i += 1;
};

arr.span()
}

fn dot(a: Span<f32x32>, b: Span<f32x32>) -> f32x32 {
let mut i = 0;
let mut acc = 0;
while i != a.len() {
acc += FixedTrait::mul(*a.at(i), *b.at(i));
i += 1;
};

acc
}

fn max(mut a: Span<f32x32>) -> f32x32 {
assert(a.len() > 0, 'span cannot be empty');

let mut max = FixedTrait::MIN();

loop {
match a.pop_front() {
Option::Some(item) => { if *item > max {
max = *item;
} },
Option::None => { break max; },
}
}
}

fn min(mut a: Span<f32x32>) -> f32x32 {
assert(a.len() > 0, 'span cannot be empty');

let mut min = FixedTrait::MAX();

loop {
match a.pop_front() {
Option::Some(item) => { if *item < min {
min = *item;
} },
Option::None => { break min; },
}
}
}

fn prod(mut a: Span<f32x32>) -> f32x32 {
let mut prod = 1;
loop {
match a.pop_front() {
Option::Some(v) => { prod = prod.mul(*v); },
Option::None => { break prod; }
};
}
}

fn sum(mut a: Span<f32x32>) -> f32x32 {
let mut prod = 1;
loop {
match a.pop_front() {
Option::Some(v) => { prod = prod + *v; },
Option::None => { break prod; }
};
}
}


pub fn linear_fit(x: Span<f32x32>, y: Span<f32x32>) -> (f32x32, f32x32) {
if x.len() != y.len() || x.len() == 0 {
panic!("x and y should be of the same lenght")
}

let n: f32x32 = x.len().try_into().unwrap();
let sum_x = x.sum();
let sum_y = y.sum();
let sum_xx = x.dot(x);
let sum_xy = x.dot(y);

let denominator = n * sum_xx - (sum_x.mul(sum_x));
if denominator == 0 {
panic!("division by zero exception")
}

let a = ((n * sum_xy) - sum_x.mul(sum_y)).div(denominator);
let b = (sum_y - a.mul(sum_x)) / n;

(a, b)
}

0 comments on commit c2cbfa1

Please sign in to comment.