
Advanced usage: allow the users to deal with manually batching models #193

Merged · 17 commits · Feb 14, 2020

Conversation

lucianopaz (Contributor)

This is a rather big and ugly PR that addresses the following issues:

  1. Makes a more adequate choice of where to add the plate axis of a Distribution.
  2. Adds a custom tfp.distributions.Distribution subclass called Plate, which enables the above point.
  3. Enables users to specify that they wrote their models in a vectorized manner, so that there is no need to perform automatic batching in sample_prior_predictive, sample, or sample_posterior_predictive.

1 and 2. Problems with the plate kwarg

At the moment we were using tfd.Sample to add plate axes to a Distribution in the following way:

def __init__(self, ...):
    ...
    self._distribution = tfd.Sample(self._distribution, sample_shape=plate)
    ...

This can have some confusing side effects, because Sample adds dimensions to the event_shape and leaves the batch_shape as it was, so:

@pm.model
def model():
    mu = yield pm.Normal("mu", tf.zeros(2), 1)
    x = yield pm.Normal("x", mu, 1, plate=3)
    y = yield pm.MvNormal("y", mu, tf.linalg.diag(tf.ones(2)), plate=3)

This makes samples from x have shape (2, 3) instead of the expected (3, 2), because mu gets interpreted as multiple batches. Samples from y, however, do come out with shape (3, 2), because mu gets interpreted as defining the MvNormal's event_shape.

The solution was to add a tfd.Distribution subclass, Plate, that is mostly copied from tfd.Sample but adds dimensions to the batch dimensions instead of the event dimensions.

I've added a flag, plate_events, to control whether Plate or Sample is used, depending on the user's wishes.
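To make the axis-placement difference concrete, here is a small NumPy sketch of the intended shape semantics (an illustration only, not the actual Plate implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

# Base distribution analogous to Normal(tf.zeros(2), 1): batch_shape (2,), event_shape ()
batch_shape, event_shape, plate = (2,), (), (3,)

# tfd.Sample appends the plate axes after the batch axes (into the event),
# so a single draw has shape batch + plate + event = (2, 3)
sample_style = rng.standard_normal(batch_shape + plate + event_shape)

# Plate instead prepends the plate axes as extra batch axes, giving (3, 2)
plate_style = rng.standard_normal(plate + batch_shape + event_shape)

print(sample_style.shape)  # (2, 3)
print(plate_style.shape)   # (3, 2)
```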

3. The use_auto_batching flag

By far the auto-batching part of this PR was the longest, and it involved some API changes.

  1. One can define a Distribution to be conditionally independent. Why would one need that? To be able to pass pm.evaluate_model a sample_shape that draws multiple samples in a vectorized way for conditionally independent nodes. The user then has to be sure that they didn't screw up something later in the model's flow, and that all operations work in a vectorized way across batches. This enables potentially faster forward sampling, at the expense that the user has to deal with all the shape problems.
  2. One can re-interpret a Distribution's batch_shape axes as event_shape axes. This gives the user fine-grained control over how the distribution's log_prob deals with batches, and eventually lets inference.sampling.sample not depend on vectorized_map.
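The second point mirrors what tfd.Independent does: moving batch axes into the event means log_prob sums over those axes instead of returning them per batch member. A minimal NumPy sketch of that reduction:

```python
import numpy as np

# Per-element log-probabilities for a distribution with batch_shape (3, 2)
logp = np.log(np.full((3, 2), 0.5))   # every element is log(0.5)

# Re-interpreting the trailing batch axis as part of the event means the
# log_prob of a single draw sums over that axis:
logp_event = logp.sum(axis=-1)        # shape (3,), each entry 2 * log(0.5)

print(logp_event.shape)  # (3,)
```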

Things left out

I started out writing this with the intention of allowing users to vectorize draws from the prior, but regrettably that problem also affects vectorized draws from the posterior and the posterior predictive, so in the end I had to do all three tasks in a single big PR (sorry).
What I didn't do, and will do in January, is write a notebook guide to using these features.

@lucianopaz lucianopaz added the enhancement New feature or request label Dec 19, 2019
@lucianopaz (Contributor, Author)

This PR isn't linked to a particular issue because it's mostly about new features, but I think it should close #91, because it adds all the flexibility of the three kinds of shapes (event, batch, and sample), with all the attendant problems.


class Plate(distribution_lib.Distribution):
"""Plate distribution via independent draws.
This distribution is useful for stacking collections of independent,
Member: indent looks off

@kullback_leibler.RegisterKL(Plate, Plate)
def _kl_sample(a, b, name="kl_sample"):
"""Batched KL divergence `KL(a || b)` for Plate distributions.
We can leverage the fact that:
Member: same

@@ -546,7 +552,12 @@ def proceed_distribution(
# might be posterior predictive or programmatically override to exchange observed variable to latent
if scoped_name not in state.untransformed_values:
# posterior predictive
return_value = state.untransformed_values[scoped_name] = dist.sample()
if dist.is_root:
Member: getting less and less different from tfd coroutines

Contributor (Author): Yeah, this PR brings us much closer to tfd coroutines. However, all of the "root" stuff and manual vectorization is optional: we do auto-batching by default, and users don't have to think so much about all of the potential shape issues. So I think we will be providing a much better UX than tfd coroutines.

Member: +1. In tfd.JointDistributionCoroutine the user needs to identify the root; if we can do that automatically, it is a win.

Member: Oh wait, we are doing that with the conditional_independent kwarg... hmm, it is probably needed, but could we find a better name for this kwarg?

@lucianopaz (Contributor, Author), Dec 20, 2019:

I think that calling them root is really vague; they should be given a name that bears meaning for probability distributions. I could only think of conditionally independent, but maybe there's a more appropriate alternative.
About identifying the root automatically: we can't do that when we define the model. We need at least one forward pass to build something like the probabilistic graphical model that we plot in pymc3 with model_graph.
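The kind of forward pass described here can be sketched as follows: evaluate the model once, record each node's parents, and take the parentless nodes as roots. This is a toy illustration; the names and structure are hypothetical, not the actual pm.flow state:

```python
# Parent sets as they might be recorded during one forward evaluation of the
# model above (hypothetical structure, not the actual pm.flow state):
parents = {
    "model/mu": set(),           # parameters are constants, so mu is a root
    "model/x": {"model/mu"},     # x depends on mu
    "model/y": {"model/mu"},     # y depends on mu
}

# Roots (conditionally independent nodes) have no recorded parents
roots = sorted(name for name, deps in parents.items() if not deps)
print(roots)  # ['model/mu']
```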

@junpenglao (Member) left a comment:

Looks good at first glance, and I think this is very much needed. I am going to do a more in-depth review in the coming days.

pymc4/distributions/plate.py (outdated, resolved)
*,
transform=None,
observed=None,
plate=None,
Member: One thing we should definitely clarify in the docs is how plate/plate_events are the high-level user API, while sample_shape and plate_shape are the internal (and lower-level TFP) API, depending on the resample/reshape distribution.

Member: Should we still call it plate, I wonder, or side with TFP's naming convention?

Contributor (Author): I think we can't call it either batch_shape or event_shape because, with either name, it would give the user the impression that it sets the full batch or event shape of the distribution, instead of just appending some dimensions. I think it needs a different name, and plate seems nice.

Contributor (Author): We could maybe call it something like stack_batch or stack_event, because that gives the impression that you are stacking things onto either the batch or the event dimensions.

Member: DAG jargon like plate should be avoided, I think. It should be clear to the user what it represents. Something like dist_sample_shape? This links it to the distribution and describes it more unambiguously.

Member: @lucianopaz thoughts?

Contributor (Author): Sorry I missed the notifications of the previous comments. @fonnesbeck, so we should change the names of the plate stuff.
@twiecki, I don't really like sample_shape because it will be confusing for the predictive sampling methods. At the very least, I find the tfd.Sample naming very confusing and I wouldn't like to copy it in pymc4.
@junpenglao, I'm at a loss for an appropriate name. The correct term for the operation we are carrying out could be either stacking or repeating: with plate and plate_events we either stack or repeat events or batches. I would call them event_stack and batch_stack, and we could then rename the Plate distribution to BatchStacker. What do you think?

Member: I like event_stack, batch_stack, and BatchStacker. Let's go with that if no one is opposed; we will probably not find the perfect name anyway.

Member: +1. I would actually be OK with plate as well. As long as we document it clearly, even if we make a mistake users will adapt (hopefully).

Contributor (Author): Great, I'll change the names and fix the conflicts tonight so we can merge this.

@junpenglao (Member)

I had a bit more thought about this PR:

So first of all, I totally agree that currently our way of handling models like

@pm.model
def model(): 
    mu = yield pm.Normal('mu', tf.zeros(2), 1) 
    x = yield pm.Normal('x', mu, 1, plate=3)
    y = yield pm.MvNormal('y', mu, tf.linalg.diag(tf.ones(2)), plate=3)

value, state = pm.flow.evaluate_model(model())
for name, dist in state.distributions.items(): 
    print(name,
          'batch_shape:', dist._distribution.batch_shape, 
          'event_shape:', dist._distribution.event_shape, 
          'shape of dist.sample() -->', dist.sample().shape)
    
# model/mu batch_shape: (2,) event_shape: () shape of dist.sample() --> (2,)
# model/x batch_shape: (2,) event_shape: (3,) shape of dist.sample() --> (2, 3)
# model/y batch_shape: () event_shape: (3, 2) shape of dist.sample() --> (3, 2)

outputs confusing random variables, and pm.Plate is potentially useful for moving this kind of repeated shape to the "right" (rightmost) axis. But I am not completely sure whether adding Plate is the right approach to solve this problem. There are a few design choices TBD here:

  1. Are we going to introduce pm.Sample and pm.Independent, which wrap tfd.Sample and tfd.Independent?
    The reason being that if we introduce these "meta-distribution wrappers", the above problem could be solved by the user wrapping Independent in the right place:
@pm.model
def model_with_independent(): 
    mu = yield pm.Normal('mu', tf.zeros(2), 1) 
    x = yield pm.Independent(pm.Normal('x', mu, 1), plate=3)
    y = yield pm.MvNormal('y', mu, tf.linalg.diag(tf.ones(2)), plate=3)

Just as currently in TFP land you would do:

def model_jd():
    mu = yield root(tfd.Normal(tf.zeros(2), 1))
    x = yield tfd.Sample(tfd.Independent(tfd.Normal(mu, 1), 1), 3)
    y = yield tfd.Sample(tfd.MultivariateNormalFullCovariance(
        mu, tf.linalg.diag(tf.ones(2))), 3)


jd = tfd.JointDistributionCoroutine(model_jd)

dists, values = jd.sample_distributions()
dists
  2. If we are not exposing pm.Independent, what are the other options?
    One option I would like to entertain is whether treating all RVs as batch_shape = () by default is reasonable, assuming that we always batch using vectorized_map (unless explicitly turned off). In that case, the above model would give output as:
value, state = pm.flow.evaluate_model(model())
for name, dist in state.distributions.items(): 
    print(name,
          'batch_shape:', dist._distribution.batch_shape, 
          'event_shape:', dist._distribution.event_shape, 
          'shape of dist.sample() -->', dist.sample().shape)
# Junpeng's expected behavior
# model/mu batch_shape: () event_shape: (2,) shape of dist.sample() --> (2,)
# model/x batch_shape: () event_shape: (3, 2) shape of dist.sample() --> (3, 2)
# model/y batch_shape: () event_shape: (3, 2) shape of dist.sample() --> (3, 2)

On the implementation side, this would mean we wrap everything with tfd.Independent by default if it is not a root. The restriction is that we will need to assume a static batch shape, so that we can rewrite reinterpreted_batch_ndims at runtime (i.e., in pm.flow).
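The static-shape bookkeeping for wrapping with tfd.Independent can be sketched like this (a hypothetical helper for illustration, not part of the PR):

```python
def independent_shapes(batch_shape, event_shape, reinterpreted_batch_ndims):
    """Shapes after tfd.Independent moves the k rightmost batch axes into the event."""
    k = reinterpreted_batch_ndims
    assert 0 <= k <= len(batch_shape)
    split = len(batch_shape) - k
    return batch_shape[:split], batch_shape[split:] + event_shape

# Normal('x', mu, 1, plate=3) with mu of shape (2,): batch (3, 2), event ().
# Wrapping with reinterpreted_batch_ndims=2 gives batch (), event (3, 2):
print(independent_shapes((3, 2), (), 2))  # ((), (3, 2))
```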

@lucianopaz (Contributor, Author)

Thanks @junpenglao! I totally agree with discussing the design. I like your suggestion of using plate and sample_shape instead of plate and the boolean flag I'm using here, but I don't like the name sample_shape: it's confusing because it looks like a sample call.
I don't like the idea of adding the meta-distributions pm.Independent and pm.Sample. I would much prefer that their functionality be controlled with keyword arguments to the construction of regular distributions.
I'm quite time-constrained, but I'll try to wrap up this PR and make it ready to merge in the next week.

@lucianopaz (Contributor, Author)

@junpenglao, this should be ready for review. In the next few weeks I'll write a notebook about the advanced usage. Also, in another PR, I'll write a docstring for Distribution; this PR has been open for some time and I would like to finish it off soonish.

@pymc-devs pymc-devs deleted a comment from lucianopaz Feb 7, 2020
@junpenglao (Member)

Sorry about the delay - I will take a look this weekend.

@junpenglao (Member)

@lucianopaz I don't have further comments; can you resolve the conflicts?

@lucianopaz (Contributor, Author)

@junpenglao, I'll resolve the conflicts tonight.

@twiecki (Member)

twiecki commented Feb 14, 2020

Looks great, thanks @lucianopaz! I'm going to merge, as I think it's in a good state; we can still follow up if there are more changes we think we need.

@twiecki twiecki merged commit cabd67c into pymc-devs:master Feb 14, 2020
@lucianopaz lucianopaz mentioned this pull request Feb 14, 2020