Skip to content

Commit

Permalink
refactor(perf): Multinomial samples from Binomial
Browse files Browse the repository at this point in the history
  • Loading branch information
YeungOnion committed Aug 3, 2024
1 parent cd16782 commit 6480267
Showing 1 changed file with 48 additions and 9 deletions.
57 changes: 48 additions & 9 deletions src/distribution/multinomial.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::ops::Div;

use crate::distribution::Discrete;
use crate::function::factorial;
use crate::statistics::*;
use crate::{Result, StatsError};
use nalgebra::{
base::allocator::Allocator, Const, DMatrix, DVector, DefaultAllocator, Dim, DimMin, Dyn,
base::allocator::Allocator, Const, DMatrix, DVector, DefaultAllocator, Dim, Dyn, Matrix,
OMatrix, OVector,
};
use rand::Rng;
Expand Down Expand Up @@ -128,12 +130,21 @@ where
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<f64, D>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> OVector<f64, D> {
let p_cdf = super::categorical::prob_mass_to_cdf(self.p().as_slice());
use crate::distribution::Binomial;
// TODO: find the right way to allocate zeros when dimension is known statically OR dynamically
let mut res = OVector::zeros_generic(self.p.shape_generic().0, Const::<1>);
for _ in 0..self.n {
let i = super::categorical::sample_unchecked(rng, &p_cdf);
let el = res.get_mut(i as usize).unwrap();
*el += 1.0;
let mut probs_taken = 0.0;
let mut samples_taken = 0;
for (w, s) in self.p().iter().zip(res.iter_mut()) {
if samples_taken >= self.n {
break;
}
let p = (w.div(1.0 - probs_taken)).min(1.0);
*s = Binomial::new(p, self.n - samples_taken)
.expect("probability already on [0,1]")
.sample(rng);
samples_taken += *s as u64;
probs_taken += p;
}
res
}
Expand Down Expand Up @@ -294,7 +305,7 @@ where
}
}

#[rustfmt::skip]
// #[rustfmt::skip]
#[cfg(test)]
mod tests {
use crate::{
Expand All @@ -303,9 +314,9 @@ mod tests {
};
use approx::UlpsEq;
use nalgebra::{
dmatrix, dvector, matrix, vector, Const, DimMin, Dyn, Matrix, OMatrix, OVector,
VecStorage,
dmatrix, dvector, matrix, vector, Const, DimMin, Dyn, Matrix, OMatrix, OVector, VecStorage,
};
use rand::{distributions::Distribution, thread_rng};
use std::fmt::{Debug, Display};

fn try_create<D>(p: OVector<f64, D>, n: u64) -> Multinomial<D>
Expand Down Expand Up @@ -438,6 +449,7 @@ mod tests {
#[test]
fn test_pmf() {
let pmf = |arg: OVector<u64, Dyn>| move |x: Multinomial<_>| x.pmf(&arg);
let pmf_3d = |arg: OVector<u64, _>| move |x: Multinomial<_>| x.pmf(&arg);
test_almost(
dvector![0.3, 0.7],
10,
Expand All @@ -452,6 +464,13 @@ mod tests {
1e-15,
pmf(dvector![1, 3, 6]),
);
test_almost(
vector![0.1, 0.3, 0.6],
10,
0.105815808,
1e-15,
pmf_3d(vector![1, 3, 6]),
);
test_almost(
dvector![0.15, 0.35, 0.3, 0.2],
10,
Expand Down Expand Up @@ -504,4 +523,24 @@ mod tests {
// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap();
// n.ln_pmf(&[1, 3]);
// }
#[test]
fn test_almost_zero_sample() {
let n = 1000;
let weights = vec![0.0, 0.0, 0.0, 0.000000001];
let multinomial = Multinomial::new(weights, n).unwrap();
let sample = multinomial.sample(&mut thread_rng());
assert_relative_eq!(sample[3], n as f64);
}
#[test]
fn test_uniform_samples() {
let n: f64 = 1000.0;
let weights = vec![1.0, 1.0];
let multinomial = Multinomial::new(weights, n as u64).unwrap();
let sample = multinomial.sample(&mut thread_rng());
assert_abs_diff_eq!(
sample[0],
n / 2.0,
epsilon = 3.0 * multinomial.variance().unwrap()[0] / n.sqrt()
);
}
}

0 comments on commit 6480267

Please sign in to comment.