diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index a326591e..1d02ab65 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -57,7 +57,6 @@ mod non_nan { /// ``` #[derive(Clone, PartialEq, Debug)] pub struct Empirical { - mean_and_var: Option<(f64, f64)>, // keys are data points, values are number of data points with equal value data: BTreeMap, u64>, @@ -65,6 +64,8 @@ pub struct Empirical { /// Total amount of data points (== sum of all _values_ inside self.data). /// Must be 0 iff data.is_empty() sum: u64, + mean: f64, + var: f64, } impl Empirical { @@ -84,9 +85,10 @@ impl Empirical { #[allow(clippy::result_unit_err)] pub fn new() -> Result { Ok(Empirical { - sum: 0, - mean_and_var: None, data: BTreeMap::new(), + sum: 0, + mean: 0.0, + var: 0.0, }) } @@ -97,17 +99,10 @@ impl Empirical { }; self.sum += 1; - match self.mean_and_var { - Some((mean, var)) => { - let sum = self.sum as f64; - let var = var + (sum - 1.) * (data_point - mean) * (data_point - mean) / sum; - let mean = mean + (data_point - mean) / sum; - self.mean_and_var = Some((mean, var)); - } - None => { - self.mean_and_var = Some((data_point, 0.)); - } - } + let sum = self.sum as f64; + self.var += (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum; + self.mean += (data_point - self.mean) / sum; + *self.data.entry(map_key).or_insert(0) += 1; } @@ -117,21 +112,25 @@ impl Empirical { None => return, }; - if let (Some(val), Some((mean, var))) = (self.data.remove(&map_key), self.mean_and_var) { - if val == 1 && self.data.is_empty() { - self.mean_and_var = None; - self.sum = 0; - return; - }; - // reset mean and var - let sum = self.sum as f64; - let mean = (sum * mean - data_point) / (sum - 1.); - let var = var - (sum - 1.) * (data_point - mean) * (data_point - mean) / sum; - self.sum -= 1; - if val != 1 { - self.data.insert(map_key, val - 1); - }; - self.mean_and_var = Some((mean, var)); + let val = match self.data.remove(&map_key) { + Some(v) => v, + None => return, + }; + + if val == 1 && self.data.is_empty() { + self.sum = 0; + self.mean = 0.0; + self.var = 0.0; + return; + }; + + // reset mean and var + let sum = self.sum as f64; + self.mean = (sum * self.mean - data_point) / (sum - 1.); + self.var -= (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum; + self.sum -= 1; + if val != 1 { + self.data.insert(map_key, val - 1); } } @@ -232,12 +231,19 @@ impl Min for Empirical { impl Distribution for Empirical { fn mean(&self) -> Option { - self.mean_and_var.map(|(mean, _)| mean) + if self.data.is_empty() { + None + } else { + Some(self.mean) + } } fn variance(&self) -> Option { - self.mean_and_var - .map(|(_, var)| var / (self.sum as f64 - 1.)) + if self.data.is_empty() { + None + } else { + Some(self.var / (self.sum as f64 - 1.)) + } } } @@ -293,7 +299,7 @@ mod tests { #[test] fn test_remove_nonexisting() { let mut empirical = Empirical::new().unwrap(); - + empirical.add(5.2); // should not panic empirical.remove(10.0);