From 78f1786d30a4784d0f541d19bc60acfafe27090b Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Fri, 11 Mar 2022 17:52:50 +0100
Subject: [PATCH] Allow broadcasting of mixture components

---
 pymc/distributions/mixture.py | 20 +++++++++-----------
 pymc/tests/test_mixture.py    | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 42 insertions(+), 11 deletions(-)

diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py
index 76f7ce1950b..8e71afa4e04 100644
--- a/pymc/distributions/mixture.py
+++ b/pymc/distributions/mixture.py
@@ -176,7 +176,6 @@ def dist(cls, w, comp_dists, **kwargs):
             )
 
         # Check that components are not associated with a registered variable in the model
-        components_ndim = set()
         components_ndim_supp = set()
         for dist in comp_dists:
             # TODO: Allow these to not be a RandomVariable as long as we can call `ndim_supp` on them
@@ -188,14 +187,8 @@ def dist(cls, w, comp_dists, **kwargs):
                     f"Component dist must be a distribution created via the `.dist()` API, got {type(dist)}"
                 )
             check_dist_not_registered(dist)
-            components_ndim.add(dist.ndim)
             components_ndim_supp.add(dist.owner.op.ndim_supp)
 
-        if len(components_ndim) > 1:
-            raise ValueError(
-                f"Mixture components must all have the same dimensionality, got {components_ndim}"
-            )
-
         if len(components_ndim_supp) > 1:
             raise ValueError(
                 f"Mixture components must all have the same support dimensionality, got {components_ndim_supp}"
@@ -214,13 +207,18 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
         # Create new rng for the mix_indexes internal RV
         mix_indexes_rng = aesara.shared(np.random.default_rng())
 
+        single_component = len(components) == 1
+        ndim_supp = components[0].owner.op.ndim_supp
+
         if size is not None:
             components = cls._resize_components(size, *components)
+        elif not single_component:
+            # We might need to broadcast components when size is not specified
+            shape = tuple(at.broadcast_shape(*components))
+            size = shape[: len(shape) - ndim_supp]
+            components = cls._resize_components(size, *components)
 
-        single_component = len(components) == 1
-
-        # Extract support and replication ndims from components and weights
-        ndim_supp = components[0].owner.op.ndim_supp
+        # Extract replication ndims from components and weights
         ndim_batch = components[0].ndim - ndim_supp
         if single_component:
             # One dimension is taken by the mixture axis in the single component case
diff --git a/pymc/tests/test_mixture.py b/pymc/tests/test_mixture.py
index 5d5c77ddca1..3c061d5ef43 100644
--- a/pymc/tests/test_mixture.py
+++ b/pymc/tests/test_mixture.py
@@ -20,6 +20,7 @@
 import scipy.stats as st
 
 from aesara import tensor as at
+from aesara.tensor.random.op import RandomVariable
 from numpy.testing import assert_allclose
 from scipy.special import logsumexp
 
@@ -677,6 +678,38 @@ def test_mixture_dtype(self):
             ).dtype
         assert mix_dtype == aesara.config.floatX
 
+    @pytest.mark.parametrize(
+        "comp_dists, expected_shape",
+        [
+            (
+                [
+                    Normal.dist([[0, 0, 0], [0, 0, 0]]),
+                    Normal.dist([0, 0, 0]),
+                    Normal.dist([0]),
+                ],
+                (2, 3),
+            ),
+            (
+                [
+                    Dirichlet.dist([[1, 1, 1], [1, 1, 1]]),
+                    Dirichlet.dist([1, 1, 1]),
+                ],
+                (2, 3),
+            ),
+        ],
+    )
+    def test_broadcast_components(self, comp_dists, expected_shape):
+        n_dists = len(comp_dists)
+        mix = Mixture.dist(w=np.ones(n_dists) / n_dists, comp_dists=comp_dists)
+        mix_eval = mix.eval()
+        assert tuple(mix_eval.shape) == expected_shape
+        assert np.unique(mix_eval).size == mix.eval().size
+        for comp_dist in mix.owner.inputs[2:]:
+            # We check that the input is a "pure" RandomVariable and not a broadcast
+            # operation. This confirms that all draws will be unique
+            assert isinstance(comp_dist.owner.op, RandomVariable)
+            assert tuple(comp_dist.shape.eval()) == expected_shape
+
 
 class TestNormalMixture(SeededTest):
     def test_normal_mixture_sampling(self):