Skip to content

Commit

Permalink
Tested exact Dirichlet.logp values againt scipy implementation
Browse files Browse the repository at this point in the history
Given a mention in RELEASE-NOTES.md
  • Loading branch information
Sayam753 committed Feb 6, 2021
1 parent 4d7c192 commit 1333c23
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)).
- `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)).
- `pm.sample` and `pm.find_MAP` no longer change the `start` argument (see [#4458](https://github.com/pymc-devs/pymc3/pull/4458)).
- Fixed `Dirichlet.logp` method to work with unit batch or event shapes (see [#4454](https://github.com/pymc-devs/pymc3/pull/4454)).

## PyMC3 3.11.0 (21 January 2021)

Expand Down
6 changes: 4 additions & 2 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1702,8 +1702,10 @@ def test_dirichlet_with_batch_shapes(self, dist_shape):
with pm.Model() as model:
d = pm.Dirichlet("a", a=a)

value = d.tag.test_value
assert_almost_equal(dirichlet_logpdf(value, a), d.distribution.logp(value).eval().sum())
pymc3_res = d.distribution.logp(d.tag.test_value).eval()
for idx in np.ndindex(a.shape[:-1]):
scipy_res = scipy.stats.dirichlet(a[idx]).logpdf(d.tag.test_value[idx])
assert_almost_equal(pymc3_res[idx], scipy_res)

def test_dirichlet_shape(self):
a = tt.as_tensor_variable(np.r_[1, 2])
Expand Down

0 comments on commit 1333c23

Please sign in to comment.