Skip to content

Commit

Permalink
refactor: impl MeanN over generic dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
YeungOnion committed Dec 3, 2024
1 parent e7a995e commit fcce9f6
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/distribution/multinomial.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::distribution::Discrete;
use crate::function::factorial;
use crate::statistics::*;
use nalgebra::{DVector, Dim, Dyn, OMatrix, OVector};
use nalgebra::{Dim, Dyn, OMatrix, OVector};

/// Implements the
/// [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution)
Expand Down Expand Up @@ -203,7 +203,7 @@ where
res
}

impl<D> MeanN<DVector<f64>> for Multinomial<D>
impl<D> MeanN<OVector<f64, D>> for Multinomial<D>
where
D: Dim,
nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>,
Expand All @@ -218,10 +218,8 @@ where
///
/// where `n` is the number of trials, `p_i` is the `i`th probability,
/// and `k` is the total number of probabilities
fn mean(&self) -> Option<DVector<f64>> {
Some(DVector::from_vec(
self.p.iter().map(|x| x * self.n as f64).collect(),
))
fn mean(&self) -> Option<OVector<f64, D>> {
Some(self.p.map(|x| x * self.n as f64))
}
}

Expand Down

0 comments on commit fcce9f6

Please sign in to comment.