Skip to content

Commit

Permalink
refactor: Categorical now stores normalized values
Browse files Browse the repository at this point in the history
norm_pmf (probabilities) was already normalized before
storing, but cdf and sf weren't. Instead, they were normalized
on every API call.

The refactor also reduces the amount of vec/slice iterations in `new`
from 4 to 2.
  • Loading branch information
FreezyLemon committed Sep 23, 2024
1 parent 1cbcc49 commit 37f8fdf
Showing 1 changed file with 48 additions and 108 deletions.
156 changes: 48 additions & 108 deletions src/distribution/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use std::f64;
#[derive(Clone, PartialEq, Debug)]
pub struct Categorical {
norm_pmf: Vec<f64>,
cdf: Vec<f64>,
sf: Vec<f64>,
norm_cdf: Vec<f64>,
norm_sf: Vec<f64>,
}

/// Represents the errors that can occur when creating a [`Categorical`].
Expand Down Expand Up @@ -98,22 +98,25 @@ impl Categorical {
return Err(CategoricalError::ProbMassSumZero);
}

// extract un-normalized cdf
let cdf = prob_mass_to_cdf(prob_mass);
// extract un-normalized sf
let sf = cdf_to_sf(&cdf);
// extract normalized probability mass
let sum = cdf[cdf.len() - 1];
let mut norm_pmf = vec![0.0; prob_mass.len()];
norm_pmf
.iter_mut()
.zip(prob_mass.iter())
.for_each(|(np, pm)| *np = *pm / sum);
Ok(Categorical { norm_pmf, cdf, sf })
}
let mut cdf_sum = 0.0;

let mut norm_cdf = Vec::with_capacity(prob_mass.len());
let mut norm_sf = Vec::with_capacity(prob_mass.len());
let mut norm_pmf = Vec::with_capacity(prob_mass.len());

fn cdf_max(&self) -> f64 {
*self.cdf.last().unwrap()
for &prob in prob_mass {
cdf_sum += prob;

norm_cdf.push(cdf_sum / prob_sum);
norm_sf.push((prob_sum - cdf_sum) / prob_sum);
norm_pmf.push(prob / prob_sum);
}

Ok(Categorical {
norm_pmf,
norm_cdf,
norm_sf,
})
}
}

Expand All @@ -123,27 +126,31 @@ impl std::fmt::Display for Categorical {
}
}

