Add gammainc(c) derivatives #513

Merged · 3 commits · Sep 19, 2021
188 changes: 177 additions & 11 deletions aesara/scalar/math.py
@@ -556,6 +556,14 @@ def st_impl(k, x):
def impl(self, k, x):
return GammaInc.st_impl(k, x)

def grad(self, inputs, grads):
(k, x) = inputs
(gz,) = grads
return [
gz * gammainc_der(k, x),
gz * exp(-x + (k - 1) * log(x) - gammaln(k)),
]

def c_support_code(self, **kwargs):
with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f:
raw = f.read()
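
The x-gradient added above uses the closed form d/dx P(k, x) = x**(k - 1) * exp(-x) / gamma(k), computed in log space as exp(-x + (k - 1) * log(x) - gammaln(k)). A standalone sketch (plain numpy/scipy, not part of the diff) that checks it against a central finite difference:

import numpy as np
import scipy.special

def dP_dx(k, x):
    # Same closed form as in GammaInc.grad
    return np.exp(-x + (k - 1) * np.log(x) - scipy.special.gammaln(k))

k, x, eps = 2.5, 1.3, 1e-6
fd = (scipy.special.gammainc(k, x + eps) - scipy.special.gammainc(k, x - eps)) / (2 * eps)
assert np.isclose(dP_dx(k, x), fd)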
@@ -592,11 +600,19 @@ class GammaIncC(BinaryScalarOp):

@staticmethod
def st_impl(k, x):
- return scipy.special.gammaincc(x, k)
+ return scipy.special.gammaincc(k, x)

def impl(self, k, x):
return GammaIncC.st_impl(k, x)

def grad(self, inputs, grads):
(k, x) = inputs
(gz,) = grads
return [
gz * gammaincc_der(k, x),
gz * -exp(-x + (k - 1) * log(x) - gammaln(k)),
]

def c_support_code(self, **kwargs):
with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f:
raw = f.read()
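
Since Q(k, x) = 1 - P(k, x), the x-gradient here is just the negative of the one in GammaInc.grad; a quick standalone check with scipy:

import numpy as np
import scipy.special

k, x, eps = 3.0, 2.0, 1e-6
fd = (scipy.special.gammaincc(k, x + eps) - scipy.special.gammaincc(k, x - eps)) / (2 * eps)
assert np.isclose(fd, -np.exp(-x + (k - 1) * np.log(x) - scipy.special.gammaln(k)))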
@@ -624,6 +640,159 @@ def __hash__(self):
gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")


class GammaIncDer(BinaryScalarOp):
"""
Gradient of the regularized lower gamma function (P) wrt the first
argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_lower_inc_gamma.hpp`

Reference: Gautschi, W. (1979). A computational procedure for incomplete gamma functions.
ACM Transactions on Mathematical Software (TOMS), 5(4), 466-481.
"""

def impl(self, k, x):

if x == 0:
return 0

sqrt_exp = -756 - x ** 2 + 60 * x
if (
(k < 0.8 and x > 15)
or (k < 12 and x > 30)
or (sqrt_exp > 0 and k < np.sqrt(sqrt_exp))
):
return -GammaIncCDer.st_impl(k, x)

precision = 1e-10
max_iters = int(1e5)

log_x = np.log(x)
log_gamma_k_plus_1 = scipy.special.gammaln(k + 1)

k_plus_n = k
log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
sum_a = 0.0
for n in range(0, max_iters + 1):
term = np.exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1)
sum_a += term

if term <= precision:
break

log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n)
k_plus_n += 1

if n >= max_iters:
warnings.warn(
f"gammainc_der did not converge after {n} iterations",
RuntimeWarning,
)
return np.nan

k_plus_n = k
log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1
sum_b = 0.0
for n in range(0, max_iters + 1):
term = np.exp(
k_plus_n * log_x - log_gamma_k_plus_n_plus_1
) * scipy.special.digamma(k_plus_n + 1)
sum_b += term

if term <= precision and n >= 1: # Require at least two iterations
return np.exp(-x) * (log_x * sum_a - sum_b)

log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n)
k_plus_n += 1

warnings.warn(
f"gammainc_der did not converge after {n} iterations",
RuntimeWarning,
)
return np.nan


gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der")
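
One way to exercise the new Op outside a compiled graph is to compare gammainc_der.impl against a central finite difference of scipy.special.gammainc in its first argument; a minimal sketch, assuming it runs where this module's names are importable:

import numpy as np
import scipy.special

def fd_dP_dk(k, x, eps=1e-6):
    # Central finite difference of the regularized lower gamma wrt k
    return (scipy.special.gammainc(k + eps, x) - scipy.special.gammainc(k - eps, x)) / (2 * eps)

for k, x in [(0.5, 0.5), (3.0, 2.0), (9.5, 12.5)]:
    assert np.isclose(gammainc_der.impl(k, x), fd_dP_dk(k, x), rtol=1e-4, atol=1e-10)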


class GammaIncCDer(BinaryScalarOp):
"""
Gradient of the regularized upper gamma function (Q) wrt the first
argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_inc_gamma.hpp`
"""

@staticmethod
def st_impl(k, x):
gamma_k = scipy.special.gamma(k)
digamma_k = scipy.special.digamma(k)
log_x = np.log(x)

# asymptotic expansion http://dlmf.nist.gov/8.11#E2
if (x >= k) and (x >= 8):
S = 0
k_minus_one_minus_n = k - 1
fac = k_minus_one_minus_n
dfac = 1
xpow = x
delta = dfac / xpow

for n in range(1, 10):
k_minus_one_minus_n -= 1
S += delta
xpow *= x
dfac = k_minus_one_minus_n * dfac + fac
fac *= k_minus_one_minus_n
delta = dfac / xpow
if np.isinf(delta):
warnings.warn(
"gammaincc_der did not converge",
RuntimeWarning,
)
return np.nan

return (
scipy.special.gammaincc(k, x) * (log_x - digamma_k)
+ np.exp(-x + (k - 1) * log_x) * S / gamma_k
)

# gradient of series expansion http://dlmf.nist.gov/8.7#E3
else:
log_precision = np.log(1e-6)
max_iters = int(1e5)
S = 0
log_s = 0.0
s_sign = 1
log_delta = log_s - 2 * np.log(k)
for n in range(1, max_iters + 1):
S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta)
s_sign = -s_sign
log_s += log_x - np.log(n)
log_delta = log_s - 2 * np.log(n + k)

if np.isinf(log_delta):
warnings.warn(
"gammaincc_der did not converge",
RuntimeWarning,
)
return np.nan

if log_delta <= log_precision:
return (
scipy.special.gammainc(k, x) * (digamma_k - log_x)
+ np.exp(k * log_x) * S / gamma_k
)

warnings.warn(
f"gammaincc_der did not converge after {n} iterations",
RuntimeWarning,
)
return np.nan

def impl(self, k, x):
return self.st_impl(k, x)


gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der")
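
Both branches of GammaIncCDer.st_impl (the asymptotic expansion when x >= k and x >= 8, the series expansion otherwise) can be checked the same way against a finite difference of scipy.special.gammaincc; a sketch under the same module-context assumption:

import numpy as np
import scipy.special

def fd_dQ_dk(k, x, eps=1e-6):
    # Central finite difference of the regularized upper gamma wrt k
    return (scipy.special.gammaincc(k + eps, x) - scipy.special.gammaincc(k - eps, x)) / (2 * eps)

for k, x in [(2.0, 10.0), (5.0, 2.0)]:  # asymptotic branch, then series branch
    assert np.isclose(gammaincc_der.impl(k, x), fd_dQ_dk(k, x), rtol=1e-3, atol=1e-10)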


class GammaU(BinaryScalarOp):
"""
compute the upper incomplete gamma function.
@@ -1083,7 +1252,7 @@ def grad(self, inp, grads):
class BetaIncDer(ScalarOp):
"""
Gradient of the regularized incomplete beta function wrt the first
- argument (alpha) or the second argument (bbeta), depending on whether the
+ argument (alpha) or the second argument (beta), depending on whether the
fourth argument to betainc_der is `True` or `False`, respectively.

Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function.
@@ -1253,16 +1422,13 @@ def _betainc_db_n_dq(f, p, q, n):
derivative_old = derivative

if d_errapx <= err_threshold:
-     break
+     return derivative

- if n >= max_iters:
-     warnings.warn(
-         f"_betainc_derivative did not converge after {n} iterations",
-         RuntimeWarning,
-     )
-     return np.nan
-
- return derivative
+ warnings.warn(
+     f"betainc_der did not converge after {n} iterations",
+     RuntimeWarning,
+ )
+ return np.nan


betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
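
The restructured loop above returns the running derivative as soon as the error estimate drops below the threshold, and only warns (returning nan) once max_iters is exhausted. A standalone finite-difference check of betainc_der.impl; the (p, q, x, wrt_p) calling convention is inferred from the docstring, so treat it as a sketch:

import numpy as np
import scipy.special

def fd_dI_dp(p, q, x, eps=1e-6):
    # Central finite difference of the regularized incomplete beta wrt p
    return (scipy.special.betainc(p + eps, q, x) - scipy.special.betainc(p - eps, q, x)) / (2 * eps)

assert np.isclose(betainc_der.impl(2.0, 3.0, 0.4, True), fd_dI_dp(2.0, 3.0, 0.4), rtol=1e-4)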
24 changes: 20 additions & 4 deletions tests/scalar/test_math.py
@@ -9,7 +9,15 @@
from aesara.scalar.math import betainc, betainc_der, gammainc, gammaincc, gammal, gammau


- def test_gammainc_nan():
+ def test_gammainc_python():
x1 = aet.dscalar()
x2 = aet.dscalar()
y = gammainc(x1, x2)
test_func = function([x1, x2], y, mode=Mode("py"))
assert np.isclose(test_func(1, 2), sp.gammainc(1, 2))


def test_gammainc_nan_c():
x1 = aet.dscalar()
x2 = aet.dscalar()
y = gammainc(x1, x2)
@@ -19,7 +27,15 @@ def test_gammainc_nan():
assert np.isnan(test_func(-1, -1))


- def test_gammaincc_nan():
+ def test_gammaincc_python():
x1 = aet.dscalar()
x2 = aet.dscalar()
y = gammaincc(x1, x2)
test_func = function([x1, x2], y, mode=Mode("py"))
assert np.isclose(test_func(1, 2), sp.gammaincc(1, 2))


def test_gammaincc_nan_c():
x1 = aet.dscalar()
x2 = aet.dscalar()
y = gammaincc(x1, x2)
@@ -29,7 +45,7 @@ def test_gammaincc_nan():
assert np.isnan(test_func(-1, -1))


- def test_gammal_nan():
+ def test_gammal_nan_c():
x1 = aet.dscalar()
x2 = aet.dscalar()
y = gammal(x1, x2)
@@ -39,7 +55,7 @@ def test_gammal_nan():
assert np.isnan(test_func(-1, -1))


- def test_gammau_nan():
+ def test_gammau_nan_c():
x1 = aet.dscalar()
x2 = aet.dscalar()
y = gammau(x1, x2)
57 changes: 57 additions & 0 deletions tests/tensor/test_math_scipy.py
@@ -272,10 +272,19 @@ def scipy_special_gammal(k, x):
),
)

_good_broadcast_binary_gamma_grad = dict(
normal=_good_broadcast_binary_gamma["normal"],
specific_branches=(
np.array([0.7, 11.0, 19.0]),
np.array([16.0, 31.0, 3.0]),
),
)
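
The specific_branches pairs are chosen so that GammaIncDer.impl takes each of its code paths; a standalone sketch of the branch predicate (mirroring the condition in the Op) applied to these values:

import numpy as np

def delegates_to_gammainccder(k, x):
    # Mirrors the branch in GammaIncDer.impl that returns -GammaIncCDer.st_impl(k, x)
    sqrt_exp = -756 - x ** 2 + 60 * x
    return (k < 0.8 and x > 15) or (k < 12 and x > 30) or (sqrt_exp > 0 and k < np.sqrt(sqrt_exp))

pairs = zip([0.7, 11.0, 19.0], [16.0, 31.0, 3.0])
print([delegates_to_gammainccder(k, x) for k, x in pairs])  # [True, True, False]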

TestGammaIncBroadcast = makeBroadcastTester(
op=aet.gammainc,
expected=expected_gammainc,
good=_good_broadcast_binary_gamma,
grad=_good_broadcast_binary_gamma_grad,
eps=2e-8,
mode=mode_no_scipy,
)
@@ -293,6 +302,7 @@ def scipy_special_gammal(k, x):
op=aet.gammaincc,
expected=expected_gammaincc,
good=_good_broadcast_binary_gamma,
grad=_good_broadcast_binary_gamma_grad,
eps=2e-8,
mode=mode_no_scipy,
)
@@ -306,6 +316,53 @@ def scipy_special_gammal(k, x):
inplace=True,
)


def test_gammainc_ddk_tabulated_values():
# This test replicates part of the old STAN test:
# https://github.com/stan-dev/math/blob/21333bb70b669a1bd54d444ecbe1258078d33153/test/unit/math/prim/scal/fun/grad_reg_lower_inc_gamma_test.cpp
k, x = aet.scalars("k", "x")
gammainc_out = aet.gammainc(k, x)
gammainc_ddk = aet.grad(gammainc_out, k)
f_grad = function([k, x], gammainc_ddk)

for test_k, test_x, expected_ddk in (
(0.0001, 0, 0), # Limit condition
(0.0001, 0.0001, -8.62594024578651),
(0.0001, 6.2501, -0.0002705821702813008),
(0.0001, 12.5001, -2.775406821933887e-7),
(0.0001, 18.7501, -3.653379783274905e-10),
(0.0001, 25.0001, -5.352425240798134e-13),
(0.0001, 29.7501, -3.912723010174313e-15),
(4.7501, 0.0001, 0),
(4.7501, 6.2501, -0.1330287013623819),
(4.7501, 12.5001, -0.004712176128251421),
(4.7501, 18.7501, -0.00004898939126595217),
(4.7501, 25.0001, -3.098781566343336e-7),
(4.7501, 29.7501, -5.478399030091586e-9),
(9.5001, 0.0001, -5.869126325643798e-15),
(9.5001, 6.2501, -0.07717967485372858),
(9.5001, 12.5001, -0.07661095137424883),
(9.5001, 18.7501, -0.005594043337407605),
(9.5001, 25.0001, -0.0001410123206233104),
(9.5001, 29.7501, -5.75023943432906e-6),
(14.2501, 0.0001, -7.24495484418588e-15),
(14.2501, 6.2501, -0.003689474744087815),
(14.2501, 12.5001, -0.1008796179460247),
(14.2501, 18.7501, -0.05124664255610913),
(14.2501, 25.0001, -0.005115177188580634),
(14.2501, 29.7501, -0.0004793406401524598),
(19.0001, 0.0001, -8.26027539153394e-15),
(19.0001, 6.2501, -0.00003509660448733015),
(19.0001, 12.5001, -0.02624562607393565),
(19.0001, 18.7501, -0.0923829735092193),
(19.0001, 25.0001, -0.03641281853907181),
(19.0001, 29.7501, -0.007828749832965796),
):
np.testing.assert_allclose(
f_grad(test_k, test_x), expected_ddk, rtol=1e-5, atol=1e-14
)
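
With GammaInc.grad wired up, the k-derivative is available through ordinary aesara differentiation; a minimal usage sketch reproducing one of the tabulated values above:

import numpy as np
import aesara.tensor as aet
from aesara import function

k, x = aet.dscalars("k", "x")
dP_dk = aet.grad(aet.gammainc(k, x), k)
f = function([k, x], dP_dk)
np.testing.assert_allclose(f(4.7501, 6.2501), -0.1330287013623819, rtol=1e-5)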


TestGammaUBroadcast = makeBroadcastTester(
op=aet.gammau,
expected=expected_gammau,
Expand Down