
Refactor Mixture distribution for V4 (#5438)
* Set `expand=True` when calling `change_size` in `SymbolicDistribution`

* Move NormalMixture tests to their own class

* Move mixture random tests from test_distributions_random to test_mixture

* Use specific imports in test_mixture

* Reenable Mixture tests in pytest workflow

* Refactor Mixture distribution

Mixtures now use an `OpFromGraph` that encapsulates the Aesara random method. This is used so that logp can be easily dispatched to the distribution without requiring involved pattern matching. The Mixture random and logp methods now fully respect the support dimensionality of its components, whereas previously only the logp method did, leading to inconsistencies between the two methods.

In the case where the weights (or size) indicate the need for more draws than what is given by the component distributions, the latter are resized to ensure there are no repeated draws.

This refactoring forces Mixture components to be basic RandomVariables, meaning that nested Mixtures or Mixtures of Symbolic distributions (like Censored) are not currently possible.
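The resizing semantics described above can be sketched in pure Python (an illustrative sketch only, not the actual `OpFromGraph` implementation; the helper name `sample_mixture` and the sampler callables are hypothetical): for each mixture draw, every component is drawn afresh and the weights select which draw to keep, so no component draw is ever reused.

```python
import random

def sample_mixture(weights, component_samplers, n_draws, rng):
    """Draw n_draws values from a mixture, taking a fresh draw from every
    component for each mixture draw so no component draw is repeated."""
    draws = []
    for _ in range(n_draws):
        # One independent draw per component for this mixture draw.
        candidates = [sampler(rng) for sampler in component_samplers]
        # The weights pick which component's draw is kept.
        (idx,) = rng.choices(range(len(weights)), weights=weights)
        draws.append(candidates[idx])
    return draws

rng = random.Random(42)
samplers = [lambda r: r.gauss(-5.0, 1.0), lambda r: r.gauss(5.0, 1.0)]
out = sample_mixture([0.3, 0.7], samplers, 1000, rng)
```

With weights `[0.3, 0.7]` the sample mean lands near `0.3 * (-5) + 0.7 * 5 = 2.0`.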

Co-authored-by: Larry Dong <[email protected]>

* Add warning when using iterable with single Mixture component

* Update Mixture docstrings

* Emphasize equivalence between an iterable of components and a single batched component
* Add example with mixture of two distinct distributions
* Add example with multivariate components
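The "mixture of two distinct distributions" docstring example can be illustrated numerically with a standard-library sketch (the helper names `norm_logp`, `expon_logp`, and `mixture_logp` are hypothetical, not PyMC API): the mixture logp combines the weights and the component logps through a logsumexp.

```python
import math

def norm_logp(x, mu, sigma):
    # Log-density of Normal(mu, sigma) at x.
    return -0.5 * math.log(2 * math.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

def expon_logp(x, lam=1.0):
    # Log-density of Exponential(lam) at x; valid for x >= 0.
    return math.log(lam) - lam * x

def mixture_logp(x, weights, component_logps):
    # log( sum_k w_k * p_k(x) ), computed via logsumexp for numerical stability.
    terms = [math.log(w) + logp(x) for w, logp in zip(weights, component_logps)]
    hi = max(terms)
    return hi + math.log(sum(math.exp(t - hi) for t in terms))

# Mixture of two distinct distributions: 0.3 * Normal(0, 1) + 0.7 * Exponential(1)
lp = mixture_logp(0.5, [0.3, 0.7], [lambda x: norm_logp(x, 0.0, 1.0), expon_logp])
```

Exponentiating `lp` recovers the weighted sum of the two component densities.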

* Refactor NormalMixture

* Refactor TestMixtureVsLatent

The two tests relied on implicit behavior of V3, where the dimensionality of the weights implied the support dimension of the mixture distribution. This, however, led to inconsistent behavior between the random method and the logp, as the latter did not enforce this assumption and did not distinguish whether values were mixed across the implied support dimension.

In this refactoring, the support dimensionality of the component variables determines the dimensionality of the mixture distribution, regardless of the weights. This leads to consistent behavior between the random and logp methods as asserted by the new checks.

Future work will explore allowing the user to specify an artificial support dimensionality higher than the one implied by the component distributions, but this is not currently possible.
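A pure-Python sketch of these support-dimensionality semantics (illustrative only; `diag_mvnormal_logp` and `mixture_logp` are hypothetical names, not PyMC API): a multivariate component reduces over its own support dimension first, so the mixture assigns a single scalar logp to each vector-valued draw, regardless of the shape of the weights.

```python
import math

def diag_mvnormal_logp(x, mu, sigma):
    # Sum over the support dimension: one scalar logp per vector-valued value.
    return sum(
        -0.5 * math.log(2 * math.pi * s**2) - (xi - m) ** 2 / (2 * s**2)
        for xi, m, s in zip(x, mu, sigma)
    )

def mixture_logp(x, weights, component_logps):
    # Mix only after each component has reduced over its support dims.
    terms = [math.log(w) + logp(x) for w, logp in zip(weights, component_logps)]
    hi = max(terms)
    return hi + math.log(sum(math.exp(t - hi) for t in terms))

# Two 2-dimensional components: the support dimension comes from the
# components, so the whole vector [0.0, 0.2] gets one scalar logp.
value = [0.0, 0.2]
comps = [
    lambda x: diag_mvnormal_logp(x, [-1.0, -1.0], [1.0, 1.0]),
    lambda x: diag_mvnormal_logp(x, [1.0, 1.0], [1.0, 1.0]),
]
lp = mixture_logp(value, [0.5, 0.5], comps)
```

Values are never mixed element-wise across the support dimension; each component's full vector density enters the weighted sum intact.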

* Remove MixtureSameFamily

Behavior is now implemented in Mixture

* Add Mixture moments
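For the first moment, a minimal sketch of the idea (an assumption about the general shape of the moment, not the PyMC implementation; `mixture_mean` is a hypothetical name): the mixture mean is the weights-weighted average of the component means.

```python
def mixture_mean(weights, component_means):
    # E[X] = sum_k w_k * E[X_k]
    return sum(w * m for w, m in zip(weights, component_means))

mixture_mean([0.25, 0.75], [-2.0, 2.0])  # 0.25 * -2 + 0.75 * 2 = 1.0
```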

* Update release notes

Co-authored-by: Larry Dong <[email protected]>
ricardoV94 and larryshamalama authored Mar 9, 2022
1 parent 7d4162c commit 620b11d
Showing 9 changed files with 1,241 additions and 1,138 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
@@ -78,6 +78,7 @@ jobs:
 pymc/tests/test_transforms.py
 pymc/tests/test_smc.py
 pymc/tests/test_bart.py
+pymc/tests/test_mixture.py
 - |
 pymc/tests/test_parallel_sampling.py
3 changes: 2 additions & 1 deletion RELEASE-NOTES.md
@@ -13,7 +13,7 @@ Instead update the vNext section until 4.0.0 is out.
 ### Not-yet working features
 We plan to get these working again, but at this point their inner workings have not been refactored.
 - Timeseries distributions (see [#4642](https://github.com/pymc-devs/pymc/issues/4642))
-- Mixture distributions (see [#4781](https://github.com/pymc-devs/pymc/issues/4781))
+- Nested Mixture distributions (see [#5533](https://github.com/pymc-devs/pymc/issues/5533))
 - Elliptical slice sampling (see [#5137](https://github.com/pymc-devs/pymc/issues/5137))
 - `BaseStochasticGradient` (see [#5138](https://github.com/pymc-devs/pymc/issues/5138))
 - `pm.sample_posterior_predictive_w` (see [#4807](https://github.com/pymc-devs/pymc/issues/4807))
@@ -72,6 +72,7 @@ All of the above apply to:
 - In the gp.utils file, the `kmeans_inducing_points` function now passes through `kmeans_kwargs` to scipy's k-means function.
 - The `replace_with_values` function has been added to `gp.utils`.
 - `MarginalSparse` has been renamed `MarginalApprox`.
+- Removed `MixtureSameFamily`. `Mixture` is now capable of handling batched multivariate components (see [#5438](https://github.com/pymc-devs/pymc/pull/5438)).
 - ...

### Expected breaks
1 change: 0 additions & 1 deletion docs/source/api/distributions/mixture.rst
@@ -8,4 +8,3 @@ Mixture
 
 Mixture
 NormalMixture
-MixtureSameFamily
3 changes: 1 addition & 2 deletions pymc/distributions/__init__.py
@@ -83,7 +83,7 @@
 NoDistribution,
 SymbolicDistribution,
 )
-from pymc.distributions.mixture import Mixture, MixtureSameFamily, NormalMixture
+from pymc.distributions.mixture import Mixture, NormalMixture
 from pymc.distributions.multivariate import (
 CAR,
 Dirichlet,
@@ -180,7 +180,6 @@
 "SkewNormal",
 "Mixture",
 "NormalMixture",
-"MixtureSameFamily",
 "Triangular",
 "DiscreteWeibull",
 "Gumbel",
1 change: 1 addition & 0 deletions pymc/distributions/distribution.py
@@ -500,6 +500,7 @@ def __new__(
 rv_out = cls.change_size(
 rv=rv_out,
 new_size=resize_shape,
+expand=True,
 )

rv_out = model.register_rv(