From 64802677d21ced4410e5fd42cb6a2626b9b182c2 Mon Sep 17 00:00:00 2001
From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com>
Date: Fri, 2 Aug 2024 21:17:26 -0500
Subject: [PATCH] refactor(perf): Multinomial samples from Binomial

---
Review notes (text between the three-dash separator and the diff is not part
of the commit message):

* `sample` draws each category count from a Binomial conditioned on the
  categories already drawn: count_i ~ Binomial(n - drawn, w_i / (1 - W)),
  where W is the sum of the raw weights of the earlier categories. The loop
  therefore has to accumulate `w` in `probs_taken`. Accumulating the
  conditional probability `p` overstates W, which skews later categories and
  can make `p` negative once four or more categories are involved
  (presumably rejected by `Binomial::new`, tripping the `expect`).
* `test_uniform_samples` uses a tolerance of three standard deviations,
  3 * sqrt(variance). Assuming `variance()[0]` is n * p * (1 - p), the earlier
  3 * variance / sqrt(n) is only about 1.5 standard deviations for n = 1000,
  p = 0.5, so the test would fail spuriously.
* Both changes edit existing `+` lines in place, so hunk sizes and the
  diffstat are unchanged; the post-image blob id on the `index` line predates
  them.
* NOTE(review): several generic parameter lists look stripped from this copy
  of the patch (e.g. on `OVector` and `fn sample`); confirm against the
  original commit before applying.

 src/distribution/multinomial.rs | 57 +++++++++++++++++++++++++++------
 1 file changed, 48 insertions(+), 9 deletions(-)

diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs
index a4af290f..9cb97111 100644
--- a/src/distribution/multinomial.rs
+++ b/src/distribution/multinomial.rs
@@ -1,9 +1,11 @@
+use std::ops::Div;
+
 use crate::distribution::Discrete;
 use crate::function::factorial;
 use crate::statistics::*;
 use crate::{Result, StatsError};
 use nalgebra::{
-    base::allocator::Allocator, Const, DMatrix, DVector, DefaultAllocator, Dim, DimMin, Dyn,
+    base::allocator::Allocator, Const, DMatrix, DVector, DefaultAllocator, Dim, Dyn,
     Matrix, OMatrix, OVector,
 };
 use rand::Rng;
@@ -128,12 +130,21 @@ where
     nalgebra::DefaultAllocator: nalgebra::allocator::Allocator,
 {
     fn sample(&self, rng: &mut R) -> OVector {
-        let p_cdf = super::categorical::prob_mass_to_cdf(self.p().as_slice());
+        use crate::distribution::Binomial;
+        // TODO: find the right way to allocate zeros when dimension is known statically OR dynamically
         let mut res = OVector::zeros_generic(self.p.shape_generic().0, Const::<1>);
-        for _ in 0..self.n {
-            let i = super::categorical::sample_unchecked(rng, &p_cdf);
-            let el = res.get_mut(i as usize).unwrap();
-            *el += 1.0;
+        let mut probs_taken = 0.0;
+        let mut samples_taken = 0;
+        for (w, s) in self.p().iter().zip(res.iter_mut()) {
+            if samples_taken >= self.n {
+                break;
+            }
+            let p = (w.div(1.0 - probs_taken)).min(1.0);
+            *s = Binomial::new(p, self.n - samples_taken)
+                .expect("probability already on [0,1]")
+                .sample(rng);
+            samples_taken += *s as u64;
+            probs_taken += w; // raw weight consumed so far, not the conditional p
         }
         res
     }
@@ -294,7 +305,7 @@ where
     }
 }
 
-#[rustfmt::skip]
+// #[rustfmt::skip]
 #[cfg(test)]
 mod tests {
     use crate::{
@@ -303,9 +314,9 @@ mod tests {
     };
     use approx::UlpsEq;
     use nalgebra::{
-        dmatrix, dvector, matrix, vector, Const, DimMin, Dyn, Matrix, OMatrix, OVector,
-        VecStorage,
+        dmatrix, dvector, matrix, vector, Const, DimMin, Dyn, Matrix, OMatrix, OVector, VecStorage,
     };
+    use rand::{distributions::Distribution, thread_rng};
     use std::fmt::{Debug, Display};
 
     fn try_create(p: OVector, n: u64) -> Multinomial
@@ -438,6 +449,7 @@ mod tests {
     #[test]
     fn test_pmf() {
         let pmf = |arg: OVector| move |x: Multinomial<_>| x.pmf(&arg);
+        let pmf_3d = |arg: OVector| move |x: Multinomial<_>| x.pmf(&arg);
         test_almost(
             dvector![0.3, 0.7],
             10,
@@ -452,6 +464,13 @@ mod tests {
             1e-15,
             pmf(dvector![1, 3, 6]),
         );
+        test_almost(
+            vector![0.1, 0.3, 0.6],
+            10,
+            0.105815808,
+            1e-15,
+            pmf_3d(vector![1, 3, 6]),
+        );
         test_almost(
             dvector![0.15, 0.35, 0.3, 0.2],
             10,
@@ -504,4 +523,24 @@ mod tests {
     //     let n = Multinomial::new(&[0.3, 0.7], 10).unwrap();
     //     n.ln_pmf(&[1, 3]);
     // }
+    #[test]
+    fn test_almost_zero_sample() {
+        let n = 1000;
+        let weights = vec![0.0, 0.0, 0.0, 0.000000001];
+        let multinomial = Multinomial::new(weights, n).unwrap();
+        let sample = multinomial.sample(&mut thread_rng());
+        assert_relative_eq!(sample[3], n as f64);
+    }
+    #[test]
+    fn test_uniform_samples() {
+        let n: f64 = 1000.0;
+        let weights = vec![1.0, 1.0];
+        let multinomial = Multinomial::new(weights, n as u64).unwrap();
+        let sample = multinomial.sample(&mut thread_rng());
+        assert_abs_diff_eq!(
+            sample[0],
+            n / 2.0,
+            epsilon = 3.0 * multinomial.variance().unwrap()[0].sqrt() // 3 sigma
+        );
+    }
 }