Skip to content

Commit

Permalink
Fix for #3210 without computing the Bayes network (#3273)
Browse files Browse the repository at this point in the history
* Fix for #3225. Made Triangular `c` attribute be handled consistently with scipy.stats. Added test and updated example code.

* Fix for #3210, which uses a completely different approach than PR #3214. It uses a context manager inside `draw_values` that makes all the values drawn from `TensorVariable`s or `MultiObservedRV`s available to nested calls of the original call to `draw_values`. It is partly inspired by how Edward2 approaches the problem of forward sampling. Ed2 tensors fix a `_values` attribute after they first call `sample` and then only return that. They can do it because of their functional scheme, where the entire graph is recreated each time the generative function is called. Our object-oriented paradigm cannot set a fixed `_values`; it has to know it is in the context of a single `draw_values` call. That is why I opted for context managers to store the drawn values.

* Removed leftover print statement

* Added release notes and draw values context managers to mixture and multivariate distributions that make many calls to `draw_values` or other distributions' `random` methods within their own `random`.
  • Loading branch information
lucianopaz authored and junpenglao committed Dec 3, 2018
1 parent 589aee1 commit 686b81d
Show file tree
Hide file tree
Showing 7 changed files with 518 additions and 320 deletions.
11 changes: 11 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,24 @@
- Add log CDF functions to continuous distributions: `Beta`, `Cauchy`, `ExGaussian`, `Exponential`, `Flat`, `Gumbel`, `HalfCauchy`, `HalfFlat`, `HalfNormal`, `Laplace`, `Logistic`, `Lognormal`, `Normal`, `Pareto`, `StudentT`, `Triangular`, `Uniform`, `Wald`, `Weibull`.
- Behavior of `sample_posterior_predictive` is now to produce posterior predictive samples, in order, from all values of the `trace`. Previously, by default it would produce 1 chain worth of samples, using a random selection from the `trace` (#3212)
- Show diagnostics for initial energy errors in HMC and NUTS.
- PR #3273 has added the `distributions.distribution._DrawValuesContext` context
manager. This is used to store the values already drawn in nested `random`
and `draw_values` calls, enabling `draw_values` to draw samples from the
joint probability distribution of RVs and not the marginals. Custom
distributions that must call `draw_values` several times in their `random`
method, or that invoke many calls to other distributions' `random` methods
(e.g. mixtures) must do all of these calls under the same `_DrawValuesContext`
context manager instance. If they do not, the conditional relations between
the distribution's parameters could be broken, and `random` could return
values drawn from an incorrect distribution.

### Maintenance

- Big rewrite of documentation (#3275)
- Fixed Triangular distribution `c` attribute handling in `random` and updated sample code for consistency (#3225)
- Refactor SMC and properly compute marginal likelihood (#3124)
- Removed use of deprecated `ymin` keyword in matplotlib's `Axes.set_ylim` (#3279)
- Fix for #3210. Now `distribution.draw_values(params)` will draw the `params` values from their joint probability distribution and not from combinations of their marginals (refer to PR #3273).

### Deprecations

Expand Down
307 changes: 210 additions & 97 deletions pymc3/distributions/distribution.py

Large diffs are not rendered by default.

26 changes: 15 additions & 11 deletions pymc3/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from pymc3.util import get_variable_name
from ..math import logsumexp
from .dist_math import bound, random_choice
from .distribution import Discrete, Distribution, draw_values, generate_samples
from .distribution import (Discrete, Distribution, draw_values,
generate_samples, _DrawValuesContext)
from .continuous import get_tau_sd, Normal


Expand Down Expand Up @@ -147,8 +148,9 @@ def logp(self, value):
broadcast_conditions=False)

def random(self, point=None, size=None):
w = draw_values([self.w], point=point)[0]
comp_tmp = self._comp_samples(point=point, size=None)
with _DrawValuesContext() as draw_context:
w = draw_values([self.w], point=point)[0]
comp_tmp = self._comp_samples(point=point, size=None)
if np.asarray(self.shape).size == 0:
distshape = np.asarray(np.broadcast(w, comp_tmp).shape)[..., :-1]
else:
Expand All @@ -163,7 +165,8 @@ def random(self, point=None, size=None):
dist_shape=distshape,
size=size).squeeze()
if (size is None) or (distshape.size == 0):
comp_samples = self._comp_samples(point=point, size=size)
with draw_context:
comp_samples = self._comp_samples(point=point, size=size)
if comp_samples.ndim > 1:
samples = np.squeeze(comp_samples[np.arange(w_samples.size), ..., w_samples])
else:
Expand All @@ -172,13 +175,14 @@ def random(self, point=None, size=None):
if w_samples.ndim == 1:
w_samples = np.reshape(np.tile(w_samples, size), (size,) + w_samples.shape)
samples = np.zeros((size,)+tuple(distshape))
for i in range(size):
w_tmp = w_samples[i, :]
comp_tmp = self._comp_samples(point=point, size=None)
if comp_tmp.ndim > 1:
samples[i, :] = np.squeeze(comp_tmp[np.arange(w_tmp.size), ..., w_tmp])
else:
samples[i, :] = np.squeeze(comp_tmp[w_tmp])
with draw_context:
for i in range(size):
w_tmp = w_samples[i, :]
comp_tmp = self._comp_samples(point=point, size=None)
if comp_tmp.ndim > 1:
samples[i, :] = np.squeeze(comp_tmp[np.arange(w_tmp.size), ..., w_tmp])
else:
samples[i, :] = np.squeeze(comp_tmp[w_tmp])

return samples

Expand Down
26 changes: 14 additions & 12 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from pymc3.theanof import floatX
from . import transforms
from pymc3.util import get_variable_name
from .distribution import Continuous, Discrete, draw_values, generate_samples
from .distribution import (Continuous, Discrete, draw_values, generate_samples,
_DrawValuesContext)
from ..model import Deterministic
from .continuous import ChiSquared, Normal
from .special import gammaln, multigammaln
Expand Down Expand Up @@ -338,18 +339,19 @@ def __init__(self, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None,
self.mean = self.median = self.mode = self.mu = self.mu

def random(self, point=None, size=None):
nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
if self._cov_type == 'cov':
cov, = draw_values([self.cov], point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov)
elif self._cov_type == 'tau':
tau, = draw_values([self.tau], point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau)
else:
chol, = draw_values([self.chol_cov], point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol)
with _DrawValuesContext():
nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
if self._cov_type == 'cov':
cov, = draw_values([self.cov], point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov)
elif self._cov_type == 'tau':
tau, = draw_values([self.tau], point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau)
else:
chol, = draw_values([self.chol_cov], point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol)

samples = dist.random(point, size)
samples = dist.random(point, size)

chi2 = np.random.chisquare
return (np.sqrt(nu) * samples.T / chi2(nu, size)).T + mu
Expand Down
10 changes: 10 additions & 0 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,16 @@ def __init__(self, name, data, distribution, total_size=None, model=None):
self.distribution = distribution
self.scaling = _get_scaling(total_size, self.logp_elemwiset.shape, self.logp_elemwiset.ndim)

# Make hashable by id for draw_values
def __hash__(self):
    """Return the object's identity so instances can be used as dict keys."""
    return id(self)

def __eq__(self, other):
    """Return True only when ``other`` is this very object.

    Equality is identity-based so that it agrees with ``__hash__``: two
    objects compare equal exactly when they hash through the same ``id``.
    """
    # Compare identities with the builtin ``id`` (as ``__hash__`` does)
    # instead of reading an ``id`` attribute off the operands; attribute
    # access raises AttributeError for any operand that does not define one.
    return id(self) == id(other)

def __ne__(self, other):
    """Return the negation of the ``==`` comparison."""
    return not self == other


def _walk_up_rv(rv):
"""Walk up theano graph to get inputs for deterministic RV."""
Expand Down
58 changes: 58 additions & 0 deletions pymc3/tests/test_random.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import pymc3 as pm
import numpy as np
from numpy import random as nr
import numpy.testing as npt
import pytest
import theano.tensor as tt
import theano

from pymc3.distributions.distribution import _draw_value, draw_values
from .helpers import SeededTest


def test_draw_value():
Expand Down Expand Up @@ -88,3 +90,59 @@ def test_dep_vars(self):
assert all([np.all(val1 != val2), np.all(val1 != val3),
np.all(val1 != val4), np.all(val2 != val3),
np.all(val2 != val4), np.all(val3 != val4)])


class TestJointDistributionDrawValues(SeededTest):
    """Check that ``draw_values`` samples several RVs from their joint distribution."""

    def test_joint_distribution(self):
        """Compare joint draws of ``a``, ``b``, ``c``, ``d`` with expected and marginal draws."""
        n_draws = 1000

        with pm.Model() as model:
            a = pm.Normal('a', mu=0, sd=100)
            b = pm.Normal('b', mu=a, sd=1e-8)
            c = pm.Normal('c', mu=a, sd=1e-8)
            d = pm.Deterministic('d', b + c)

        # Reference values built straight from standard normal noise.
        # NOTE(review): these are generated *before* the explicit re-seed
        # below, so the allclose checks presumably rely on the generator
        # already being seeded with ``self.random_seed`` at this point
        # (e.g. by ``SeededTest`` setup) -- confirm.
        noise = np.random.randn(3, n_draws)
        expected_a = noise[0] * 100
        expected_b = expected_a + noise[1] * 1e-8
        expected_c = expected_a + noise[2] * 1e-8
        expected_d = expected_b + expected_c

        # Joint draw of all four variables in a single ``draw_values`` call.
        nr.seed(self.random_seed)
        joint = draw_values([a, b, c, d], size=n_draws)
        drawn_a, drawn_b, drawn_c, drawn_d = [
            np.array(values).flatten() for values in joint
        ]

        # The joint draws must reproduce the reference values.
        for expected, drawn in ((expected_a, drawn_a),
                                (expected_b, drawn_b),
                                (expected_c, drawn_c),
                                (expected_d, drawn_d)):
            assert np.allclose(expected, drawn)

        # ``b`` and ``c`` are tightly centred on ``a``, so jointly drawn
        # values must be pairwise almost identical.
        for left, right in ((drawn_a, drawn_b),
                            (drawn_a, drawn_c),
                            (drawn_b, drawn_c)):
            assert np.all(np.abs(left - right) < 1e-6)

        def marginal(rv):
            # One independent ``draw_values`` call per sample.
            return np.array(
                [draw_values([rv]) for _ in range(n_draws)]
            ).flatten()

        marginal_a = marginal(a)
        marginal_b = marginal(b)
        marginal_c = marginal(c)
        # Also exercise ``draw_values`` inside the model context.
        with model:
            marginal_d = marginal(d)

        # Marginal draws must differ from the joint draws ...
        for drawn, separate in ((drawn_b, marginal_b),
                                (drawn_c, marginal_c),
                                (drawn_d, marginal_d)):
            assert not np.all(np.abs(drawn - separate) < 1e-2)

        # ... and must not be strongly correlated with one another.
        for first, second in ((marginal_a, marginal_b),
                              (marginal_a, marginal_c),
                              (marginal_b, marginal_c)):
            assert np.abs(np.corrcoef(first, second)[0, 1]) < 0.1
Loading

0 comments on commit 686b81d

Please sign in to comment.