Skip to content

Commit

Permalink
write cdf tests + update orion-numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Jul 25, 2024
1 parent 00ba76a commit aacfc29
Show file tree
Hide file tree
Showing 14 changed files with 368 additions and 242 deletions.
92 changes: 82 additions & 10 deletions packages/orion-algo/src/algo/cdf.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,59 @@ use core::array::{SpanTrait, SpanIter};
use orion_algo::span_math::SpanMathTrait;
use orion_numbers::FixedTrait;

/// Computes the cumulative distribution function (CDF) for a given set of values using the
/// standard normal distribution formula. This implementation allows for optional location (`loc`)
/// and scale (`scale`) parameters, which default to 0.0 and 1.0 respectively if not provided.
///
/// # Arguments
/// * `x` - A `Span<T>` containing the data points for which the CDF is to be computed.
/// * `loc` - An optional `Span<T>` representing the location parameter (mean) for each data point.
/// If `Some(Span<T>)` is provided, it must either contain a single value or have the same
/// length as `x`. If `None` is provided, defaults to a Span of a single 0.0 value.
/// * `scale` - An optional `Span<T>` representing the scale parameter (standard deviation) for each
/// data point. If `Some(Span<T>)` is provided, it must either contain a single value
/// or have the same length as `x`. If `None` is provided, defaults to a Span of a single 1.0 value.
///
/// # Returns
/// A `Span<T>` representing the CDF values corresponding to each entry in `x`.
///
/// # Panics
/// * The function panics if the lengths of `loc` or `scale` Spans are more than one and not equal to
/// the length of `x`.
///
/// # Examples
/// Basic usage:
///
/// ```
/// let x = array![FixedTrait::new_unscaled(2), FixedTrait::new_unscaled(1), FixedTrait::new_unscaled(0)].span();
/// let result = cdf(x, None, None);
/// // Expected output: CDF values for a standard normal distribution
/// ```
///
/// With location and scale parameters:
///
/// ```
/// let x = array![FixedTrait::new_unscaled(2), FixedTrait::new_unscaled(1), FixedTrait::new_unscaled(0)].span();
/// let loc = array![FixedTrait::new_unscaled(1), FixedTrait::new_unscaled(1), FixedTrait::new_unscaled(1)].span();
/// let scale = array![FixedTrait::new_unscaled(1), FixedTrait::new_unscaled(1), FixedTrait::new_unscaled(1)].span();
/// let result = cdf(x, Some(loc), Some(scale));
/// // Expected output: Adjusted CDF values using specified loc and scale
/// ```
pub fn cdf<
T, +FixedTrait<T>, +SpanMathTrait<T>, +Sub<T>, +Div<T>, +Mul<T>, +Drop<T>, +Add<T>, +Copy<T>
>(
x: Span<T>, loc: Option<Span<T>>, scale: Option<Span<T>>
) //-> Span<T>
{
) -> Span<T> {
// Default loc to 0.0 if not provided
let mut loc = match loc {
Option::Some(val) => val,
Option::None => array![FixedTrait::ZERO()].span()
Option::None => array![FixedTrait::ZERO].span()
};

// Default scale to 1.0 if not provided
let mut scale = match scale {
Option::Some(val) => val,
Option::None => array![FixedTrait::ONE()].span()
Option::None => array![FixedTrait::ONE].span()
};

// single value or same length as x
Expand Down Expand Up @@ -50,16 +87,51 @@ pub fn cdf<
scale_first_val
};

// Calculate: 0.5 * (1.0 + erf((x_val - loc_val) / (scale_val * (2.0f64).sqrt())))
let calc = FixedTrait::HALF()
* (FixedTrait::ONE()
+ ((*x_val - loc_val)
/ (scale_val * (FixedTrait::ONE() + FixedTrait::ONE()).sqrt()))
.erf());
// Calculate: 0.5 * (1.0 + erf((x_val - loc_val) / (scale_val * sqrt(2.0))))
let sqrt_2 = FixedTrait::sqrt(FixedTrait::TWO);
let x_minus_loc = FixedTrait::sub(*x_val, loc_val);
let scale_times_sqrt_2 = FixedTrait::mul(scale_val, sqrt_2);
let division_result = FixedTrait::div(x_minus_loc, scale_times_sqrt_2);
let erf_result = FixedTrait::erf(division_result);
let one_plus_erf = FixedTrait::add(FixedTrait::ONE, erf_result);
let calc = FixedTrait::mul(FixedTrait::HALF, one_plus_erf);

res_data.append(calc);
},
Option::None => { break; }
}
};

res_data.span()
}

