diff --git a/edward/inferences/hmc.py b/edward/inferences/hmc.py
index 45ed38363..710dae95f 100644
--- a/edward/inferences/hmc.py
+++ b/edward/inferences/hmc.py
@@ -73,13 +73,25 @@ def build_update(self):
     The updates assume each Empirical random variable is directly
     parameterized by `tf.Variable`s.
     """
-    old_sample = {z: tf.gather(qz.params, tf.maximum(self.t - 1, 0))
-                  for z, qz in six.iteritems(self.latent_vars)}
+
+    # Gather the initial state, transformed to unconstrained space.
+    try:
+      self.latent_vars_unconstrained
+    except AttributeError:
+      raise ValueError("This implementation of HMC requires that all "
+                       "variables have unconstrained support. Please "
+                       "initialize with auto_transform=True to ensure this. "
+                       "(If your variables already have unconstrained "
+                       "support, this is a no-op.)")
+    old_sample = {z_unconstrained:
+                  tf.gather(qz_unconstrained.params, tf.maximum(self.t - 1, 0))
+                  for z_unconstrained, qz_unconstrained in
+                  six.iteritems(self.latent_vars_unconstrained)}
     old_sample = OrderedDict(old_sample)
 
     # Sample momentum.
     old_r_sample = OrderedDict()
-    for z, qz in six.iteritems(self.latent_vars):
+    for z, qz in six.iteritems(self.latent_vars_unconstrained):
       event_shape = qz.event_shape
       normal = Normal(loc=tf.zeros(event_shape, dtype=qz.dtype),
                       scale=tf.ones(event_shape, dtype=qz.dtype))
@@ -87,7 +99,8 @@ def build_update(self):
 
     # Simulate Hamiltonian dynamics.
     new_sample, new_r_sample = leapfrog(old_sample, old_r_sample,
-                                        self.step_size, self._log_joint,
+                                        self.step_size,
+                                        self._log_joint_unconstrained,
                                         self.n_steps)
 
     # Calculate acceptance ratio.
@@ -95,8 +108,8 @@ def build_update(self):
                              for r in six.itervalues(old_r_sample)])
     ratio -= tf.reduce_sum([0.5 * tf.reduce_sum(tf.square(r))
                             for r in six.itervalues(new_r_sample)])
-    ratio += self._log_joint(new_sample)
-    ratio -= self._log_joint(old_sample)
+    ratio += self._log_joint_unconstrained(new_sample)
+    ratio -= self._log_joint_unconstrained(old_sample)
 
     # Accept or reject sample.
     u = Uniform(low=tf.constant(0.0, dtype=ratio.dtype),
@@ -108,19 +121,51 @@ def build_update(self):
       # `tf.cond` returns tf.Tensor if output is a list of size 1.
       sample_values = [sample_values]
 
-    sample = {z: sample_value for z, sample_value in
+    sample = {z_unconstrained: sample_value for
+              z_unconstrained, sample_value in
               zip(six.iterkeys(new_sample), sample_values)}
 
     # Update Empirical random variables.
     assign_ops = []
-    for z, qz in six.iteritems(self.latent_vars):
-      variable = qz.get_variables()[0]
-      assign_ops.append(tf.scatter_update(variable, self.t, sample[z]))
+    for z_unconstrained, qz_unconstrained in six.iteritems(
+        self.latent_vars_unconstrained):
+      variable = qz_unconstrained.get_variables()[0]
+      assign_ops.append(tf.scatter_update(
+          variable, self.t, sample[z_unconstrained]))
 
     # Increment n_accept (if accepted).
     assign_ops.append(self.n_accept.assign_add(tf.where(accept, 1, 0)))
     return tf.group(*assign_ops)
 
+  def _log_joint_unconstrained(self, z_sample):
+    """Given a sample in unconstrained latent space, transform it back
+    into the original space, and compute the log joint density with the
+    appropriate Jacobian correction.
+    """
+
+    unconstrained_to_z = {v: k for (k, v) in self.transformations.items()}
+
+    # Transform all samples back into the original (potentially
+    # constrained) space.
+    z_sample_transformed = {}
+    log_det_jacobian = 0.0
+    for z_unconstrained, qz_unconstrained in z_sample.items():
+      z = (unconstrained_to_z[z_unconstrained]
+           if z_unconstrained in unconstrained_to_z
+           else z_unconstrained)
+
+      try:
+        bij = self.transformations[z].bijector
+        z_sample_transformed[z] = bij.inverse(qz_unconstrained)
+        log_det_jacobian += tf.reduce_sum(
+            bij.inverse_log_det_jacobian(qz_unconstrained))
+      except (KeyError, AttributeError):
+        # z is not in self.transformations, or is not transformed by a bijector.
+        z_sample_transformed[z] = qz_unconstrained
+
+    return self._log_joint(z_sample_transformed) + log_det_jacobian
+
   def _log_joint(self, z_sample):
     """Utility function to calculate model's log joint density,
     log p(x, z), for inputs z (and fixed data x).
@@ -133,6 +178,7 @@ def _log_joint(self, z_sample):
     # Form dictionary in order to replace conditioning on prior or
     # observed variable with conditioning on a specific value.
     dict_swap = z_sample.copy()
+
     for x, qx in six.iteritems(self.data):
       if isinstance(x, RandomVariable):
         if isinstance(qx, RandomVariable):
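Note on the change above: `_log_joint_unconstrained` applies the usual change-of-variables correction. If the bijector g maps a constrained z to its unconstrained version zeta = g(z), the density HMC should target in the unconstrained space is log p(x, g^{-1}(zeta)) + log |det d g^{-1}(zeta) / d zeta|, which the method accumulates through `bij.inverse` and `bij.inverse_log_det_jacobian`. A minimal standalone sketch of that bookkeeping follows; it assumes the same TF 1.x contrib bijector API the patch already uses, and the Softplus bijector, tensor values, and variable names are illustrative only, not part of the change.

import tensorflow as tf
from tensorflow.contrib.distributions import bijectors

# Illustrative constrained -> unconstrained map for a positive variable:
# forward sends R+ to R, so inverse (softplus) sends R back to R+.
bij = bijectors.Invert(bijectors.Softplus())

zeta = tf.constant([0.5, -1.3])  # a sample living in unconstrained space
z = bij.inverse(zeta)            # mapped back to the constrained space, z > 0
log_det = tf.reduce_sum(bij.inverse_log_det_jacobian(zeta))

# The quantity HMC targets in unconstrained space is then
#   log_joint({z_rv: z}) + log_det
# which is what _log_joint_unconstrained returns.

The bookkeeping that builds these transformations is added to `Inference.initialize`, below.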
diff --git a/edward/inferences/inference.py b/edward/inferences/inference.py
index 1d3cdc3b2..85dc7f116 100644
--- a/edward/inferences/inference.py
+++ b/edward/inferences/inference.py
@@ -13,6 +13,7 @@
 from edward.util import check_data, check_latent_vars, get_session, \
     get_variables, Progbar, transform
+from tensorflow.contrib.distributions import bijectors
 
 
 @six.add_metaclass(abc.ABCMeta)
 class Inference(object):
@@ -217,25 +218,46 @@ def initialize(self, n_iter=1000, n_print=None, scale=None,
 
     self.scale = scale
 
-    # Set of all latent variables binded to their transformation on
-    # the unconstrained space (if any).
+    # Map from original latent vars to their unconstrained versions.
    self.transformations = {}
     if auto_transform:
       latent_vars = self.latent_vars.copy()
-      self.latent_vars = {}
+      self.latent_vars = {}  # maps original latent vars to constrained Q's
+      self.latent_vars_unconstrained = {}  # unconstrained vars to unconstrained Q's
       for z, qz in six.iteritems(latent_vars):
         if hasattr(z, 'support') and hasattr(qz, 'support') and \
-                z.support != qz.support and qz.support != 'point':
-          z_transform = transform(z)
-          self.transformations[z] = z_transform
-          if qz.support == 'points':  # don't transform empirical approx's
-            self.latent_vars[z_transform] = qz
+                z.support != qz.support and qz.support != 'point':
+
+          # Transform z to an unconstrained space.
+          z_unconstrained = transform(z)
+          self.transformations[z] = z_unconstrained
+
+          # Make sure we also have a qz that covers the unconstrained space.
+          if qz.support == "points":
+            qz_unconstrained = qz
           else:
-            qz_transform = transform(qz)
-            self.latent_vars[z_transform] = qz_transform
-            self.transformations[qz] = qz_transform
+            qz_unconstrained = transform(qz)
+          self.latent_vars_unconstrained[z_unconstrained] = qz_unconstrained
+
+          # Additionally construct the transformation of qz back into the
+          # original constrained space.
+          if z_unconstrained != z:
+            qz_constrained = transform(
+                qz_unconstrained, bijectors.Invert(z_unconstrained.bijector))
+
+            try:  # attempt to push forward the params of Empirical approxs
+              qz_constrained.params = z_unconstrained.bijector.inverse(
+                  qz_unconstrained.params)
+            except AttributeError:  # qz_unconstrained is not Empirical
+              pass
+
+          else:
+            qz_constrained = qz_unconstrained
+
+          self.latent_vars[z] = qz_constrained
         else:
           self.latent_vars[z] = qz
+          self.latent_vars_unconstrained[z] = qz
       del latent_vars
     if logdir is not None:
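For orientation, `initialize` with `auto_transform=True` now maintains three dictionaries: `self.transformations` maps each original latent variable to its unconstrained `TransformedDistribution`, `self.latent_vars_unconstrained` maps the unconstrained variables to the approximations the sampler actually updates, and `self.latent_vars` maps the original variables to approximations pushed back onto the constrained space. A rough sketch of what a caller sees right after `initialize`; the Gamma/Empirical model and variable names are illustrative, not part of the patch.

import tensorflow as tf
import edward as ed
from edward.models import Empirical, Gamma

z = Gamma(2.0, 2.0)  # latent variable with constrained (positive) support
qz = Empirical(params=tf.Variable(tf.random_normal([1000])))

inference = ed.inferences.HMC({z: qz})
inference.initialize(auto_transform=True)

z_unconstrained = inference.transformations[z]  # z mapped onto the real line
# Empirical approximations are not themselves transformed, so this is qz:
qz_unconstrained = inference.latent_vars_unconstrained[z_unconstrained]
# Constrained-space view: its params are qz's params pushed through the
# inverse bijector, so callers can keep indexing by the original z.
qz_constrained = inference.latent_vars[z]

The tests below exercise this end to end on Gamma and Beta-Bernoulli models.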
diff --git a/tests/inferences/test_inference_auto_transform.py b/tests/inferences/test_inference_auto_transform.py
index ced8e48e6..36ab88649 100644
--- a/tests/inferences/test_inference_auto_transform.py
+++ b/tests/inferences/test_inference_auto_transform.py
@@ -7,7 +7,7 @@
 import tensorflow as tf
 
 from edward.models import (Empirical, Gamma, Normal, PointMass,
-                           TransformedDistribution)
+                           TransformedDistribution, Beta, Bernoulli)
 from edward.util import transform
 from tensorflow.contrib.distributions import bijectors
 
@@ -129,9 +129,9 @@ def test_hmc_custom(self):
       # target distribution.
       n_samples = 10000
       x_unconstrained = inference.transformations[x]
-      qx_constrained = Empirical(x_unconstrained.bijector.inverse(qx.params))
+      qx_constrained_params = x_unconstrained.bijector.inverse(qx.params)
       x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
-      qx_mean, qx_var = tf.nn.moments(qx_constrained.params[500:], 0)
+      qx_mean, qx_var = tf.nn.moments(qx_constrained_params[500:], 0)
       stats = sess.run([x_mean, qx_mean, x_var, qx_var])
       self.assertAllClose(stats[0], stats[1], rtol=1e-1, atol=1e-1)
       self.assertAllClose(stats[2], stats[3], rtol=1e-1, atol=1e-1)
@@ -152,16 +152,68 @@ def test_hmc_default(self):
 
       # Check approximation on constrained space has same moments as
       # target distribution.
-      n_samples = 10000
-      x_unconstrained = inference.transformations[x]
-      qx = inference.latent_vars[x_unconstrained]
-      qx_constrained = Empirical(x_unconstrained.bijector.inverse(qx.params))
+      n_samples = 1000
+      qx_constrained = inference.latent_vars[x]
       x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
       qx_mean, qx_var = tf.nn.moments(qx_constrained.params[500:], 0)
       stats = sess.run([x_mean, qx_mean, x_var, qx_var])
       self.assertAllClose(stats[0], stats[1], rtol=1e-1, atol=1e-1)
       self.assertAllClose(stats[2], stats[3], rtol=1e-1, atol=1e-1)
 
+  def test_hmc_betabernoulli(self):
+    """Do we correctly handle dependencies of transformed variables?"""
+    with self.test_session() as sess:
+      # model
+      z = Beta(1., 1., name="z")
+      xs = Bernoulli(probs=z, sample_shape=10)
+      x_obs = np.asarray([0, 0, 1, 1, 0, 0, 0, 0, 0, 1], dtype=np.int32)
+
+      # inference
+      qz_samples = tf.Variable(tf.random_uniform(shape=(1000,)))
+      qz = ed.models.Empirical(params=qz_samples, name="z_posterior")
+      inference_hmc = ed.inferences.HMC({z: qz}, data={xs: x_obs})
+      inference_hmc.run(step_size=1.0, n_steps=5, auto_transform=True)
+
+      # check that the inferred posterior mean/variance is close to
+      # that of the exact Beta posterior
+      z_unconstrained = inference_hmc.transformations[z]
+      qz_constrained = z_unconstrained.bijector.inverse(qz_samples)
+      qz_mean, qz_var = sess.run(tf.nn.moments(qz_constrained, 0))
+
+      true_posterior = Beta(1. + np.sum(x_obs), 1. + np.sum(1 - x_obs))
+      pz_mean, pz_var = sess.run((true_posterior.mean(),
+                                  true_posterior.variance()))
+      self.assertAllClose(qz_mean, pz_mean, rtol=5e-2, atol=5e-2)
+      self.assertAllClose(qz_var, pz_var, rtol=1e-2, atol=1e-2)
+
+  def test_klqp_betabernoulli(self):
+    with self.test_session() as sess:
+      # model
+      z = Beta(1., 1., name="z")
+      xs = Bernoulli(probs=z, sample_shape=10)
+      x_obs = np.asarray([0, 0, 1, 1, 0, 0, 0, 0, 0, 1], dtype=np.int32)
+
+      # inference
+      qz_mean = tf.get_variable("qz_mean",
+                                initializer=tf.random_normal(()))
+      qz_std = tf.nn.softplus(tf.get_variable(name="qz_prestd",
+                                              initializer=tf.random_normal(())))
+      qz_unconstrained = ed.models.Normal(loc=qz_mean, scale=qz_std,
+                                          name="z_posterior")
+
+      inference_klqp = ed.inferences.KLqp({z: qz_unconstrained},
+                                          data={xs: x_obs})
+      inference_klqp.run(n_iter=500, auto_transform=True)
+
+      z_unconstrained = inference_klqp.transformations[z]
+      qz_constrained = z_unconstrained.bijector.inverse(
+          qz_unconstrained.sample(1000))
+      qz_mean, qz_var = sess.run(tf.nn.moments(qz_constrained, 0))
+
+      true_posterior = Beta(np.sum(x_obs) + 1., np.sum(1 - x_obs) + 1.)
+      pz_mean, pz_var = sess.run((true_posterior.mean(),
+                                  true_posterior.variance()))
+      self.assertAllClose(qz_mean, pz_mean, rtol=5e-2, atol=5e-2)
+      self.assertAllClose(qz_var, pz_var, rtol=1e-2, atol=1e-2)
+
 
 if __name__ == '__main__':
   ed.set_seed(124125)
   tf.test.main()
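Taken together, the new tests exercise the user-facing payoff of this change: after running inference with `auto_transform=True`, posterior samples on the original constrained support can be read straight from `inference.latent_vars[z]`, with no manual bijector inversion (this is what `test_hmc_default` now checks). A condensed sketch of that workflow, reusing the Beta-Bernoulli model from the tests above; the burn-in of 500 and the variable names are illustrative.

import numpy as np
import tensorflow as tf
import edward as ed
from edward.models import Bernoulli, Beta, Empirical

z = Beta(1., 1.)
x = Bernoulli(probs=z, sample_shape=10)
x_obs = np.array([0, 0, 1, 1, 0, 0, 0, 0, 0, 1], dtype=np.int32)

qz = Empirical(params=tf.Variable(tf.random_uniform([1000])))
inference = ed.inferences.HMC({z: qz}, data={x: x_obs})
inference.run(step_size=1.0, n_steps=5, auto_transform=True)

# Samples already live on the original (0, 1) support.
qz_constrained = inference.latent_vars[z]
samples = ed.get_session().run(qz_constrained.params[500:])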