clean up code on pointmass vs normal
dustinvtran committed Mar 5, 2017
1 parent 0105b82 commit b5851b3
Showing 1 changed file with 8 additions and 32 deletions.
edward/inferences/laplace.py (8 additions, 32 deletions)
@@ -75,36 +75,12 @@ def __init__(self, latent_vars, data=None, model_wrapper=None):
     super(MAP, self).__init__(latent_vars, data, model_wrapper)
 
   def initialize(self, var_list=None, *args, **kwargs):
-    # TODO the algorithm relies on sampling from a pointmass
-    # + maybe better to use pointmasses internally, where user passes
-    # in normal distributions, we get its mean tf variables, and define
-    # a pointmass with them.
-    # + how to store it? map.py uses self.latent vars as the pointmass
-    # + for now, we can just hack it
-    self.latent_vars_temp = self.latent_vars.copy()
+    # Store latent variables in a temporary attribute; MAP will
+    # optimize ``PointMass`` random variables, which subsequently
+    # optimizes mean parameters of the normal approximations.
+    self.latent_vars_normal = self.latent_vars.copy()
     self.latent_vars = {z: PointMass(params=qz.mu)
-                        for z, qz in six.iteritems(self.latent_vars_temp)}
-    # # Variables may not be instantiated for model wrappers until
-    # # their methods are first called. For now, hard-code
-    # # ``var_list`` inside ``build_loss_and_gradients``.
-    # if var_list is None and self.model_wrapper is None:
-    #   # Traverse random variable graphs to get default list of variables.
-    #   # For Laplace, the default is the mean parameters of the
-    #   # normal approximation and any model parameters.
-    #   var_list = set()
-    #   trainables = tf.trainable_variables()
-    #   for z, qz in six.iteritems(self.latent_vars):
-    #     if isinstance(z, RandomVariable):
-    #       var_list.update(get_variables(z, collection=trainables))
-
-    #     var_list.update(get_variables(qz.mu, collection=trainables))
-
-    #   for x, qx in six.iteritems(self.data):
-    #     if isinstance(x, RandomVariable) and \
-    #         not isinstance(qx, RandomVariable):
-    #       var_list.update(get_variables(x, collection=trainables))
-
-    #   var_list = list(var_list)
+                        for z, qz in six.iteritems(self.latent_vars_normal)}
     super(Laplace, self).initialize(var_list, *args, **kwargs)
 
   def finalize(self, feed_dict=None):
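For readers following the hunk above: the user passes Normal approximations to Laplace, while MAP optimization internally runs over PointMass random variables placed at the Normal means, so optimizing the point masses moves the variables that back those means. A minimal sketch of that pattern, assuming the Edward/TensorFlow 1.x API of this period (``Normal(mu=..., sigma=...)``, ``PointMass(params=...)``); the variable names below are illustrative, not taken from the repository:

    import six
    import tensorflow as tf
    from edward.models import Normal, PointMass

    # User-facing approximation: a Normal with a trainable mean and a
    # softplus-constrained scale for a 2-dimensional latent variable.
    z = Normal(mu=tf.zeros(2), sigma=tf.ones(2))
    qz = Normal(mu=tf.Variable(tf.zeros(2)),
                sigma=tf.nn.softplus(tf.Variable(tf.zeros(2))))

    # What the commit stores as ``latent_vars_normal``.
    latent_vars_normal = {z: qz}

    # What MAP actually optimizes: point masses at the Normal means.
    # Each PointMass wraps qz.mu, so optimizing it moves qz's mean.
    latent_vars = {rv: PointMass(params=qrv.mu)
                   for rv, qrv in six.iteritems(latent_vars_normal)}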
@@ -131,7 +107,7 @@ def finalize(self, feed_dict=None):
 
     assign_ops = []
     for z, hessian in zip(six.iterkeys(self.latent_vars), hessians):
-      qz = self.latent_vars_temp[z]
+      qz = self.latent_vars_normal[z]
       sigma_var = get_variables(qz.sigma)[0]
       if isinstance(qz, MultivariateNormalCholesky):
         sigma = tf.matrix_inverse(tf.cholesky(hessian))
@@ -144,6 +120,6 @@
 
     sess = get_session()
     sess.run(assign_ops, feed_dict)
-    self.latent_vars = self.latent_vars_temp.copy()
-    del self.latent_vars_temp
+    self.latent_vars = self.latent_vars_normal.copy()
+    del self.latent_vars_normal
     super(Laplace, self).finalize()
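The ``finalize`` hunks fill in the second half of the Laplace approximation: the covariance is the inverse of the Hessian of the negative log joint density at the MAP estimate, assigned into the ``sigma`` variables of the stored normal approximations (in the file, the Hessians are the ``hessians`` entries iterated over above, computed outside the shown hunks). A small NumPy sketch of that relationship, with a made-up Hessian purely for illustration:

    import numpy as np

    # Hypothetical Hessian of the negative log joint density, evaluated
    # at the MAP estimate.
    hessian = np.array([[2.0, 0.3],
                        [0.3, 1.5]])

    # Laplace approximation: the covariance is the inverse Hessian.
    sigma = np.linalg.inv(hessian)

    # A Cholesky factor of the covariance, for Cholesky-parameterized
    # normal approximations.
    chol = np.linalg.cholesky(sigma)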
