Laplace approximation now uses Multivariate Normal's #506

Merged
merged 6 commits on Mar 5, 2017
Changes from all commits
9 changes: 9 additions & 0 deletions docs/tex/bib.bib
@@ -652,3 +652,12 @@ @article{marin2012approximate
number = {6},
pages = {1167--1180}
}

@article{fisher1925theory,
author = {Fisher, R A},
title = {{Theory of statistical estimation}},
journal = {Mathematical Proceedings of the Cambridge Philosophical Society},
year = {1925},
volume = {22},
number = {5}
}
12 changes: 7 additions & 5 deletions docs/tex/tutorials/map-laplace.tex
@@ -20,13 +20,16 @@ \subsection{Laplace approximation}
&\approx
\text{Normal}(\mathbf{z}\;;\; \mathbf{z}_\text{MAP}, \Lambda^{-1}).
\end{align*}
This requires computing a precision matrix $\Lambda$. The Laplace approximation
uses the Hessian of the log joint density at the MAP estimate,
defined component-wise as
This requires computing a precision matrix $\Lambda$. Derived from a
Taylor expansion, the Laplace approximation uses the Hessian of the
negative log joint density at the MAP estimate. For flat priors
(equivalent to maximum likelihood), the precision matrix is known
as the observed Fisher information \citep{fisher1925theory}.
It is defined component-wise as
\begin{align*}
\Lambda_{ij}
&=
\frac{\partial^2 \log p(\mathbf{x}, \mathbf{z})}{\partial z_i \partial z_j}.
\frac{\partial^2}{\partial z_i \partial z_j} -\log p(\mathbf{x}, \mathbf{z}).
\end{align*}
Edward uses automatic differentiation, specifically with TensorFlow's
computational graphs, making this gradient computation both simple and
@@ -37,4 +40,3 @@ \subsection{Laplace approximation}
implementation in Edward's code base.

\subsubsection{References}\label{references}
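
As a worked complement to the tutorial text in this file (a sketch, not part of the diff): expanding the negative log joint density to second order around $\mathbf{z}_\text{MAP}$, where the first-order term vanishes because $\mathbf{z}_\text{MAP}$ is a mode, gives

\begin{align*}
-\log p(\mathbf{x}, \mathbf{z})
&\approx
-\log p(\mathbf{x}, \mathbf{z}_\text{MAP})
+ \frac{1}{2}
(\mathbf{z} - \mathbf{z}_\text{MAP})^\top
\Lambda\,
(\mathbf{z} - \mathbf{z}_\text{MAP}),
\end{align*}

and exponentiating and normalizing this quadratic recovers
$p(\mathbf{z} \mid \mathbf{x}) \approx \text{Normal}(\mathbf{z}\;;\; \mathbf{z}_\text{MAP}, \Lambda^{-1})$,
with $\Lambda$ the Hessian of the negative log joint at the mode, as stated above.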

4 changes: 2 additions & 2 deletions edward/__init__.py
@@ -21,6 +21,6 @@
RandomVariable
from edward.util import copy, dot, get_ancestors, get_children, \
get_descendants, get_dims, get_parents, get_session, get_siblings, \
get_variables, hessian, logit, multivariate_rbf, placeholder, \
random_variables, rbf, reduce_logmeanexp, set_seed, to_simplex
get_variables, logit, multivariate_rbf, placeholder, random_variables, \
rbf, reduce_logmeanexp, set_seed, to_simplex
from edward.version import __version__
1 change: 1 addition & 0 deletions edward/inferences/__init__.py
@@ -8,6 +8,7 @@
from edward.inferences.inference import *
from edward.inferences.klpq import *
from edward.inferences.klqp import *
from edward.inferences.laplace import *
from edward.inferences.map import *
from edward.inferences.metropolis_hastings import *
from edward.inferences.monte_carlo import *
136 changes: 136 additions & 0 deletions edward/inferences/laplace.py
@@ -0,0 +1,136 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import tensorflow as tf

from edward.inferences.map import MAP
from edward.models import \
MultivariateNormalCholesky, MultivariateNormalDiag, \
MultivariateNormalFull, PointMass, RandomVariable
from edward.util import get_session, get_variables


class Laplace(MAP):
"""Laplace approximation (Laplace, 1774).

It approximates the posterior distribution using a multivariate
normal distribution centered at the mode of the posterior.

We implement this by running ``MAP`` to find the posterior mode.
This forms the mean of the normal approximation. We then compute the
inverse Hessian at the mode of the posterior. This forms the
covariance of the normal approximation.
"""
def __init__(self, latent_vars, data=None, model_wrapper=None):
"""
Parameters
----------
latent_vars : list of RandomVariable or
dict of RandomVariable to RandomVariable
Collection of random variables to perform inference on. If list,
      each random variable will be implicitly optimized using a
``MultivariateNormalCholesky`` random variable that is defined
internally (with unconstrained support). If dictionary, each
random variable must be a ``MultivariateNormalCholesky``,
``MultivariateNormalFull``, or ``MultivariateNormalDiag`` random
variable.

Notes
-----
If ``MultivariateNormalDiag`` random variables are specified as
approximations, then the Laplace approximation will only produce
    the diagonal. This does not capture correlation among the
    variables, but it avoids a potentially expensive matrix
    inversion.

Examples
--------
>>> X = tf.placeholder(tf.float32, [N, D])
>>> w = Normal(mu=tf.zeros(D), sigma=tf.ones(D))
>>> y = Normal(mu=ed.dot(X, w), sigma=tf.ones(N))
>>>
>>> qw = MultivariateNormalFull(mu=tf.Variable(tf.random_normal([D])),
>>> sigma=tf.Variable(tf.random_normal([D, D])))
>>>
>>> inference = ed.Laplace({w: qw}, data={X: X_train, y: y_train})
"""
if isinstance(latent_vars, list):
with tf.variable_scope("posterior"):
if model_wrapper is None:
latent_vars = {rv: MultivariateNormalCholesky(
mu=tf.Variable(tf.random_normal(rv.batch_shape())),
chol=tf.Variable(tf.random_normal(
rv.get_batch_shape().concatenate(rv.get_batch_shape()[-1]))))
for rv in latent_vars}
elif len(latent_vars) == 1:
latent_vars = {latent_vars[0]: MultivariateNormalCholesky(
mu=tf.Variable(tf.random_normal([model_wrapper.n_vars])),
chol=tf.Variable(tf.random_normal([model_wrapper.n_vars] * 2)))}
elif len(latent_vars) == 0:
latent_vars = {}
else:
raise NotImplementedError("A list of more than one element is "
"not supported. See documentation.")
elif isinstance(latent_vars, dict):
for qz in six.itervalues(latent_vars):
if not isinstance(
qz, (MultivariateNormalCholesky, MultivariateNormalDiag,
MultivariateNormalFull)):
          raise TypeError("Posterior approximation must consist of only "
                          "MultivariateNormalCholesky, MultivariateNormalDiag, "
"or MultivariateNormalFull random variables.")

# call grandparent's method; avoid parent (MAP)
super(MAP, self).__init__(latent_vars, data, model_wrapper)

def initialize(self, var_list=None, *args, **kwargs):
# Store latent variables in a temporary attribute; MAP will
# optimize ``PointMass`` random variables, which subsequently
# optimizes mean parameters of the normal approximations.
self.latent_vars_normal = self.latent_vars.copy()
self.latent_vars = {z: PointMass(params=qz.mu)
for z, qz in six.iteritems(self.latent_vars_normal)}
super(Laplace, self).initialize(var_list, *args, **kwargs)

def finalize(self, feed_dict=None):
"""Function to call after convergence.

Computes the Hessian at the mode.

Parameters
----------
feed_dict : dict, optional
Feed dictionary for a TensorFlow session run during evaluation
of Hessian. It is used to feed placeholders that are not fed
during initialization.
"""
if feed_dict is None:
feed_dict = {}

for key, value in six.iteritems(self.data):
if isinstance(key, tf.Tensor) and "Placeholder" in key.op.type:
feed_dict[key] = value

var_list = list(six.itervalues(self.latent_vars))
hessians = tf.hessians(self.loss, var_list)

assign_ops = []
for z, hessian in zip(six.iterkeys(self.latent_vars), hessians):
qz = self.latent_vars_normal[z]
sigma_var = get_variables(qz.sigma)[0]
if isinstance(qz, MultivariateNormalCholesky):
sigma = tf.matrix_inverse(tf.cholesky(hessian))
elif isinstance(qz, MultivariateNormalDiag):
sigma = 1.0 / tf.diag_part(hessian)
else: # qz is MultivariateNormalFull
sigma = tf.matrix_inverse(hessian)

assign_ops.append(sigma_var.assign(sigma))

sess = get_session()
sess.run(assign_ops, feed_dict)
self.latent_vars = self.latent_vars_normal.copy()
del self.latent_vars_normal
super(Laplace, self).finalize()
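
A short end-to-end sketch of the new class, based on the docstring example above (illustrative only, not part of the diff; N, D, X_train, and y_train are placeholders, and it assumes the 2017-era Edward API in which ``run()`` chains ``initialize()``, ``update()``, and ``finalize()``):

import edward as ed
import numpy as np
import tensorflow as tf
from edward.models import MultivariateNormalFull, Normal

N, D = 40, 3  # placeholder sizes for the docstring's N and D
X_train = np.random.randn(N, D).astype(np.float32)
y_train = np.random.randn(N).astype(np.float32)

# Bayesian linear regression, as in the docstring example.
X = tf.placeholder(tf.float32, [N, D])
w = Normal(mu=tf.zeros(D), sigma=tf.ones(D))
y = Normal(mu=ed.dot(X, w), sigma=tf.ones(N))

# Full-covariance normal approximation to the posterior over w.
qw = MultivariateNormalFull(
    mu=tf.Variable(tf.random_normal([D])),
    sigma=tf.Variable(tf.random_normal([D, D])))

# MAP fits qw's mean; finalize() then inverts the Hessian of the
# negative log joint at the mode to fill in qw's covariance.
inference = ed.Laplace({w: qw}, data={X: X_train, y: y_train})
inference.run(n_iter=500)

# qw.mu now holds the MAP estimate and qw.sigma the Laplace covariance.
sess = ed.get_session()
print(sess.run([qw.mu, qw.sigma]))

Swapping ``MultivariateNormalFull`` for ``MultivariateNormalDiag`` keeps only the diagonal of the covariance, as described in the Notes above.
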
27 changes: 1 addition & 26 deletions edward/inferences/map.py
@@ -7,7 +7,7 @@

from edward.inferences.variational_inference import VariationalInference
from edward.models import RandomVariable, PointMass
from edward.util import copy, hessian
from edward.util import copy


class MAP(VariationalInference):
@@ -143,28 +143,3 @@ def build_loss_and_gradients(self, var_list):
grads = tf.gradients(loss, [v._ref() for v in var_list])
grads_and_vars = list(zip(grads, var_list))
return loss, grads_and_vars


class Laplace(MAP):
"""Laplace approximation.

It approximates the posterior distribution using a normal
distribution centered at the mode of the posterior.
"""
def __init__(self, *args, **kwargs):
super(Laplace, self).__init__(*args, **kwargs)

def finalize(self):
"""Function to call after convergence.

Computes the Hessian at the mode.
"""
# use only a batch of data to estimate hessian
x = self.data
z = {z: qz.value() for z, qz in six.iteritems(self.latent_vars)}
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope='posterior')
inv_cov = hessian(self.model_wrapper.log_prob(x, z), var_list)
print("Precision matrix:")
print(inv_cov.eval())
super(Laplace, self).finalize()
41 changes: 19 additions & 22 deletions edward/inferences/variational_inference.py
@@ -44,28 +44,25 @@ def initialize(self, optimizer=None, var_list=None, use_prettytensor=False,
"""
super(VariationalInference, self).initialize(*args, **kwargs)

if var_list is None:
if self.model_wrapper is None:
# Traverse random variable graphs to get default list of variables.
var_list = set([])
trainables = tf.trainable_variables()
for z, qz in six.iteritems(self.latent_vars):
if isinstance(z, RandomVariable):
var_list.update(get_variables(z, collection=trainables))

var_list.update(get_variables(qz, collection=trainables))

for x, qx in six.iteritems(self.data):
if isinstance(x, RandomVariable) and \
not isinstance(qx, RandomVariable):
var_list.update(get_variables(x, collection=trainables))

var_list = list(var_list)
else:
# Variables may not be instantiated for model wrappers until
# their methods are first called. For now, hard-code
# ``var_list`` inside build_losses.
var_list = None
# Variables may not be instantiated for model wrappers until
# their methods are first called. For now, hard-code
# ``var_list`` inside ``build_loss_and_gradients``.
if var_list is None and self.model_wrapper is None:
# Traverse random variable graphs to get default list of variables.
var_list = set()
trainables = tf.trainable_variables()
for z, qz in six.iteritems(self.latent_vars):
if isinstance(z, RandomVariable):
var_list.update(get_variables(z, collection=trainables))

var_list.update(get_variables(qz, collection=trainables))

for x, qx in six.iteritems(self.data):
if isinstance(x, RandomVariable) and \
not isinstance(qx, RandomVariable):
var_list.update(get_variables(x, collection=trainables))

var_list = list(var_list)

self.loss, grads_and_vars = self.build_loss_and_gradients(var_list)

58 changes: 0 additions & 58 deletions edward/util/tensorflow.py
@@ -52,64 +52,6 @@ def dot(x, y):
return tf.reshape(tf.matmul(mat, tf.expand_dims(vec, 1)), [-1])


def hessian(y, xs):
"""Calculate Hessian of y with respect to each x in xs.

Parameters
----------
y : tf.Tensor
Tensor to calculate Hessian of.
xs : list of tf.Variable
List of TensorFlow variables to calculate with respect to.
The variables can have different shapes.

Returns
-------
tf.Tensor
A 2-D tensor where each row is
.. math:: \partial_{xs} ( [ \partial_{xs} y ]_j ).

Raises
------
InvalidArgumentError
If the inputs have Inf or NaN values.
"""
y = tf.convert_to_tensor(y)
dependencies = [tf.verify_tensor_all_finite(y, msg='')]
dependencies.extend([tf.verify_tensor_all_finite(x, msg='') for x in xs])

with tf.control_dependencies(dependencies):
# Calculate flattened vector grad_{xs} y.
grads = tf.gradients(y, xs)
grads = [tf.reshape(grad, [-1]) for grad in grads]
grads = tf.concat(grads, 0)
# Loop over each element in the vector.
mat = []
d = grads.get_shape()[0]
if not isinstance(d, int):
d = grads.eval().shape[0]

for j in range(d):
# Calculate grad_{xs} ( [ grad_{xs} y ]_j ).
gradjgrads = tf.gradients(grads[j], xs)
# Flatten into vector.
hi = []
for l in range(len(xs)):
hij = gradjgrads[l]
# return 0 if gradient doesn't exist; TensorFlow returns None
if hij is None:
hij = tf.zeros(xs[l].get_shape(), dtype=tf.float32)

hij = tf.reshape(hij, [-1])
hi.append(hij)

hi = tf.concat(hi, 0)
mat.append(hi)

# Form matrix where each row is grad_{xs} ( [ grad_{xs} y ]_j ).
return tf.stack(mat)


def logit(x):
"""Evaluate :math:`\log(x / (1 - x))` elementwise.

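
The hand-rolled ``hessian`` helper removed above is superseded by TensorFlow's built-in ``tf.hessians``, which ``Laplace.finalize`` now calls. A minimal illustration of the replacement (a sketch, not part of the diff; unlike the removed helper, ``tf.hessians`` returns one Hessian per variable rather than a single matrix concatenated across all variables):

import tensorflow as tf

# Toy objective y = 0.5 * x^T A x, whose Hessian is A (A is symmetric).
A = tf.constant([[3.0, 1.0], [1.0, 2.0]])
x = tf.Variable([1.0, -1.0])
Ax = tf.squeeze(tf.matmul(A, tf.expand_dims(x, 1)), axis=1)
y = 0.5 * tf.reduce_sum(x * Ax)

hessian_x, = tf.hessians(y, [x])  # list with one Hessian per variable

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(hessian_x))  # [[3. 1.], [1. 2.]]

This mirrors the ``tf.hessians(self.loss, var_list)`` call added in ``Laplace.finalize``.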