From 993a4b5732452f4f12250c60f07defc3e1480849 Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Mon, 23 Sep 2024 09:58:32 +0200 Subject: [PATCH] =?UTF-8?q?refactor!:=20forbid=20=CE=B1=3Dinf=20||=20?= =?UTF-8?q?=CE=B2=3Dinf=20in=20Beta=20distr.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This fixes `skewness()` returning an incorrect value for numeric limits and simplifies some functions. Also clarifies and fixes up some docs. --- src/distribution/beta.rs | 211 +++++++++-------------------------- src/distribution/internal.rs | 4 +- 2 files changed, 53 insertions(+), 162 deletions(-) diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 5ea768de..81856468 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -26,23 +26,19 @@ pub struct Beta { #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum BetaError { - /// Shape A is NaN, zero or negative. + /// Shape A is NaN, infinite, zero or negative. ShapeAInvalid, - /// Shape B is NaN, zero or negative. + /// Shape B is NaN, infinite, zero or negative. ShapeBInvalid, - - /// Shape A and Shape B are infinite. - BothShapesInfinite, } impl std::fmt::Display for BetaError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { - BetaError::ShapeAInvalid => write!(f, "Shape A is NaN, zero or negative"), - BetaError::ShapeBInvalid => write!(f, "Shape B is NaN, zero or negative"), - BetaError::BothShapesInfinite => write!(f, "Shape A and shape B are infinite"), + BetaError::ShapeAInvalid => write!(f, "Shape A is NaN, infinite, zero or negative"), + BetaError::ShapeBInvalid => write!(f, "Shape B is NaN, infinite, zero or negative"), } } } @@ -55,7 +51,7 @@ impl Beta { /// /// # Errors /// - /// Returns an error if `shape_a` or `shape_b` are `NaN`. + /// Returns an error if `shape_a` or `shape_b` are `NaN` or infinite. /// Also returns an error if `shape_a <= 0.0` or `shape_b <= 0.0` /// /// # Examples @@ -70,18 +66,14 @@ impl Beta { /// assert!(result.is_err()); /// ``` pub fn new(shape_a: f64, shape_b: f64) -> Result { - if shape_a.is_nan() || shape_a <= 0.0 { + if shape_a.is_nan() || shape_a.is_infinite() || shape_a <= 0.0 { return Err(BetaError::ShapeAInvalid); } - if shape_b.is_nan() || shape_b <= 0.0 { + if shape_b.is_nan() || shape_b.is_infinite() || shape_b <= 0.0 { return Err(BetaError::ShapeBInvalid); } - if shape_a.is_infinite() && shape_b.is_infinite() { - return Err(BetaError::BothShapesInfinite); - } - Ok(Beta { shape_a, shape_b }) } @@ -92,8 +84,8 @@ impl Beta { /// ``` /// use statrs::distribution::Beta; /// - /// let n = Beta::new(2.0, 2.0).unwrap(); - /// assert_eq!(n.shape_a(), 2.0); + /// let n = Beta::new(1.0, 2.0).unwrap(); + /// assert_eq!(n.shape_a(), 1.0); /// ``` pub fn shape_a(&self) -> f64 { self.shape_a @@ -106,7 +98,7 @@ impl Beta { /// ``` /// use statrs::distribution::Beta; /// - /// let n = Beta::new(2.0, 2.0).unwrap(); + /// let n = Beta::new(1.0, 2.0).unwrap(); /// assert_eq!(n.shape_b(), 2.0); /// ``` pub fn shape_b(&self) -> f64 { @@ -133,8 +125,7 @@ impl ::rand::distributions::Distribution for Beta { impl ContinuousCDF for Beta { /// Calculates the cumulative distribution function for the beta - /// distribution - /// at `x` + /// distribution at `x`. /// /// # Formula /// @@ -143,20 +134,12 @@ impl ContinuousCDF for Beta { /// ``` /// /// where `α` is shapeA, `β` is shapeB, and `I_x` is the regularized - /// lower incomplete beta function + /// lower incomplete beta function. fn cdf(&self, x: f64) -> f64 { if x < 0.0 { 0.0 } else if x >= 1.0 { 1.0 - } else if self.shape_a.is_infinite() { - if x < 1.0 { - 0.0 - } else { - 1.0 - } - } else if self.shape_b.is_infinite() { - 1.0 } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) { x } else { @@ -164,8 +147,7 @@ impl ContinuousCDF for Beta { } } - /// Calculates the survival function for the beta - /// distribution at `x` + /// Calculates the survival function for the beta distribution at `x`. /// /// # Formula /// @@ -174,20 +156,12 @@ impl ContinuousCDF for Beta { /// ``` /// /// where `α` is shapeA, `β` is shapeB, and `I_x` is the regularized - /// lower incomplete beta function + /// lower incomplete beta function. fn sf(&self, x: f64) -> f64 { if x < 0.0 { 1.0 } else if x >= 1.0 { 0.0 - } else if self.shape_a.is_infinite() { - if x < 1.0 { - 1.0 - } else { - 0.0 - } - } else if self.shape_b.is_infinite() { - 0.0 } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) { 1. - x } else { @@ -196,8 +170,11 @@ impl ContinuousCDF for Beta { } /// Calculates the inverse cumulative distribution function for the beta - /// distribution - /// at `x` + /// distribution at `x`. + /// + /// # Panics + /// + /// If x is not in `[0, 1]`. /// /// # Formula /// @@ -206,7 +183,7 @@ impl ContinuousCDF for Beta { /// ``` /// /// where `α` is shapeA, `β` is shapeB, and `I_x` is the inverse of the - /// regularized lower incomplete beta function + /// regularized lower incomplete beta function. fn inverse_cdf(&self, x: f64) -> f64 { if !(0.0..=1.0).contains(&x) { panic!("x must be in [0, 1]"); @@ -217,9 +194,8 @@ impl ContinuousCDF for Beta { } impl Min for Beta { - /// Returns the minimum value in the domain of the - /// beta distribution representable by a double precision - /// float + /// Returns the minimum value in the domain of the beta distribution + /// representable by a double precision float. /// /// # Formula /// @@ -232,9 +208,8 @@ impl Min for Beta { } impl Max for Beta { - /// Returns the maximum value in the domain of the - /// beta distribution representable by a double precision - /// float + /// Returns the maximum value in the domain of the beta distribution + /// representable by a double precision float. /// /// # Formula /// @@ -247,7 +222,7 @@ impl Max for Beta { } impl Distribution for Beta { - /// Returns the mean of the beta distribution + /// Returns the mean of the beta distribution. /// /// # Formula /// @@ -255,19 +230,12 @@ impl Distribution for Beta { /// α / (α + β) /// ``` /// - /// where `α` is shapeA and `β` is shapeB + /// where `α` is shapeA and `β` is shapeB. fn mean(&self) -> Option { - let mean = if self.shape_a.is_infinite() { - 1.0 - } else { - self.shape_a / (self.shape_a + self.shape_b) - }; - Some(mean) + Some(self.shape_a / (self.shape_a + self.shape_b)) } - /// Returns the variance of the beta distribution - /// - /// # Remarks + /// Returns the variance of the beta distribution. /// /// # Formula /// @@ -275,20 +243,17 @@ impl Distribution for Beta { /// (α * β) / ((α + β)^2 * (α + β + 1)) /// ``` /// - /// where `α` is shapeA and `β` is shapeB + /// where `α` is shapeA and `β` is shapeB. fn variance(&self) -> Option { - let var = if self.shape_a.is_infinite() || self.shape_b.is_infinite() { - 0.0 - } else { + Some( self.shape_a * self.shape_b / ((self.shape_a + self.shape_b) * (self.shape_a + self.shape_b) - * (self.shape_a + self.shape_b + 1.0)) - }; - Some(var) + * (self.shape_a + self.shape_b + 1.0)), + ) } - /// Returns the entropy of the beta distribution + /// Returns the entropy of the beta distribution. /// /// # Formula /// @@ -296,21 +261,17 @@ impl Distribution for Beta { /// ln(B(α, β)) - (α - 1)ψ(α) - (β - 1)ψ(β) + (α + β - 2)ψ(α + β) /// ``` /// - /// where `α` is shapeA, `β` is shapeB and `ψ` is the digamma function + /// where `α` is shapeA, `β` is shapeB and `ψ` is the digamma function. fn entropy(&self) -> Option { - let entr = if self.shape_a.is_infinite() || self.shape_b.is_infinite() { - // unsupported limit - return None; - } else { + Some( beta::ln_beta(self.shape_a, self.shape_b) - (self.shape_a - 1.0) * gamma::digamma(self.shape_a) - (self.shape_b - 1.0) * gamma::digamma(self.shape_b) - + (self.shape_a + self.shape_b - 2.0) * gamma::digamma(self.shape_a + self.shape_b) - }; - Some(entr) + + (self.shape_a + self.shape_b - 2.0) * gamma::digamma(self.shape_a + self.shape_b), + ) } - /// Returns the skewness of the Beta distribution + /// Returns the skewness of the Beta distribution. /// /// # Formula /// @@ -318,33 +279,24 @@ impl Distribution for Beta { /// 2(β - α) * sqrt(α + β + 1) / ((α + β + 2) * sqrt(αβ)) /// ``` /// - /// where `α` is shapeA and `β` is shapeB + /// where `α` is shapeA and `β` is shapeB. fn skewness(&self) -> Option { - let skew = if self.shape_a.is_infinite() { - -2.0 - } else if self.shape_b.is_infinite() { - 2.0 - } else { + Some( 2.0 * (self.shape_b - self.shape_a) * (self.shape_a + self.shape_b + 1.0).sqrt() - / ((self.shape_a + self.shape_b + 2.0) * (self.shape_a * self.shape_b).sqrt()) - }; - Some(skew) + / ((self.shape_a + self.shape_b + 2.0) * (self.shape_a * self.shape_b).sqrt()), + ) } } impl Mode> for Beta { - /// Returns the mode of the Beta distribution. + /// Returns the mode of the Beta distribution. Returns `None` if `α <= 1` + /// or `β <= 1`. /// /// # Remarks /// - /// Since the mode is technically only calculate for `α > 1, β > 1`, those + /// Since the mode is technically only calculated for `α > 1, β > 1`, those /// are the only values we allow. We may consider relaxing this constraint - /// in - /// the future. - /// - /// # Panics - /// - /// If `α <= 1` or `β <= 1` + /// in the future. /// /// # Formula /// @@ -358,8 +310,6 @@ impl Mode> for Beta { // of 'anti-mode; if self.shape_a <= 1.0 || self.shape_b <= 1.0 { None - } else if self.shape_a.is_infinite() { - Some(1.0) } else { Some((self.shape_a - 1.0) / (self.shape_a + self.shape_b - 2.0)) } @@ -382,18 +332,6 @@ impl Continuous for Beta { fn pdf(&self, x: f64) -> f64 { if !(0.0..=1.0).contains(&x) { 0.0 - } else if self.shape_a.is_infinite() { - if ulps_eq!(x, 1.0) { - f64::INFINITY - } else { - 0.0 - } - } else if self.shape_b.is_infinite() { - if x == 0.0 { - f64::INFINITY - } else { - 0.0 - } } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) { 1.0 } else if self.shape_a > 80.0 || self.shape_b > 80.0 { @@ -416,22 +354,10 @@ impl Continuous for Beta { /// ln(x^(α - 1) * (1 - x)^(β - 1) / B(α, β)) /// ``` /// - /// where `α` is shapeA, `β` is shapeB, and `Γ` is the gamma function + /// where `α` is shapeA, `β` is shapeB, and `Γ` is the gamma function. fn ln_pdf(&self, x: f64) -> f64 { if !(0.0..=1.0).contains(&x) { f64::NEG_INFINITY - } else if self.shape_a.is_infinite() { - if ulps_eq!(x, 1.0) { - f64::INFINITY - } else { - f64::NEG_INFINITY - } - } else if self.shape_b.is_infinite() { - if x == 0.0 { - f64::INFINITY - } else { - f64::NEG_INFINITY - } } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) { 0.0 } else { @@ -468,7 +394,7 @@ mod tests { #[test] fn test_create() { - let valid = [(1.0, 1.0), (9.0, 1.0), (5.0, 100.0), (1.0, f64::INFINITY), (f64::INFINITY, 1.0)]; + let valid = [(1.0, 1.0), (9.0, 1.0), (5.0, 100.0)]; for (a, b) in valid { create_ok(a, b); } @@ -480,8 +406,8 @@ mod tests { (0.0, 0.0), (0.0, 0.1), (1.0, 0.0), - (0.0, f64::INFINITY), - (f64::INFINITY, 0.0), + (0.5, f64::INFINITY), + (f64::INFINITY, 0.5), (f64::NAN, 1.0), (1.0, f64::NAN), (f64::NAN, f64::NAN), @@ -502,8 +428,6 @@ mod tests { ((1.0, 1.0), 0.5), ((9.0, 1.0), 0.9), ((5.0, 100.0), 0.047619047619047619047616), - ((1.0, f64::INFINITY), 0.0), - ((f64::INFINITY, 1.0), 1.0), ]; for ((a, b), res) in test { test_relative(a, b, res, f); @@ -517,8 +441,6 @@ mod tests { ((1.0, 1.0), 1.0 / 12.0), ((9.0, 1.0), 9.0 / 1100.0), ((5.0, 100.0), 500.0 / 1168650.0), - ((1.0, f64::INFINITY), 0.0), - ((f64::INFINITY, 1.0), 0.0), ]; for ((a, b), res) in test { test_relative(a, b, res, f); @@ -536,9 +458,6 @@ mod tests { test_relative(a, b, res, f); } test_absolute(1.0, 1.0, 0.0, 1e-14, f); - let entropy = |x: Beta| x.entropy(); - test_none(1.0, f64::INFINITY, entropy); - test_none(f64::INFINITY, 1.0, entropy); } #[test] @@ -547,16 +466,12 @@ mod tests { test_relative(1.0, 1.0, 0.0, skewness); test_relative(9.0, 1.0, -1.4740554623801777107177478829, skewness); test_relative(5.0, 100.0, 0.817594109275534303545831591, skewness); - test_relative(1.0, f64::INFINITY, 2.0, skewness); - test_relative(f64::INFINITY, 1.0, -2.0, skewness); } #[test] fn test_mode() { let mode = |x: Beta| x.mode().unwrap(); test_relative(5.0, 100.0, 0.038834951456310676243255386, mode); - test_relative(92.0, f64::INFINITY, 0.0, mode); - test_relative(f64::INFINITY, 2.0, 1.0, mode); } #[test] @@ -590,13 +505,7 @@ mod tests { ((5.0, 100.0), 0.0, 0.0), ((5.0, 100.0), 0.5, 4.534102298350337661e-23), ((5.0, 100.0), 1.0, 0.0), - ((5.0, 100.0), 1.0, 0.0), - ((1.0, f64::INFINITY), 0.0, f64::INFINITY), - ((1.0, f64::INFINITY), 0.5, 0.0), - ((1.0, f64::INFINITY), 1.0, 0.0), - ((f64::INFINITY, 1.0), 0.0, 0.0), - ((f64::INFINITY, 1.0), 0.5, 0.0), - ((f64::INFINITY, 1.0), 1.0, f64::INFINITY), + ((5.0, 100.0), 1.0, 0.0) ]; for ((a, b), x, expect) in test { test_relative(a, b, expect, f(x)); @@ -628,12 +537,6 @@ mod tests { ((5.0, 100.0), 0.0, f64::NEG_INFINITY), ((5.0, 100.0), 0.5, -51.447830024537682154565870), ((5.0, 100.0), 1.0, f64::NEG_INFINITY), - ((1.0, f64::INFINITY), 0.0, f64::INFINITY), - ((1.0, f64::INFINITY), 0.5, f64::NEG_INFINITY), - ((1.0, f64::INFINITY), 1.0, f64::NEG_INFINITY), - ((f64::INFINITY, 1.0), 0.0, f64::NEG_INFINITY), - ((f64::INFINITY, 1.0), 0.5, f64::NEG_INFINITY), - ((f64::INFINITY, 1.0), 1.0, f64::INFINITY), ]; for ((a, b), x, expect) in test { test_relative(a, b, expect, f(x)); @@ -665,12 +568,6 @@ mod tests { ((5.0, 100.0), 0.0, 0.0), ((5.0, 100.0), 0.5, 1.0), ((5.0, 100.0), 1.0, 1.0), - ((1.0, f64::INFINITY), 0.0, 1.0), - ((1.0, f64::INFINITY), 0.5, 1.0), - ((1.0, f64::INFINITY), 1.0, 1.0), - ((f64::INFINITY, 1.0), 0.0, 0.0), - ((f64::INFINITY, 1.0), 0.5, 0.0), - ((f64::INFINITY, 1.0), 1.0, 1.0), ]; for ((a, b), x, expect) in test { test_relative(a, b, expect, cdf(x)); @@ -690,12 +587,6 @@ mod tests { ((5.0, 100.0), 0.0, 1.0), ((5.0, 100.0), 0.5, 0.0), ((5.0, 100.0), 1.0, 0.0), - ((1.0, f64::INFINITY), 0.0, 0.0), - ((1.0, f64::INFINITY), 0.5, 0.0), - ((1.0, f64::INFINITY), 1.0, 0.0), - ((f64::INFINITY, 1.0), 0.0, 1.0), - ((f64::INFINITY, 1.0), 0.5, 1.0), - ((f64::INFINITY, 1.0), 1.0, 0.0), ]; for ((a, b), x, expect) in test { test_relative(a, b, expect, sf(x)); diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index ba6d1090..9075669e 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -323,7 +323,7 @@ pub mod test { #[test] #[should_panic] fn test_create_err_failure() { - test_create_err(0.0, 0.5, BetaError::BothShapesInfinite); + test_create_err(0.0, 0.5, BetaError::ShapeBInvalid); } #[test] @@ -340,7 +340,7 @@ pub mod test { #[test] fn test_is_none_success() { - test_none(f64::INFINITY, 1.2, |dist| dist.entropy()); + test_none(0.5, 1.2, |dist| dist.mode()); } #[test]