Skip to content

Commit

Permalink
Explicitly track latent and posterior variables in both constrained a…
Browse files Browse the repository at this point in the history
…nd unconstrained space
  • Loading branch information
davmre committed Dec 20, 2017
1 parent 95a6ba0 commit e659133
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 25 deletions.
66 changes: 56 additions & 10 deletions edward/inferences/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,30 +73,43 @@ def build_update(self):
The updates assume each Empirical random variable is directly
parameterized by `tf.Variable`s.
"""
old_sample = {z: tf.gather(qz.params, tf.maximum(self.t - 1, 0))
for z, qz in six.iteritems(self.latent_vars)}

# Gather the initial state, transformed to unconstrained space.
try:
self.latent_vars_unconstrained
except:
raise ValueError("This implementation of HMC requires that all "
"variables have unconstrained support. Please "
"initialize with auto_transform=True to ensure "
"this. (if your variables already have unconstrained "
"support then doing this is a no-op).")
old_sample = {z_unconstrained:
tf.gather(qz_unconstrained.params, tf.maximum(self.t - 1, 0))
for z_unconstrained, qz_unconstrained in
six.iteritems(self.latent_vars_unconstrained)}
old_sample = OrderedDict(old_sample)

# Sample momentum.
old_r_sample = OrderedDict()
for z, qz in six.iteritems(self.latent_vars):
for z, qz in six.iteritems(self.latent_vars_unconstrained):
event_shape = qz.event_shape
normal = Normal(loc=tf.zeros(event_shape, dtype=qz.dtype),
scale=tf.ones(event_shape, dtype=qz.dtype))
old_r_sample[z] = normal.sample()

# Simulate Hamiltonian dynamics.
new_sample, new_r_sample = leapfrog(old_sample, old_r_sample,
self.step_size, self._log_joint,
self.step_size,
self._log_joint_unconstrained,
self.n_steps)

# Calculate acceptance ratio.
ratio = tf.reduce_sum([0.5 * tf.reduce_sum(tf.square(r))
for r in six.itervalues(old_r_sample)])
ratio -= tf.reduce_sum([0.5 * tf.reduce_sum(tf.square(r))
for r in six.itervalues(new_r_sample)])
ratio += self._log_joint(new_sample)
ratio -= self._log_joint(old_sample)
ratio += self._log_joint_unconstrained(new_sample)
ratio -= self._log_joint_unconstrained(old_sample)

# Accept or reject sample.
u = Uniform(low=tf.constant(0.0, dtype=ratio.dtype),
Expand All @@ -108,19 +121,51 @@ def build_update(self):
# `tf.cond` returns tf.Tensor if output is a list of size 1.
sample_values = [sample_values]

sample = {z: sample_value for z, sample_value in
sample = {z_unconstrained: sample_value for
z_unconstrained, sample_value in
zip(six.iterkeys(new_sample), sample_values)}

# Update Empirical random variables.
assign_ops = []
for z, qz in six.iteritems(self.latent_vars):
variable = qz.get_variables()[0]
assign_ops.append(tf.scatter_update(variable, self.t, sample[z]))
for z_unconstrained, qz_unconstrained in six.iteritems(
self.latent_vars_unconstrained):
variable = qz_unconstrained.get_variables()[0]
assign_ops.append(tf.scatter_update(
variable, self.t, sample[z_unconstrained]))

# Increment n_accept (if accepted).
assign_ops.append(self.n_accept.assign_add(tf.where(accept, 1, 0)))
return tf.group(*assign_ops)

def _log_joint_unconstrained(self, z_sample):
"""
Given a sample in unconstrained latent space, transform it back into
the original space, and compute the log joint density with appropriate
Jacobian correction.
"""

unconstrained_to_z = {v: k for (k, v) in self.transformations.items()}

# transform all samples back into the original (potentially
# constrained) space.
z_sample_transformed = {}
log_det_jacobian = 0.0
for z_unconstrained, qz_unconstrained in z_sample.items():
z = (unconstrained_to_z[z_unconstrained]
if z_unconstrained in unconstrained_to_z
else z_unconstrained)

try:
bij = self.transformations[z].bijector
z_sample_transformed[z] = bij.inverse(qz_unconstrained)
log_det_jacobian += tf.reduce_sum(
bij.inverse_log_det_jacobian(qz_unconstrained))
except: # if z not in self.transformations,
# or is not a TransformedDist w/ bijector
z_sample_transformed[z] = qz_unconstrained

return self._log_joint(z_sample_transformed) + log_det_jacobian

def _log_joint(self, z_sample):
"""Utility function to calculate model's log joint density,
log p(x, z), for inputs z (and fixed data x).
Expand All @@ -133,6 +178,7 @@ def _log_joint(self, z_sample):
# Form dictionary in order to replace conditioning on prior or
# observed variable with conditioning on a specific value.
dict_swap = z_sample.copy()

