Dirichlet multinomial (continued) #4373

Merged
57 commits, merged Jan 16, 2021
b7492d2
Add implementation of DM distribution.
bsmith89 Oct 1, 2019
2106f7c
Fix class name mistake.
bsmith89 Oct 2, 2019
487fc8a
Add DM dist to exported multivariate distributions.
bsmith89 Oct 2, 2019
24d7ec8
Export DirichletMultinomial in pymc3.distributions
bsmith89 Dec 7, 2019
4fbd1d9
Attempt at matching Multinomial initialization.
bsmith89 Dec 16, 2019
685a428
Add some simple tests for DM.
bsmith89 Dec 16, 2019
ad8e77e
Correctly deal with 1d n and 2d alpha.
bsmith89 Dec 16, 2019
8fa717a
Fix typo in DM random.
bsmith89 Dec 16, 2019
4db6b1c
Fix faulty tests for DM.
bsmith89 Dec 16, 2019
01d359b
Drop redundant initialization test for DM.
bsmith89 Dec 16, 2019
4892355
Add test that DM is normalized for n=1 case.
bsmith89 Dec 16, 2019
bc5f3bf
Add DM test case based on BetaBinomial.
bsmith89 Dec 16, 2019
ffa705c
Update pymc3/distributions/multivariate.py
ColCarroll Sep 19, 2020
c801ef1
- Infer shape by default (copied code from Dirichlet Distribution)
ricardoV94 Dec 22, 2020
c8921ee
- Use size information in random method
ricardoV94 Dec 22, 2020
e801568
- Restore merge accidental deletions
ricardoV94 Dec 22, 2020
3483ab5
- Underscore missing
ricardoV94 Dec 22, 2020
23ba2e4
- More merge cleaning
ricardoV94 Dec 22, 2020
fe018ec
Bring DirichletMultinomial initialization into alignment with Multino…
bsmith89 Dec 29, 2020
25fd41a
Align all DM tests with Multinomial.
bsmith89 Jan 1, 2021
28b0a62
Align DirichletMultinomial random implementation with Multinomial.
bsmith89 Jan 1, 2021
d363f96
Match DM random method to Multinomial implementation.
bsmith89 Jan 3, 2021
9b6828c
Change alpha -> a
ricardoV94 Jan 4, 2021
d438dfc
Run pre-commit
ricardoV94 Jan 4, 2021
dde5c45
Keep standard order of methods random and logp
ricardoV94 Jan 4, 2021
49b432d
Update docstrings for valid input types.
ricardoV94 Jan 4, 2021
83fbda6
Add new test to ensure DM matches BetaBinom
ricardoV94 Jan 4, 2021
9748a9d
Change DM alpha -> a in docstrings.
bsmith89 Jan 4, 2021
7b20680
Test two additional parameterization shapes in `test_dirichlet_multin…
bsmith89 Jan 4, 2021
66c83b0
Revert debugging comments.
bsmith89 Jan 4, 2021
672ef56
Revert unrelated changes.
bsmith89 Jan 4, 2021
2d5d20e
Fix minor Black inconsistency.
bsmith89 Jan 4, 2021
922515b
Drop no-longer-functional reshaping code.
bsmith89 Jan 5, 2021
aa89d0a
Assert shape of random samples is as expected.
bsmith89 Jan 5, 2021
2343004
Explicitly test random sample shapes, including batch dimensions.
bsmith89 Jan 5, 2021
a08bc51
Sort imports.
bsmith89 Jan 5, 2021
22beead
Simplify _random
ricardoV94 Jan 6, 2021
7bad831
Reorder tests more logically
ricardoV94 Jan 6, 2021
9bbddba
Refactor tests
ricardoV94 Jan 6, 2021
086459f
Require shape argument
ricardoV94 Jan 6, 2021
f8499d3
Remove unused import `to_tuple`
ricardoV94 Jan 6, 2021
1cd2a9f
Simplify logic to handle list as input for `a`
ricardoV94 Jan 6, 2021
ef00fe1
Raise ShapeError in random()
ricardoV94 Jan 10, 2021
f2ac8e9
Finish batch and repr unittests
ricardoV94 Jan 10, 2021
f5dcdc3
Add note about mode
ricardoV94 Jan 10, 2021
c4e017a
Tiny rewording
ricardoV94 Jan 10, 2021
d46dd50
Change mode to _defaultval
ricardoV94 Jan 12, 2021
3ab518d
Revert comment for Multinomial mode
ricardoV94 Jan 12, 2021
cdd6d27
Update shape check logic
ricardoV94 Jan 12, 2021
24447a4
Add DM to release notes.
bsmith89 Jan 12, 2021
c5e9b67
Merge branch 'master' into dirichlet_multinomial_fork
bsmith89 Jan 12, 2021
0bd6c3d
Minor docstring revisions as suggested by @AlexAndorra.
bsmith89 Jan 14, 2021
f919456
Revise the revision.
bsmith89 Jan 14, 2021
c082f00
Add comment clarifying bounds checking in logp()
bsmith89 Jan 14, 2021
ea0ae59
Address review suggestions
ricardoV94 Jan 15, 2021
b451967
Update `matches_beta_binomial` to take into consideration float preci…
ricardoV94 Jan 15, 2021
128d5cf
Add DM to multivariate distributions docs.
bsmith89 Jan 16, 2021
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
@@ -22,6 +22,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang
 - Add `logcdf` method to all univariate discrete distributions (see [#4387](https://github.com/pymc-devs/pymc3/pull/4387)).
 - Add `random` method to `MvGaussianRandomWalk` (see [#4388](https://github.com/pymc-devs/pymc3/pull/4388))
 - `AsymmetricLaplace` distribution added (see [#4392](https://github.com/pymc-devs/pymc3/pull/4392)).
+- `DirichletMultinomial` distribution added (see [#4373](https://github.com/pymc-devs/pymc3/pull/4373)).
 
 ### Maintenance
 - Fixed bug whereby partial traces returns after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318)
2 changes: 2 additions & 0 deletions pymc3/distributions/__init__.py
@@ -81,6 +81,7 @@
 from pymc3.distributions.mixture import Mixture, MixtureSameFamily, NormalMixture
 from pymc3.distributions.multivariate import (
     Dirichlet,
+    DirichletMultinomial,
     KroneckerNormal,
     LKJCholeskyCov,
     LKJCorr,
@@ -155,6 +156,7 @@
     "MvStudentT",
     "Dirichlet",
     "Multinomial",
+    "DirichletMultinomial",
     "Wishart",
     "WishartBartlett",
     "LKJCholeskyCov",
158 changes: 157 additions & 1 deletion pymc3/distributions/multivariate.py
@@ -42,15 +42,17 @@
 )
 from pymc3.distributions.shape_utils import broadcast_dist_samples_to, to_tuple
 from pymc3.distributions.special import gammaln, multigammaln
+from pymc3.exceptions import ShapeError
 from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker
 from pymc3.model import Deterministic
-from pymc3.theanof import floatX
+from pymc3.theanof import floatX, intX
 
 __all__ = [
     "MvNormal",
     "MvStudentT",
     "Dirichlet",
     "Multinomial",
+    "DirichletMultinomial",
     "Wishart",
     "WishartBartlett",
     "LKJCorr",
@@ -690,6 +692,160 @@ def logp(self, x):
)


class DirichletMultinomial(Discrete):
    R"""Dirichlet Multinomial log-likelihood.

    Dirichlet mixture of Multinomials distribution, with a marginalized PMF.

    .. math::

        f(x \mid n, a) = \frac{\Gamma(n + 1)\Gamma(\sum a_k)}
                              {\Gamma(n + \sum a_k)}
                         \prod_{k=1}^K
                         \frac{\Gamma(x_k + a_k)}
                              {\Gamma(x_k + 1)\Gamma(a_k)}

    ==========  ===========================================
    Support     :math:`x \in \{0, 1, \ldots, n\}` such that
                :math:`\sum x_i = n`
    Mean        :math:`n \frac{a_i}{\sum{a_k}}`
    ==========  ===========================================

    Parameters
    ----------
    n : int or array
        Total counts in each replicate. If n is an array its shape must be (N,)
        with N = a.shape[0].

    a : one- or two-dimensional array
        Dirichlet parameter. Elements must be strictly positive.
        The number of categories is given by the length of the last axis.

    shape : integer tuple
        Describes the shape of the distribution. For example, if n=array([5, 10])
        and a=array([1, 1, 1]), shape should be (2, 3).
    """

    def __init__(self, n, a, shape, *args, **kwargs):
        super().__init__(shape=shape, defaults=("_defaultval",), *args, **kwargs)
Comment on lines +730 to +731

Member: The Dirichlet distribution makes use of the get_test_value function to compute its distribution shape. Can we use get_test_value to determine the shape here as well? Doing so will even help us with #4379. Ping @brandonwillard to ask how the get_test_value function works.

Member (Author): I don't think it is that simple, since the shape can be influenced by the n parameter as well as a, whereas in the Dirichlet all the information is necessarily contained in a (when shape is not specified).

ricardoV94 (Member, Author), Jan 15, 2021: Also, the Dirichlet functionality is wrapped in a DeprecationWarning (even though I don't seem to be able to trigger it), which suggests that they wanted to abandon that approach at some point.

Member: @ricardoV94, just a follow-up: it indeed makes sense to avoid the use of the get_test_value function, as also discussed in #4000 (comment).
        n = intX(n)
        a = floatX(a)
        if len(self.shape) > 1:
            self.n = tt.shape_padright(n)
            self.a = tt.as_tensor_variable(a) if a.ndim > 1 else tt.shape_padleft(a)
        else:
            # n is a scalar, a is a 1d array
            self.n = tt.as_tensor_variable(n)
            self.a = tt.as_tensor_variable(a)

        p = self.a / self.a.sum(-1, keepdims=True)

        self.mean = self.n * p
        # Mode is only an approximation. Exact computation requires a complex
        # iterative algorithm as described in https://doi.org/10.1016/j.spl.2009.09.013
        mode = tt.cast(tt.round(self.mean), "int32")
        diff = self.n - tt.sum(mode, axis=-1, keepdims=True)
        inc_bool_arr = tt.abs_(diff) > 0
        mode = tt.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
        self._defaultval = mode
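The rounding-and-repair trick above can be sketched in plain NumPy: round the mean to integers, then push any rounding surplus or deficit into the first category so the result still sums to n. The helper name `approx_dm_mode` is hypothetical (not part of the PR), and this is only the approximation used here, not the exact mode.

```python
import numpy as np

def approx_dm_mode(n, a):
    """Approximate Dirichlet-Multinomial mode: round the mean, then absorb
    the rounding error into the first category so each draw sums to n.
    A sketch of the inc_subtensor adjustment above, not the exact mode."""
    a = np.asarray(a, dtype=float)
    mean = n * a / a.sum(-1, keepdims=True)
    mode = np.round(mean).astype(int)
    # after rounding, the counts may no longer sum to n; fix up category 0
    diff = n - mode.sum(-1, keepdims=True)
    mode[..., :1] += diff
    return mode
```

For example, n=5 with a=[1, 1, 1] has mean [1.67, 1.67, 1.67], which rounds to [2, 2, 2]; the surplus of 1 is subtracted from the first category to restore the sum constraint.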

    def _random(self, n, a, size=None):
        # numpy will cast dirichlet and multinomial samples to float64 by default
        original_dtype = a.dtype

        # Thanks to the default shape handling done in generate_samples, the last
        # axis of n is a dummy axis that allows it to broadcast well with `a`
        n = np.broadcast_to(n, size)
        a = np.broadcast_to(a, size)
        n = n[..., 0]

        # np.random.multinomial needs `n` to be a scalar int and `a` a
        # sequence, so we semi-flatten them and iterate over them
        n_ = n.reshape([-1])
        a_ = a.reshape([-1, a.shape[-1]])
        p_ = np.array([np.random.dirichlet(aa) for aa in a_])
        samples = np.array([np.random.multinomial(nn, pp) for nn, pp in zip(n_, p_)])
        samples = samples.reshape(a.shape)

        # We cast back to the original dtype
        return samples.astype(original_dtype)
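The compounding scheme used here — draw p from the Dirichlet, then counts from the Multinomial — can be illustrated standalone. This is a sketch with a hypothetical helper name, using NumPy's Generator API rather than the legacy np.random functions in the method above:

```python
import numpy as np

def sample_dirichlet_multinomial(n, a, size=None, seed=None):
    """Draw Dirichlet-Multinomial samples by compounding: first sample
    category probabilities p ~ Dirichlet(a), then counts x ~ Multinomial(n, p).
    A standalone sketch of the two-stage scheme in `_random` above."""
    rng = np.random.default_rng(seed)
    a = np.asarray(a, dtype=float)
    p = rng.dirichlet(a, size=size)  # shape (size, K) when size is an int
    flat_p = p.reshape(-1, a.shape[-1])
    # iterate row by row: each multinomial draw takes one probability vector
    draws = np.array([rng.multinomial(n, row) for row in flat_p])
    return draws.reshape(p.shape)

x = sample_dirichlet_multinomial(10, [2.0, 3.0, 5.0], size=1000, seed=42)
```

Every draw sums to n, and the sample mean approaches n * a / sum(a), matching the Mean entry in the docstring above.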

    def random(self, point=None, size=None):
        """
        Draw random values from Dirichlet-Multinomial distribution.

        Parameters
        ----------
        point: dict, optional
            Dict of variable values on which random values are to be
            conditioned (uses default point if not specified).
        size: int, optional
            Desired size of random sample (returns one sample if not
            specified).

        Returns
        -------
        array
        """
        n, a = draw_values([self.n, self.a], point=point, size=size)
        samples = generate_samples(
            self._random,
            n,
            a,
            dist_shape=self.shape,
            size=size,
        )

        # If distribution is initialized with .dist(), valid init shape is not asserted.
        # Under normal use in a model context valid init shape is asserted at start.
        expected_shape = to_tuple(size) + to_tuple(self.shape)
        sample_shape = tuple(samples.shape)
        if sample_shape != expected_shape:
            raise ShapeError(
                f"Expected sample shape was {expected_shape} but got {sample_shape}. "
                "This may reflect an invalid initialization shape."
            )

        return samples

    def logp(self, value):
        """
        Calculate log-probability of DirichletMultinomial distribution
        at specified value.

        Parameters
        ----------
        value: integer array
            Value for which log-probability is calculated.

        Returns
        -------
        TensorVariable
        """
        a = self.a
        n = self.n
        sum_a = a.sum(axis=-1, keepdims=True)

        const = (gammaln(n + 1) + gammaln(sum_a)) - gammaln(n + sum_a)
        series = gammaln(value + a) - (gammaln(value + 1) + gammaln(a))
        result = const + series.sum(axis=-1, keepdims=True)
        # Bounds checking to confirm parameters and data meet all constraints
        # and that each observation value_i sums to n_i.
        return bound(
            result,
            tt.all(tt.ge(value, 0)),
            tt.all(tt.gt(a, 0)),
            tt.all(tt.ge(n, 0)),
            tt.all(tt.eq(value.sum(axis=-1, keepdims=True), n)),
            broadcast_conditions=False,
        )
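As a sanity check on the gammaln expression above: with two categories, the Dirichlet-Multinomial reduces to the Beta-Binomial (the relationship the PR's `matches_beta_binomial` test exploits). A NumPy/SciPy sketch — the `dm_logp` helper is hypothetical, mirroring the formula rather than calling PyMC3:

```python
import numpy as np
from scipy.special import gammaln
from scipy.stats import betabinom

def dm_logp(value, n, a):
    """Dirichlet-Multinomial log-pmf, mirroring the gammaln expression
    in `logp` above (NumPy sketch, single observation)."""
    value = np.asarray(value, dtype=float)
    a = np.asarray(a, dtype=float)
    sum_a = a.sum(-1)
    const = gammaln(n + 1) + gammaln(sum_a) - gammaln(n + sum_a)
    series = gammaln(value + a) - gammaln(value + 1) - gammaln(a)
    return const + series.sum(-1)

# With K=2, DM(n, [alpha, beta]) evaluated at (k, n - k) matches
# BetaBinomial(n, alpha, beta) evaluated at k.
n, alpha, beta = 10, 2.0, 3.0
ours = dm_logp([4, 6], n, [alpha, beta])
ref = betabinom.logpmf(4, n, alpha, beta)
```

The agreement follows by expanding both pmfs in Gamma functions: the binomial coefficient and Beta-function ratios regroup term by term into the DM expression.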

    def _distr_parameters_for_repr(self):
        return ["n", "a"]
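A brute-force normalization check in the spirit of the PR's own tests ("Add test that DM is normalized for n=1 case"): summing exp(logp) over every composition of n into K categories should give 1. Pure standard library; the helper names are hypothetical.

```python
from itertools import product
from math import exp, lgamma

def dm_logp_scalar(x, n, a):
    """Scalar Dirichlet-Multinomial log-pmf (a sketch of the formula above)."""
    sum_a = sum(a)
    out = lgamma(n + 1) + lgamma(sum_a) - lgamma(n + sum_a)
    for xi, ai in zip(x, a):
        out += lgamma(xi + ai) - lgamma(xi + 1) - lgamma(ai)
    return out

def dm_total_mass(n, a):
    """Sum the pmf over every composition of n into len(a) categories;
    the result should be ~1.0 if the density is correctly normalized."""
    k = len(a)
    total = 0.0
    # enumerate all count vectors in {0..n}^k and keep those summing to n
    for x in product(range(n + 1), repeat=k):
        if sum(x) == n:
            total += exp(dm_logp_scalar(x, n, a))
    return total
```

The enumeration is exponential in K, so this only works for tiny n and K — exactly the regime unit tests use.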


def posdef(AA):
    try:
        linalg.cholesky(AA)