Eliminated redundant computation in repeated calls to leapfrog() function #414
Conversation
    for _ in range(self.n_steps):
      new_sample, new_r_sample = leapfrog(new_sample, new_r_sample,
                                          self.step_size, self._log_joint)
    new_sample, new_r_sample = leapfrog(new_sample, new_r_sample,
cool! i think we can also remove the two lines above this now as well.
  grad_log_joint = tf.gradients(log_joint(z_new), list(six.itervalues(z_new)))
  for i, key in enumerate(six.iterkeys(z_old)):
    r_new[key] += 0.5 * step_size * grad_log_joint[i]

  for n in range(n_steps):
since we're not using the loop variable, we should do `for _ in ...` for readability
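A minimal illustration of the idiom being suggested; the loop body here is only a placeholder, not code from the patch:

```python
results = []
# The underscore signals that the loop index itself is never used.
for _ in range(3):
    results.append("step")  # placeholder body, for illustration only
```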
  z_new = {}
  r_new = {}

  for key in z_old:
i think these lines can be condensed with the copy method on dicts, e.g., `z_new = z_old.copy()` and `r_new = r_old.copy()`.
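For context, the suggestion replaces a key-by-key copy with Python's built-in shallow dict copy. This is only a sketch; the loop body shown as the "before" case is assumed for illustration, not quoted from the patch:

```python
z_old = {"z": 0.0}
r_old = {"z": 1.0}

# Before (assumed shape of the original code): build the new dicts entry by entry.
z_new = {}
r_new = {}
for key in z_old:
    z_new[key] = z_old[key]
    r_new[key] = r_old[key]

# After: a shallow copy per dict does the same thing in one call.
z_new = z_old.copy()
r_new = r_old.copy()
```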
thanks for the catch; this looks great. only minor comments above
Branch updated from aaee1d2 to 1514f33.
hi @matthewdhoffman: i made the minor changes from my comments above. let me know if you approve/have comments, and i'll merge.
LGTM, thanks for cleaning it up. (And for showing me a couple of better idioms!)
In the old code, each call to `leapfrog()` computed the gradient twice with respect to latent variables that hadn't changed. This patch eliminates that duplicated computation, resulting in a nearly 2x speedup. It works by adding an `n_steps` parameter to `leapfrog()` and moving the loop into the `leapfrog()` function.
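For readers who want the shape of the change without opening the diff, here is a rough sketch of a `leapfrog()` that takes `n_steps` and evaluates one gradient per step. It is not the exact Edward implementation; following the snippets quoted above, it only assumes that `z_old` and `r_old` are dicts mapping latent-variable keys to tensors and that `log_joint` maps such a dict to a scalar log density.

```python
import six
import tensorflow as tf


def leapfrog(z_old, r_old, step_size, log_joint, n_steps):
  """Sketch of an n_steps leapfrog integrator with one gradient per step.

  Assumes z_old/r_old are dicts of tensors keyed by latent variable, and
  log_joint maps such a dict to a scalar log density.
  """
  z_new = z_old.copy()
  r_new = r_old.copy()

  # Initial half-step for the momentum.
  grad_log_joint = tf.gradients(log_joint(z_new), list(six.itervalues(z_new)))
  for i, key in enumerate(six.iterkeys(z_new)):
    r_new[key] += 0.5 * step_size * grad_log_joint[i]

  for n in range(n_steps):
    # Full step for the position.
    for key in z_new:
      z_new[key] += step_size * r_new[key]

    # One gradient evaluation per step: it closes this step's momentum update
    # and, except on the last step, also opens the next one. Calling a
    # single-step leapfrog() repeatedly forces this gradient to be recomputed.
    grad_log_joint = tf.gradients(log_joint(z_new), list(six.itervalues(z_new)))
    scale = step_size if n + 1 < n_steps else 0.5 * step_size
    for i, key in enumerate(six.iterkeys(z_new)):
      r_new[key] += scale * grad_log_joint[i]

  return z_new, r_new
```

The saving comes from reusing the gradient at the end of one step as the gradient at the start of the next, which is exactly the computation the per-call version duplicated.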