for x, qx in six.iteritems(self.data):
if isinstance(x, RandomVariable):
if isinstance(qx, RandomVariable):
Expand Down
44 changes: 33 additions & 11 deletions edward/inferences/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from edward.util import check_data, check_latent_vars, get_session, \
get_variables, Progbar, transform

from tensorflow.contrib.distributions import bijectors

@six.add_metaclass(abc.ABCMeta)
class Inference(object):
Expand Down Expand Up @@ -217,25 +218,46 @@ def initialize(self, n_iter=1000, n_print=None, scale=None,

self.scale = scale

# Set of all latent variables binded to their transformation on
# the unconstrained space (if any).
# map from original latent vars to unconstrained versions
self.transformations = {}
if auto_transform:
latent_vars = self.latent_vars.copy()
self.latent_vars = {}
self.latent_vars = {} # maps original latent vars to constrained Q's
self.latent_vars_unconstrained = {} # maps unconstrained vars to unconstrained Q's
for z, qz in six.iteritems(latent_vars):
if hasattr(z, 'support') and hasattr(qz, 'support') and \
z.support != qz.support and qz.support != 'point':
z_transform = transform(z)
self.transformations[z] = z_transform
if qz.support == 'points': # don't transform empirical approx's
self.latent_vars[z_transform] = qz
z.support != qz.support and qz.support != 'point':

# transform z to an unconstrained space
z_unconstrained = transform(z)
self.transformations[z] = z_unconstrained

# make sure we also have a qz that covers the unconstrained space
if qz.support == "points":
qz_unconstrained = qz
else:
qz_transform = transform(qz)
self.latent_vars[z_transform] = qz_transform
self.transformations[qz] = qz_transform
qz_unconstrained = transform(qz)
self.latent_vars_unconstrained[z_unconstrained] = qz_unconstrained

# additionally construct the transformation of qz
# back into the original constrained space
if z_unconstrained != z:
qz_constrained = transform(
qz_unconstrained, bijectors.Invert(z_unconstrained.bijector))

try: # attempt to pushforward the params of Empirical distributions
qz_constrained.params = z_unconstrained.bijector.inverse(
qz_unconstrained.params)
except: # qz_unconstrained is not an Empirical distribution
pass

else:
qz_constrained = qz_unconstrained

self.latent_vars[z] = qz_constrained
else:
self.latent_vars[z] = qz
self.latent_vars_unconstrained[z] = qz
del latent_vars

if logdir is not None:
Expand Down
7 changes: 3 additions & 4 deletions tests/inferences/test_inference_auto_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_laplace_default(self):

# Check approximation on constrained space has same mode as
# target distribution.
qx = inference.latent_vars[x]
qx = inference.latent_vars[x]
stats = sess.run([x.mode(), qx.mean()])
self.assertAllClose(stats[0], stats[1], rtol=1e-5, atol=1e-5)

Expand All @@ -127,7 +127,7 @@ def test_hmc_custom(self):

# Check approximation on constrained space has same moments as
# target distribution.
n_samples = 10000
n_samples = 10000
x_unconstrained = inference.transformations[x]
qx_constrained_params = x_unconstrained.bijector.inverse(qx.params)
x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
Expand All @@ -144,8 +144,7 @@ def test_hmc_default(self):
x.support = 'nonnegative'

inference = ed.HMC([x])
inference.initialize(auto_transform=True,
step_size=0.8)
inference.initialize(auto_transform=True, step_size=0.8)
tf.global_variables_initializer().run()
for _ in range(inference.n_iter):
info_dict = inference.update()
Expand Down

0 comments on commit e659133

Please sign in to comment.