tensorboard for variational inference #598

Merged (12 commits, May 28, 2017)
11 changes: 11 additions & 0 deletions edward/inferences/gan_inference.py
@@ -101,6 +101,12 @@ def initialize(self, optimizer=None, optimizer_d=None,
self.train_d = optimizer_d.apply_gradients(grads_and_vars_d,
global_step=global_step_d)

if self.logging:
summary_key = 'summaries_' + str(id(self))
tf.summary.scalar('loss_fake', self.loss_d, collections=[summary_key])
Member: Can you rename these to 'loss_discriminative' and 'loss_generative' respectively?

Contributor Author: Done.

tf.summary.scalar('loss_samples', self.loss, collections=[summary_key])
self.summarize = tf.summary.merge_all(key=summary_key)

def build_loss_and_gradients(self, var_list):
x_true = list(six.itervalues(self.data))[0]
x_fake = list(six.iterkeys(self.data))[0]
@@ -110,6 +116,11 @@ def build_loss_and_gradients(self, var_list):
with tf.variable_scope("Disc", reuse=True):
d_fake = self.discriminator(x_fake)

if self.logging:
summary_key = 'summaries_' + str(id(self))
tf.summary.histogram('disc_outputs', tf.concat([d_true, d_fake], axis=0),
collections=[summary_key])

loss_d = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.ones_like(d_true), logits=d_true) + \
tf.nn.sigmoid_cross_entropy_with_logits(
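The pattern above registers every summary of one inference object under a per-instance collection key ('summaries_' + str(id(self))), so that two inference objects living in the same graph do not pick up each other's summaries when tf.summary.merge_all(key=...) is called. Below is a minimal, self-contained sketch of that behaviour; the Toy class and the 'log' directory are illustrative stand-ins, not part of Edward.

```python
import tensorflow as tf

class Toy(object):
  """Illustrative stand-in for an inference object using a per-instance key."""
  def __init__(self, value):
    key = 'summaries_' + str(id(self))
    tf.summary.scalar('loss', tf.constant(value), collections=[key])
    # merge_all(key=...) only picks up summaries registered under this key.
    self.summarize = tf.summary.merge_all(key=key)

a, b = Toy(1.0), Toy(2.0)
writer = tf.summary.FileWriter('log')
with tf.Session() as sess:
  # a.summarize serializes only a's scalar; b's summaries stay separate.
  writer.add_summary(sess.run(a.summarize), global_step=0)
  writer.add_summary(sess.run(b.summarize), global_step=0)
writer.close()
```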
5 changes: 5 additions & 0 deletions edward/inferences/monte_carlo.py
@@ -97,6 +97,11 @@ def initialize(self, *args, **kwargs):
self.n_accept_over_t = self.n_accept / self.t
self.train = self.build_update()

if self.logging:
summary_key = 'summaries_' + str(id(self))
tf.summary.scalar('n_accept', self.n_accept, collections=[summary_key])
self.summarize = tf.summary.merge_all(key=summary_key)

def update(self, feed_dict=None):
"""Run one iteration of sampling for Monte Carlo.

55 changes: 50 additions & 5 deletions edward/inferences/variational_inference.py
@@ -53,25 +53,70 @@ def initialize(self, optimizer=None, var_list=None, use_prettytensor=False,
"""
super(VariationalInference, self).initialize(*args, **kwargs)

latent_var_list = set()
data_var_list = set()
if var_list is None:
# Traverse random variable graphs to get default list of variables.
var_list = set()
trainables = tf.trainable_variables()
for z, qz in six.iteritems(self.latent_vars):
if isinstance(z, RandomVariable):
var_list.update(get_variables(z, collection=trainables))
latent_var_list.update(get_variables(z, collection=trainables))

var_list.update(get_variables(qz, collection=trainables))
latent_var_list.update(get_variables(qz, collection=trainables))

for x, qx in six.iteritems(self.data):
Member: A good practice is to instantiate variables closest to where they're used, e.g. instantiate data_var_list above this line.

if isinstance(x, RandomVariable) and \
not isinstance(qx, RandomVariable):
var_list.update(get_variables(x, collection=trainables))
data_var_list.update(get_variables(x, collection=trainables))

var_list = list(var_list)
var_list = list(data_var_list | latent_var_list)

self.loss, grads_and_vars = self.build_loss_and_gradients(var_list)

if self.logging:
summary_key = 'summaries_' + str(id(self))
tf.summary.scalar("loss", self.loss, collections=[summary_key])
with tf.name_scope('variational'):
for grad, var in grads_and_vars:
if var in latent_var_list:
tf.summary.histogram("parameter_" +
var.name.replace(':', '_'),
var, collections=[summary_key])
tf.summary.histogram("gradient_" +
grad.name.replace(':', '_'),
grad, collections=[summary_key])
tf.summary.scalar("gradient_norm_" +
grad.name.replace(':', '_'),
tf.norm(grad), collections=[summary_key])
# replace ':' with '_' because TF does not allow ':' in summary names
Contributor: @akshaykhatri639 Really looking forward to this feature!
I tried it using MAP inference, but the naming turns out weird.

[screenshot: TensorBoard with the oddly named gradient summaries]

I would replace grad.name.replace(':', '_') with var.name.replace(':', '_') everywhere, because the gradients are named after the operations. By prefacing it with gradient_ it will still be clear from the name that it's a gradient; see the sketch below.

I have been using tf.name_scope('model') for the actual model.

[screenshot: TensorBoard graph showing the 'model' name scope]

Maybe rename to tf.name_scope('training') for logging? But that's minor; I can just rename the scope for my model.
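A rough sketch of the renaming suggested here, keying each summary off the sanitized variable name instead of the gradient op's name; grads_and_vars and summary_key refer to the objects in the diff above, so this is an untested drop-in idea rather than the PR's code.

```python
# Name all three summaries after the variable rather than the gradient op,
# so parameter_*, gradient_* and gradient_norm_* share a readable suffix.
# ':' is stripped because TF does not allow it in summary names.
for grad, var in grads_and_vars:
  var_name = var.name.replace(':', '_')
  tf.summary.histogram('parameter_' + var_name, var,
                       collections=[summary_key])
  tf.summary.histogram('gradient_' + var_name, grad,
                       collections=[summary_key])
  tf.summary.scalar('gradient_norm_' + var_name, tf.norm(grad),
                    collections=[summary_key])
```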

Contributor:

  1. Naming of the gradient summaries is messy: see the comment above.

  2. Tracking gradients is super useful, but it might be even more useful to look at gradients*learning_rate or (parameters - gradients*learning_rate)/parameters.
    Especially with adaptive learning rates, where the user doesn't know the current learning rate, it is impossible to tell from the histograms whether a gradient is in the right ballpark.

  3. When training blows up and some parameters become NaN or Inf (e.g. because the learning rate is too large), I get an error message that is hard to decipher:

InvalidArgumentError (see above for traceback): Nan in summary histogram for: model_1/parameter_model/embeddings/alpha_0
         [[Node: model_1/parameter_model/embeddings/alpha_0 = HistogramSummary[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](model_1/parameter_model/embeddings/alpha_0/tag, model/embeddings/alpha/read)]]

Maybe there's a nicer way to catch this? I guess that's more an idiosyncrasy of TensorFlow than of Edward, but catching it in a more readable way would be nice; one possible approach is sketched below.
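One possible way to get a clearer failure than the raw HistogramSummary error, sketched as a suggestion rather than as what the PR implements: wrap the summarized tensor in tf.check_numerics, which raises an InvalidArgumentError whose message names the offending parameter as soon as it contains NaN or Inf. grads_and_vars and summary_key are the objects from the diff above.

```python
# Check each parameter before it reaches the histogram summary, so a
# blow-up fails with a message naming the variable instead of the
# summary node.
for grad, var in grads_and_vars:
  var_name = var.name.replace(':', '_')
  checked = tf.check_numerics(var, message='NaN/Inf in parameter ' + var.name)
  tf.summary.histogram('parameter_' + var_name, checked,
                       collections=[summary_key])
```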

Contributor Author (@akshaykhatri639, May 7, 2017):

  1. You are right about the names of the summary ops. I will make the required changes.

  2. Won't this become a bit complicated? Each optimizer modifies the weights in a slightly different way and may store one or more cache/momentum matrices per gradient. Should we implement it differently for each optimizer? Otherwise, the values in the summary still won't be close to the values the weights were actually updated with.

  3. Should I check whether any values become NaN before writing the summary ops?

Member:

  1. There's a self.debug member in Edward inferences for checking whether certain ops blow up. I agree with Maja that it would be nice if we could somehow raise a more informative error (but without running a check every iteration just to raise that error).
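For reference, a minimal sketch of the kind of graph-wide check a debug flag like self.debug could hook into; the use of tf.add_check_numerics_ops() here is an assumption about the mechanism, and it also illustrates the cost mentioned above, since every float op is checked on every run.

```python
import tensorflow as tf

x = tf.Variable(1.0)
loss = tf.log(x - 1.0)  # produces -inf, as a stand-in for a blow-up

# Adds a CheckNumerics op after every floating-point tensor in the graph.
op_check = tf.add_check_numerics_ops()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run([loss, op_check])  # raises InvalidArgumentError naming the bad op
```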


with tf.name_scope('model'):
for grad, var in grads_and_vars:
if var in data_var_list:
tf.summary.histogram("parameter_" + var.name.replace(':', '_'),
var, collections=[summary_key])
tf.summary.histogram("gradient_" +
grad.name.replace(':', '_'),
grad, collections=[summary_key])
tf.summary.scalar("gradient_norm_" +
grad.name.replace(':', '_'),
tf.norm(grad), collections=[summary_key])

# summaries for variables passed in explicitly via var_list (i.e., when var_list was not None)
with tf.name_scope(''):
for grad, var in grads_and_vars:
if var not in latent_var_list and var not in data_var_list:
tf.summary.histogram("parameter_" + var.name.replace(':', '_'),
var, collections=[summary_key])
tf.summary.histogram("gradient_" +
grad.name.replace(':', '_'),
grad, collections=[summary_key])
tf.summary.scalar("gradient_norm_" +
grad.name.replace(':', '_'),
tf.norm(grad), collections=[summary_key])

self.summarize = tf.summary.merge_all(key=summary_key)

if optimizer is None:
# Use ADAM with a decaying scale factor.
global_step = tf.Variable(0, trainable=False, name="global_step")
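Finally, a hedged sketch of end-to-end usage of the new summaries. The logdir argument is assumed to be what switches self.logging on, and the model pieces (z, qz, x, x_train) are placeholders for whatever model is being fit; none of these names are guaranteed by this diff alone.

```python
import edward as ed

# z, qz, x, x_train come from a model defined elsewhere; only the
# logging flow is illustrated. `logdir` is assumed to enable
# self.logging and create the summary writer.
inference = ed.KLqp({z: qz}, data={x: x_train})
inference.run(n_iter=1000, logdir='log')
```

The written events can then be inspected with tensorboard --logdir=log.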