#[cfg(feature = "rand")]
use rand::distributions::Distribution as RandDistribution;

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<usize> for Categorical {
impl RandDistribution<usize> for Categorical {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> usize {
sample_unchecked(rng, &self.cdf)
let draw = rng.gen::<f64>();
self.norm_cdf.iter().position(|val| *val >= draw).unwrap()
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<u64> for Categorical {
impl RandDistribution<u64> for Categorical {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
sample_unchecked(rng, &self.cdf) as u64
<Self as RandDistribution<usize>>::sample(&self, rng) as u64
}
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Categorical {
impl RandDistribution<f64> for Categorical {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
sample_unchecked(rng, &self.cdf) as f64
<Self as RandDistribution<usize>>::sample(&self, rng) as f64
}
}

Expand All @@ -159,11 +166,7 @@ impl DiscreteCDF<u64, f64> for Categorical {
///
/// where `p_j` is the probability mass for the `j`th category
fn cdf(&self, x: u64) -> f64 {
if x >= self.cdf.len() as u64 {
1.0
} else {
self.cdf.get(x as usize).unwrap() / self.cdf_max()
}
*self.norm_cdf.get(x as usize).unwrap_or(&1.0)
}

/// Calculates the survival function for the categorical distribution
Expand All @@ -175,11 +178,7 @@ impl DiscreteCDF<u64, f64> for Categorical {
/// [ sum(p_j) from x..end ]
/// ```
fn sf(&self, x: u64) -> f64 {
if x >= self.sf.len() as u64 {
0.0
} else {
self.sf.get(x as usize).unwrap() / self.cdf_max()
}
*self.norm_sf.get(x as usize).unwrap_or(&0.0)
}

/// Calculates the inverse cumulative distribution function for the
Expand All @@ -203,8 +202,17 @@ impl DiscreteCDF<u64, f64> for Categorical {
if x >= 1.0 || x <= 0.0 {
panic!("x must be in [0, 1]")
}
let denorm_prob = x * self.cdf_max();
binary_index(&self.cdf, denorm_prob) as u64

// `Vec::binary_search` will either return the index of a value equal to x
// or an index where x could be inserted into the sorted Vec.
// Both fit the description, so return either one.
match self
.norm_cdf
.binary_search_by(|v| v.partial_cmp(&x).unwrap())
{
Ok(idx) => idx as u64,
Err(idx) => idx as u64,
}
}
}

Expand Down Expand Up @@ -234,7 +242,7 @@ impl Max<u64> for Categorical {
/// n
/// ```
fn max(&self) -> u64 {
self.cdf.len() as u64 - 1
self.norm_cdf.len() as u64 - 1
}
}

Expand Down Expand Up @@ -337,74 +345,6 @@ impl Discrete<u64, f64> for Categorical {
}
}

/// Draws a sample from the categorical distribution described by `cdf`
/// without doing any bounds checking
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
pub fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, cdf: &[f64]) -> usize {
let draw = rng.gen::<f64>() * cdf.last().unwrap();
cdf.iter().position(|val| *val >= draw).unwrap()
}

/// Computes the cdf from the given probability masses. Performs
/// no parameter or bounds checking.
pub fn prob_mass_to_cdf(prob_mass: &[f64]) -> Vec<f64> {
let mut cdf = Vec::with_capacity(prob_mass.len());
prob_mass.iter().fold(0.0, |s, p| {
let sum = s + p;
cdf.push(sum);
sum
});
cdf
}

/// Computes the sf from the given cumulative densities.
/// Performs no parameter or bounds checking.
pub fn cdf_to_sf(cdf: &[f64]) -> Vec<f64> {
let max = *cdf.last().unwrap();
cdf.iter().map(|x| max - x).collect()
}

// Returns the index of val if placed into the sorted search array.
// If val is greater than all elements, it therefore would return
// the length of the array (N). If val is less than all elements, it would
// return 0. Otherwise val returns the index of the first element larger than
// it within the search array.
fn binary_index(search: &[f64], val: f64) -> usize {
use std::cmp;

let mut low = 0_isize;
let mut high = search.len() as isize - 1;
while low <= high {
let mid = low + ((high - low) / 2);
let el = *search.get(mid as usize).unwrap();
if el > val {
high = mid - 1;
} else if el < val {
low = mid.saturating_add(1);
} else {
return mid as usize;
}
}
cmp::min(search.len(), cmp::max(low, 0) as usize)
}

#[test]
fn test_prob_mass_to_cdf() {
let arr = [0.0, 0.5, 0.5, 3.0, 1.1];
let res = prob_mass_to_cdf(&arr);
assert_eq!(res, [0.0, 0.5, 1.0, 4.0, 5.1]);
}

#[test]
fn test_binary_index() {
let arr = [0.0, 3.0, 5.0, 9.0, 10.0];
assert_eq!(0, binary_index(&arr, -1.0));
assert_eq!(2, binary_index(&arr, 5.0));
assert_eq!(3, binary_index(&arr, 5.2));
assert_eq!(5, binary_index(&arr, 10.1));
}

#[rustfmt::skip]
#[cfg(test)]
mod tests {
Expand Down Expand Up @@ -541,10 +481,10 @@ mod tests {
fn test_cdf_sf_mirror() {
let mass = [4.0, 2.5, 2.5, 1.0];
let cat = Categorical::new(&mass).unwrap();
assert_eq!(cat.cdf(0), 1.-cat.sf(0));
assert_eq!(cat.cdf(1), 1.-cat.sf(1));
assert_eq!(cat.cdf(2), 1.-cat.sf(2));
assert_eq!(cat.cdf(3), 1.-cat.sf(3));
assert_eq!(cat.cdf(0), 1. - cat.sf(0));
assert_eq!(cat.cdf(1), 1. - cat.sf(1));
assert_eq!(cat.cdf(2), 1. - cat.sf(2));
assert_eq!(cat.cdf(3), 1. - cat.sf(3));
}

#[test]
Expand Down

0 comments on commit 37f8fdf

Please sign in to comment.