Feature/simpler conjugacy #588
Conversation
Thanks for the comments. I'm also excited to revisit the empirical question of what (if anything) conjugacy buys us over more black-box methods. I'll pull out

mu = Normal(0., 1.)
x = Normal(3 * mu, 1.)

but if [...]

It might be possible to revisit this at some point—for example, we could use some of the logic in simplify.py to give the [...]

That makes sense. Too bad [...]

Once you tease out [...]
I merged to update the branch to master (e.g., following ParamMixture). Following the previous comments, can you also document the utility functions in simplify.py and conjugacy.py? I had trouble following the internal details of the functions when trying to understand complete_conditional.

Other comments below.
def complete_conditional(rv, blanket, log_joint=None):
  with tf.name_scope('complete_conditional_%s' % rv.name) as scope:
    # log_joint holds all the information we need to get a conditional.
    extended_blanket = copy(blanket)
Is this shallow copy of the list needed?
Good catch; there's a bug here. I meant to add rv to blanket here. It's fixed now.
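For clarity, a minimal sketch of what the fix described above might amount to (assuming blanket is a Python list; the variable names follow the snippet quoted earlier, but this is not the actual patch):

from copy import copy

# Copy the blanket and make sure rv itself is included, so the log joint
# built from extended_blanket covers every term that involves rv.
extended_blanket = copy(blanket)
extended_blanket.append(rv)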
return '_log_joint_of_' + ('&'.join([i.name[:-1] for i in blanket])) + '_'


def get_log_joint(blanket):
My understanding is that get_log_joint will get the log joint tensor instead of forming a new one if it already exists. If we call complete_conditional twice, on two nodes which share variables in their blanket, it would redo the full joint. Is that correct?
That's right. I introduced this caching for other reasons (which aren't all that relevant now), but it still speeds up graph construction.

But maybe it makes sense to move that caching a level up, so that each RandomVariable caches its conjugate_log_prob()?

(The term "cache" is a little funny here, since the graph remembers everything regardless. We're just caching a pointer.)
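For illustration, here is a minimal sketch of how name-keyed caching of the log-joint tensor could work (the dict, the helper name, and the build_log_joint callback are hypothetical, not the actual implementation in conjugacy.py):

# Hypothetical module-level cache mapping a blanket-derived name to the
# already-built log-joint tensor (really just a pointer into the graph).
_log_joint_cache = {}

def get_log_joint_cached(blanket, build_log_joint):
  # Same naming scheme as the return statement quoted above.
  key = '_log_joint_of_' + '&'.join([rv.name[:-1] for rv in blanket]) + '_'
  if key not in _log_joint_cache:
    _log_joint_cache[key] = build_log_joint(blanket)
  return _log_joint_cache[key]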
swap_back = {}
for s_stat in s_stat_exprs.keys():
  s_stat_shape = s_stat_exprs[s_stat][0][0].get_shape()
  s_stat_placeholder = tf.placeholder(np.float32, s_stat_shape)
Is the float32 for placeholders because the log joint is float32?
It's because we want to take the gradient w.r.t. the placeholders, and TF doesn't like taking gradients w.r.t. non-floats.
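As a standalone illustration of why the placeholders must be floats, here is a toy version of the gradient trick (the names t_ph and eta and the toy log joint are made up for this sketch; TF1-style graph mode is assumed):

import numpy as np
import tensorflow as tf

# Suppose the copied log joint has the exponential-family form
#   log p = eta * t(x) + const,
# with the sufficient statistic t(x) swapped out for a float32 placeholder.
t_ph = tf.placeholder(np.float32, [])
eta = tf.constant(2.5)          # stands in for some function of the blanket
log_joint = eta * t_ph + 7.0    # the "+ const" part vanishes under d/dt

# d/dt (eta * t + const) = eta, so the gradient reads off the natural
# parameter with the right shape; non-float placeholders would break this.
natural_param = tf.gradients(log_joint, t_ph)[0]

with tf.Session() as sess:
  print(sess.run(natural_param, {t_ph: 0.0}))  # -> 2.5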
@matthewdhoffman very much looking forward to this feature! Checked out the branch to play around with the examples. Direct import wasn't working, so I added a line to the [...]

A minor comment. It could be cool if in the [...] with [...] or do the equivalent replacement in the return statement of the [...]
Is there anything this PR is still waiting on? Happy to merge now. I have some minor suggestions, but I'll submit another PR after this one.
Nope, I think it's at a good checkpoint—let's merge. I haven't done the [...] I made [...] Also, it turns out that [...]
@matthewdhoffman you will gradually make all of our jobs meaningless with this PR :)
This is a new, improved, rewritten version of the feature/conjugacy module I started months ago.

It mostly boils down to supporting one function: ed.complete_conditional(). This function takes as input rv (a RandomVariable) and blanket (a set of other RandomVariables), and tries to do some exponential-family algebra to figure out the conditional distribution of rv given blanket. (Behavior is unchanged if rv in blanket == True.)

The other user-facing change is adding a sample_shape kwarg to RandomVariable, which brings Edward syntax a bit more up to date with the tf.contrib.distributions syntax and reduces the amount of algebra the system has to do (compared with syntax like Normal(mu + tf.zeros([5]))). This supersedes PR #323.
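As a quick illustration of the new kwarg (a hedged sketch; the positional constructor arguments follow the style used elsewhere in this thread and aren't verified against the final API):

import tensorflow as tf
from edward.models import Normal

mu = Normal(0., 1.)

# With sample_shape: five i.i.d. draws, and the dependence on mu stays a
# simple linear function that the conjugacy machinery can recognize.
x = Normal(mu, 1., sample_shape=5)

# Without it, broadcasting tricks like the one below force extra algebra:
# x = Normal(mu + tf.zeros([5]), 1.)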
The high-level algorithm is:

1. Form the log joint of rv and all of the (other) RandomVariables in blanket.
2. Crawl the resulting graph, stopping at any node that is either:
   a. a member of blanket (which truncates the search), or
   b. rv or a nonlinear function of rv (which truncates the search and adds the node that stopped the search to a list of sufficient-statistic nodes).
3. Simplify the graph (e.g., log(mul(a, b)) becomes add(log(a), log(b))).
4. Check the list of sufficient statistics (of rv) to see if there's a match in our table of exponential-family distributions—if not, we're done.
5. If there is a match, compute the natural parameters by calling tf.gradients() with respect to the sufficient-statistic placeholders on the copied log-joint. This trick, which exploits the trivial identity ∂/∂t (η•t) = η, is there to save us from worrying about shapes.
6. Construct a new RandomVariable with those natural parameters.
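To make the intended usage concrete, here is a hedged end-to-end sketch on the Normal-Normal model from earlier in the thread (the exact call signature and whether rv may appear in its own blanket are assumptions based on the description above, not verified against the final API):

import edward as ed
from edward.models import Normal

# Normal-Normal model; sample_shape gives five i.i.d. observations of x.
mu = Normal(0., 1.)
x = Normal(mu, 1., sample_shape=5)

# Ask for the complete conditional p(mu | x); the blanket lists the model's
# variables. Including mu itself is harmless per the note above.
mu_cond = ed.complete_conditional(mu, [mu, x])  # should come back as a Normal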