From 2f6c33d4bbc002bfe76232b85346b4f6276357e9 Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Wed, 5 Jan 2022 12:43:32 +0100
Subject: [PATCH] Fix limit of `log1mexp` gradient at zero and improve numerical precision

---
 aesara/scalar/math.py     |  8 +++++++-
 tests/tensor/test_math.py | 11 +++++++++++
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/aesara/scalar/math.py b/aesara/scalar/math.py
index 5697140914..20f45b7be5 100644
--- a/aesara/scalar/math.py
+++ b/aesara/scalar/math.py
@@ -20,10 +20,13 @@
     complex_types,
     discrete_types,
     exp,
+    expm1,
     float64,
     float_types,
+    isinf,
     log,
     log1p,
+    switch,
     true_div,
     upcast,
     upgrade_to_float,
@@ -1201,7 +1204,10 @@ def impl(self, x):
     def grad(self, inp, grads):
         (x,) = inp
         (gz,) = grads
-        return [gz * true_div(1.0, 1.0 - exp(-x))]
+        res = true_div(-1.0, expm1(-x))
+        # Correct gradient at 0.0 to be -inf
+        res = switch(isinf(res), -np.inf, res)
+        return [gz * res]
 
     def c_code(self, node, name, inp, out, sub):
         (x,) = inp
diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index 40d4bfa435..e00ad7d983 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -74,6 +74,7 @@
     isnan,
     isnan_,
     log,
+    log1mexp,
     log1p,
     log2,
     log10,
@@ -3343,3 +3344,13 @@ def test_pprint():
     x = vector("x")
     y = aet_sum(x, axis=0)
     assert pprint(y) == "sum(x, axis=(0,))"
+
+
+def test_log1mexp_grad_lim():
+    x = dscalar("x")
+    grad_x = grad(log1mexp(x), [x])[0]
+    grad_x_fn = function([x], grad_x)
+    assert grad_x_fn(0.0) == -np.inf
+    assert grad_x_fn(-0.0) == -np.inf
+    assert grad_x_fn(-1e-309) == -np.inf
+    assert grad_x_fn(-1e-308) != -np.inf