#[cfg(test)]
mod tests {
use super::cdf;
use orion_numbers::{f16x16::{core::f16x16, helpers::assert_relative_span}, FixedTrait};

#[test]
fn test_cdf_loc_scale_are_none() {
let x: Span<f16x16> = array![FixedTrait::ONE, FixedTrait::HALF, FixedTrait::ZERO].span();

let res = cdf(x, Option::None, Option::None);
let expected = array![55138, 45316, 32768].span();

assert_relative_span(res, expected, 'res != expected', Option::None);
}

#[test]
fn test_cdf_loc_scale_are_some() {
let x: Span<f16x16> = array![FixedTrait::ONE, FixedTrait::HALF, FixedTrait::ZERO].span();

let loc: Span<f16x16> = array![FixedTrait::HALF, FixedTrait::HALF, FixedTrait::HALF].span();

let scale: Span<f16x16> = array![FixedTrait::HALF, FixedTrait::HALF, FixedTrait::HALF]
.span();

let res = cdf(x, Option::Some(loc), Option::Some(scale));
let expected = array![55138, 32768, 10398].span();

assert_relative_span(res, expected, 'res != expected', Option::None);
}
}
2 changes: 1 addition & 1 deletion packages/orion-algo/src/algo/linear_fit.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub fn linear_fit<
let sum_xy = x.dot(y);

