Skip to content

Commit

Permalink
Fix random_choice to handle multidim p and sizes that are not None (#…
Browse files Browse the repository at this point in the history
…3380)

* Fixed dist_math.random_choice to correctly handle multidimensional p and sizes other than None.

* Fixed mixture distribution conflict.

* Moved to_tuple from distribution.py to dist_math.py
  • Loading branch information
lucianopaz authored and twiecki committed Feb 25, 2019
1 parent 6187eee commit 9ef2947
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 25 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- Added tests for mixtures of multidimensional distributions to the test suite.
- Fixed incorrect usage of `broadcast_distribution_samples` in `DiscreteWeibull`.
- `Mixture`'s default dtype is now determined by `theano.config.floatX`.
- `dist_math.random_choice` now handles nd-arrays of category probabilities, and also handles sizes that are not `None`. Also removed unused `k` kwarg from `dist_math.random_choice`.

### Deprecations

Expand Down
35 changes: 29 additions & 6 deletions pymc3/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@
c = - .5 * np.log(2. * np.pi)


def to_tuple(shape):
    """Return ``shape`` as a tuple.

    ``None`` and empty inputs map to the empty tuple. Any other input is
    passed through ``np.atleast_1d`` and the resulting array is unpacked
    into a tuple, so a scalar becomes a 1-tuple.
    """
    if shape is None:
        return ()
    dims = np.atleast_1d(shape)
    # An empty array (e.g. from ``()`` or ``[]``) also means "no shape".
    return tuple(dims) if dims.size != 0 else ()


def bound(logp, *conditions, **kwargs):
"""
Bounds a log probability density with several conditions.
Expand Down Expand Up @@ -308,11 +319,12 @@ def random_choice(*args, **kwargs):
Args:
p: array
Probability of each class
size: int
Number of draws to return
k: int
Number of bins
Probability of each class. If p.ndim > 1, the last axis is
interpreted as the probability of each class, and numpy.random.choice
is iterated for every other axis element.
size: int or tuple
Shape of the desired output array. If p is multidimensional, size
should broadcast with p.shape[:-1].
Returns:
random sample: array
Expand All @@ -323,8 +335,19 @@ def random_choice(*args, **kwargs):
k = p.shape[-1]

if p.ndim > 1:
# If a 2d vector of probabilities is passed return a sample for each row of categorical probability
# If p is an nd-array, the last axis is interpreted as the class
# probability. We must iterate over the elements of all the other
# dimensions.
# We first ensure that p is broadcasted to the output's shape
size = to_tuple(size) + (1,)
p = np.broadcast_arrays(p, np.empty(size))[0]
out_shape = p.shape[:-1]
# np.random.choice accepts 1D p arrays, so we semiflatten p to
# iterate calls using the last axis as the category probabilities
p = np.reshape(p, (-1, p.shape[-1]))
samples = np.array([np.random.choice(k, p=p_) for p_ in p])
# We reshape to the desired output shape
samples = np.reshape(samples, out_shape)
else:
samples = np.random.choice(k, p=p, size=size)
return samples
Expand Down
12 changes: 1 addition & 11 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ObservedRV, MultiObservedRV, Context, InitContextMeta
)
from ..vartypes import string_types
from .dist_math import to_tuple

__all__ = ['DensityDist', 'Distribution', 'Continuous', 'Discrete',
'NoDistribution', 'TensorType', 'draw_values', 'generate_samples']
Expand Down Expand Up @@ -553,17 +554,6 @@ def _draw_value(param, point=None, givens=None, size=None):
return output
raise ValueError('Unexpected type in draw_value: %s' % type(param))


def to_tuple(shape):
    """Return ``shape`` as a tuple.

    ``None`` and empty inputs map to the empty tuple. Any other input is
    passed through ``np.atleast_1d`` and the resulting array is unpacked
    into a tuple, so a scalar becomes a 1-tuple.
    """
    if shape is None:
        return ()
    dims = np.atleast_1d(shape)
    # An empty array (e.g. from ``()`` or ``[]``) also means "no shape".
    return tuple(dims) if dims.size != 0 else ()

def _is_one_d(dist_shape):
if hasattr(dist_shape, 'dshape') and dist_shape.dshape in ((), (0,), (1,)):
return True
Expand Down
11 changes: 4 additions & 7 deletions pymc3/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from pymc3.util import get_variable_name
from ..math import logsumexp
from .dist_math import bound, random_choice
from .dist_math import bound, random_choice, to_tuple
from .distribution import (Discrete, Distribution, draw_values,
generate_samples, _DrawValuesContext,
_DrawValuesContextBlocker, to_tuple,
_DrawValuesContextBlocker,
broadcast_distribution_samples)
from .continuous import get_tau_sigma, Normal
from ..theanof import _conversion_map
Expand Down Expand Up @@ -464,11 +464,8 @@ def random(self, point=None, size=None):
# mixture components, and the rest is all about size,
# dist_shape and broadcasting
w_ = np.reshape(w, (-1, w.shape[-1]))
w_samples = generate_samples(random_choice,
p=w_,
broadcast_shape=w.shape[:-1] or (1,),
dist_shape=w.shape[:-1] or (1,),
size=None) # w's shape already includes size
w_samples = random_choice(p=w_,
size=None) # w's shape already includes size
# Now we broadcast the chosen components to the dist_shape
w_samples = np.reshape(w_samples, w.shape[:-1])
if size is not None and dist_shape[:len(size)] != size:
Expand Down
1 change: 1 addition & 0 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def test_probability_vector_shape(self):
"""Check that if a 2d array of probabilities are passed to categorical correct shape is returned"""
p = np.ones((10, 5))
assert pm.Categorical.dist(p=p).random().shape == (10,)
assert pm.Categorical.dist(p=p).random(size=4).shape == (4, 10)


class TestScalarParameterSamples(SeededTest):
Expand Down
2 changes: 1 addition & 1 deletion pymc3/tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pymc3.theanof import floatX
import theano
from theano import tensor as tt
from pymc3.distributions.distribution import to_tuple
from pymc3.distributions.dist_math import to_tuple

# Generate data
def generate_normal_mixture_data(w, mu, sd, size=1000):
Expand Down

0 comments on commit 9ef2947

Please sign in to comment.