
Feature/simpler conjugacy #588

Merged
dustinvtran merged 57 commits into master from feature/simpler_conjugacy on Apr 11, 2017

Conversation

@matthewdhoffman
Collaborator

This is a new, improved, rewritten version of the feature/conjugacy module I started months ago.

It mostly boils down to supporting one function: ed.complete_conditional(). This function takes as input rv (a RandomVariable) and blanket (a collection of other RandomVariables), and tries to do some exponential-family algebra to work out the conditional distribution of rv given blanket. (Behavior is unchanged if rv is itself in blanket.)
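A minimal usage sketch (Beta-Bernoulli, the case exercised by the beta_bernoulli_conjugate.py example mentioned below; constructor argument names may differ across Edward/TensorFlow versions):

import edward as ed
from edward.models import Bernoulli, Beta

pi = Beta(a=1.0, b=1.0)               # prior on the Bernoulli probability
x = Bernoulli(p=pi, sample_shape=10)  # 10 conditionally iid observations

# p(pi | x): complete_conditional should recognize the conjugate pair and
# return a Beta RandomVariable whose parameters are a function of x
pi_cond = ed.complete_conditional(pi, [pi, x])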

The other user-facing change is adding a sample_shape kwarg to RandomVariable, which brings Edward syntax a bit more up to date with the tf.contrib.distributions syntax and reduces the amount of algebra the system has to do (compared with syntax like Normal(mu + tf.zeros([5]))). This supersedes PR #323.
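As a rough illustration (constructor argument order and names depend on the Edward version), both snippets below describe five draws with the same scalar mean, but the second keeps the parameter algebra trivial:

import tensorflow as tf
from edward.models import Normal

mu = Normal(0.0, 1.0)

# old style: broadcast the scalar mean up to the batch shape by hand
x_old = Normal(mu + tf.zeros([5]), tf.ones([5]))

# new style: keep the parameters scalar and request five draws via sample_shape
x_new = Normal(mu, 1.0, sample_shape=5)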

The high-level algorithm is:

  1. Build a TF node that computes the full log-joint distribution of rv and all of the (other) RandomVariables in blanket.
  2. Crawl down the TF graph depth-first from the log-joint node until we hit either
    a. A member of blanket (which truncates the search) or
    b. rv or a nonlinear function of rv (which truncates the search and adds the node that stopped the search to a list of sufficient-statistic nodes).
  3. Compute an s-expression-like intermediate representation of all sufficient-statistic nodes, and do some algebra on those nodes to simplify them (e.g., log(mul(a, b)) becomes add(log(a), log(b)); see the first sketch after this list).
  4. Look at the set of sufficient statistic nodes (and the support of rv) to see if there's a match in our table of exponential-family distributions—if not, we're done.
  5. If there's a match, copy the log-joint node to a scratch namespace, replacing all sufficient-statistic nodes with placeholders.
  6. Compute natural parameters in that scratch namespace by calling tf.gradients() with respect to the sufficient-statistic placeholders on the copied log-joint. This trick, which exploits the trivial identity ∂/∂t (η•t) = η, saves us from worrying about shapes (see the second sketch after this list).
  7. Copy those natural parameters back to the original namespace, getting rid of any placeholders, and construct a new exponential-family RandomVariable with those natural parameters.
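A toy version (illustrative only, not the module's code) of the kind of step-3 rewrite rule, operating on tuples of the form ('op', arg0, arg1):

def expand_log_of_product(expr):
  # rewrite log(mul(a, b)) as add(log(a), log(b)); leave anything else alone
  if (isinstance(expr, tuple) and expr[0] == 'log' and
      isinstance(expr[1], tuple) and expr[1][0] == 'mul'):
    a, b = expr[1][1], expr[1][2]
    return ('add', ('log', a), ('log', b))
  return expr

print(expand_log_of_product(('log', ('mul', 'a', 'b'))))
# ('add', ('log', 'a'), ('log', 'b'))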
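And a minimal, self-contained sketch of the step-6 trick (illustrative names only): if the copied log-joint is affine in a sufficient-statistic placeholder t, i.e. log_joint = η•t + const, then tf.gradients hands back η with the right shape for free.

import tensorflow as tf

t = tf.placeholder(tf.float32, [3])        # stand-in for a sufficient-statistic node
eta = tf.constant([0.5, -1.0, 2.0])        # the natural parameter we want to recover
log_joint = tf.reduce_sum(eta * t) + 7.0   # affine in t, so d(log_joint)/dt == eta

natural_parameter = tf.gradients(log_joint, t)[0]

with tf.Session() as sess:
  # the fed value of t is irrelevant: the gradient of an affine function is constant
  print(sess.run(natural_parameter, feed_dict={t: [0.0, 0.0, 0.0]}))  # [ 0.5 -1.  2.]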

@matthewdhoffman
Collaborator Author

matthewdhoffman commented Mar 26, 2017

Thanks for the comments. I'm also excited to revisit the empirical question of what (if anything) conjugacy buys us over more black-box methods.

I'll pull out sample_shape and ParamMixture as separate PRs, make blanket optional, and move/refactor support and conjugate_log_prob. Regarding other questions:

  1. The copy/copyback approach is annoying (and slow, because graph manipulation in TF is slow), and ideally I'd like to find a way to get rid of it. The problem it solves is that sometimes sufficient-statistic nodes can depend on other sufficient-statistic nodes. E.g., log(x) depends on x, and so taking the gradient w.r.t. x gives you the wrong answers. I tried solving this problem a few different ways, and this was the only one that worked as generally as I needed it to. In particular, we want to let the user say things like
mu = Normal(0., 1.)
x = Normal(3 * mu, 1.)

but if Normal.conjugate_log_prob() doesn't know that its input is a function of a random variable, then there's no way to write it so that (3 * mu)^2 is not treated as a function of mu.value().

It might be possible to revisit this at some point—for example, we could use some of the logic in simplify.py to give the conjugate_log_prob() functions more knowledge about their inputs, or we could write a version of tf.gradients() with a stop_nodes parameter that says where to pretend there's a stop_gradient() node. But it'd probably be a significant undertaking.
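For context, a tiny illustration (not the module's code) of the underlying issue: tf.gradients follows every path from the output back to x, so when one sufficient statistic (log(x)) is built on top of another (x), the only way to cut the extra path today is to have inserted tf.stop_gradient when the graph was built.

import tensorflow as tf

x = tf.constant(2.0)
z = x + tf.log(x)                        # both x and log(x) depend on x
g_both = tf.gradients(z, x)[0]           # 1 + 1/x = 1.5

x2 = tf.constant(2.0)
z2 = x2 + tf.stop_gradient(tf.log(x2))   # path through log(x2) cut at build time
g_cut = tf.gradients(z2, x2)[0]          # 1.0

with tf.Session() as sess:
  print(sess.run([g_both, g_cut]))       # [1.5, 1.0]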

  2. Simplifying simplify.py: Good suggestions. I'll double-check the implementation and decide whether as_float() and NodeWrapper are really necessary/helpful.

@dustinvtran
Member

That makes sense. Too bad tf.gradients doesn't already let one specify stop nodes, rather than requiring tf.stop_gradient when building the graph. I agree the alternatives sound like a significant undertaking.

Once you tease out sample_shape/ParamMixture I can add commits to those branches to help. I won't push any commits to this feature/simpler_conjugacy branch (yet).

@dustinvtran force-pushed the feature/simpler_conjugacy branch 2 times, most recently from 7f85baf to b5acb1f on April 5, 2017 15:20
@dustinvtran left a comment (Member)

I merged master to bring the branch up to date (e.g., following ParamMixture).

Following the previous comments, can you also document the utility functions in simplify.py and conjugacy.py? I had trouble following the internal details of the functions when trying to understand complete_conditional.

Other comments below.

def complete_conditional(rv, blanket, log_joint=None):
  with tf.name_scope('complete_conditional_%s' % rv.name) as scope:
    # log_joint holds all the information we need to get a conditional.
    extended_blanket = copy(blanket)
Member

Is this shallow copy of the list needed?

Collaborator Author

Good catch; there's a bug here. I meant to add rv to blanket. It's fixed now.

return '_log_joint_of_' + ('&'.join([i.name[:-1] for i in blanket])) + '_'


def get_log_joint(blanket):
Member

My understanding is that get_log_joint returns the existing log-joint tensor, rather than forming a new one, if it already exists. If we call complete_conditional twice on two nodes that share variables in their blanket, it would redo the full joint. Is that correct?

Collaborator Author

That's right. I introduced this caching for other reasons (which aren't all that relevant now), but it still speeds up graph construction.

But maybe it makes sense to move that caching a level up, so that each RandomVariable caches its conjugate_log_prob()?

(The term "cache" is a little funny here, since the graph remembers everything regardless. We're just caching a pointer.)
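Roughly the pattern under discussion, as a hypothetical helper (not the module's code): the "cache" is just a dict from a blanket-derived name to the already-built log-joint tensor.

_log_joint_cache = {}

def get_or_build_log_joint(blanket, build_fn):
  # key mirrors the naming scheme above: one name per blanket member
  key = '&'.join(sorted(rv.name for rv in blanket))
  if key not in _log_joint_cache:
    _log_joint_cache[key] = build_fn(blanket)   # build the TF subgraph once
  return _log_joint_cache[key]                  # later calls just reuse the pointer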

swap_back = {}
for s_stat in s_stat_exprs.keys():
  s_stat_shape = s_stat_exprs[s_stat][0][0].get_shape()
  s_stat_placeholder = tf.placeholder(np.float32, s_stat_shape)
Member

Is the float32 for placeholders because the log joint is float32?

Collaborator Author

It's because we want to take the gradient w.r.t. the placeholders, and TF doesn't like taking gradients w.r.t. non-floats.

@mariru
Contributor

mariru commented Apr 10, 2017

@matthewdhoffman very much looking forward to this feature! I checked out the branch to play around with the examples. Direct import wasn't working, so I added a line to the __init__.py file.

A minor comment: it would be nice if the beta_bernoulli_conjugate.py example printed the name of the distribution, i.e., on line 38 substitute

print('p(pi | x) type:', pi_cond.parameters['name'])

with

print('p(pi | x) type:', pi_cond.__class__.__name__)

or do the equivalent replacement in the return statement of the complete_conditional() function.

@dustinvtran
Member

Is there anything this PR is still waiting on? Happy to merge now. I have some minor suggestions, but I'll submit another PR after this one.

@matthewdhoffman
Collaborator Author

Nope, I think it's at a good checkpoint—let's merge. I haven't done the conjugate_log_prob() refactoring yet, but I probably won't have much time to look at it in the next couple of days and it can happen separately.

I made blanket optional, but added a warning about using that feature multiple times. (It'll throw a gnarly error if the user tries to do that; we could probably add a check that warns more clearly.)

Also, it turns out that get_blanket() doesn't really do what we need here, since it only looks at the Markov blanket of the directed model, not the moralized graph. It might be nice at some point to add a moralization routine or Bayes ball or something to figure out what we actually need to condition on, but it shouldn't have much of a performance impact except possibly on graph creation.

@dustinvtran merged commit df5c913 into master on Apr 11, 2017
@dustinvtran deleted the feature/simpler_conjugacy branch on April 11, 2017 23:59
@dawenl
Member

dawenl commented Apr 12, 2017

@matthewdhoffman you will gradually make all of our jobs meaningless with this PR :)

@dustinvtran mentioned this pull request on Apr 14, 2017