Skip to content

Commit

Permalink
Deprecate eps argument in math.invlogit
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Aug 17, 2021
1 parent 4dd0538 commit 3fb5807
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pymc3/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class LogOdds(ElemwiseTransform):
name = "logodds"

def backward(self, rv_var, rv_value):
return invlogit(rv_value, 0.0)
return invlogit(rv_value)

def forward(self, rv_var, rv_value):
return logit(rv_value)
Expand Down
10 changes: 8 additions & 2 deletions pymc3/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,15 @@ def logdiffexp_numpy(a, b):
return a + log1mexp_numpy(b - a, negative_input=True)


def invlogit(x, eps=sys.float_info.epsilon):
def invlogit(x, eps=None):
"""The inverse of the logit function, 1 / (1 + exp(-x))."""
return (1.0 - 2.0 * eps) / (1.0 + at.exp(-x)) + eps
if eps is not None:
warnings.warn(
"pymc3.math.invlogit no longer supports the ``eps`` argument and it will be ignored.",
DeprecationWarning,
stacklevel=2,
)
return at.sigmoid(x)


def logbern(log_p):
Expand Down
15 changes: 15 additions & 0 deletions pymc3/tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
LogDet,
cartesian,
expand_packed_triangular,
invlogit,
invprobit,
kron_dot,
kron_solve_lower,
Expand Down Expand Up @@ -250,3 +251,17 @@ def test_expand_packed_triangular():
assert np.all(expand_upper.eval({packed: upper_packed}) == upper)
assert np.all(expand_diag_lower.eval({packed: lower_packed}) == floatX(np.diag(vals)))
assert np.all(expand_diag_upper.eval({packed: upper_packed}) == floatX(np.diag(vals)))


def test_invlogit_deprecation_warning():
with pytest.warns(
DeprecationWarning,
match="pymc3.math.invlogit no longer supports the",
):
res = invlogit(np.array(-750.0), 1e-5).eval()

with pytest.warns(None) as record:
res_zero_eps = invlogit(np.array(-750.0)).eval()
assert not record

assert np.isclose(res, res_zero_eps)

0 comments on commit 3fb5807

Please sign in to comment.