tensorboard for variational inference #598
```diff
@@ -53,25 +53,70 @@ def initialize(self, optimizer=None, var_list=None, use_prettytensor=False,
     """
     super(VariationalInference, self).initialize(*args, **kwargs)
 
+    latent_var_list = set()
+    data_var_list = set()
     if var_list is None:
       # Traverse random variable graphs to get default list of variables.
-      var_list = set()
       trainables = tf.trainable_variables()
       for z, qz in six.iteritems(self.latent_vars):
         if isinstance(z, RandomVariable):
-          var_list.update(get_variables(z, collection=trainables))
+          latent_var_list.update(get_variables(z, collection=trainables))
 
-        var_list.update(get_variables(qz, collection=trainables))
+        latent_var_list.update(get_variables(qz, collection=trainables))
 
       for x, qx in six.iteritems(self.data):
```
Review comment (inline, on the loop over `self.data`): a good practice is to instantiate variables closest to where they're used, e.g., instantiate […].
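A minimal sketch of that suggestion, assuming the reviewer is referring to `data_var_list` (the truncated comment does not say which variable): each accumulator set is declared immediately before the loop that fills it. This fragment mirrors the method body in the diff above rather than standing on its own.

```python
# Sketch only: declare each accumulator right before the loop that populates
# it, instead of declaring both at the top of initialize().
latent_var_list = set()
for z, qz in six.iteritems(self.latent_vars):
  if isinstance(z, RandomVariable):
    latent_var_list.update(get_variables(z, collection=trainables))
  latent_var_list.update(get_variables(qz, collection=trainables))

data_var_list = set()  # instantiated closest to its first use
for x, qx in six.iteritems(self.data):
  if isinstance(x, RandomVariable) and not isinstance(qx, RandomVariable):
    data_var_list.update(get_variables(x, collection=trainables))
```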
The diff continues:

```diff
         if isinstance(x, RandomVariable) and \
                 not isinstance(qx, RandomVariable):
-          var_list.update(get_variables(x, collection=trainables))
+          data_var_list.update(get_variables(x, collection=trainables))
 
-      var_list = list(var_list)
+      var_list = list(data_var_list | latent_var_list)
 
     self.loss, grads_and_vars = self.build_loss_and_gradients(var_list)
 
     if self.logging:
       summary_key = 'summaries_' + str(id(self))
       tf.summary.scalar("loss", self.loss, collections=[summary_key])
+      with tf.name_scope('variational'):
+        for grad, var in grads_and_vars:
+          if var in latent_var_list:
+            tf.summary.histogram("parameter_" +
+                                 var.name.replace(':', '_'),
+                                 var, collections=[summary_key])
+            tf.summary.histogram("gradient_" +
+                                 grad.name.replace(':', '_'),
+                                 grad, collections=[summary_key])
+            tf.summary.scalar("gradient_norm_" +
+                              grad.name.replace(':', '_'),
+                              tf.norm(grad), collections=[summary_key])
+            # replace : with _ because tf does not allow : in var names in summaries
```
Review comments (inline, on the colon-replacement comment above):

- @akshaykhatri639 Really looking forward to this feature! I would replace […]; I have been using […]. Maybe rename to […]?
- Maybe there's a nicer way to catch this? I guess that's more an idiosyncrasy of TensorFlow than of Edward. But still, catching it in a more readable way would be nice.
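A hedged sketch of one way the colon handling could be made more self-explanatory; this is an assumption about what the reviewers might have in mind, not what the PR does. In TensorFlow 1.x a tensor's `.name` ends in an output index such as `:0`, while `.op.name` omits it, so a small helper (hypothetical name `summary_tag`) can centralize the detail:

```python
import tensorflow as tf

def summary_tag(prefix, tensor):
  """Hypothetical helper: build a colon-free summary name from a tensor.

  tensor.name looks like 'qmu/loc:0'; tensor.op.name drops the ':0' output
  index, which is the character tf.summary rejects in names.
  """
  return prefix + tensor.op.name

# Usage inside the logging loop above (sketch):
#   tf.summary.histogram(summary_tag("parameter_", var),
#                        var, collections=[summary_key])
#   tf.summary.histogram(summary_tag("gradient_", grad),
#                        grad, collections=[summary_key])
```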
And further down:

```diff
+
+      with tf.name_scope('model'):
+        for grad, var in grads_and_vars:
+          if var in data_var_list:
+            tf.summary.histogram("parameter_" + var.name.replace(':', '_'),
+                                 var, collections=[summary_key])
+            tf.summary.histogram("gradient_" +
+                                 grad.name.replace(':', '_'),
+                                 grad, collections=[summary_key])
+            tf.summary.scalar("gradient_norm_" +
+                              grad.name.replace(':', '_'),
+                              tf.norm(grad), collections=[summary_key])
+
+      # when var_list is not initialized with None
+      with tf.name_scope(''):
+        for grad, var in grads_and_vars:
+          if var not in latent_var_list and var not in data_var_list:
+            tf.summary.histogram("parameter_" + var.name.replace(':', '_'),
+                                 var, collections=[summary_key])
+            tf.summary.histogram("gradient_" +
+                                 grad.name.replace(':', '_'),
+                                 grad, collections=[summary_key])
+            tf.summary.scalar("gradient_norm_" +
+                              grad.name.replace(':', '_'),
+                              tf.norm(grad), collections=[summary_key])
 
       self.summarize = tf.summary.merge_all(key=summary_key)
 
     if optimizer is None:
       # Use ADAM with a decaying scale factor.
       global_step = tf.Variable(0, trainable=False, name="global_step")
```
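For readers unfamiliar with the default branch above, a hedged sketch of what "ADAM with a decaying scale factor" typically looks like; the schedule and constants are illustrative assumptions, not necessarily the exact values used here:

```python
import tensorflow as tf

# Illustrative only: an Adam optimizer whose learning rate decays as
# global_step advances.  The starter rate, decay steps, and decay rate are
# assumptions made for this sketch.
global_step = tf.Variable(0, trainable=False, name="global_step")
starter_learning_rate = 0.1
learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
                                           decay_steps=100, decay_rate=0.9,
                                           staircase=True)
optimizer = tf.train.AdamOptimizer(learning_rate)
```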
Review comment: Can you rename these to 'loss_discriminative' and 'loss_generative' respectively?

Reply: Done.
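Finally, a hedged usage sketch of how the merged summaries produced by this change are typically consumed with TensorBoard. The writer set-up below is standard TensorFlow 1.x, not Edward-specific plumbing; `train_op` and `summarize` stand in for whatever training op and merged-summary op the inference builds, and the log directory name is arbitrary:

```python
import tensorflow as tf

# Hedged sketch: write the merged summaries (loss, parameter histograms,
# gradient histograms and norms from the diff above) so TensorBoard can plot
# them, grouped under the 'variational' and 'model' name scopes.
writer = tf.summary.FileWriter("log/variational_inference",
                               graph=tf.get_default_graph())
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for step in range(1000):
    # train_op and summarize are stand-ins for the ops built by initialize().
    _, summary = sess.run([train_op, summarize])
    writer.add_summary(summary, global_step=step)
writer.close()
# Then launch:  tensorboard --logdir=log/variational_inference
```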