Skip to content

Commit

Permalink
Implement betainc and derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo authored and brandonwillard committed Jul 1, 2021
1 parent b5313f1 commit 1ece906
Show file tree
Hide file tree
Showing 6 changed files with 346 additions and 1 deletion.
4 changes: 4 additions & 0 deletions aesara/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,10 @@ class BinaryScalarOp(ScalarOp):
nin = 2


class TernaryScalarOp(ScalarOp):
nin = 3


class LogicalComparison(BinaryScalarOp):
def __init__(self, *args, **kwargs):
BinaryScalarOp.__init__(self, *args, **kwargs)
Expand Down
229 changes: 229 additions & 0 deletions aesara/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import os
import warnings

import numpy as np
import scipy.special
Expand All @@ -14,12 +15,15 @@
from aesara.gradient import grad_not_implemented
from aesara.scalar.basic import (
BinaryScalarOp,
TernaryScalarOp,
UnaryScalarOp,
complex_types,
discrete_types,
exp,
float64,
float_types,
log,
log1p,
true_div,
upcast,
upgrade_to_float,
Expand Down Expand Up @@ -1044,3 +1048,228 @@ def c_code(self, node, name, inp, out, sub):


log1mexp = Log1mexp(upgrade_to_float, name="scalar_log1mexp")


class BetaInc(TernaryScalarOp):
"""
Regularized incomplete beta function
"""

nfunc_spec = ("scipy.special.betainc", 3, 1)

def impl(self, a, b, x):
return scipy.special.betainc(a, b, x)

def grad(self, inp, grads):
a, b, x = inp
(gz,) = grads

return [
gz * betainc_dda_scalar(a, b, x),
gz * betainc_ddb_scalar(a, b, x),
gz
* exp(
log1p(-x) * (b - 1)
+ log(x) * (a - 1)
- (gammaln(a) + gammaln(b) - gammaln(a + b))
),
]


betainc = BetaInc(upgrade_to_float_no_complex, name="betainc")


class BetaIncDda(TernaryScalarOp):
"""
Gradient of the regularized incomplete beta function wrt to the first argument (a)
"""

def impl(self, a, b, x):
return _betainc_derivative(a, b, x, wrtp=True)


betainc_dda_scalar = BetaIncDda(upgrade_to_float_no_complex, name="betainc_dda")


class BetaIncDdb(TernaryScalarOp):
"""
Gradient of the regularized incomplete beta function wrt to the second argument (b)
"""

def impl(self, a, b, x):
return _betainc_derivative(a, b, x, wrtp=False)


betainc_ddb_scalar = BetaIncDdb(upgrade_to_float_no_complex, name="betainc_ddb")


def _betainc_derivative(p, q, x, wrtp=True):
"""
Compute the derivative of regularized incomplete beta function wrt to p (alpha) or q (beta)
Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
Journal of Statistical Software, 3(1), 1-20.
"""

def _betainc_a_n(f, p, q, n):
"""
Numerator (a_n) of the nth approximant of the continued fraction
representation of the regularized incomplete beta function
"""

if n == 1:
return p * f * (q - 1) / (q * (p + 1))

p2n = p + 2 * n
F1 = p ** 2 * f ** 2 * (n - 1) / (q ** 2)
F2 = (
(p + q + n - 2)
* (p + n - 1)
* (q - n)
/ ((p2n - 3) * (p2n - 2) ** 2 * (p2n - 1))
)

return F1 * F2

def _betainc_b_n(f, p, q, n):
"""
Offset (b_n) of the nth approximant of the continued fraction
representation of the regularized incomplete beta function
"""
pf = p * f
p2n = p + 2 * n

N1 = 2 * (pf + 2 * q) * n * (n + p - 1) + p * q * (p - 2 - pf)
D1 = q * (p2n - 2) * p2n

return N1 / D1

def _betainc_da_n_dp(f, p, q, n):
"""
Derivative of a_n wrt p
"""

if n == 1:
return -p * f * (q - 1) / (q * (p + 1) ** 2)

pp = p ** 2
ppp = pp * p
p2n = p + 2 * n

N1 = -(n - 1) * f ** 2 * pp * (q - n)
N2a = (-8 + 8 * p + 8 * q) * n ** 3
N2b = (16 * pp + (-44 + 20 * q) * p + 26 - 24 * q) * n ** 2
N2c = (10 * ppp + (14 * q - 46) * pp + (-40 * q + 66) * p - 28 + 24 * q) * n
N2d = 2 * pp ** 2 + (-13 + 3 * q) * ppp + (-14 * q + 30) * pp
N2e = (-29 + 19 * q) * p + 10 - 8 * q

D1 = q ** 2 * (p2n - 3) ** 2
D2 = (p2n - 2) ** 3 * (p2n - 1) ** 2

return (N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2

def _betainc_da_n_dq(f, p, q, n):
"""
Derivative of a_n wrt q
"""
if n == 1:
return p * f / (q * (p + 1))

p2n = p + 2 * n
F1 = (p ** 2 * f ** 2 / (q ** 2)) * (n - 1) * (p + n - 1) * (2 * q + p - 2)
D1 = (p2n - 3) * (p2n - 2) ** 2 * (p2n - 1)

return F1 / D1

def _betainc_db_n_dp(f, p, q, n):
"""
Derivative of b_n wrt p
"""
p2n = p + 2 * n
pp = p ** 2
q4 = 4 * q
p4 = 4 * p

F1 = (p * f / q) * (
(-p4 - q4 + 4) * n ** 2 + (p4 - 4 + q4 - 2 * pp) * n + pp * q
)
D1 = (p2n - 2) ** 2 * p2n ** 2

return F1 / D1

def _betainc_db_n_dq(f, p, q, n):
"""
Derivative of b_n wrt to q
"""
p2n = p + 2 * n
return -(p ** 2 * f) / (q * (p2n - 2) * p2n)

# Input validation
if not (0 <= x <= 1) or p < 0 or q < 0:
return np.nan

if x > (p / (p + q)):
return -_betainc_derivative(q, p, 1 - x, not wrtp)

min_iters = 3
max_iters = 200
err_threshold = 1e-12

derivative_old = 0

Am2, Am1 = 1, 1
Bm2, Bm1 = 0, 1
dAm2, dAm1 = 0, 0
dBm2, dBm1 = 0, 0

f = (q * x) / (p * (1 - x))
K = np.exp(
p * np.log(x) + (q - 1) * np.log1p(-x) - np.log(p) - scipy.special.betaln(p, q)
)
if wrtp:
dK = np.log(x) - 1 / p + scipy.special.digamma(p + q) - scipy.special.digamma(p)
else:
dK = np.log1p(-x) + scipy.special.digamma(p + q) - scipy.special.digamma(q)

for n in range(1, max_iters + 1):
a_n_ = _betainc_a_n(f, p, q, n)
b_n_ = _betainc_b_n(f, p, q, n)
if wrtp:
da_n = _betainc_da_n_dp(f, p, q, n)
db_n = _betainc_db_n_dp(f, p, q, n)
else:
da_n = _betainc_da_n_dq(f, p, q, n)
db_n = _betainc_db_n_dq(f, p, q, n)

A = a_n_ * Am2 + b_n_ * Am1
B = a_n_ * Bm2 + b_n_ * Bm1
dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1
dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1

Am2, Am1 = Am1, A
Bm2, Bm1 = Bm1, B
dAm2, dAm1 = dAm1, dA
dBm2, dBm1 = dBm1, dB

if n < min_iters - 1:
continue

F1 = A / B
F2 = (dA - F1 * dB) / B
derivative = K * (F1 * dK + F2)

errapx = abs(derivative_old - derivative)
d_errapx = errapx / max(err_threshold, abs(derivative))
derivative_old = derivative

if d_errapx <= err_threshold:
break

if n >= max_iters:
warnings.warn(
f"_betainc_derivative did not converge after {n} iterations",
RuntimeWarning,
)
return np.nan

return derivative
5 changes: 5 additions & 0 deletions aesara/tensor/inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,11 @@ def log1mexp_inplace(x):
"""Compute log(1 - exp(x)), also known as log1mexp"""


@scalar_elemwise
def betainc_inplace(a, b, x):
"""Regularized incomplete beta function"""


@scalar_elemwise
def second_inplace(a):
"""Fill `a` with `b`"""
Expand Down
6 changes: 6 additions & 0 deletions aesara/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,11 @@ def log1mexp(x):
"""Compute log(1 - exp(x)), also known as log1mexp"""


@scalar_elemwise
def betainc(a, b, x):
"""Regularized incomplete beta function"""


@scalar_elemwise
def real(z):
"""Return real component of complex-valued tensor `z`"""
Expand Down Expand Up @@ -2909,6 +2914,7 @@ def logsumexp(x, axis=None, keepdims=False):
"softplus",
"log1pexp",
"log1mexp",
"betainc",
"real",
"imag",
"angle",
Expand Down
79 changes: 78 additions & 1 deletion tests/scalar/test_math.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import numpy as np
from numpy.testing import assert_allclose, assert_almost_equal

import aesara.tensor as aet
from aesara import config, function
from aesara.graph.fg import FunctionGraph
from aesara.link.c.basic import CLinker
from aesara.scalar.math import gammainc, gammaincc, gammal, gammau
from aesara.scalar.math import betainc, gammainc, gammaincc, gammal, gammau


def test_gammainc_nan():
Expand Down Expand Up @@ -44,3 +46,78 @@ def test_gammau_nan():
assert np.isnan(test_func(-1, 1))
assert np.isnan(test_func(1, -1))
assert np.isnan(test_func(-1, -1))


class TestBetaIncGrad:
def test_stan_grad_combined(self):
# This test replicates the following STAN test:
# https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_reg_inc_beta_test.cpp
a, b, z = aet.scalars("a", "b", "z")
betainc_out = betainc(a, b, z)
betainc_grad = aet.grad(betainc_out, [a, b], null_gradients="return")
f_grad = function([a, b, z], betainc_grad)

for test_a, test_b, test_z, expected_dda, expected_ddb in (
(1.0, 1.0, 1.0, 0, np.nan),
(1.0, 1.0, 0.4, -0.36651629, 0.30649537),
):
assert_allclose(
f_grad(test_a, test_b, test_z), [expected_dda, expected_ddb]
)

def test_stan_grad_partial(self):
# This test combines the following STAN tests:
# https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_dda_test.cpp
# https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_ddb_test.cpp
# https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/inc_beta_ddz_test.cpp
a, b, z = aet.scalars("a", "b", "z")
betainc_out = betainc(a, b, z)
betainc_grad = aet.grad(betainc_out, [a, b, z])
f_grad = function([a, b, z], betainc_grad)

decimal_precision = 7 if config.floatX == "float64" else 3

for test_a, test_b, test_z, expected_dda, expected_ddb, expected_ddz in (
(1.5, 1.25, 0.001, -0.00028665637, 4.41357328e-05, 0.063300692),
(1.5, 1.25, 0.5, -0.26038693947, 0.29301795, 1.1905416),
(1.5, 1.25, 0.6, -0.23806757, 0.32279575, 1.23341068),
(1.5, 1.25, 0.999, -0.00022264493, 0.0018969609, 0.35587692),
(15000, 1.25, 0.001, 0, 0, 0),
(15000, 1.25, 0.5, 0, 0, 0),
(15000, 1.25, 0.6, 0, 0, 0),
(15000, 1.25, 0.999, -6.59543226e-10, 2.00849793e-06, 0.009898182),
(1.5, 12500, 0.001, -3.93756641e-05, 1.47821755e-09, 0.1848717),
(1.5, 12500, 0.5, 0, 0, 0),
(1.5, 12500, 0.6, 0, 0, 0),
(1.5, 12500, 0.999, 0, 0, 0),
(15000, 12500, 0.001, 0, 0, 0),
(15000, 12500, 0.5, -8.72102443e-53, 9.55282792e-53, 5.01131256e-48),
(15000, 12500, 0.6, -4.085621e-14, -5.5067062e-14, 1.15135267e-71),
(15000, 12500, 0.999, 0, 0, 0),
):

assert_almost_equal(
f_grad(test_a, test_b, test_z),
[expected_dda, expected_ddb, expected_ddz],
decimal=decimal_precision,
)

def test_boik_robison_cox(self):
# This test compares against the tabulated values in:
# Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
# Journal of Statistical Software, 3(1), 1-20.
a, b, z = aet.scalars("a", "b", "z")
betainc_out = betainc(a, b, z)
betainc_grad = aet.grad(betainc_out, [a, b])
f_grad = function([a, b, z], betainc_grad)

for test_a, test_b, test_z, expected_dda, expected_ddb in (
(1.5, 11.0, 0.001, -4.5720356e-03, 1.1845673e-04),
(1.5, 11.0, 0.5, -2.5501997e-03, 9.0824388e-04),
(1000.0, 1000.0, 0.5, -8.9224793e-03, 8.9224793e-03),
(1000.0, 1000.0, 0.55, -3.6713108e-07, 4.0584118e-07),
):
assert_almost_equal(
f_grad(test_a, test_b, test_z),
[expected_dda, expected_ddb],
)
Loading

0 comments on commit 1ece906

Please sign in to comment.