Add regularization terms in klqp.py #813

Merged
merged 3 commits on Jan 5, 2018
edward/inferences/klqp.py: 71 changes (60 additions, 11 deletions)
@@ -46,6 +46,9 @@ class KLqp(VariationalInference):

where $z^{(s)} \sim q(z; \lambda)$ and $\\beta^{(s)}
\sim q(\\beta)$.

The objective function also includes the sum of all tensors in the
`REGULARIZATION_LOSSES` collection as a penalty term.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -164,6 +167,9 @@ class ReparameterizationKLqp(VariationalInference):

This class minimizes the objective using the reparameterization
gradient.

The objective function also includes the sum of all tensors in the
`REGULARIZATION_LOSSES` collection as a penalty term.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -221,6 +227,9 @@ class ReparameterizationKLKLqp(VariationalInference):

This class minimizes the objective using the reparameterization
gradient and an analytic KL term.

The objective function also includes the sum of all tensors in the
`REGULARIZATION_LOSSES` collection as a penalty term.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -292,6 +301,9 @@ class ReparameterizationEntropyKLqp(VariationalInference):

This class minimizes the objective using the reparameterization
gradient and an analytic entropy term.

The objective function also includes the sum of all tensors in the
`REGULARIZATION_LOSSES` collection as a penalty term.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -350,6 +362,9 @@ class ScoreKLqp(VariationalInference):

This class minimizes the objective using the score function
gradient.

The objective function also includes the sum of all tensors in the
`REGULARIZATION_LOSSES` collection as a penalty term.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -407,6 +422,9 @@ class ScoreKLKLqp(VariationalInference):

This class minimizes the objective using the score function gradient
and an analytic KL term.

The objective function also includes the sum of all tensors in the
`REGULARIZATION_LOSSES` collection as a penalty term.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -478,6 +496,9 @@ class ScoreEntropyKLqp(VariationalInference):

This class minimizes the objective using the score function gradient
and an analytic entropy term.

The objective function also includes the sum of all tensors in the
`REGULARIZATION_LOSSES` collection as a penalty term.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -542,6 +563,9 @@ class ScoreRBKLqp(VariationalInference):
stochastic nodes in the computation graph. It does not
Rao-Blackwellize within a node such as when a node represents
multiple random variables via non-scalar batch shape.

The objective function also includes the sum of all tensors in the
`REGULARIZATION_LOSSES` collection as a penalty term.
"""
def __init__(self, latent_vars=None, data=None):
"""Create an inference algorithm.
@@ -640,14 +664,17 @@ def build_reparam_loss_and_gradients(inference, var_list):

p_log_prob = tf.reduce_mean(p_log_prob)
q_log_prob = tf.reduce_mean(q_log_prob)
reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

if inference.logging:
tf.summary.scalar("loss/p_log_prob", p_log_prob,
collections=[inference._summary_key])
tf.summary.scalar("loss/q_log_prob", q_log_prob,
collections=[inference._summary_key])
tf.summary.scalar("loss/reg_penalty", reg_penalty,
collections=[inference._summary_key])

loss = -(p_log_prob - q_log_prob)
loss = -(p_log_prob - q_log_prob - reg_penalty)

grads = tf.gradients(loss, var_list)
grads_and_vars = list(zip(grads, var_list))
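A minimal standalone sketch of the bookkeeping above (plain TF 1.x, no Edward; the two penalty terms and constants are made up for illustration): `tf.losses.get_regularization_losses()` simply returns the tensors registered in the collection, and `tf.reduce_sum` collapses them into the scalar `reg_penalty` that is subtracted inside the negated objective.

```python
# Standalone sketch of the reg_penalty bookkeeping; values are illustrative.
import tensorflow as tf

w = tf.get_variable("w", [4], initializer=tf.ones_initializer())
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                     0.1 * tf.nn.l2_loss(w))            # 0.1 * (4 / 2) = 0.2
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                     0.01 * tf.reduce_sum(tf.abs(w)))   # 0.01 * 4 = 0.04

reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

p_log_prob = tf.constant(-3.0)  # stand-ins for the Monte Carlo estimates
q_log_prob = tf.constant(-5.0)
loss = -(p_log_prob - q_log_prob - reg_penalty)  # penalty raises the loss

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run([reg_penalty, loss]))  # [0.24, -1.76]
```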
@@ -702,13 +729,17 @@ def build_reparam_kl_loss_and_gradients(inference, var_list):
tf.reduce_sum(inference.kl_scaling.get(z, 1.0) * kl_divergence(qz, z))
for z, qz in six.iteritems(inference.latent_vars)])

reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

if inference.logging:
tf.summary.scalar("loss/p_log_lik", p_log_lik,
collections=[inference._summary_key])
tf.summary.scalar("loss/kl_penalty", kl_penalty,
collections=[inference._summary_key])
tf.summary.scalar("loss/reg_penalty", reg_penalty,
collections=[inference._summary_key])

loss = -(p_log_lik - kl_penalty)
loss = -(p_log_lik - kl_penalty - reg_penalty)

grads = tf.gradients(loss, var_list)
grads_and_vars = list(zip(grads, var_list))
@@ -766,13 +797,17 @@ def build_reparam_entropy_loss_and_gradients(inference, var_list):
tf.reduce_sum(qz.entropy())
for z, qz in six.iteritems(inference.latent_vars)])

reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

if inference.logging:
tf.summary.scalar("loss/p_log_prob", p_log_prob,
collections=[inference._summary_key])
tf.summary.scalar("loss/q_entropy", q_entropy,
collections=[inference._summary_key])
tf.summary.scalar("loss/reg_penalty", reg_penalty,
collections=[inference._summary_key])

loss = -(p_log_prob + q_entropy)
loss = -(p_log_prob + q_entropy - reg_penalty)

grads = tf.gradients(loss, var_list)
grads_and_vars = list(zip(grads, var_list))
@@ -823,21 +858,24 @@ def build_score_loss_and_gradients(inference, var_list):

p_log_prob = tf.stack(p_log_prob)
q_log_prob = tf.stack(q_log_prob)
reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

if inference.logging:
tf.summary.scalar("loss/p_log_prob", tf.reduce_mean(p_log_prob),
collections=[inference._summary_key])
tf.summary.scalar("loss/q_log_prob", tf.reduce_mean(q_log_prob),
collections=[inference._summary_key])
tf.summary.scalar("loss/reg_penalty", reg_penalty,
collections=[inference._summary_key])

losses = p_log_prob - q_log_prob
loss = -tf.reduce_mean(losses)
loss = -(tf.reduce_mean(losses) - reg_penalty)

q_rvs = list(six.itervalues(inference.latent_vars))
q_vars = [v for v in var_list
if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
q_grads = tf.gradients(
-tf.reduce_mean(q_log_prob * tf.stop_gradient(losses)),
-(tf.reduce_mean(q_log_prob * tf.stop_gradient(losses)) - reg_penalty),
q_vars)
p_vars = [v for v in var_list if v not in q_vars]
p_grads = tf.gradients(loss, p_vars)
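Why `reg_penalty` also enters the surrogate objective for `q_grads`: if the variational parameters themselves carry a regularizer, the penalty depends on `q_vars`, so it must be differentiated with respect to them as well. A hedged sketch follows (names and the 0.01 coefficient are illustrative, and hand-registering a penalty on a variational parameter is an assumption of this example, not something the PR requires).

```python
# Hedged sketch: an L2 penalty registered directly on a variational parameter,
# so the score-function surrogate above contributes its gradient as well.
import numpy as np
import tensorflow as tf
import edward as ed
from edward.models import Bernoulli, Normal

z = Normal(loc=0.0, scale=1.0)
x = Bernoulli(logits=z, sample_shape=10)

qz_loc = tf.get_variable("qz/loc", [])
qz_scale = tf.nn.softplus(tf.get_variable("qz/scale", []))
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                     0.01 * tf.square(qz_loc))  # penalize large variational means
qz = Normal(loc=qz_loc, scale=qz_scale)

x_data = np.random.binomial(1, 0.5, 10).astype(np.int32)
inference = ed.ScoreKLqp({z: qz}, data={x: x_data})
inference.run(n_iter=100)
```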
@@ -891,19 +929,24 @@ def build_score_kl_loss_and_gradients(inference, var_list):
tf.reduce_sum(inference.kl_scaling.get(z, 1.0) * kl_divergence(qz, z))
for z, qz in six.iteritems(inference.latent_vars)])

reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

if inference.logging:
tf.summary.scalar("loss/p_log_lik", tf.reduce_mean(p_log_lik),
collections=[inference._summary_key])
tf.summary.scalar("loss/kl_penalty", kl_penalty,
collections=[inference._summary_key])
tf.summary.scalar("loss/reg_penalty", reg_penalty,
collections=[inference._summary_key])

loss = -(tf.reduce_mean(p_log_lik) - kl_penalty)
loss = -(tf.reduce_mean(p_log_lik) - kl_penalty - reg_penalty)

q_rvs = list(six.itervalues(inference.latent_vars))
q_vars = [v for v in var_list
if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
q_grads = tf.gradients(
-(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_lik)) - kl_penalty),
-(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_lik)) - kl_penalty -
reg_penalty),
q_vars)
p_vars = [v for v in var_list if v not in q_vars]
p_grads = tf.gradients(loss, p_vars)
@@ -962,22 +1005,26 @@ def build_score_entropy_loss_and_gradients(inference, var_list):
tf.reduce_sum(qz.entropy())
for z, qz in six.iteritems(inference.latent_vars)])

reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

if inference.logging:
tf.summary.scalar("loss/p_log_prob", tf.reduce_mean(p_log_prob),
collections=[inference._summary_key])
tf.summary.scalar("loss/q_log_prob", tf.reduce_mean(q_log_prob),
collections=[inference._summary_key])
tf.summary.scalar("loss/q_entropy", q_entropy,
collections=[inference._summary_key])
tf.summary.scalar("loss/reg_penalty", reg_penalty,
collections=[inference._summary_key])

loss = -(tf.reduce_mean(p_log_prob) + q_entropy)
loss = -(tf.reduce_mean(p_log_prob) + q_entropy - reg_penalty)

q_rvs = list(six.itervalues(inference.latent_vars))
q_vars = [v for v in var_list
if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
q_grads = tf.gradients(
-(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_prob)) +
q_entropy),
q_entropy - reg_penalty),
q_vars)
p_vars = [v for v in var_list if v not in q_vars]
p_grads = tf.gradients(loss, p_vars)
@@ -1062,7 +1109,8 @@ def build_score_rb_loss_and_gradients(inference, var_list):
qi_log_prob = tf.stack(qi_log_prob)
grad = tf.gradients(
-tf.reduce_mean(qi_log_prob *
tf.stop_gradient(pi_log_prob - qi_log_prob)),
tf.stop_gradient(pi_log_prob - qi_log_prob)) +
tf.reduce_sum(tf.losses.get_regularization_losses()),
var)
grads.extend(grad)
grads_vars.append(var)
@@ -1071,7 +1119,8 @@ def build_score_rb_loss_and_gradients(inference, var_list):
loss = -(tf.reduce_mean([tf.reduce_sum(list(six.itervalues(p_log_prob)))
for p_log_prob in p_log_probs]) -
tf.reduce_mean([tf.reduce_sum(list(six.itervalues(q_log_prob)))
for q_log_prob in q_log_probs]))
for q_log_prob in q_log_probs]) -
tf.reduce_sum(tf.losses.get_regularization_losses()))
model_vars = [v for v in var_list if v not in grads_vars]
model_grads = tf.gradients(loss, model_vars)
grads.extend(model_grads)
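One property worth noting (a sanity check added by this edit, not asserted in the PR): when nothing has been registered in `REGULARIZATION_LOSSES`, the collection is empty and the penalty reduces to 0.0, so models that use no regularizers see exactly the same objectives as before.

```python
# Empty-collection case: the penalty collapses to 0.0, leaving the original
# objectives untouched for models that register no regularizers.
import tensorflow as tf

with tf.Graph().as_default():
  reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())  # sum over []
  with tf.Session() as sess:
    print(sess.run(reg_penalty))  # 0.0
```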