Skip to content

Commit

Permalink
Allow broadcasting of mixture components
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 11, 2022
1 parent 7a80d49 commit 27fd32f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 11 deletions.
20 changes: 9 additions & 11 deletions pymc/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ def dist(cls, w, comp_dists, **kwargs):
)

# Check that components are not associated with a registered variable in the model
components_ndim = set()
components_ndim_supp = set()
for dist in comp_dists:
# TODO: Allow these to not be a RandomVariable as long as we can call `ndim_supp` on them
Expand All @@ -188,14 +187,8 @@ def dist(cls, w, comp_dists, **kwargs):
f"Component dist must be a distribution created via the `.dist()` API, got {type(dist)}"
)
check_dist_not_registered(dist)
components_ndim.add(dist.ndim)
components_ndim_supp.add(dist.owner.op.ndim_supp)

if len(components_ndim) > 1:
raise ValueError(
f"Mixture components must all have the same dimensionality, got {components_ndim}"
)

if len(components_ndim_supp) > 1:
raise ValueError(
f"Mixture components must all have the same support dimensionality, got {components_ndim_supp}"
Expand All @@ -214,13 +207,18 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
# Create new rng for the mix_indexes internal RV
mix_indexes_rng = aesara.shared(np.random.default_rng())

single_component = len(components) == 1
ndim_supp = components[0].owner.op.ndim_supp

if size is not None:
components = cls._resize_components(size, *components)
elif not single_component:
# We might need to broadcast components when size is not specified
shape = tuple(at.broadcast_shape(*components))
size = shape[: len(shape) - ndim_supp]
components = cls._resize_components(size, *components)

single_component = len(components) == 1

# Extract support and replication ndims from components and weights
ndim_supp = components[0].owner.op.ndim_supp
# Extract replication ndims from components and weights
ndim_batch = components[0].ndim - ndim_supp
if single_component:
# One dimension is taken by the mixture axis in the single component case
Expand Down
33 changes: 33 additions & 0 deletions pymc/tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import scipy.stats as st

from aesara import tensor as at
from aesara.tensor.random.op import RandomVariable
from numpy.testing import assert_allclose
from scipy.special import logsumexp

Expand Down Expand Up @@ -677,6 +678,38 @@ def test_mixture_dtype(self):
).dtype
assert mix_dtype == aesara.config.floatX

@pytest.mark.parametrize(
"comp_dists, expected_shape",
[
(
[
Normal.dist([[0, 0, 0], [0, 0, 0]]),
Normal.dist([0, 0, 0]),
Normal.dist([0]),
],
(2, 3),
),
(
[
Dirichlet.dist([[1, 1, 1], [1, 1, 1]]),
Dirichlet.dist([1, 1, 1]),
],
(2, 3),
),
],
)
def test_broadcast_components(self, comp_dists, expected_shape):
n_dists = len(comp_dists)
mix = Mixture.dist(w=np.ones(n_dists) / n_dists, comp_dists=comp_dists)
mix_eval = mix.eval()
assert tuple(mix_eval.shape) == expected_shape
assert np.unique(mix_eval).size == mix.eval().size
for comp_dist in mix.owner.inputs[2:]:
# We check that the input is a "pure" RandomVariable and not a broadcast
# operation. This confirms that all draws will be unique
assert isinstance(comp_dist.owner.op, RandomVariable)
assert tuple(comp_dist.shape.eval()) == expected_shape


class TestNormalMixture(SeededTest):
def test_normal_mixture_sampling(self):
Expand Down

0 comments on commit 27fd32f

Please sign in to comment.