From e659133539cd4ed6ef0f57d09414a6397568b47f Mon Sep 17 00:00:00 2001
From: Dave Moore
Date: Fri, 15 Dec 2017 16:53:11 -0800
Subject: [PATCH] Explicitly track latent and posterior variables in both
 constrained and unconstrained space

---
 edward/inferences/hmc.py                      | 66 ++++++++++++++++---
 edward/inferences/inference.py                | 44 +++++++++++----
 .../test_inference_auto_transform.py          |  7 +-
 3 files changed, 92 insertions(+), 25 deletions(-)

diff --git a/edward/inferences/hmc.py b/edward/inferences/hmc.py
index 45ed38363..710dae95f 100644
--- a/edward/inferences/hmc.py
+++ b/edward/inferences/hmc.py
@@ -73,13 +73,25 @@ def build_update(self):
     The updates assume each Empirical random variable is directly
     parameterized by `tf.Variable`s.
     """
-    old_sample = {z: tf.gather(qz.params, tf.maximum(self.t - 1, 0))
-                  for z, qz in six.iteritems(self.latent_vars)}
+
+    # Gather the initial state, transformed to unconstrained space.
+    try:
+      self.latent_vars_unconstrained
+    except AttributeError:
+      raise ValueError("This implementation of HMC requires that all "
+                       "variables have unconstrained support. Please "
+                       "initialize with auto_transform=True to ensure "
+                       "this. (If your variables already have unconstrained "
+                       "support, then this is a no-op.)")
+    old_sample = {z_unconstrained:
+                  tf.gather(qz_unconstrained.params, tf.maximum(self.t - 1, 0))
+                  for z_unconstrained, qz_unconstrained in
+                  six.iteritems(self.latent_vars_unconstrained)}
     old_sample = OrderedDict(old_sample)
 
     # Sample momentum.
     old_r_sample = OrderedDict()
-    for z, qz in six.iteritems(self.latent_vars):
+    for z, qz in six.iteritems(self.latent_vars_unconstrained):
       event_shape = qz.event_shape
       normal = Normal(loc=tf.zeros(event_shape, dtype=qz.dtype),
                       scale=tf.ones(event_shape, dtype=qz.dtype))
@@ -87,7 +99,8 @@ def build_update(self):
 
     # Simulate Hamiltonian dynamics.
     new_sample, new_r_sample = leapfrog(old_sample, old_r_sample,
-                                        self.step_size, self._log_joint,
+                                        self.step_size,
+                                        self._log_joint_unconstrained,
                                         self.n_steps)
 
     # Calculate acceptance ratio.
@@ -95,8 +108,8 @@ def build_update(self):
                              for r in six.itervalues(old_r_sample)])
     ratio -= tf.reduce_sum([0.5 * tf.reduce_sum(tf.square(r))
                             for r in six.itervalues(new_r_sample)])
-    ratio += self._log_joint(new_sample)
-    ratio -= self._log_joint(old_sample)
+    ratio += self._log_joint_unconstrained(new_sample)
+    ratio -= self._log_joint_unconstrained(old_sample)
 
     # Accept or reject sample.
     u = Uniform(low=tf.constant(0.0, dtype=ratio.dtype),
@@ -108,19 +121,51 @@ def build_update(self):
       # `tf.cond` returns tf.Tensor if output is a list of size 1.
       sample_values = [sample_values]
 
-    sample = {z: sample_value for z, sample_value in
+    sample = {z_unconstrained: sample_value for
+              z_unconstrained, sample_value in
               zip(six.iterkeys(new_sample), sample_values)}
 
     # Update Empirical random variables.
     assign_ops = []
-    for z, qz in six.iteritems(self.latent_vars):
-      variable = qz.get_variables()[0]
-      assign_ops.append(tf.scatter_update(variable, self.t, sample[z]))
+    for z_unconstrained, qz_unconstrained in six.iteritems(
+            self.latent_vars_unconstrained):
+      variable = qz_unconstrained.get_variables()[0]
+      assign_ops.append(tf.scatter_update(
+          variable, self.t, sample[z_unconstrained]))
 
     # Increment n_accept (if accepted).
     assign_ops.append(self.n_accept.assign_add(tf.where(accept, 1, 0)))
     return tf.group(*assign_ops)
 
+  def _log_joint_unconstrained(self, z_sample):
+    """
+    Given a sample in unconstrained latent space, transform it back into
+    the original space, and compute the log joint density with the
+    appropriate Jacobian correction.
+    """
+
+    unconstrained_to_z = {v: k for (k, v) in self.transformations.items()}
+
+    # Transform all samples back into the original (potentially
+    # constrained) space.
+    z_sample_transformed = {}
+    log_det_jacobian = 0.0
+    for z_unconstrained, qz_unconstrained in z_sample.items():
+      z = (unconstrained_to_z[z_unconstrained]
+           if z_unconstrained in unconstrained_to_z
+           else z_unconstrained)
+
+      try:
+        bij = self.transformations[z].bijector
+        z_sample_transformed[z] = bij.inverse(qz_unconstrained)
+        log_det_jacobian += tf.reduce_sum(
+            bij.inverse_log_det_jacobian(qz_unconstrained))
+      except (KeyError, AttributeError):
+        # z is not in self.transformations, or was not transformed
+        # with a bijector; pass the sample through unchanged.
+        z_sample_transformed[z] = qz_unconstrained
+
+    return self._log_joint(z_sample_transformed) + log_det_jacobian
+
   def _log_joint(self, z_sample):
     """Utility function to calculate model's log joint density,
     log p(x, z), for inputs z (and fixed data x).
@@ -133,6 +178,7 @@ def _log_joint(self, z_sample):
     # Form dictionary in order to replace conditioning on prior or
     # observed variable with conditioning on a specific value.
     dict_swap = z_sample.copy()
+
     for x, qx in six.iteritems(self.data):
       if isinstance(x, RandomVariable):
         if isinstance(qx, RandomVariable):
diff --git a/edward/inferences/inference.py b/edward/inferences/inference.py
index 1d3cdc3b2..85dc7f116 100644
--- a/edward/inferences/inference.py
+++ b/edward/inferences/inference.py
@@ -13,6 +13,7 @@
 from edward.util import check_data, check_latent_vars, get_session, \
     get_variables, Progbar, transform
+from tensorflow.contrib.distributions import bijectors
 
 
 @six.add_metaclass(abc.ABCMeta)
 class Inference(object):
@@ -217,25 +218,46 @@ def initialize(self, n_iter=1000, n_print=None, scale=None,
 
     self.scale = scale
 
-    # Set of all latent variables binded to their transformation on
-    # the unconstrained space (if any).
+    # Map from original latent vars to their unconstrained versions.
     self.transformations = {}
     if auto_transform:
       latent_vars = self.latent_vars.copy()
-      self.latent_vars = {}
+      self.latent_vars = {}  # maps original latent vars to constrained q's
+      self.latent_vars_unconstrained = {}  # maps unconstrained vars to unconstrained q's
       for z, qz in six.iteritems(latent_vars):
         if hasattr(z, 'support') and hasattr(qz, 'support') and \
-                z.support != qz.support and qz.support != 'point':
-          z_transform = transform(z)
-          self.transformations[z] = z_transform
-          if qz.support == 'points':  # don't transform empirical approx's
-            self.latent_vars[z_transform] = qz
+                z.support != qz.support and qz.support != 'point':
+
+          # Transform z to an unconstrained space.
+          z_unconstrained = transform(z)
+          self.transformations[z] = z_unconstrained
+
+          # Make sure we also have a qz that covers the unconstrained space.
+          if qz.support == "points":  # don't transform empirical approximations
+            qz_unconstrained = qz
           else:
-            qz_transform = transform(qz)
-            self.latent_vars[z_transform] = qz_transform
-            self.transformations[qz] = qz_transform
+            qz_unconstrained = transform(qz)
+          self.latent_vars_unconstrained[z_unconstrained] = qz_unconstrained
+
+          # Additionally construct the transformation of qz
+          # back into the original constrained space.
+          if z_unconstrained != z:
+            qz_constrained = transform(
+                qz_unconstrained, bijectors.Invert(z_unconstrained.bijector))
+
+            try:  # attempt to pushforward the params of Empirical distributions
+              qz_constrained.params = z_unconstrained.bijector.inverse(
+                  qz_unconstrained.params)
+            except AttributeError:  # qz_unconstrained is not an Empirical distribution
+              pass
+
+          else:
+            qz_constrained = qz_unconstrained
+
+          self.latent_vars[z] = qz_constrained
         else:
           self.latent_vars[z] = qz
+          self.latent_vars_unconstrained[z] = qz
       del latent_vars
 
     if logdir is not None:
diff --git a/tests/inferences/test_inference_auto_transform.py b/tests/inferences/test_inference_auto_transform.py
index 1e6d31c63..36ab88649 100644
--- a/tests/inferences/test_inference_auto_transform.py
+++ b/tests/inferences/test_inference_auto_transform.py
@@ -107,7 +107,7 @@ def test_laplace_default(self):
 
       # Check approximation on constrained space has same mode as
       # target distribution.
-      qx = inference.latent_vars[x]
+      qx = inference.latent_vars[x]
       stats = sess.run([x.mode(), qx.mean()])
       self.assertAllClose(stats[0], stats[1], rtol=1e-5, atol=1e-5)
 
@@ -127,7 +127,7 @@ def test_hmc_custom(self):
 
       # Check approximation on constrained space has same moments as
      # target distribution.
-      n_samples = 10000
+      n_samples = 10000
       x_unconstrained = inference.transformations[x]
       qx_constrained_params = x_unconstrained.bijector.inverse(qx.params)
       x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
@@ -144,8 +144,7 @@ def test_hmc_default(self):
       x.support = 'nonnegative'
 
      inference = ed.HMC([x])
-      inference.initialize(auto_transform=True,
-                           step_size=0.8)
+      inference.initialize(auto_transform=True, step_size=0.8)
       tf.global_variables_initializer().run()
       for _ in range(inference.n_iter):
         info_dict = inference.update()
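
Reviewer note (not part of the patch): a minimal sketch of how the new code path is meant to be exercised, mirroring `test_hmc_default` above. The `Gamma` model, the explicit `support` attribute, and the final sample inspection are illustrative assumptions, not APIs introduced by this change.

```python
import edward as ed
from edward.models import Gamma

# A latent variable with constrained (nonnegative) support. As in
# test_hmc_default, we set the support attribute explicitly.
x = Gamma(concentration=2.0, rate=0.5)
x.support = 'nonnegative'

# Passing a list lets ed.HMC construct a default Empirical approximation.
inference = ed.HMC([x])

# With auto_transform=True, initialize() populates both
# inference.latent_vars_unconstrained (the space the chain runs in) and
# inference.latent_vars (the approximation mapped back to constrained
# space); without it, build_update() raises the new ValueError.
inference.run(auto_transform=True, step_size=0.8)

# qx lives on the original, constrained space.
qx = inference.latent_vars[x]
print(ed.get_session().run(qx.params[-5:]))  # last few (nonnegative) samples
```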