use distributions.KL instead of utility function
dustinvtran committed Mar 4, 2017
1 parent 50587de commit 108cf12
Showing 4 changed files with 23 additions and 242 deletions.
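For orientation, the change drops Edward's hand-rolled edward.util.kl_multivariate_normal in favor of the KL rules that ship with tf.contrib.distributions. A minimal sketch of the new pattern, assuming the TensorFlow 1.0-era contrib API (Normal still takes mu/sigma arguments and the divergence function is ds.kl):

import tensorflow as tf
from tensorflow.contrib import distributions as ds

# Two diagonal Gaussians; ds.kl returns the per-dimension KL, which the
# inference code then sums into a single scalar penalty.
qz = ds.Normal(mu=tf.zeros(3), sigma=0.5 * tf.ones(3))
z = ds.Normal(mu=tf.ones(3), sigma=tf.ones(3))
kl = tf.reduce_sum(ds.kl(qz, z))  # KL(q(z) || p(z))

with tf.Session() as sess:
  print(sess.run(kl))

The same call replaces both the two-argument and the default standard-normal uses of the old helper, as the klqp.py hunks below show.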
5 changes: 2 additions & 3 deletions edward/__init__.py
@@ -21,7 +21,6 @@
RandomVariable
from edward.util import copy, dot, get_ancestors, get_children, \
get_descendants, get_dims, get_parents, get_session, get_siblings, \
- get_variables, hessian, kl_multivariate_normal, log_sum_exp, logit, \
- multivariate_rbf, placeholder, random_variables, rbf, set_seed, \
- to_simplex
+ get_variables, hessian, log_sum_exp, logit, multivariate_rbf, \
+ placeholder, random_variables, rbf, set_seed, to_simplex
from edward.version import __version__
43 changes: 21 additions & 22 deletions edward/inferences/klqp.py
@@ -8,7 +8,8 @@

from edward.inferences.variational_inference import VariationalInference
from edward.models import RandomVariable, Normal
- from edward.util import copy, kl_multivariate_normal
+ from edward.util import copy
+ from tensorflow.contrib import distributions as ds


class KLqp(VariationalInference):
@@ -92,25 +93,23 @@ def build_loss_and_gradients(self, var_list):
of the loss function.
- If the variational model is a normal distribution and the prior is
- standard normal, then loss function can be written as
+ If the KL divergence between the variational model and the prior
+ is tractable, then the loss function can be written as
.. math::
-\mathbb{E}_{q(z; \lambda)}[\log p(x \mid z)] +
\\text{KL}( q(z; \lambda) \| p(z) ),
where the KL term is computed analytically (Kingma and Welling,
- 2014).
+ 2014). We compute this automatically when :math:`p(z)` and
+ :math:`q(z; \lambda)` are Normal.
"""
is_reparameterizable = all([rv.is_reparameterized and rv.is_continuous
for rv in six.itervalues(self.latent_vars)])
- qz_is_normal = all([isinstance(rv, Normal) for
- rv in six.itervalues(self.latent_vars)])
- z_is_normal = all([isinstance(rv, Normal) for
- rv in six.iterkeys(self.latent_vars)])
- is_analytic_kl = qz_is_normal and \
- (z_is_normal or hasattr(self.model_wrapper, 'log_lik'))
+ is_analytic_kl = all([isinstance(z, Normal) and isinstance(qz, Normal)
+ for z, qz in six.iteritems(self.latent_vars)]) or \
+ hasattr(self.model_wrapper, 'log_lik')
if is_reparameterizable:
if is_analytic_kl:
return build_reparam_kl_loss_and_gradients(self, var_list)
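The Normal/Normal guard above matters because the contrib KL dispatcher only knows about pairs with a registered rule and raises rather than silently approximating anything else. An illustrative sketch, under the assumptions that ds.kl raises NotImplementedError for unregistered pairs and that the TF 1.0 Gamma constructor takes alpha/beta:

import tensorflow as tf
from tensorflow.contrib import distributions as ds

qz = ds.Normal(mu=tf.zeros(2), sigma=tf.ones(2))
z = ds.Gamma(alpha=tf.ones(2), beta=tf.ones(2))

try:
  ds.kl(qz, z)  # no Normal-Gamma KL rule is registered
except NotImplementedError:
  pass  # KLqp instead falls back to a sampling-based estimate of the KL term

When the guard fails, KLqp uses the reparameterization or score-function estimators defined later in this file rather than the analytic-KL path.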
@@ -451,12 +450,12 @@ def build_reparam_kl_loss_and_gradients(inference, var_list):

if inference.model_wrapper is None:
kl = tf.reduce_sum([
- inference.kl_scaling.get(z, 1.0) * tf.reduce_sum(
- kl_multivariate_normal(qz.mu, qz.sigma, z.mu, z.sigma))
+ inference.kl_scaling.get(z, 1.0) * tf.reduce_sum(ds.kl(qz, z))
for z, qz in six.iteritems(inference.latent_vars)])
else:
- kl = tf.reduce_sum([tf.reduce_sum(kl_multivariate_normal(qz.mu, qz.sigma))
- for qz in six.itervalues(inference.latent_vars)])
+ kl = tf.reduce_sum([tf.reduce_sum(
+ ds.kl(qz, Normal(mu=tf.zeros_like(qz), sigma=tf.ones_like(qz))))
+ for qz in six.itervalues(inference.latent_vars)])

loss = -(tf.reduce_mean(p_log_lik) - kl)

@@ -521,8 +520,8 @@ def build_reparam_entropy_loss_and_gradients(inference, var_list):

p_log_prob = tf.stack(p_log_prob)

- q_entropy = tf.reduce_sum([inference.data.get(z, 1.0) * qz.entropy()
- for z, qz in six.iteritems(inference.latent_vars)])
+ q_entropy = tf.reduce_sum([
+ qz.entropy() for z, qz in six.iteritems(inference.latent_vars)])

loss = -(tf.reduce_mean(p_log_prob) + q_entropy)

@@ -647,12 +646,12 @@ def build_score_kl_loss_and_gradients(inference, var_list):

if inference.model_wrapper is None:
kl = tf.reduce_sum([
- inference.kl_scaling.get(z, 1.0) * tf.reduce_sum(
- kl_multivariate_normal(qz.mu, qz.sigma, z.mu, z.sigma))
+ inference.kl_scaling.get(z, 1.0) * tf.reduce_sum(ds.kl(qz, z))
for z, qz in six.iteritems(inference.latent_vars)])
else:
- kl = tf.reduce_sum([tf.reduce_sum(kl_multivariate_normal(qz.mu, qz.sigma))
- for qz in six.itervalues(inference.latent_vars)])
+ kl = tf.reduce_sum([tf.reduce_sum(
+ ds.kl(qz, Normal(mu=tf.zeros_like(qz), sigma=tf.ones_like(qz))))
+ for qz in six.itervalues(inference.latent_vars)])

if var_list is None:
var_list = tf.trainable_variables()
@@ -716,8 +715,8 @@ def build_score_entropy_loss_and_gradients(inference, var_list):
p_log_prob = tf.stack(p_log_prob)
q_log_prob = tf.stack(q_log_prob)

- q_entropy = tf.reduce_sum([inference.data.get(z, 1.0) * qz.entropy()
- for z, qz in six.iteritems(inference.latent_vars)])
+ q_entropy = tf.reduce_sum([
+ qz.entropy() for z, qz in six.iteritems(inference.latent_vars)])

if var_list is None:
var_list = tf.trainable_variables()
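In the model-wrapper branches above there is no prior random variable to hand to ds.kl, so the diff builds a standard Normal of matching shape on the spot via tf.zeros_like / tf.ones_like. A rough stand-alone equivalent (illustrative only; the variable names are made up and the TF 1.0-era mu/sigma constructor is assumed):

import tensorflow as tf
from tensorflow.contrib import distributions as ds

# Variational factor with a trainable location and a positive scale.
qz_mu = tf.Variable(tf.zeros(5))
qz_sigma = tf.nn.softplus(tf.Variable(tf.zeros(5)))
qz = ds.Normal(mu=qz_mu, sigma=qz_sigma)

# Standard Normal prior of the same shape, mirroring the zeros_like/ones_like
# construction in the diff.
prior = ds.Normal(mu=tf.zeros_like(qz_mu), sigma=tf.ones_like(qz_mu))

kl_penalty = tf.reduce_sum(ds.kl(qz, prior))

In the diff itself zeros_like/ones_like are applied to qz directly, which works because Edward's Normal random variable is tensor-convertible and stands in for a sample of qz.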
63 changes: 0 additions & 63 deletions edward/util/tensorflow.py
@@ -110,69 +110,6 @@ def hessian(y, xs):
return tf.stack(mat)


- def kl_multivariate_normal(loc_one, scale_one, loc_two=0.0, scale_two=1.0):
- """Calculate the KL of multivariate normal distributions with
- diagonal covariances.
- Parameters
- ----------
- loc_one : tf.Tensor
- A 0-D tensor, 1-D tensor of length n, or 2-D tensor of shape M
- x n where each row represents the mean of a n-dimensional
- Gaussian.
- scale_one : tf.Tensor
- A tensor of same shape as ``loc_one``, representing the
- standard deviation.
- loc_two : tf.Tensor, optional
- A tensor of same shape as ``loc_one``, representing the
- mean of another Gaussian.
- scale_two : tf.Tensor, optional
- A tensor of same shape as ``loc_one``, representing the
- standard deviation of another Gaussian.
- Returns
- -------
- tf.Tensor
- For 0-D or 1-D tensor inputs, outputs the 0-D tensor
- ``KL( N(z; loc_one, scale_one) || N(z; loc_two, scale_two) )``
- For 2-D tensor inputs, outputs the 1-D tensor
- ``[KL( N(z; loc_one[m,:], scale_one[m,:]) || ``
- ``N(z; loc_two[m,:], scale_two[m,:]) )]_{m=1}^M``
- Raises
- ------
- InvalidArgumentError
- If the location variables have Inf or NaN values, or if the scale
- variables are not positive.
- """
- loc_one = tf.convert_to_tensor(loc_one)
- scale_one = tf.convert_to_tensor(scale_one)
- loc_two = tf.convert_to_tensor(loc_two)
- scale_two = tf.convert_to_tensor(scale_two)
- dependencies = [tf.verify_tensor_all_finite(loc_one, msg=''),
- tf.verify_tensor_all_finite(loc_two, msg=''),
- tf.assert_positive(scale_one),
- tf.assert_positive(scale_two)]
- loc_one = control_flow_ops.with_dependencies(dependencies, loc_one)
- scale_one = control_flow_ops.with_dependencies(dependencies, scale_one)
-
- if loc_two == 0.0 and scale_two == 1.0:
- # With default arguments, we can avoid some intermediate computation.
- out = tf.square(scale_one) + tf.square(loc_one) - \
- 1.0 - 2.0 * tf.log(scale_one)
- else:
- loc_two = control_flow_ops.with_dependencies(dependencies, loc_two)
- scale_two = control_flow_ops.with_dependencies(dependencies, scale_two)
- out = tf.square(scale_one / scale_two) + \
- tf.square((loc_two - loc_one) / scale_two) - \
- 1.0 + 2.0 * tf.log(scale_two) - 2.0 * tf.log(scale_one)
-
- if len(out.get_shape()) <= 1: # scalar or vector
- return 0.5 * tf.reduce_sum(out)
- else: # matrix
- return 0.5 * tf.reduce_sum(out, 1)


def log_mean_exp(input_tensor, axis=None, keep_dims=False):
"""Compute the ``log_mean_exp`` of elements in a tensor, taking
the mean across axes given by ``axis``.
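For reference, the deleted helper evaluated the standard closed form for the KL between diagonal Gaussians, which is the same quantity the Normal-Normal rule behind ds.kl computes. Written in the docstring's own math notation, the half-sum form below matches the 0.5 * tf.reduce_sum(out) in the removed code:

.. math::

  \text{KL}\big(\mathcal{N}(\mu_1, \sigma_1^2) \,\|\, \mathcal{N}(\mu_2, \sigma_2^2)\big)
    = \sum_i \left[ \log\frac{\sigma_{2,i}}{\sigma_{1,i}}
    + \frac{\sigma_{1,i}^2 + (\mu_{1,i} - \mu_{2,i})^2}{2\,\sigma_{2,i}^2}
    - \frac{1}{2} \right]

With loc_two = 0 and scale_two = 1 this reduces to the simplified expression the helper special-cased for its default arguments.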
154 changes: 0 additions & 154 deletions tests/test-util/test_kl_multivariate_normal.py

This file was deleted.
