Skip to content

Commit

Permalink
Reintroduce logit_p argument in Bernoulli
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Aug 17, 2021
1 parent 69a4e60 commit 4dd0538
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 21 deletions.
16 changes: 10 additions & 6 deletions pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,15 @@ class Bernoulli(Discrete):

@classmethod
def dist(cls, p=None, logit_p=None, *args, **kwargs):
if p is not None and logit_p is not None:
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
elif p is None and logit_p is None:
raise ValueError("Incompatible parametrization. Must specify either p or logit_p.")

if logit_p is not None:
p = at.sigmoid(logit_p)

p = at.as_tensor_variable(floatX(p))
# mode = at.cast(tround(p), "int8")
return super().dist([p], **kwargs)

def logp(value, p):
Expand All @@ -351,12 +358,9 @@ def logp(value, p):
-------
TensorVariable
"""
# if self._is_logit:
# lp = at.switch(value, self._logit_p, -self._logit_p)
# return -log1pexp(-lp)
# else:

return bound(
at.switch(value, at.log(p), at.log(1 - p)),
at.switch(value, at.log(p), at.log1p(-p)),
value >= 0,
value <= 1,
p >= 0,
Expand Down
34 changes: 22 additions & 12 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1608,40 +1608,50 @@ def test_beta_binomial(self):
{"alpha": Rplus, "beta": Rplus, "n": NatSmall},
)

@pytest.mark.xfail(reason="Bernoulli logit_p not refactored yet")
def test_bernoulli_logit_p(self):
def test_bernoulli(self):
self.check_logp(
Bernoulli,
Bool,
{"logit_p": R},
lambda value, logit_p: sp.bernoulli.logpmf(value, scipy.special.expit(logit_p)),
{"p": Unit},
lambda value, p: sp.bernoulli.logpmf(value, p),
)
self.check_logcdf(
self.check_logp(
Bernoulli,
Bool,
{"logit_p": R},
lambda value, logit_p: sp.bernoulli.logcdf(value, scipy.special.expit(logit_p)),
lambda value, logit_p: sp.bernoulli.logpmf(value, scipy.special.expit(logit_p)),
)

def test_bernoulli(self):
self.check_logp(
self.check_logcdf(
Bernoulli,
Bool,
{"p": Unit},
lambda value, p: sp.bernoulli.logpmf(value, p),
lambda value, p: sp.bernoulli.logcdf(value, p),
)
self.check_logcdf(
Bernoulli,
Bool,
{"p": Unit},
lambda value, p: sp.bernoulli.logcdf(value, p),
{"logit_p": R},
lambda value, logit_p: sp.bernoulli.logcdf(value, scipy.special.expit(logit_p)),
)
self.check_selfconsistency_discrete_logcdf(
Bernoulli,
Bool,
{"p": Unit},
)

def test_bernoulli_wrong_arguments(self):
m = pm.Model()

msg = "Incompatible parametrization. Can't specify both p and logit_p"
with m:
with pytest.raises(ValueError, match=msg):
Bernoulli("x", p=0.5, logit_p=0)

msg = "Incompatible parametrization. Must specify either p or logit_p"
with m:
with pytest.raises(ValueError, match=msg):
Bernoulli("x")

def test_discrete_weibull(self):
self.check_logp(
DiscreteWeibull,
Expand Down
3 changes: 1 addition & 2 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,11 +1025,10 @@ class TestBernoulli(BaseTestDistribution):
]


@pytest.mark.skip("Still not implemented")
class TestBernoulliLogitP(BaseTestDistribution):
pymc_dist = pm.Bernoulli
pymc_dist_params = {"logit_p": 1.0}
expected_rv_op_params = {"mean": 0, "sigma": 10.0}
expected_rv_op_params = {"p": expit(1.0)}
tests_to_run = ["check_pymc_params_match_rv_op"]


Expand Down
1 change: 0 additions & 1 deletion pymc3/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def get_city_data():
return data.merge(unique, "inner", on="fips")


@pytest.mark.xfail(reason="Bernoulli logitp distribution not refactored")
class TestARM5_4(SeededTest):
def build_model(self):
data = pd.read_csv(
Expand Down

0 comments on commit 4dd0538

Please sign in to comment.