diff --git a/edward/inferences/gan_inference.py b/edward/inferences/gan_inference.py
index 79a42e6a3..87876a5b4 100644
--- a/edward/inferences/gan_inference.py
+++ b/edward/inferences/gan_inference.py
@@ -101,6 +101,14 @@ def initialize(self, optimizer=None, optimizer_d=None,
     self.train_d = optimizer_d.apply_gradients(grads_and_vars_d,
                                                global_step=global_step_d)
 
+    if self.logging:
+      summary_key = 'summaries_' + str(id(self))
+      tf.summary.scalar('loss_discriminative', self.loss_d,
+                        collections=[summary_key])
+      tf.summary.scalar('loss_generative', self.loss,
+                        collections=[summary_key])
+      self.summarize = tf.summary.merge_all(key=summary_key)
+
   def build_loss_and_gradients(self, var_list):
     x_true = list(six.itervalues(self.data))[0]
     x_fake = list(six.iterkeys(self.data))[0]
@@ -110,6 +118,11 @@ def build_loss_and_gradients(self, var_list):
     with tf.variable_scope("Disc", reuse=True):
       d_fake = self.discriminator(x_fake)
 
+    if self.logging:
+      summary_key = 'summaries_' + str(id(self))
+      tf.summary.histogram('disc_outputs', tf.concat([d_true, d_fake], axis=0),
+                           collections=[summary_key])
+
     loss_d = tf.nn.sigmoid_cross_entropy_with_logits(
         labels=tf.ones_like(d_true), logits=d_true) + \
         tf.nn.sigmoid_cross_entropy_with_logits(
diff --git a/edward/inferences/klpq.py b/edward/inferences/klpq.py
index 216116a27..446c35373 100644
--- a/edward/inferences/klpq.py
+++ b/edward/inferences/klpq.py
@@ -129,6 +129,11 @@ def build_loss_and_gradients(self, var_list):
     p_log_prob = tf.stack(p_log_prob)
     q_log_prob = tf.stack(q_log_prob)
 
+    if self.logging:
+      summary_key = 'summaries_' + str(id(self))
+      tf.summary.histogram("p_log_prob", p_log_prob, collections=[summary_key])
+      tf.summary.histogram("q_log_prob", q_log_prob, collections=[summary_key])
+
     log_w = p_log_prob - q_log_prob
     log_w_norm = log_w - tf.reduce_logsumexp(log_w)
     w_norm = tf.exp(log_w_norm)
diff --git a/edward/inferences/klqp.py b/edward/inferences/klqp.py
index eb150bb5a..04ce955da 100644
--- a/edward/inferences/klqp.py
+++ b/edward/inferences/klqp.py
@@ -390,6 +390,12 @@ def build_reparam_loss_and_gradients(inference, var_list):
   p_log_prob = tf.stack(p_log_prob)
   q_log_prob = tf.stack(q_log_prob)
+
+  if inference.logging:
+    summary_key = 'summaries_' + str(id(inference))
+    tf.summary.histogram("p_log_prob", p_log_prob, collections=[summary_key])
+    tf.summary.histogram("q_log_prob", q_log_prob, collections=[summary_key])
+
   loss = -tf.reduce_mean(p_log_prob - q_log_prob)
 
   grads = tf.gradients(loss, var_list)
@@ -444,6 +450,11 @@ def build_reparam_kl_loss_and_gradients(inference, var_list):
       inference.kl_scaling.get(z, 1.0) * tf.reduce_sum(ds.kl(qz, z))
       for z, qz in six.iteritems(inference.latent_vars)])
 
+  if inference.logging:
+    summary_key = 'summaries_' + str(id(inference))
+    tf.summary.histogram('p_log_lik', p_log_lik, collections=[summary_key])
+    tf.summary.scalar('kl', kl, collections=[summary_key])
+
   loss = -(tf.reduce_mean(p_log_lik) - kl)
 
   grads = tf.gradients(loss, var_list)
@@ -502,6 +513,11 @@ def build_reparam_entropy_loss_and_gradients(inference, var_list):
   q_entropy = tf.reduce_sum([
       qz.entropy() for z, qz in six.iteritems(inference.latent_vars)])
 
+  if inference.logging:
+    summary_key = 'summaries_' + str(id(inference))
+    tf.summary.histogram('p_log_prob', p_log_prob, collections=[summary_key])
+    tf.summary.scalar('q_entropy', q_entropy, collections=[summary_key])
+
   loss = -(tf.reduce_mean(p_log_prob) + q_entropy)
 
   grads = tf.gradients(loss, var_list)
@@ -553,6 +569,11 @@ def build_score_loss_and_gradients(inference, var_list):
   p_log_prob = tf.stack(p_log_prob)
   q_log_prob = tf.stack(q_log_prob)
 
+  if inference.logging:
+    summary_key = 'summaries_' + str(id(inference))
+    tf.summary.histogram('p_log_prob', p_log_prob, collections=[summary_key])
+    tf.summary.histogram('q_log_prob', q_log_prob, collections=[summary_key])
+
   losses = p_log_prob - q_log_prob
   loss = -tf.reduce_mean(losses)
@@ -608,6 +629,12 @@ def build_score_kl_loss_and_gradients(inference, var_list):
       inference.kl_scaling.get(z, 1.0) * tf.reduce_sum(ds.kl(qz, z))
       for z, qz in six.iteritems(inference.latent_vars)])
 
+  if inference.logging:
+    summary_key = 'summaries_' + str(id(inference))
+    tf.summary.histogram('p_log_lik', p_log_lik, collections=[summary_key])
+    tf.summary.histogram('q_log_prob', q_log_prob, collections=[summary_key])
+    tf.summary.scalar('kl', kl, collections=[summary_key])
+
   loss = -(tf.reduce_mean(p_log_lik) - kl)
 
   grads = tf.gradients(
       -(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_lik)) - kl),
@@ -665,6 +692,12 @@ def build_score_entropy_loss_and_gradients(inference, var_list):
   q_entropy = tf.reduce_sum([
       qz.entropy() for z, qz in six.iteritems(inference.latent_vars)])
 
+  if inference.logging:
+    summary_key = 'summaries_' + str(id(inference))
+    tf.summary.histogram('p_log_prob', p_log_prob, collections=[summary_key])
+    tf.summary.histogram('q_log_prob', q_log_prob, collections=[summary_key])
+    tf.summary.scalar('q_entropy', q_entropy, collections=[summary_key])
+
   loss = -(tf.reduce_mean(p_log_prob) + q_entropy)
 
   grads = tf.gradients(
       -(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_prob)) +
diff --git a/edward/inferences/monte_carlo.py b/edward/inferences/monte_carlo.py
index 4455d3cb5..1401c3939 100644
--- a/edward/inferences/monte_carlo.py
+++ b/edward/inferences/monte_carlo.py
@@ -97,6 +97,11 @@ def initialize(self, *args, **kwargs):
     self.n_accept_over_t = self.n_accept / self.t
     self.train = self.build_update()
 
+    if self.logging:
+      summary_key = 'summaries_' + str(id(self))
+      tf.summary.scalar('n_accept', self.n_accept, collections=[summary_key])
+      self.summarize = tf.summary.merge_all(key=summary_key)
+
   def update(self, feed_dict=None):
     """Run one iteration of sampling for Monte Carlo.
diff --git a/edward/inferences/variational_inference.py b/edward/inferences/variational_inference.py
index 027f3c5f4..9a18737f9 100644
--- a/edward/inferences/variational_inference.py
+++ b/edward/inferences/variational_inference.py
@@ -53,25 +53,70 @@ def initialize(self, optimizer=None, var_list=None, use_prettytensor=False,
     """
     super(VariationalInference, self).initialize(*args, **kwargs)
 
+    latent_var_list = set()
+    data_var_list = set()
     if var_list is None:
       # Traverse random variable graphs to get default list of variables.
-      var_list = set()
       trainables = tf.trainable_variables()
       for z, qz in six.iteritems(self.latent_vars):
         if isinstance(z, RandomVariable):
-          var_list.update(get_variables(z, collection=trainables))
+          latent_var_list.update(get_variables(z, collection=trainables))
 
-        var_list.update(get_variables(qz, collection=trainables))
+        latent_var_list.update(get_variables(qz, collection=trainables))
 
       for x, qx in six.iteritems(self.data):
         if isinstance(x, RandomVariable) and \
                 not isinstance(qx, RandomVariable):
-          var_list.update(get_variables(x, collection=trainables))
+          data_var_list.update(get_variables(x, collection=trainables))
 
-      var_list = list(var_list)
+      var_list = list(data_var_list | latent_var_list)
 
     self.loss, grads_and_vars = self.build_loss_and_gradients(var_list)
 
+    if self.logging:
+      summary_key = 'summaries_' + str(id(self))
+      tf.summary.scalar("loss", self.loss, collections=[summary_key])
+      with tf.name_scope('variational'):
+        for grad, var in grads_and_vars:
+          if var in latent_var_list:
+            # TensorFlow forbids ':' in summary names, so replace it with '_'.
+            tf.summary.histogram("parameter_" +
+                                 var.name.replace(':', '_'),
+                                 var, collections=[summary_key])
+            tf.summary.histogram("gradient_" +
+                                 var.name.replace(':', '_'),
+                                 grad, collections=[summary_key])
+            tf.summary.scalar("gradient_norm_" +
+                              var.name.replace(':', '_'),
+                              tf.norm(grad), collections=[summary_key])
+
+      with tf.name_scope('model'):
+        for grad, var in grads_and_vars:
+          if var in data_var_list:
+            tf.summary.histogram("parameter_" + var.name.replace(':', '_'),
+                                 var, collections=[summary_key])
+            tf.summary.histogram("gradient_" +
+                                 var.name.replace(':', '_'),
+                                 grad, collections=[summary_key])
+            tf.summary.scalar("gradient_norm_" +
+                              var.name.replace(':', '_'),
+                              tf.norm(grad), collections=[summary_key])
+
+      # Variables passed in explicitly through var_list rather than discovered above.
+      with tf.name_scope(''):
+        for grad, var in grads_and_vars:
+          if var not in latent_var_list and var not in data_var_list:
+            tf.summary.histogram("parameter_" + var.name.replace(':', '_'),
+                                 var, collections=[summary_key])
+            tf.summary.histogram("gradient_" +
+                                 var.name.replace(':', '_'),
+                                 grad, collections=[summary_key])
+            tf.summary.scalar("gradient_norm_" +
+                              var.name.replace(':', '_'),
+                              tf.norm(grad), collections=[summary_key])
+
+      self.summarize = tf.summary.merge_all(key=summary_key)
+
     if optimizer is None:
       # Use ADAM with a decaying scale factor.
       global_step = tf.Variable(0, trainable=False, name="global_step")
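
For context, here is a minimal usage sketch (not part of the diff) of how these summaries would be exercised. It assumes Edward's standard Normal/KLqp API and that passing `logdir` to `inference.run` enables `self.logging`, so the merged summaries added above are written out for TensorBoard; the toy model and variable names are illustrative only.

    # Usage sketch: fit a Normal mean with KLqp and write the new summaries.
    import edward as ed
    import numpy as np
    import tensorflow as tf
    from edward.models import Normal

    # Model: x_n ~ Normal(mu, 1) with a standard normal prior on mu.
    mu = Normal(loc=0.0, scale=1.0)
    x = Normal(loc=tf.ones(50) * mu, scale=1.0)

    # Variational approximation q(mu).
    qmu = Normal(loc=tf.Variable(0.0),
                 scale=tf.nn.softplus(tf.Variable(0.5)))

    x_data = np.random.normal(3.0, 1.0, size=50).astype(np.float32)

    # Passing logdir enables self.logging, so the loss scalar, the
    # p_log_prob / q_log_prob histograms, and the per-variable parameter,
    # gradient, and gradient-norm summaries are recorded; inspect them with
    # `tensorboard --logdir=log`.
    inference = ed.KLqp({mu: qmu}, data={x: x_data})
    inference.run(n_iter=500, logdir='log')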