Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compound/Gibbs step and discrete suppot #306

Merged
merged 139 commits into from
Aug 27, 2020
Merged
Show file tree
Hide file tree
Changes from 63 commits
Commits
Show all changes
139 commits
Select commit Hold shift + click to select a range
1a00a83
fix tests
Jul 1, 2020
b434843
add init for compound with fixed experiments
Jul 28, 2020
4dc6a22
add init for compound with fixed experiments
Jul 28, 2020
c113492
fix issues with import
Jul 28, 2020
bd35890
remove unnec. files
Jul 28, 2020
c8675c5
fix issues with variable grouping in compound
Jul 30, 2020
7163c3d
complete variable merging based on sampler
Jul 31, 2020
8c0b162
fix issues with performance
Jul 31, 2020
4204d99
add seed
Jul 31, 2020
5033e74
add notebook with discrete examples/fix issues with state for categor…
Jul 31, 2020
1980bda
update notebook
Jul 31, 2020
b5e5428
fix class check
Jul 31, 2020
c7f8ed5
fix issues with class check in sampler
Jul 31, 2020
2350938
add version
Jul 31, 2020
2ebb1db
fix scoped name
Jul 31, 2020
b47cc79
add support for discrete sampling for poisson with normal/round proposal
Jul 31, 2020
fe826dc
add casting argument
Aug 1, 2020
61622f4
add basic tests for discrete sampling
Aug 1, 2020
c724d92
add support for gibbs with no tests
Aug 1, 2020
0382a5a
remove unnec files
Aug 1, 2020
bff9f28
rm unn files
Aug 1, 2020
399d087
remove . notebooks
Aug 1, 2020
85e8843
fix some issues
Aug 8, 2020
51d92f0
add support for proposal generation in rwm/add tests for compound
Aug 9, 2020
41daff4
fix tests
Aug 9, 2020
2b5427f
fix tests
Aug 9, 2020
58e58da
remove notebooks
Aug 9, 2020
f847153
remove some files
Aug 9, 2020
6e2671a
add basic logging
Aug 9, 2020
48220aa
fix logging
Aug 9, 2020
b5d513a
fix state functions
Aug 12, 2020
e4f274c
fix xla
Aug 12, 2020
7c4e416
fix logging
Aug 12, 2020
49d01fb
expose samplers
Aug 12, 2020
62eec96
fix issues with proposal functions/overload abstract operators for co…
Aug 13, 2020
b13cb09
fix merge/add speed notebook with different number of samplers+xla
Aug 13, 2020
4776251
add test
Aug 13, 2020
6bafc6d
fix seed
Aug 13, 2020
d3472e2
add more tests on compound
Aug 13, 2020
5d6368d
fix docs in state functions
Aug 13, 2020
26f2aa8
restore compound tests
Aug 13, 2020
4413f83
fix docs
Aug 13, 2020
6694039
fix docs
Aug 13, 2020
ff551bf
remove rwm_da
Aug 13, 2020
7a750de
fix naming
Aug 13, 2020
360fb2b
fix issues with pylint/mypy
Aug 14, 2020
d4c4b1f
resolve conflicts
Aug 14, 2020
76db69a
black/remove smc tests
Aug 15, 2020
fcbb5e1
fix tests
Aug 15, 2020
c0e9a34
fix init
Aug 15, 2020
e8c59f0
reduce tests
Aug 15, 2020
62f6645
fix executor tests
Aug 15, 2020
8632139
black
Aug 15, 2020
f846a39
fix test
Aug 15, 2020
245a268
fix final test
Aug 15, 2020
bd8e0c9
fix rev issues
Aug 16, 2020
904f664
fix prefer_static
Aug 20, 2020
c3fe21a
pylint fix
Aug 21, 2020
4c8c383
fix symmetry issues
Aug 22, 2020
2f57139
fix bug
Aug 22, 2020
180df6c
part seed add
Aug 22, 2020
40641b9
black
Aug 22, 2020
173943a
remove debug
Aug 23, 2020
e0c22c4
fix distributions transformed
Aug 25, 2020
575f202
refactor sampling
Aug 25, 2020
0756894
refactor sampler
Aug 25, 2020
6fc57f0
print->log
Aug 27, 2020
c347b73
fix tests
Jul 1, 2020
779f3a4
add init for compound with fixed experiments
Jul 28, 2020
9a10e42
add init for compound with fixed experiments
Jul 28, 2020
876ac2c
fix issues with import
Jul 28, 2020
2441cb7
remove unnec. files
Jul 28, 2020
1317292
fix issues with variable grouping in compound
Jul 30, 2020
ee1f81b
complete variable merging based on sampler
Jul 31, 2020
42ef4f1
fix issues with performance
Jul 31, 2020
a865448
add seed
Jul 31, 2020
8453219
add notebook with discrete examples/fix issues with state for categor…
Jul 31, 2020
7d6d565
update notebook
Jul 31, 2020
bd48468
fix class check
Jul 31, 2020
75103fc
fix issues with class check in sampler
Jul 31, 2020
851fd5d
add version
Jul 31, 2020
1701c4e
fix scoped name
Jul 31, 2020
3d57ad6
add support for discrete sampling for poisson with normal/round proposal
Jul 31, 2020
4e67a2c
add casting argument
Aug 1, 2020
566b062
add basic tests for discrete sampling
Aug 1, 2020
ffd59db
add support for gibbs with no tests
Aug 1, 2020
03c86aa
remove unnec files
Aug 1, 2020
6ed0a71
rm unn files
Aug 1, 2020
3a899e8
remove . notebooks
Aug 1, 2020
234d9c8
fix some issues
Aug 8, 2020
c2a5a0d
add support for proposal generation in rwm/add tests for compound
Aug 9, 2020
903086e
fix tests
Aug 9, 2020
17165ce
fix tests
Aug 9, 2020
7a2607b
remove notebooks
Aug 9, 2020
e48f4c4
remove some files
Aug 9, 2020
94d6882
add basic logging
Aug 9, 2020
0e5c12a
fix logging
Aug 9, 2020
847b7fd
fix state functions
Aug 12, 2020
a536125
fix xla
Aug 12, 2020
837215b
fix logging
Aug 12, 2020
b1664b8
expose samplers
Aug 12, 2020
6da2b02
fix issues with proposal functions/overload abstract operators for co…
Aug 13, 2020
868b09f
fix merge/add speed notebook with different number of samplers+xla
Aug 13, 2020
50d23f1
add test
Aug 13, 2020
04b4b42
fix seed
Aug 13, 2020
fa68a49
add more tests on compound
Aug 13, 2020
8c01277
fix docs in state functions
Aug 13, 2020
db010a0
restore compound tests
Aug 13, 2020
02443c4
fix docs
Aug 13, 2020
d71dcf6
fix docs
Aug 13, 2020
68720fe
remove rwm_da
Aug 13, 2020
6538efe
fix naming
Aug 13, 2020
28a66c9
fix issues with pylint/mypy
Aug 14, 2020
810de35
resolve conflicts
Aug 14, 2020
de657bf
black/remove smc tests
Aug 15, 2020
a9cbafa
fix tests
Aug 15, 2020
9836054
fix init
Aug 15, 2020
e85aa02
reduce tests
Aug 15, 2020
4e1f0dd
fix executor tests
Aug 15, 2020
522faac
black
Aug 15, 2020
f023e8c
fix test
Aug 15, 2020
3e82bad
fix final test
Aug 15, 2020
231f2d8
fix rev issues
Aug 16, 2020
77cf17b
fix prefer_static
Aug 20, 2020
be7ff36
pylint fix
Aug 21, 2020
321a23d
fix symmetry issues
Aug 22, 2020
119121b
fix bug
Aug 22, 2020
deba707
part seed add
Aug 22, 2020
dfb7d78
black
Aug 22, 2020
18ec85a
remove debug
Aug 23, 2020
0e894e4
fix distributions transformed
Aug 25, 2020
1116454
refactor sampling
Aug 25, 2020
0fa7d3e
refactor sampler
Aug 25, 2020
94ede84
print->log
Aug 27, 2020
ddeee4b
Merge branch 'compound' of https://github.com/rrkarim/pymc4 into comp…
Aug 27, 2020
0de3d7d
upgrade black
Aug 27, 2020
9b4fe4a
Merge remote-tracking branch 'upstream/master' into compound
Aug 27, 2020
7462e48
fix logp for determinitiscs_values
Aug 27, 2020
d906238
black
Aug 27, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions nbconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"notebooks/baseball.ipynb",
"notebooks/basic-usage.ipynb",
"notebooks/rugby_analytics.ipynb",
# will reinstate in a later PR
# "notebooks/radon_hierarchical.ipynb",
# will reinstate in a later PR
# "notebooks/radon_hierarchical.ipynb",
]