let denominator = n * sum_xx - (sum_x.mul(sum_x));
if denominator == FixedTrait::ZERO() {
if denominator == FixedTrait::ZERO {
panic!("division by zero exception")
}

Expand Down
15 changes: 8 additions & 7 deletions packages/orion-algo/src/span_math/span_f16x16.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use orion_numbers::{f16x16::core::{f16x16, ONE}, FixedTrait};
use orion_numbers::{f16x16::core::f16x16, FixedTrait};

use orion_algo::span_math::SpanMathTrait;

Expand Down Expand Up @@ -34,7 +34,7 @@ fn arange(n: u32) -> Span<f16x16> {
let mut i = 0;
let mut arr = array![];
while i < n {
arr.append(i.try_into().unwrap() * ONE);
arr.append(i.try_into().unwrap() * FixedTrait::ONE);
i += 1;
};

Expand All @@ -55,7 +55,7 @@ fn dot(a: Span<f16x16>, b: Span<f16x16>) -> f16x16 {
fn max(mut a: Span<f16x16>) -> f16x16 {
assert(a.len() > 0, 'span cannot be empty');

let mut max = FixedTrait::MIN();
let mut max = FixedTrait::MIN;

loop {
match a.pop_front() {
Expand All @@ -70,7 +70,7 @@ fn max(mut a: Span<f16x16>) -> f16x16 {
fn min(mut a: Span<f16x16>) -> f16x16 {
assert(a.len() > 0, 'span cannot be empty');

let mut min = FixedTrait::MAX();
let mut min = FixedTrait::MAX;

loop {
match a.pop_front() {
Expand Down Expand Up @@ -105,8 +105,9 @@ fn sum(mut a: Span<f16x16>) -> f16x16 {

#[cfg(test)]
mod tests {
use super::{arange, dot, max, min, prod, sum, ONE};
use super::{arange, dot, max, min, prod, sum};
use orion_numbers::f16x16::helpers::assert_precise;
use orion_numbers::F16x16Impl;

#[test]
fn test_arange() {
Expand All @@ -129,7 +130,7 @@ mod tests {
let y = array![0, 131072, 262144, 393216, 524288, 655360].span(); // 0, 2, 4, 6, 8, 10
let result = dot(x, y);

assert_precise(result, (110 * ONE).into(), 'should be equal', Option::None);
assert_precise(result, (110 * F16x16Impl::ONE).into(), 'should be equal', Option::None);
}

#[test]
Expand Down Expand Up @@ -165,6 +166,6 @@ mod tests {

let result = sum(x);

assert_precise(result, (15 * ONE).into(), 'should be equal', Option::None);
assert_precise(result, (15 * F16x16Impl::ONE).into(), 'should be equal', Option::None);
}
}
17 changes: 10 additions & 7 deletions packages/orion-algo/src/span_math/span_f32x32.cairo
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use orion_numbers::{FixedTrait};
use orion_numbers::f32x32::core::{f32x32, ONE};
use orion_numbers::f32x32::core::f32x32;

use orion_algo::span_math::SpanMathTrait;

Expand Down Expand Up @@ -34,7 +34,7 @@ fn arange(n: u32) -> Span<f32x32> {
let mut i = 0;
let mut arr = array![];
while i < n {
arr.append(i.try_into().unwrap() * ONE);
arr.append(i.try_into().unwrap() * FixedTrait::ONE);
i += 1;
};

Expand All @@ -55,7 +55,7 @@ fn dot(a: Span<f32x32>, b: Span<f32x32>) -> f32x32 {
fn max(mut a: Span<f32x32>) -> f32x32 {
assert(a.len() > 0, 'span cannot be empty');

let mut max = FixedTrait::MIN();
let mut max = FixedTrait::MIN;

loop {
match a.pop_front() {
Expand All @@ -70,7 +70,7 @@ fn max(mut a: Span<f32x32>) -> f32x32 {
fn min(mut a: Span<f32x32>) -> f32x32 {
assert(a.len() > 0, 'span cannot be empty');

let mut min = FixedTrait::MAX();
let mut min = FixedTrait::MAX;

loop {
match a.pop_front() {
Expand Down Expand Up @@ -105,8 +105,9 @@ fn sum(mut a: Span<f32x32>) -> f32x32 {

#[cfg(test)]
mod tests {
use super::{arange, dot, max, min, prod, sum, ONE};
use super::{arange, dot, max, min, prod, sum};
use orion_numbers::f32x32::helpers::assert_precise;
use orion_numbers::F32x32Impl;

#[test]
fn test_arange() {
Expand All @@ -131,7 +132,7 @@ mod tests {
.span(); // 0, 2, 4, 6, 8, 10
let result = dot(x, y);

assert_precise(result, (7208960 * ONE).into(), 'should be equal', Option::None);
assert_precise(result, (7208960 * F32x32Impl::ONE).into(), 'should be equal', Option::None);
}

#[test]
Expand Down Expand Up @@ -171,6 +172,8 @@ mod tests {

let result = sum(x);

assert_precise(result, (98304000 * ONE).into(), 'should be equal', Option::None);
assert_precise(
result, (98304000 * F32x32Impl::ONE).into(), 'should be equal', Option::None
);
}
}
48 changes: 15 additions & 33 deletions packages/orion-numbers/src/f16x16/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,19 @@ use orion_numbers::FixedTrait;

pub type f16x16 = i32;

// CONSTANTS
pub const TWO: f16x16 = 131072; // 2 ** 17
pub const ONE: f16x16 = 65536; // 2 ** 16
pub const HALF: f16x16 = 32768; // 2 ** 15
pub const MAX: f16x16 = 2147483647; // 2 ** 31 -1
pub const MIN: f16x16 = -2147483648; // 2 ** 31


pub impl F16x16Impl of FixedTrait<f16x16> {
fn ZERO() -> f16x16 {
0
}
// CONSTANTS
const ZERO: f16x16 = 0;
const HALF: f16x16 = 32768; // 2 ** 15
const ONE: f16x16 = 65536; // 2 ** 16
const TWO: f16x16 = 131072; // 2 ** 17
const MAX: f16x16 = 2147483647; // 2 ** 31 -1
const MIN: f16x16 = -2147483648; // 2 ** 31

fn HALF() -> f16x16 {
HALF
}

fn ONE() -> f16x16 {
ONE
}

fn MAX() -> f16x16 {
MAX
}

fn MIN() -> f16x16 {
MIN
}

fn new_unscaled(x: i32) -> f16x16 {
x * ONE
x * Self::ONE
}

fn new(x: i32) -> f16x16 {
Expand All @@ -45,7 +27,7 @@ pub impl F16x16Impl of FixedTrait<f16x16> {
}

fn from_unscaled_felt(x: felt252) -> f16x16 {
return FixedTrait::from_felt(x * ONE.into());
return FixedTrait::from_felt(x * Self::ONE.into());
}

fn abs(self: f16x16) -> f16x16 {
Expand Down Expand Up @@ -182,27 +164,27 @@ pub impl F16x16Impl of FixedTrait<f16x16> {
}

fn INF() -> f16x16 {
MAX
Self::MAX
}

fn POS_INF() -> f16x16 {
MAX
Self::MAX
}

fn NEG_INF() -> f16x16 {
MIN
Self::MIN
}

fn is_inf(self: f16x16) -> bool {
self == MAX
self == Self::MAX
}

fn is_pos_inf(self: f16x16) -> bool {
self == MAX
self == Self::MAX
}

fn is_neg_inf(self: f16x16) -> bool {
self == MIN
self == Self::MIN
}

fn erf(self: f16x16) -> f16x16 {
Expand Down
8 changes: 4 additions & 4 deletions packages/orion-numbers/src/f16x16/erf.cairo
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use orion_numbers::f16x16::{core::{f16x16, ONE}, lut};
use orion_numbers::FixedTrait;
use orion_numbers::f16x16::{core::{f16x16}, lut};
use orion_numbers::{FixedTrait};

const ERF_COMPUTATIONAL_ACCURACY: i32 = 100;
const ROUND_CHECK_NUMBER: i32 = 10;
Expand All @@ -18,7 +18,7 @@ pub fn erf(x: f16x16) -> f16x16 {
if x.abs() < MAX_ERF_NUMBER {
erf_value = lut::erf_lut(x.abs());
} else {
erf_value = ONE;
erf_value = FixedTrait::ONE;
}

FixedTrait::mul(erf_value, x.sign())
Expand All @@ -27,7 +27,7 @@ pub fn erf(x: f16x16) -> f16x16 {

// Tests
//
//
//
// --------------------------------------------------------------------------------------------------------------

#[cfg(test)]
Expand Down
Loading

0 comments on commit aacfc29

Please sign in to comment.