From 5dd94fe7c70c98960c57782ef9f53fad3eb4218d Mon Sep 17 00:00:00 2001 From: Jeremy Craig Sawruk Date: Thu, 11 Mar 2021 19:18:17 -0500 Subject: [PATCH] Add Weibull distribution functions --- .../src/main/sphinx/functions/math.rst | 11 +++++++ .../presto/operator/scalar/MathFunctions.java | 30 +++++++++++++++++++ .../operator/scalar/TestMathFunctions.java | 25 ++++++++++++++++ 3 files changed, 66 insertions(+) diff --git a/presto-docs/src/main/sphinx/functions/math.rst b/presto-docs/src/main/sphinx/functions/math.rst index 106faa636f9d..4199a68d4381 100644 --- a/presto-docs/src/main/sphinx/functions/math.rst +++ b/presto-docs/src/main/sphinx/functions/math.rst @@ -109,6 +109,12 @@ Mathematical Functions The lambda parameter must be a positive real number (of type DOUBLE). The probability p must lie on the interval [0, 1). +.. function:: inverse_weibull_cdf(a, b, p) -> double + + Compute the inverse of the Weibull cdf with given parameters ``a``, ``b`` for the probability ``p``. + The ``a``, ``b`` parameters must be positive double values. The probability ``p`` must be a double + on the interval [0, 1]. + .. function:: normal_cdf(mean, sd, v) -> double Compute the Normal cdf with given mean and standard deviation (sd): P(N < v; mean, sd). @@ -220,6 +226,11 @@ Mathematical Functions ``truncate(REAL '12.333', 0)`` -> result is 12.0 ``truncate(REAL '12.333', 1)`` -> result is 12.3 +.. function:: weibull_cdf(a, b, value) -> double + + Compute the Weibull cdf with given parameters a, b: P(N <= value). The ``a`` + and ``b`` parameters must be positive doubles and ``value`` must also be a double. + .. function:: width_bucket(x, bound1, bound2, n) -> bigint Returns the bin number of ``x`` in an equi-width histogram with the diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java index d6203539681a..2efeb0e1a4c1 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java @@ -35,6 +35,7 @@ import org.apache.commons.math3.distribution.CauchyDistribution; import org.apache.commons.math3.distribution.ChiSquaredDistribution; import org.apache.commons.math3.distribution.PoissonDistribution; +import org.apache.commons.math3.distribution.WeibullDistribution; import org.apache.commons.math3.special.Erf; import java.math.BigDecimal; @@ -816,6 +817,35 @@ public static double poissonCdf( return distribution.cumulativeProbability((int) value); } + @Description("Inverse of Weibull cdf given a, b parameters and probability") + @ScalarFunction + @SqlType(StandardTypes.DOUBLE) + public static double inverseWeibullCdf( + @SqlType(StandardTypes.DOUBLE) double a, + @SqlType(StandardTypes.DOUBLE) double b, + @SqlType(StandardTypes.DOUBLE) double p) + { + checkCondition(p >= 0 && p <= 1, INVALID_FUNCTION_ARGUMENT, "p must be in the interval [0, 1]"); + checkCondition(a > 0, INVALID_FUNCTION_ARGUMENT, "a must be greater than 0"); + checkCondition(b > 0, INVALID_FUNCTION_ARGUMENT, "b must be greater than 0"); + WeibullDistribution distribution = new WeibullDistribution(null, a, b, WeibullDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY); + return distribution.inverseCumulativeProbability(p); + } + + @Description("Weibull cdf given the a, b parameters and value") + @ScalarFunction + @SqlType(StandardTypes.DOUBLE) + public static double weibullCdf( + @SqlType(StandardTypes.DOUBLE) double a, + @SqlType(StandardTypes.DOUBLE) double b, + @SqlType(StandardTypes.DOUBLE) double value) + { + checkCondition(a > 0, INVALID_FUNCTION_ARGUMENT, "a must be greater than 0"); + checkCondition(b > 0, INVALID_FUNCTION_ARGUMENT, "b must be greater than 0"); + WeibullDistribution distribution = new WeibullDistribution(null, a, b, WeibullDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY); + return distribution.cumulativeProbability(value); + } + @Description("round to nearest integer") @ScalarFunction("round") @SqlType(StandardTypes.TINYINT) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java index 1aebf8f5fefb..dd9f073d8153 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java @@ -1475,6 +1475,31 @@ public void testPoissonCdf() assertInvalidFunction("poisson_cdf(3, -10)", "value must be a non-negative integer"); } + public void testInverseWeibullCdf() + { + assertFunction("inverse_weibull_cdf(1.0, 1.0, 0.0)", DOUBLE, 0.0); + assertFunction("round(inverse_weibull_cdf(1.0, 1.0, 0.632), 2)", DOUBLE, 1.00); + assertFunction("round(inverse_weibull_cdf(1.0, 0.6, 0.91), 2)", DOUBLE, 1.44); + + assertInvalidFunction("inverse_weibull_cdf(0, 3, 0.5)", "a must be greater than 0"); + assertInvalidFunction("inverse_weibull_cdf(3, 0, 0.5)", "b must be greater than 0"); + assertInvalidFunction("inverse_weibull_cdf(3, 5, -0.1)", "p must be in the interval [0, 1]"); + assertInvalidFunction("inverse_weibull_cdf(3, 5, 1.1)", "p must be in the interval [0, 1]"); + } + + @Test + public void testWeibullCdf() + throws Exception + { + assertFunction("weibull_cdf(1.0, 1.0, 0.0)", DOUBLE, 0.0); + assertFunction("weibull_cdf(1.0, 1.0, 40.0)", DOUBLE, 1.0); + assertFunction("round(weibull_cdf(1.0, 0.6, 3.0), 2)", DOUBLE, 0.99); + assertFunction("round(weibull_cdf(1.0, 0.9, 2.0), 2)", DOUBLE, 0.89); + + assertInvalidFunction("weibull_cdf(0, 3, 0.5)", "a must be greater than 0"); + assertInvalidFunction("weibull_cdf(3, 0, 0.5)", "b must be greater than 0"); + } + @Test public void testWilsonInterval() {