Commit 2a3d9a3: Dirichlet multinomial (continued) (#4373)
* Add implementation of DM distribution.

* Fix class name mistake.

* Add DM dist to exported multivariate distributions.

* Export DirichletMultinomial in pymc3.distributions

As suggested in
#3639 (comment)

Also see:
#3639 (comment)
but this seems to be part of a broader discussion.

* Attempt at matching Multinomial initialization.

* Add some simple tests for DM.

* Correctly deal with 1d n and 2d alpha.

* Fix typo in DM random.

* Fix faulty tests for DM.

* Drop redundant initialization test for DM.

* Add test that DM is normalized for n=1 case.

* Add DM test case based on BetaBinomial.

* Update pymc3/distributions/multivariate.py

* - Infer shape by default (copied code from Dirichlet Distribution)
- Add default shape in `test_distributions_random.py`

* - Use size information in random method
- Change random unittests

* - Restore accidental merge deletions

* - Add missing underscore

* - More merge cleanup

* Bring DirichletMultinomial initialization into alignment with Multinomial.

* Align all DM tests with Multinomial.

* Align DirichletMultinomial random implementation with Multinomial.

* Match DM random method to Multinomial implementation.

* Change alpha -> a
Remove _repr_latex_

* Run pre-commit

* Keep standard order of methods random and logp

* Update docstrings for valid input types.
Progress on batch test.

* Add new test to ensure DM matches BetaBinom

* Change DM alpha -> a in docstrings.

* Test two additional parameterization shapes in `test_dirichlet_multinomial_random`.

* Revert debugging comments.

* Revert unrelated changes.

* Fix minor Black inconsistency.

* Drop no-longer-functional reshaping code.

* Assert shape of random samples is as expected.

* Explicitly test random sample shapes, including batch dimensions.

* Sort imports.

* Simplify _random

It should be okay not to change the input dtype explicitly, as the Multinomial does, because the input to np.random.dirichlet should be safe: a float32-to-float64 conversion that bumps a value from 1.00 to 1.01... is fine, and while underflow from 0.01 to 0.0 would still be problematic, we don't know yet whether that is an issue in practice. Feeding the output of np.random.dirichlet into np.random.multinomial should be safe, since it is already float64 by then. We still need to convert back to the previous dtype, since numpy changes it by default. (A short sketch of this dtype round-trip follows the commit list below.)

The `size_` argument was no longer being used.

* Reorder tests more logically

* Refactor tests

Merged mode tests, since shape must be given explicitly anyway.
Moved `test_dirichlet_multinomial_random` to `test_distributions_random.py` and renamed it to `test_dirichlet_multinomial_shapes`.

* Require shape argument

Also be more forgiving if the user passes lists instead of arrays (WIP/suggestion only)

* Remove unused import `to_tuple`

* Simplify logic to handle list as input for `a`

* Raise ShapeError in random()

* Finish batch and repr unittests

* Add note about mode

* Tiny rewording

* Change mode to _defaultval

* Revert comment for Multinomial mode

* Update shape check logic

* Add DM to release notes.

* Minor docstring revisions as suggested by @AlexAndorra.

* Revise the revision.

* Add comment clarifying bounds checking in logp()

* Address review suggestions

* Update `matches_beta_binomial` to take float precision into account (a numeric sketch of the DM/BetaBinomial equivalence appears at the end of this page)

* Add DM to multivariate distributions docs.

Co-authored-by: Byron Smith <[email protected]>
Co-authored-by: Colin <[email protected]>
3 people authored Jan 16, 2021
1 parent 1769258 commit 2a3d9a3
Showing 6 changed files with 390 additions and 1 deletion.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
@@ -22,6 +22,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang
- Add `logcdf` method to all univariate discrete distributions (see [#4387](https://github.com/pymc-devs/pymc3/pull/4387)).
- Add `random` method to `MvGaussianRandomWalk` (see [#4388](https://github.com/pymc-devs/pymc3/pull/4388))
- `AsymmetricLaplace` distribution added (see [#4392](https://github.com/pymc-devs/pymc3/pull/4392)).
- `DirichletMultinomial` distribution added (see [#4373](https://github.com/pymc-devs/pymc3/pull/4373)).

### Maintenance
- Fixed bug whereby partial traces returned after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318)
1 change: 1 addition & 0 deletions docs/source/api/distributions/multivariate.rst
@@ -14,6 +14,7 @@ Multivariate
   LKJCorr
   Multinomial
   Dirichlet
   DirichletMultinomial

.. automodule:: pymc3.distributions.multivariate
   :members:
2 changes: 2 additions & 0 deletions pymc3/distributions/__init__.py
@@ -81,6 +81,7 @@
from pymc3.distributions.mixture import Mixture, MixtureSameFamily, NormalMixture
from pymc3.distributions.multivariate import (
    Dirichlet,
    DirichletMultinomial,
    KroneckerNormal,
    LKJCholeskyCov,
    LKJCorr,
@@ -155,6 +156,7 @@
"MvStudentT",
"Dirichlet",
"Multinomial",
"DirichletMultinomial",
"Wishart",
"WishartBartlett",
"LKJCholeskyCov",
158 changes: 157 additions & 1 deletion pymc3/distributions/multivariate.py
@@ -42,15 +42,17 @@
)
from pymc3.distributions.shape_utils import broadcast_dist_samples_to, to_tuple
from pymc3.distributions.special import gammaln, multigammaln
from pymc3.exceptions import ShapeError
from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker
from pymc3.model import Deterministic
from pymc3.theanof import floatX
from pymc3.theanof import floatX, intX

__all__ = [
    "MvNormal",
    "MvStudentT",
    "Dirichlet",
    "Multinomial",
    "DirichletMultinomial",
    "Wishart",
    "WishartBartlett",
    "LKJCorr",
@@ -690,6 +692,160 @@ def logp(self, x):
        )


class DirichletMultinomial(Discrete):
    R"""Dirichlet Multinomial log-likelihood.

    Dirichlet mixture of Multinomials distribution, with a marginalized PMF.

    .. math::

        f(x \mid n, a) = \frac{\Gamma(n + 1)\Gamma(\sum a_k)}
                              {\Gamma(n + \sum a_k)}
                         \prod_{k=1}^K
                         \frac{\Gamma(x_k + a_k)}
                              {\Gamma(x_k + 1)\Gamma(a_k)}

    ==========  ===========================================
    Support     :math:`x \in \{0, 1, \ldots, n\}` such that
                :math:`\sum x_i = n`
    Mean        :math:`n \frac{a_i}{\sum{a_k}}`
    ==========  ===========================================

    Parameters
    ----------
    n : int or array
        Total counts in each replicate. If n is an array its shape must be (N,)
        with N = a.shape[0]
    a : one- or two-dimensional array
        Dirichlet parameter. Elements must be strictly positive.
        The number of categories is given by the length of the last axis.
    shape : integer tuple
        Describes shape of distribution. For example if n=array([5, 10]), and
        a=array([1, 1, 1]), shape should be (2, 3).
    """

    def __init__(self, n, a, shape, *args, **kwargs):

        super().__init__(shape=shape, defaults=("_defaultval",), *args, **kwargs)

        n = intX(n)
        a = floatX(a)
        if len(self.shape) > 1:
            self.n = tt.shape_padright(n)
            self.a = tt.as_tensor_variable(a) if a.ndim > 1 else tt.shape_padleft(a)
        else:
            # n is a scalar, p is a 1d array
            self.n = tt.as_tensor_variable(n)
            self.a = tt.as_tensor_variable(a)

        p = self.a / self.a.sum(-1, keepdims=True)

        self.mean = self.n * p
        # Mode is only an approximation. Exact computation requires a complex
        # iterative algorithm as described in https://doi.org/10.1016/j.spl.2009.09.013
        mode = tt.cast(tt.round(self.mean), "int32")
        diff = self.n - tt.sum(mode, axis=-1, keepdims=True)
        inc_bool_arr = tt.abs_(diff) > 0
        mode = tt.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
        self._defaultval = mode

    def _random(self, n, a, size=None):
        # numpy will cast dirichlet and multinomial samples to float64 by default
        original_dtype = a.dtype

        # Thanks to the default shape handling done in generate_samples, the last
        # axis of n is a dummy axis that allows it to broadcast well with `a`
        n = np.broadcast_to(n, size)
        a = np.broadcast_to(a, size)
        n = n[..., 0]

        # np.random.multinomial needs `n` to be a scalar int and `a` a
        # sequence so we semi flatten them and iterate over them
        n_ = n.reshape([-1])
        a_ = a.reshape([-1, a.shape[-1]])
        p_ = np.array([np.random.dirichlet(aa) for aa in a_])
        samples = np.array([np.random.multinomial(nn, pp) for nn, pp in zip(n_, p_)])
        samples = samples.reshape(a.shape)

        # We cast back to the original dtype
        return samples.astype(original_dtype)

    def random(self, point=None, size=None):
        """
        Draw random values from Dirichlet-Multinomial distribution.

        Parameters
        ----------
        point: dict, optional
            Dict of variable values on which random values are to be
            conditioned (uses default point if not specified).
        size: int, optional
            Desired size of random sample (returns one sample if not
            specified).

        Returns
        -------
        array
        """
        n, a = draw_values([self.n, self.a], point=point, size=size)
        samples = generate_samples(
            self._random,
            n,
            a,
            dist_shape=self.shape,
            size=size,
        )

        # If distribution is initialized with .dist(), valid init shape is not asserted.
        # Under normal use in a model context valid init shape is asserted at start.
        expected_shape = to_tuple(size) + to_tuple(self.shape)
        sample_shape = tuple(samples.shape)
        if sample_shape != expected_shape:
            raise ShapeError(
                f"Expected sample shape was {expected_shape} but got {sample_shape}. "
                "This may reflect an invalid initialization shape."
            )

        return samples

    def logp(self, value):
        """
        Calculate log-probability of DirichletMultinomial distribution
        at specified value.

        Parameters
        ----------
        value: integer array
            Value for which log-probability is calculated.

        Returns
        -------
        TensorVariable
        """
        a = self.a
        n = self.n
        sum_a = a.sum(axis=-1, keepdims=True)

        const = (gammaln(n + 1) + gammaln(sum_a)) - gammaln(n + sum_a)
        series = gammaln(value + a) - (gammaln(value + 1) + gammaln(a))
        result = const + series.sum(axis=-1, keepdims=True)
        # Bounds checking to confirm parameters and data meet all constraints
        # and that each observation value_i sums to n_i.
        return bound(
            result,
            tt.all(tt.ge(value, 0)),
            tt.all(tt.gt(a, 0)),
            tt.all(tt.ge(n, 0)),
            tt.all(tt.eq(value.sum(axis=-1, keepdims=True), n)),
            broadcast_conditions=False,
        )

    def _distr_parameters_for_repr(self):
        return ["n", "a"]


def posdef(AA):
    try:
        linalg.cholesky(AA)
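To close, a minimal usage sketch of the new distribution (a hypothetical model, assuming a PyMC3 build that includes this commit; the data and prior are invented for illustration):

```python
import numpy as np
import pymc3 as pm

# Hypothetical data: two replicates of counts over three categories.
counts = np.array([[2, 1, 2], [4, 3, 3]])
n = counts.sum(axis=-1)  # total counts per replicate: [5, 10]

with pm.Model():
    # Per the docstring above: with n of shape (2,) and a of length 3,
    # the distribution shape must be passed explicitly as (2, 3).
    a = pm.Exponential("a", lam=1.0, shape=3)  # strictly positive, as required
    x = pm.DirichletMultinomial("x", n=n, a=a, shape=(2, 3), observed=counts)
```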

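And the DM/BetaBinomial equivalence referenced in the commit list, checked numerically (a scipy-based sketch mirroring the tests; `dm_logpmf` is a hypothetical helper transcribing the docstring formula, and the parameter values are made up):

```python
import numpy as np
from scipy import stats
from scipy.special import gammaln

def dm_logpmf(x, n, a):
    # Log PMF of the Dirichlet-multinomial, transcribed from the docstring above.
    sum_a = a.sum()
    const = gammaln(n + 1) + gammaln(sum_a) - gammaln(n + sum_a)
    series = gammaln(x + a) - (gammaln(x + 1) + gammaln(a))
    return const + series.sum()

# With K=2 categories, DM(n, [a1, a2]) reduces to BetaBinomial(n, a1, a2).
n, a = 10, np.array([1.5, 2.5])
for k in range(n + 1):
    dm = dm_logpmf(np.array([k, n - k]), n, a)
    bb = stats.betabinom.logpmf(k, n, a[0], a[1])
    assert np.isclose(dm, bb)
```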