diff --git a/edward/__init__.py b/edward/__init__.py
index d71eab3ee..c28673559 100644
--- a/edward/__init__.py
+++ b/edward/__init__.py
@@ -21,7 +21,6 @@ RandomVariable
 from edward.util import copy, dot, get_ancestors, get_children, \
     get_descendants, get_dims, get_parents, get_session, get_siblings, \
-    get_variables, hessian, kl_multivariate_normal, log_sum_exp, logit, \
-    multivariate_rbf, placeholder, random_variables, rbf, set_seed, \
-    to_simplex
+    get_variables, hessian, log_sum_exp, logit, multivariate_rbf, \
+    placeholder, random_variables, rbf, set_seed, to_simplex
 from edward.version import __version__
diff --git a/edward/inferences/klqp.py b/edward/inferences/klqp.py
index 9bc5bdd0f..a71f25b2a 100644
--- a/edward/inferences/klqp.py
+++ b/edward/inferences/klqp.py
@@ -8,7 +8,8 @@
 
 from edward.inferences.variational_inference import VariationalInference
 from edward.models import RandomVariable, Normal
-from edward.util import copy, kl_multivariate_normal
+from edward.util import copy
+from tensorflow.contrib import distributions as ds
 
 
 class KLqp(VariationalInference):
@@ -92,8 +93,8 @@ def build_loss_and_gradients(self, var_list):
     of the loss function.
 
-    If the variational model is a normal distribution and the prior is
-    standard normal, then loss function can be written as
+    If the KL divergence between the variational model and the prior
+    is tractable, then the loss function can be written as
 
     .. math::
 
@@ -101,16 +102,14 @@
       \\text{KL}( q(z; \lambda) \| p(z) ),
 
     where the KL term is computed analytically (Kingma and Welling,
-    2014).
+    2014). We compute this automatically when :math:`p(z)` and
+    :math:`q(z; \lambda)` are Normal.
     """
     is_reparameterizable = all([rv.is_reparameterized and rv.is_continuous
                                 for rv in six.itervalues(self.latent_vars)])
-    qz_is_normal = all([isinstance(rv, Normal) for
-                        rv in six.itervalues(self.latent_vars)])
-    z_is_normal = all([isinstance(rv, Normal) for
-                       rv in six.iterkeys(self.latent_vars)])
-    is_analytic_kl = qz_is_normal and \
-        (z_is_normal or hasattr(self.model_wrapper, 'log_lik'))
+    is_analytic_kl = all([isinstance(z, Normal) and isinstance(qz, Normal)
+                          for z, qz in six.iteritems(self.latent_vars)]) or \
+        hasattr(self.model_wrapper, 'log_lik')
     if is_reparameterizable:
       if is_analytic_kl:
         return build_reparam_kl_loss_and_gradients(self, var_list)
@@ -451,12 +450,12 @@ def build_reparam_kl_loss_and_gradients(inference, var_list):
 
   if inference.model_wrapper is None:
     kl = tf.reduce_sum([
-        inference.kl_scaling.get(z, 1.0) * tf.reduce_sum(
-            kl_multivariate_normal(qz.mu, qz.sigma, z.mu, z.sigma))
+        inference.kl_scaling.get(z, 1.0) * tf.reduce_sum(ds.kl(qz, z))
         for z, qz in six.iteritems(inference.latent_vars)])
   else:
-    kl = tf.reduce_sum([tf.reduce_sum(kl_multivariate_normal(qz.mu, qz.sigma))
-                        for qz in six.itervalues(inference.latent_vars)])
+    kl = tf.reduce_sum([tf.reduce_sum(
+        ds.kl(qz, Normal(mu=tf.zeros_like(qz), sigma=tf.ones_like(qz))))
+        for qz in six.itervalues(inference.latent_vars)])
 
   loss = -(tf.reduce_mean(p_log_lik) - kl)
@@ -521,8 +520,8 @@ def build_reparam_entropy_loss_and_gradients(inference, var_list):
 
   p_log_prob = tf.stack(p_log_prob)
 
-  q_entropy = tf.reduce_sum([inference.data.get(z, 1.0) * qz.entropy()
-                             for z, qz in six.iteritems(inference.latent_vars)])
+  q_entropy = tf.reduce_sum([
+      qz.entropy() for z, qz in six.iteritems(inference.latent_vars)])
 
   loss = -(tf.reduce_mean(p_log_prob) + q_entropy)
@@ -647,12 +646,12 @@ def build_score_kl_loss_and_gradients(inference, var_list):
 
   if inference.model_wrapper is None:
     kl = tf.reduce_sum([
-        inference.kl_scaling.get(z, 1.0) * tf.reduce_sum(
-            kl_multivariate_normal(qz.mu, qz.sigma, z.mu, z.sigma))
+        inference.kl_scaling.get(z, 1.0) * tf.reduce_sum(ds.kl(qz, z))
         for z, qz in six.iteritems(inference.latent_vars)])
   else:
-    kl = tf.reduce_sum([tf.reduce_sum(kl_multivariate_normal(qz.mu, qz.sigma))
-                        for qz in six.itervalues(inference.latent_vars)])
+    kl = tf.reduce_sum([tf.reduce_sum(
+        ds.kl(qz, Normal(mu=tf.zeros_like(qz), sigma=tf.ones_like(qz))))
+        for qz in six.itervalues(inference.latent_vars)])
 
   if var_list is None:
     var_list = tf.trainable_variables()
@@ -716,8 +715,8 @@ def build_score_entropy_loss_and_gradients(inference, var_list):
 
   p_log_prob = tf.stack(p_log_prob)
   q_log_prob = tf.stack(q_log_prob)
 
-  q_entropy = tf.reduce_sum([inference.data.get(z, 1.0) * qz.entropy()
-                             for z, qz in six.iteritems(inference.latent_vars)])
+  q_entropy = tf.reduce_sum([
+      qz.entropy() for z, qz in six.iteritems(inference.latent_vars)])
 
   if var_list is None:
     var_list = tf.trainable_variables()
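The klqp.py changes above swap the hand-rolled `kl_multivariate_normal` helper for TensorFlow's registered KL machinery. A minimal sketch of what the new code path computes, assuming TensorFlow 1.x where `tf.contrib.distributions` exposes `kl` and `Normal(mu=..., sigma=...)` as used in this diff; the toy parameter values are illustrative, not from the patch:

```python
import tensorflow as tf
from tensorflow.contrib import distributions as ds

# ds.kl dispatches on the pair of distribution types; for two Normals it
# returns the analytic elementwise KL, so no Monte Carlo estimate is needed.
qz = ds.Normal(mu=tf.constant([0.0, 10.0]), sigma=tf.constant([1.0, 2.0]))
z = ds.Normal(mu=tf.zeros(2), sigma=tf.ones(2))  # standard normal prior

# Summing the elementwise KL reproduces the scalar that the removed
# kl_multivariate_normal helper returned for 1-D inputs.
kl = tf.reduce_sum(ds.kl(qz, z))

with tf.Session() as sess:
  print(sess.run(kl))  # ~50.806854 = 0.0 + KL(N(10, 2) || N(0, 1))
```

Edward's `Normal` random variables are passed to `ds.kl` directly in the new code, so they evidently interoperate with the `tf.contrib.distributions` KL registry.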
diff --git a/edward/util/tensorflow.py b/edward/util/tensorflow.py
index 0e4ae6dee..dfea20cb4 100644
--- a/edward/util/tensorflow.py
+++ b/edward/util/tensorflow.py
@@ -110,69 +110,6 @@ def hessian(y, xs):
   return tf.stack(mat)
 
 
-def kl_multivariate_normal(loc_one, scale_one, loc_two=0.0, scale_two=1.0):
-  """Calculate the KL of multivariate normal distributions with
-  diagonal covariances.
-
-  Parameters
-  ----------
-  loc_one : tf.Tensor
-    A 0-D tensor, 1-D tensor of length n, or 2-D tensor of shape M
-    x n where each row represents the mean of a n-dimensional
-    Gaussian.
-  scale_one : tf.Tensor
-    A tensor of same shape as ``loc_one``, representing the
-    standard deviation.
-  loc_two : tf.Tensor, optional
-    A tensor of same shape as ``loc_one``, representing the
-    mean of another Gaussian.
-  scale_two : tf.Tensor, optional
-    A tensor of same shape as ``loc_one``, representing the
-    standard deviation of another Gaussian.
-
-  Returns
-  -------
-  tf.Tensor
-    For 0-D or 1-D tensor inputs, outputs the 0-D tensor
-    ``KL( N(z; loc_one, scale_one) || N(z; loc_two, scale_two) )``
-    For 2-D tensor inputs, outputs the 1-D tensor
-    ``[KL( N(z; loc_one[m,:], scale_one[m,:]) ||``
-    ``N(z; loc_two[m,:], scale_two[m,:]) )]_{m=1}^M``
-
-  Raises
-  ------
-  InvalidArgumentError
-    If the location variables have Inf or NaN values, or if the scale
-    variables are not positive.
-  """
-  loc_one = tf.convert_to_tensor(loc_one)
-  scale_one = tf.convert_to_tensor(scale_one)
-  loc_two = tf.convert_to_tensor(loc_two)
-  scale_two = tf.convert_to_tensor(scale_two)
-  dependencies = [tf.verify_tensor_all_finite(loc_one, msg=''),
-                  tf.verify_tensor_all_finite(loc_two, msg=''),
-                  tf.assert_positive(scale_one),
-                  tf.assert_positive(scale_two)]
-  loc_one = control_flow_ops.with_dependencies(dependencies, loc_one)
-  scale_one = control_flow_ops.with_dependencies(dependencies, scale_one)
-
-  if loc_two == 0.0 and scale_two == 1.0:
-    # With default arguments, we can avoid some intermediate computation.
-    out = tf.square(scale_one) + tf.square(loc_one) - \
-        1.0 - 2.0 * tf.log(scale_one)
-  else:
-    loc_two = control_flow_ops.with_dependencies(dependencies, loc_two)
-    scale_two = control_flow_ops.with_dependencies(dependencies, scale_two)
-    out = tf.square(scale_one / scale_two) + \
-        tf.square((loc_two - loc_one) / scale_two) - \
-        1.0 + 2.0 * tf.log(scale_two) - 2.0 * tf.log(scale_one)
-
-  if len(out.get_shape()) <= 1:  # scalar or vector
-    return 0.5 * tf.reduce_sum(out)
-  else:  # matrix
-    return 0.5 * tf.reduce_sum(out, 1)
-
-
 def log_mean_exp(input_tensor, axis=None, keep_dims=False):
   """Compute the ``log_mean_exp`` of elements in a tensor, taking
   the mean across axes given by ``axis``.
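For reference, the closed form that the deleted helper implemented, and that `ds.kl` evaluates analytically for a Normal/Normal pair, is the KL divergence between diagonal Gaussians (reconstructed here from the deleted code's arithmetic):

```latex
\mathrm{KL}\big(\mathcal{N}(\mu_1, \sigma_1^2) \,\|\, \mathcal{N}(\mu_2, \sigma_2^2)\big)
  = \frac{1}{2} \sum_{i=1}^{n} \left[
      \frac{\sigma_{1,i}^2}{\sigma_{2,i}^2}
      + \frac{(\mu_{2,i} - \mu_{1,i})^2}{\sigma_{2,i}^2}
      - 1 + 2 \log \sigma_{2,i} - 2 \log \sigma_{1,i}
    \right]
```

With the helper's default arguments (mu_2 = 0, sigma_2 = 1) this reduces to the deleted `if` branch: one half the sum of `sigma_1^2 + mu_1^2 - 1 - 2 log sigma_1`.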
diff --git a/tests/test-util/test_kl_multivariate_normal.py b/tests/test-util/test_kl_multivariate_normal.py
deleted file mode 100644
index e9f9df986..000000000
--- a/tests/test-util/test_kl_multivariate_normal.py
+++ /dev/null
@@ -1,154 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-from edward.util import kl_multivariate_normal
-
-
-class test_kl_multivariate_normal_class(tf.test.TestCase):
-
-  def test_kl_multivariate_normal_0d(self):
-    with self.test_session():
-      loc_one = tf.constant(0.0)
-      scale_one = tf.constant(1.0)
-      self.assertAllClose(kl_multivariate_normal(loc_one,
-                                                 scale_one).eval(),
-                          0.0)
-      loc_one = tf.constant(10.0)
-      scale_one = tf.constant(2.0)
-      self.assertAllClose(kl_multivariate_normal(loc_one,
-                                                 scale_one).eval(),
-                          50.806854)
-      loc_one = tf.constant(0.0)
-      scale_one = tf.constant(1.0)
-      loc_two = tf.constant(0.0)
-      scale_two = tf.constant(1.0)
-      self.assertAllClose(kl_multivariate_normal(loc_one,
-                                                 scale_one,
-                                                 loc_two=loc_two,
-                                                 scale_two=scale_two).eval(),
-                          0.0)
-      loc_one = tf.constant(10.0)
-      scale_one = tf.constant(2.0)
-      loc_two = tf.constant(10.0)
-      scale_two = tf.constant(5.0)
-      self.assertAllClose(kl_multivariate_normal(loc_one,
-                                                 scale_one,
-                                                 loc_two=loc_two,
-                                                 scale_two=scale_two).eval(),
-                          0.496290802)
-
-  def test_kl_multivariate_normal_1d(self):
-    with self.test_session():
-      loc_one = tf.constant([0.0])
-      scale_one = tf.constant([1.0])
-      self.assertAllClose(kl_multivariate_normal(loc_one,
-                                                 scale_one).eval(),
-                          0.0)
-      loc_one = tf.constant([10.0])
-      scale_one = tf.constant([2.0])
-      self.assertAllClose(kl_multivariate_normal(loc_one,
-                                                 scale_one).eval(),
-                          50.806854)
-      loc_one = tf.constant([10.0])
-      scale_one = tf.constant([2.0])
-      loc_two = tf.constant([10.0])
-      scale_two = tf.constant([2.0])
-      self.assertAllClose(kl_multivariate_normal(loc_one,
-                                                 scale_one,
-                                                 loc_two=loc_two,
-                                                 scale_two=scale_two).eval(),
-                          0.0)
-      loc_one = tf.constant([0.0, 0.0])
-      scale_one = tf.constant([1.0, 1.0])
-      self.assertAllClose(kl_multivariate_normal(loc_one,
-                                                 scale_one).eval(),
-                          0.0)
-      loc_one = tf.constant([10.0, 10.0])
-      scale_one = tf.constant([2.0, 2.0])
-      self.assertAllClose(kl_multivariate_normal(loc_one,
-                                                 scale_one).eval(),
-                          101.61370849)
-      loc_one = tf.constant([10.0, 10.0])
-      scale_one = tf.constant([2.0, 2.0])
-      loc_two = tf.constant([9.0, 9.0])
-      scale_two = tf.constant([1.0, 1.0])
-      self.assertAllClose(kl_multivariate_normal(loc_one,
-                                                 scale_one,
-                                                 loc_two=loc_two,
-                                                 scale_two=scale_two).eval(),
-                          2.6137056350)
-
-  def test_kl_multivariate_normal_2d(self):
-    with self.test_session():
-      loc_one = tf.constant([[0.0, 0.0], [0.0, 0.0]])
-      scale_one = tf.constant([[1.0, 1.0], [1.0, 1.0]])
-      self.assertAllClose(kl_multivariate_normal(loc_one,
-                                                 scale_one).eval(),
-                          np.array([0.0, 0.0]))
-      loc_one = tf.constant([[10.0, 10.0], [10.0, 10.0]])
-      scale_one = tf.constant([[2.0, 2.0], [2.0, 2.0]])
-      self.assertAllClose(kl_multivariate_normal(loc_one,
-                                                 scale_one).eval(),
-                          np.array([101.61370849, 101.61370849]))
-      loc_one = tf.constant([[10.0, 10.0], [10.0, 10.0]])
-      scale_one = tf.constant([[2.0, 2.0], [2.0, 2.0]])
-      loc_two = tf.constant([[10.0, 10.0], [10.0, 10.0]])
-      scale_two = tf.constant([[2.0, 2.0], [2.0, 2.0]])
-      self.assertAllClose(kl_multivariate_normal(loc_one,
-                                                 scale_one,
-                                                 loc_two=loc_two,
-                                                 scale_two=scale_two).eval(),
-                          np.array([0.0, 0.0]))
-      loc_one = tf.constant([[10.0, 10.0], [0.0, 0.0]])
-      scale_one = tf.constant([[2.0, 2.0], [1.0, 1.0]])
-      loc_two = tf.constant([[9.0, 9.0], [0.0, 0.0]])
-      scale_two = tf.constant([[1.0, 1.0], [1.0, 1.0]])
-      self.assertAllClose(kl_multivariate_normal(loc_one,
-                                                 scale_one,
-                                                 loc_two=loc_two,
-                                                 scale_two=scale_two).eval(),
-                          np.array([2.6137056350, 0.0]))
-
-  def test_contraint_raises(self):
-    with self.test_session():
-      loc_one = tf.constant(10.0)
-      scale_one = tf.constant(-1.0)
-      loc_two = tf.constant(10.0)
-      scale_two = tf.constant(-1.0)
-      with self.assertRaisesOpError('Condition'):
-        kl_multivariate_normal(loc_one,
-                               scale_one).eval()
-        kl_multivariate_normal(loc_one,
-                               scale_one,
-                               loc_two=loc_two,
-                               scale_two=scale_two).eval()
-      loc_one = np.inf * tf.constant(10.0)
-      scale_one = tf.constant(1.0)
-      loc_two = tf.constant(10.0)
-      scale_two = tf.constant(1.0)
-      with self.assertRaisesOpError('Inf'):
-        kl_multivariate_normal(loc_one,
-                               scale_one).eval()
-        kl_multivariate_normal(loc_one,
-                               scale_one,
-                               loc_two=loc_two,
-                               scale_two=scale_two).eval()
-      loc_one = tf.constant(10.0)
-      scale_one = tf.constant(1.0)
-      loc_two = np.nan * tf.constant(10.0)
-      scale_two = tf.constant(1.0)
-      with self.assertRaisesOpError('NaN'):
-        kl_multivariate_normal(loc_one,
-                               scale_one).eval()
-        kl_multivariate_normal(loc_one,
-                               scale_one,
-                               loc_two=loc_two,
-                               scale_two=scale_two).eval()
-
-
-if __name__ == '__main__':
-  tf.test.main()
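The deleted tests get no replacement in this patch. A minimal sketch of how the scalar cases could be checked against `ds.kl` instead, assuming TensorFlow 1.x; the expected values 50.806854 and 0.496290802 come from the deleted file above, while the class and method names are hypothetical:

```python
import tensorflow as tf
from tensorflow.contrib import distributions as ds


class test_ds_kl_normal_class(tf.test.TestCase):

  def test_kl_normal_0d(self):
    with self.test_session():
      # KL(N(10, 2) || N(0, 1)): the old helper's default standard-normal case.
      kl = ds.kl(ds.Normal(mu=10.0, sigma=2.0),
                 ds.Normal(mu=0.0, sigma=1.0))
      self.assertAllClose(kl.eval(), 50.806854)
      # KL(N(10, 2) || N(10, 5)): the explicit two-distribution case.
      kl = ds.kl(ds.Normal(mu=10.0, sigma=2.0),
                 ds.Normal(mu=10.0, sigma=5.0))
      self.assertAllClose(kl.eval(), 0.496290802)


if __name__ == '__main__':
  tf.test.main()
```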