diff --git a/pymc3/distributions/discrete.py b/pymc3/distributions/discrete.py index c172a2f29a..8ab3b06230 100644 --- a/pymc3/distributions/discrete.py +++ b/pymc3/distributions/discrete.py @@ -333,8 +333,15 @@ class Bernoulli(Discrete): @classmethod def dist(cls, p=None, logit_p=None, *args, **kwargs): + if p is not None and logit_p is not None: + raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.") + elif p is None and logit_p is None: + raise ValueError("Incompatible parametrization. Must specify either p or logit_p.") + + if logit_p is not None: + p = at.sigmoid(logit_p) + p = at.as_tensor_variable(floatX(p)) - # mode = at.cast(tround(p), "int8") return super().dist([p], **kwargs) def logp(value, p): @@ -351,12 +358,9 @@ def logp(value, p): ------- TensorVariable """ - # if self._is_logit: - # lp = at.switch(value, self._logit_p, -self._logit_p) - # return -log1pexp(-lp) - # else: + return bound( - at.switch(value, at.log(p), at.log(1 - p)), + at.switch(value, at.log(p), at.log1p(-p)), value >= 0, value <= 1, p >= 0, diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 5c46e14105..7de42a104c 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -1608,33 +1608,30 @@ def test_beta_binomial(self): {"alpha": Rplus, "beta": Rplus, "n": NatSmall}, ) - @pytest.mark.xfail(reason="Bernoulli logit_p not refactored yet") - def test_bernoulli_logit_p(self): + def test_bernoulli(self): self.check_logp( Bernoulli, Bool, - {"logit_p": R}, - lambda value, logit_p: sp.bernoulli.logpmf(value, scipy.special.expit(logit_p)), + {"p": Unit}, + lambda value, p: sp.bernoulli.logpmf(value, p), ) - self.check_logcdf( + self.check_logp( Bernoulli, Bool, {"logit_p": R}, - lambda value, logit_p: sp.bernoulli.logcdf(value, scipy.special.expit(logit_p)), + lambda value, logit_p: sp.bernoulli.logpmf(value, scipy.special.expit(logit_p)), ) - - def test_bernoulli(self): - self.check_logp( + self.check_logcdf( Bernoulli, Bool, {"p": Unit}, - lambda value, p: sp.bernoulli.logpmf(value, p), + lambda value, p: sp.bernoulli.logcdf(value, p), ) self.check_logcdf( Bernoulli, Bool, - {"p": Unit}, - lambda value, p: sp.bernoulli.logcdf(value, p), + {"logit_p": R}, + lambda value, logit_p: sp.bernoulli.logcdf(value, scipy.special.expit(logit_p)), ) self.check_selfconsistency_discrete_logcdf( Bernoulli, @@ -1642,6 +1639,19 @@ def test_bernoulli(self): {"p": Unit}, ) + def test_bernoulli_wrong_arguments(self): + m = pm.Model() + + msg = "Incompatible parametrization. Can't specify both p and logit_p" + with m: + with pytest.raises(ValueError, match=msg): + Bernoulli("x", p=0.5, logit_p=0) + + msg = "Incompatible parametrization. Must specify either p or logit_p" + with m: + with pytest.raises(ValueError, match=msg): + Bernoulli("x") + def test_discrete_weibull(self): self.check_logp( DiscreteWeibull, diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index fd6e95c430..51e3ccce52 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -1025,11 +1025,10 @@ class TestBernoulli(BaseTestDistribution): ] -@pytest.mark.skip("Still not implemented") class TestBernoulliLogitP(BaseTestDistribution): pymc_dist = pm.Bernoulli pymc_dist_params = {"logit_p": 1.0} - expected_rv_op_params = {"mean": 0, "sigma": 10.0} + expected_rv_op_params = {"p": expit(1.0)} tests_to_run = ["check_pymc_params_match_rv_op"] diff --git a/pymc3/tests/test_examples.py b/pymc3/tests/test_examples.py index 160712536f..f44add09ee 100644 --- a/pymc3/tests/test_examples.py +++ b/pymc3/tests/test_examples.py @@ -51,7 +51,6 @@ def get_city_data(): return data.merge(unique, "inner", on="fips") -@pytest.mark.xfail(reason="Bernoulli logitp distribution not refactored") class TestARM5_4(SeededTest): def build_model(self): data = pd.read_csv(