Skip to content

Commit

Permalink
More WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
ColCarroll committed Jun 15, 2018
1 parent 1998ed9 commit 22147c4
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 19 deletions.
3 changes: 2 additions & 1 deletion pymc3/distributions/bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def _random(self, lower, upper, point=None, size=None):
samples = np.zeros(size, dtype=self.dtype).flatten()
i, n = 0, len(samples)
while i < len(samples):
sample = self._wrapped.random(point=point, size=n)
sample = np.atleast_1d(self._wrapped.random(point=point, size=n))

select = sample[np.logical_and(sample >= lower, sample <= upper)]
samples[i:(i + len(select))] = select[:]
i += len(select)
Expand Down
36 changes: 31 additions & 5 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,14 @@ def to_tuple(shape):
shape = tuple(shape)
return shape

def _is_one_d(dist_shape):
if hasattr(dist_shape, 'dshape') and dist_shape.dshape in ((), (0,), (1,)):
return True
elif hasattr(dist_shape, 'shape') and dist_shape.shape in ((), (0,), (1,)):
return True
elif dist_shape == ():
return True
return False

def generate_samples(generator, *args, **kwargs):
"""Generate samples from the distribution of a random variable.
Expand Down Expand Up @@ -439,6 +447,7 @@ def generate_samples(generator, *args, **kwargs):
Any remaining *args and **kwargs are passed on to the generator function.
"""
dist_shape = kwargs.pop('dist_shape', ())
one_d = _is_one_d(dist_shape)
size = kwargs.pop('size', None)
broadcast_shape = kwargs.pop('broadcast_shape', None)
if size is None:
Expand All @@ -457,24 +466,41 @@ def generate_samples(generator, *args, **kwargs):
dist_shape = to_tuple(dist_shape)
broadcast_shape = to_tuple(broadcast_shape)
size_tup = to_tuple(size)

# All inputs are scalars, end up size (size_tup, dist_shape)
if broadcast_shape == () or broadcast_shape == (0,):
samples = generator(size=size_tup + dist_shape, *args, **kwargs)
# Inputs already have the right shape. Just get the right size.
elif broadcast_shape[-len(dist_shape):] == dist_shape:
if size == 1 or (broadcast_shape == size_tup + dist_shape):
samples = generator(size=broadcast_shape, *args, **kwargs)
elif dist_shape == broadcast_shape:
samples = generator(size=size_tup + dist_shape, *args, **kwargs)
else:
samples = generator(size=size, *args, **kwargs)
elif dist_shape == broadcast_shape:
samples = generator(size=size, *args, **kwargs)
if size_tup[-len(broadcast_shape):] != broadcast_shape:
samples = generator(size=size_tup + broadcast_shape, *args, **kwargs)
else:
samples = generator(size=size_tup + dist_shape, *args, **kwargs)
# Inputs have the right size, have to manually broadcast to the right dist_shape
elif broadcast_shape[:len(size_tup)] == size_tup:
suffix = broadcast_shape[len(size_tup):] + dist_shape
samples = [generator(*args, **kwargs).reshape(size_tup + (1,)) for _ in range(np.prod(suffix, dtype=int))]
samples = np.hstack(samples).reshape(size_tup + suffix)
# Args have been broadcast correctly, can just ask for the right shape out
elif dist_shape[-len(broadcast_shape):] == broadcast_shape:
samples = generator(size=size_tup + dist_shape, *args, **kwargs)
else:
raise TypeError(f'''Attempted to generate values with incompatible shapes:
size: {size}
dist_shape: {dist_shape}
broadcast_shape: {broadcast_shape}
''')
samples = samples.squeeze()
return samples

# reshape samples here
if samples.shape[0] == 1 and size == 1:
if len(samples.shape) > len(dist_shape) and samples.shape[-len(dist_shape):] == dist_shape:
samples = samples.reshape(samples.shape[1:])

if one_d and samples.shape[-1] == 1:
samples = samples.reshape(samples.shape[:-1])
return np.asarray(samples)
22 changes: 15 additions & 7 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,10 @@ def __init__(self, n, p, *args, **kwargs):

if len(self.shape) > 1:
m = self.shape[-2]
try:
assert n.shape == (m,)
except (AttributeError, AssertionError):
n = n * tt.ones(m)
# try:
# assert n.shape == (m,)
# except (AttributeError, AssertionError):
# n = n * tt.ones(m)
self.n = tt.shape_padright(n)
self.p = p if p.ndim > 1 else tt.shape_padleft(p)
elif n.ndim == 1:
Expand All @@ -521,27 +521,35 @@ def _random(self, n, p, size=None):
# Now, re-normalize all of the values in float64 precision. This is done inside the conditionals
if size == p.shape:
size = None
if (n.ndim == 0) and (p.ndim == 1):
elif size[-len(p.shape):] == p.shape:
size = size[:len(size) - len(p.shape)]

n_dim = n.squeeze().ndim

if (n_dim == 0) and (p.ndim == 1):
p = p / p.sum()
randnum = np.random.multinomial(n, p.squeeze(), size=size)
elif (n.ndim == 0) and (p.ndim > 1):
elif (n_dim == 0) and (p.ndim > 1):
p = p / p.sum(axis=1, keepdims=True)
randnum = np.asarray([
np.random.multinomial(n.squeeze(), pp, size=size)
for pp in p
])
elif (n.ndim > 0) and (p.ndim == 1):
randnum = np.moveaxis(randnum, 1, 0)
elif (n_dim > 0) and (p.ndim == 1):
p = p / p.sum()
randnum = np.asarray([
np.random.multinomial(nn, p.squeeze(), size=size)
for nn in n
])
randnum = np.moveaxis(randnum, 1, 0)
else:
p = p / p.sum(axis=1, keepdims=True)
randnum = np.asarray([
np.random.multinomial(nn, pp, size=size)
for (nn, pp) in zip(n, p)
])
randnum = np.moveaxis(randnum, 1, 0)
return randnum.astype(original_dtype)

def random(self, point=None, size=None):
Expand Down
1 change: 0 additions & 1 deletion pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,6 @@ def test_multinomial_mode(self, p, n):
[[.25, .25, .25, .25], (1, 4), 3],
# 3: expect to fail
# [[.25, .25, .25, .25], (10, 4)],
[[.25, .25, .25, .25], (10, 1, 4), 5],
# 5: expect to fail
# [[[.25, .25, .25, .25]], (2, 4), [7, 11]],
[[[.25, .25, .25, .25],
Expand Down
9 changes: 5 additions & 4 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,13 @@ def test_different_shapes_and_sample_sizes(self):
s = list(size)
except TypeError:
s = [size]
s.extend(shape)
if s == [1]:
s = []
if shape not in ((), (1,)):
s.extend(shape)
e = tuple(s)
a = self.sample_random_variable(rv, size).shape
expected.append(e)
actual.append(a)
assert expected == actual
assert e == a


class TestNormal(BaseTestCases.BaseTestCase):
Expand Down
3 changes: 2 additions & 1 deletion pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,9 @@ def test_respects_shape(self):

def test_multivariate(self):
with pm.Model():
m = pm.Multinomial('m', n=5, p=np.array([[0.25, 0.25, 0.25, 0.25]]), shape=(1, 4))
m = pm.Multinomial('m', n=5, p=np.array([0.25, 0.25, 0.25, 0.25]), shape=4)
trace = pm.sample_generative(10)

assert m.random(size=10).shape == (10, 4)
assert trace['m'].shape == (10, 4)

Expand Down

0 comments on commit 22147c4

Please sign in to comment.