Skip to content

Commit

Permalink
Add tests for tt.switch related bugs (#4448)
Browse files Browse the repository at this point in the history
* Add tests for edge cases

* Add release-note
  • Loading branch information
ricardoV94 authored Jan 29, 2021
1 parent 03d7af5 commit 2d3ec8f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

### Maintenance
- We upgraded to `Theano-PyMC v1.1.2` which [includes bugfixes](https://github.com/pymc-devs/aesara/compare/rel-1.1.0...rel-1.1.2) for warning floods and compiledir locking (see [#4444](https://github.com/pymc-devs/pymc3/pull/4444))
- `Theano-PyMC v1.1.2` also fixed an important issue in `tt.switch` that affected the behavior of several PyMC distributions, including at least the `Bernoulli` and `TruncatedNormal` (see[#4448](https://github.com/pymc-devs/pymc3/pull/4448))
- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)).

## PyMC3 3.11.0 (21 January 2021)
Expand Down
28 changes: 28 additions & 0 deletions pymc3/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,34 @@ def test_tensor_type_conversion(self):

assert m["x2_missing"].type == gf._extra_vars_shared["x2_missing"].type

def test_theano_switch_broadcast_edge_cases(self):
# Tests against two subtle issues related to a previous bug in Theano where tt.switch would not
# always broadcast tensors with single values https://github.com/pymc-devs/aesara/issues/270

# Known issue 1: https://github.com/pymc-devs/pymc3/issues/4389
data = np.zeros(10)
with pm.Model() as m:
p = pm.Beta("p", 1, 1)
obs = pm.Bernoulli("obs", p=p, observed=data)
# Assert logp is correct
npt.assert_allclose(
obs.logp(m.test_point),
np.log(0.5) * 10,
)

# Known issue 2: https://github.com/pymc-devs/pymc3/issues/4417
# fmt: off
data = np.array([
1.35202174, -0.83690274, 1.11175166, 1.29000367, 0.21282749,
0.84430966, 0.24841369, 0.81803141, 0.20550244, -0.45016253,
])
# fmt: on
with pm.Model() as m:
mu = pm.Normal("mu", 0, 5)
obs = pm.TruncatedNormal("obs", mu=mu, sigma=1, lower=-1, upper=2, observed=data)
# Assert dlogp is correct
npt.assert_allclose(m.dlogp([mu])({"mu": 0}), 2.499424682024436, rtol=1e-5)


def test_multiple_observed_rv():
"Test previously buggy MultiObservedRV comparison code."
Expand Down

0 comments on commit 2d3ec8f

Please sign in to comment.