314 changes: 314 additions & 0 deletions notebooks/discrete_distributions_sampling.ipynb

Large diffs are not rendered by default.

11 changes: 7 additions & 4 deletions notebooks/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import matplotlib.pyplot as plt
import numpy as np


def plot_samples(x, batched_samples, labels, names, ylim=None):
if not isinstance(batched_samples, np.ndarray):
batched_samples = np.asarray(batched_samples)
n_samples = batched_samples.shape[0]
if ylim is not None:
ymin, ymax = ylim
else:
ymin, ymax = batched_samples.min()-0.2, batched_samples.max()+0.2
fig, ax = plt.subplots(n_samples, 1, figsize=(14, n_samples*3))
ymin, ymax = batched_samples.min() - 0.2, batched_samples.max() + 0.2
fig, ax = plt.subplots(n_samples, 1, figsize=(14, n_samples * 3))
if isinstance(labels, (list, tuple)):
labels = [np.asarray(label) for label in labels]
else:
Expand All @@ -29,6 +30,7 @@ def plot_samples(x, batched_samples, labels, names, ylim=None):
axi.set_title(lab)
plt.show()


def plot_cov_matrix(k, X, labels, names, vlim=None, cmap="inferno", interpolation="none"):
cov = k(X, X)
cov = np.asarray(cov)
Expand All @@ -42,8 +44,9 @@ def plot_cov_matrix(k, X, labels, names, vlim=None, cmap="inferno", interpolatio
else:
labels = np.asarray(labels)
n_samples = 1
fig, ax = plt.subplots(1, n_samples, figsize=(5*n_samples, 4))
if not isinstance(ax, np.ndarray): ax = np.asarray([ax])
fig, ax = plt.subplots(1, n_samples, figsize=(5 * n_samples, 4))
if not isinstance(ax, np.ndarray):
ax = np.asarray([ax])
for i in range(ax.shape[0]):
axi = ax[i]
if isinstance(labels, (list, tuple)):
Expand Down
6 changes: 4 additions & 2 deletions pymc4/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""PyMC4."""

from . import utils
from .coroutine_model import Model, model
from .scopes import name_scope, variable_name
from . import coroutine_model
from . import distributions
Expand All @@ -12,12 +12,14 @@
evaluate_meta_model,
evaluate_meta_posterior_predictive_model,
)
from .coroutine_model import Model, model
from . import inference
from .distributions import *
from .forward_sampling import sample_prior_predictive, sample_posterior_predictive
from .inference.sampling import sample
from .mcmc.samplers import *
from . import gp
from . import mcmc
from .variational import *


__version__ = "4.0a2"
2 changes: 2 additions & 0 deletions pymc4/distributions/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
from .distribution import Potential, Deterministic
from .mixture import Mixture
from . import transforms
from .mixture import *
from .state_functions import *
8 changes: 4 additions & 4 deletions pymc4/distributions/batchstack.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _fn(self, **kwargs):
shape = prefer_static.concat(
[
prefer_static.ones(
prefer_static.rank_from_shape(self.batch_stack), dtype=self.batch_stack.dtype
prefer_static.rank_from_shape(self.batch_stack), dtype=self.batch_stack.dtype,
),
self.distribution.batch_shape_tensor(),
self.distribution.event_shape_tensor(),
Expand Down Expand Up @@ -60,7 +60,7 @@ class BatchStacker(distribution_lib.Distribution):

The probability function is,

.. math::
.. math::
p(x) = prod{ p(x[i]) : i = 0, ..., (n - 1) }

Examples
Expand All @@ -77,7 +77,7 @@ class BatchStacker(distribution_lib.Distribution):
>>> lp = s.log_prob(x)
>>> lp.shape.as_list()
[5]

Example 2: `[5, 4]`-draws of a bivariate Normal.

>>> s = BatchStacker(
Expand Down Expand Up @@ -186,7 +186,7 @@ def _log_prob(self, x, **kwargs):
x = tf.reshape(
x,
shape=tf.pad(
tf.shape(x), paddings=[[prefer_static.maximum(0, -d), 0]], constant_values=1
tf.shape(x), paddings=[[prefer_static.maximum(0, -d), 0]], constant_values=1,
),
)
# (2) Compute x's log_prob.
Expand Down
21 changes: 17 additions & 4 deletions pymc4/distributions/continuous.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class GeneralizedNormal(ContinuousDistribution):

.. math::

f(x \mid \mu, \alpha, \beta) =
f(x \mid \mu, \alpha, \beta) =
\frac{\beta}{2 \Gamma(1/\beta)}
\exp(-(|x - \mu| /\alpha)^\beta)

Expand Down Expand Up @@ -354,7 +354,10 @@ def __init__(self, name, concentration0, concentration1, **kwargs):

@staticmethod
def _init_distribution(conditions, **kwargs):
concentration0, concentration1 = conditions["concentration0"], conditions["concentration1"]
concentration0, concentration1 = (
conditions["concentration0"],
conditions["concentration1"],
)
return tfd.Beta(concentration0=concentration0, concentration1=concentration1, **kwargs)


Expand Down Expand Up @@ -779,7 +782,10 @@ def __init__(self, name, concentration0, concentration1, **kwargs):

@staticmethod
def _init_distribution(conditions, **kwargs):
concentration0, concentration1 = conditions["concentration0"], conditions["concentration1"]
concentration0, concentration1 = (
conditions["concentration0"],
conditions["concentration1"],
)
return tfd.Kumaraswamy(
concentration0=concentration0, concentration1=concentration1, **kwargs
)
Expand Down Expand Up @@ -987,7 +993,7 @@ class Moyal(ContinuousDistribution):

.. math::

f(x \mid \mu, \sigma) =
f(x \mid \mu, \sigma) =
\frac{1}{\sqrt{2\pi}\sigma}
\exp\left(-\frac{1}{2}\left[\frac{x-\mu}{\sigma}+\exp\left(-\frac{x-\mu}{\sigma}\right)\right]\right)

Expand Down Expand Up @@ -1493,6 +1499,13 @@ class Weibull(PositiveContinuousDistribution):
Shape parameter (concentration > 0).
scale : float|tensor
Scale parameter (scale > 0).

Developer Notes
---------------
The Weibull distribution is implemented as a standard uniform distribution transformed by the
Inverse of the WeibullCDF bijector. The shape to broadcast the low and high parameters for the
Uniform distribution are obtained using
tensorflow_probability.python.internal.distribution_util.prefer_static_broadcast_shape()
"""

def __init__(self, name, concentration, scale, **kwargs):
Expand Down
18 changes: 14 additions & 4 deletions pymc4/distributions/discrete.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
"""PyMC4 discrete random variables."""
import tensorflow as tf
from tensorflow_probability import distributions as tfd
from tensorflow_probability.python.internal import prefer_static
from pymc4.distributions.distribution import (
PositiveDiscreteDistribution,
BoundedDiscreteDistribution,
)
from pymc4.distributions import transforms

from pymc4.distributions.state_functions import (
categorical_uniform_fn,
bernoulli_fn,
)

__all__ = [
"Bernoulli",
"Binomial",
Expand Down Expand Up @@ -58,9 +64,11 @@ class Bernoulli(BoundedDiscreteDistribution):
probs : float
Probability of success (0 < probs < 1).
"""
_grad_support = False

def __init__(self, name, probs, **kwargs):
super().__init__(name, probs=probs, **kwargs)
self._default_new_state_part = bernoulli_fn()

@staticmethod
def _init_distribution(conditions, **kwargs):
Expand Down Expand Up @@ -245,6 +253,7 @@ class DiscreteUniform(BoundedDiscreteDistribution):
high : int
Upper limit (high > low).
"""
_grad_support = False

def __init__(self, name, low, high, **kwargs):
super().__init__(name, low=low, high=high, **kwargs)
Expand Down Expand Up @@ -296,9 +305,12 @@ class Categorical(BoundedDiscreteDistribution):
probs : array of floats
probs > 0 and the elements of probs must sum to 1.
"""
_grad_support = False

def __init__(self, name, probs, **kwargs):
super().__init__(name, probs=probs, **kwargs)
classes = prefer_static.shape(probs)[-1]
self._default_new_state_part = categorical_uniform_fn(classes=classes)

@staticmethod
def _init_distribution(conditions, **kwargs):
Expand All @@ -310,7 +322,7 @@ def lower_limit(self):
return 0.0

def upper_limit(self):
return self.conditions["probs"].shape[-1]
return float(tf.shape(self.conditions["probs"])[-1])
rrkarim marked this conversation as resolved.
Show resolved Hide resolved


class Geometric(BoundedDiscreteDistribution):
Expand Down Expand Up @@ -486,8 +498,6 @@ class Poisson(PositiveDiscreteDistribution):
"""

# For some ridiculous reason, tfp needs poisson values to be floats...
_test_value = 0.0 # type: ignore

def __init__(self, name, rate, **kwargs):
super().__init__(name, rate=rate, **kwargs)

Expand Down Expand Up @@ -770,4 +780,4 @@ def lower_limit(self):
return 0.0

def upper_limit(self):
return self.conditions["cutpoints"].shape[-1]
return prefer_static.shape(self.conditions["cutpoints"])[-1]
21 changes: 10 additions & 11 deletions pymc4/distributions/distribution.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tensorflow_probability import distributions as tfd
from pymc4.coroutine_model import Model, unpack
from pymc4.distributions.batchstack import BatchStacker
from . import transforms
from pymc4.distributions import transforms

NameType = Union[str, int]

Expand All @@ -29,6 +29,7 @@
class Distribution(Model):
"""Statistical distribution."""

_grad_support: bool = True
_test_value = 0.0
_base_parameters = ["dtype", "validate_args", "allow_nan_stats"]

Expand All @@ -48,9 +49,10 @@ def __init__(
**kwargs,
):
self.conditions, self.base_parameters = self.unpack_conditions(
dtype=dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, **kwargs
dtype=dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, **kwargs,
)
self._distribution = self._init_distribution(self.conditions, **self.base_parameters)
self._default_new_state_part = None
super().__init__(
self.unpack_distribution, name=name, keep_return=True, keep_auxiliary=False
)
Expand Down Expand Up @@ -104,7 +106,7 @@ def unpack_conditions(cls, **kwargs) -> Tuple[dict, dict]:
@property
def test_value(self):
return tf.cast(
tf.broadcast_to(self._test_value, self.batch_shape + self.event_shape), self.dtype
tf.broadcast_to(self._test_value, self.batch_shape + self.event_shape), self.dtype,
)

def sample(self, sample_shape=(), seed=None):
Expand Down Expand Up @@ -139,14 +141,14 @@ def sample_numpy(self, sample_shape=(), seed=None):
def get_test_sample(self, sample_shape=(), seed=None):
"""
Get the test value using a function signature similar to meth:`~.sample`.

Parameters
----------
sample_shape : tuple
sample shape
seed : int | None
ignored. Is only present to match the signature of meth:`~.sample`

Returns
-------
The distribution's ``test_value`` broadcasted to
Expand Down Expand Up @@ -295,6 +297,9 @@ def upper_limit(self):


class BoundedDiscreteDistribution(DiscreteDistribution, BoundedDistribution):
def _init_transform(self, transform):
return transform

@property
def _test_value(self):
return tf.cast(tf.round(0.5 * (self.upper_limit() + self.lower_limit())), self.dtype)
Expand Down Expand Up @@ -339,12 +344,6 @@ def upper_limit(self):
class PositiveDiscreteDistribution(BoundedDiscreteDistribution):
_test_value = 1

def _init_transform(self, transform):
rrkarim marked this conversation as resolved.
Show resolved Hide resolved
if transform is None:
return transforms.Log()
else:
return transform

def lower_limit(self):
return 0

Expand Down
6 changes: 3 additions & 3 deletions pymc4/distributions/half_student_t.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class HalfStudentT(distribution.Distribution):
"""

def __init__(
self, df, loc, scale, validate_args=False, allow_nan_stats=True, name="HalfStudentT"
self, df, loc, scale, validate_args=False, allow_nan_stats=True, name="HalfStudentT",
):
r"""
Construct a half-Student's t distribution with ``df``, ``loc`` and ``scale``.
Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(
@staticmethod
def _param_shapes(sample_shape):
return dict(
zip(("df", "loc", "scale"), ([tf.convert_to_tensor(sample_shape, dtype=tf.int32)] * 3))
zip(("df", "loc", "scale"), ([tf.convert_to_tensor(sample_shape, dtype=tf.int32)] * 3),)
)

@classmethod
Expand Down Expand Up @@ -252,7 +252,7 @@ def _variance(self):
)
if self.allow_nan_stats:
return tf.where(
df > 1.0, result_where_defined, dtype_util.as_numpy_dtype(self.dtype)(np.nan)
df > 1.0, result_where_defined, dtype_util.as_numpy_dtype(self.dtype)(np.nan),
)
else:
return distribution_util.with_dependencies(
Expand Down
6 changes: 1 addition & 5 deletions pymc4/distributions/mixture.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,12 @@
class Mixture(Distribution):
r"""
Mixture random variable.

Often used to model subpopulation heterogeneity

.. math:: f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i)

======== ============================================
Support :math:`\cap_{i = 1}^n \textrm{support}(f_i)`
Mean :math:`\sum_{i = 1}^n w_i \mu_i`
======== ============================================

Parameters
----------
p : tf.Tensor
Expand Down Expand Up @@ -98,7 +94,7 @@ def _init_distribution(conditions, **kwargs):
)
distr = [el._distribution for el in d]
return tfd.Mixture(
tfd.Categorical(probs=p, **kwargs), distr, **kwargs, use_static_graph=True
tfd.Categorical(probs=p, **kwargs), distr, **kwargs, use_static_graph=True,
)
# else if 'd' is a pymc distribution with batch_size > 1
elif isinstance(d, Distribution):
Expand Down
5 changes: 4 additions & 1 deletion pymc4/distributions/multivariate.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,10 @@ def __init__(self, name, mean_direction, concentration, **kwargs):

@staticmethod
def _init_distribution(conditions, **kwargs):
mean_direction, concentration = conditions["mean_direction"], conditions["concentration"]
mean_direction, concentration = (
conditions["mean_direction"],
conditions["concentration"],
)
return tfd.VonMisesFisher(
mean_direction=mean_direction, concentration=concentration, **kwargs
)
Expand Down
Loading