Advanced usage: allow the users to deal with manually batching models #193
Conversation
This PR isn't linked to a particular issue because it's mostly about new features. But I think that it should close #91, because it adds all the flexibility of the three kinds of shape (event, batch and sample), with all the added problems.
pymc4/distributions/plate.py (Outdated)

```python
class Plate(distribution_lib.Distribution):
    """Plate distribution via independent draws.

    This distribution is useful for stacking collections of independent,
```
indent looks off
pymc4/distributions/plate.py (Outdated)

```python
@kullback_leibler.RegisterKL(Plate, Plate)
def _kl_sample(a, b, name="kl_sample"):
    """Batched KL divergence `KL(a || b)` for Plate distributions.

    We can leverage the fact that:
```
same
```diff
@@ -546,7 +552,12 @@ def proceed_distribution(
     # might be posterior predictive or programmatically override to exchange observed variable to latent
     if scoped_name not in state.untransformed_values:
         # posterior predictive
-        return_value = state.untransformed_values[scoped_name] = dist.sample()
+        if dist.is_root:
```
This is getting less and less different from tfd coroutines.
Yeah, this PR brings us much closer to tfd coroutines. However, all of the "root" stuff and manual vectorization is optional. We do auto batching by default and the users don't have to think so much about all of the potential shape issues. So I think that we will be providing a much better UX than tfd coroutines
+1, in tfd.JointDistributionCoroutine the user needs to identify the root; if we can do that automatically, it is a win.
Oh wait, we are doing that with the `conditional_independent` kwarg... hmm, it is probably needed, but could we find a better name for this kwarg?
I think that calling them root is really vague. I think that they should be given a name that bears a meaning in probability distributions. I could only think of conditionally independent, but maybe there's a more appropriate alternative.

About identifying the root automatically, we can't do that when we define the model. We at least need to do one forward pass to build something like the probabilistic graphical model that we plot in pymc3 with `model_graph`.
Looks good at first glance, and I think this is very much needed. I am going to do a more in-depth review in the coming days.
pymc4/distributions/distribution.py (Outdated)

```python
    *,
    transform=None,
    observed=None,
    plate=None,
```
One thing we should definitely clarify in the docs is how `plate`/`plate_events` are the high-level user API, while `sample_shape` and `plate_shape` are the internal (and lower-level TFP) API, depending on the resample/reshape distribution.
Should we still call it plate, I wonder, and not side with TFP's naming convention?
I think that we can't call it either `batch_shape` or `event_shape` because, with either name, it would give the impression to the user that it is setting the full batch or event shape of the distribution, instead of just appending some dimensions to either of them. I think that it needs to have a different name, and plate seems nice.
We could maybe call it something like `stack_batch` or `stack_event`, because that gives the impression that you are stacking things onto either the batch or event dimensions.
DAG jargon like `plate` should be avoided, I think. It should be clear to the user what it represents. Something like `dist_sample_shape`? This links it to the distribution, and describes it more unambiguously.
@lucianopaz thoughts?
Sorry that I missed the notifications of the previous comments. @fonnesbeck, so we should change the names of the `plate` stuff. @twiecki, I don't really like `sample_shape` because it will be confusing for the predictive sampling methods. At least, I find the `tfd.Sample` naming very confusing, and I wouldn't like to copy it in pymc4. @junpenglao, I'm at a loss for an appropriate name for these. The correct term for the operation that we are carrying out could be either stacking or repeating. With `plate` and `plate_events` we either stack or repeat events or batches. I would call them `event_stack` and `batch_stack`, and we could then change the `Plate` distribution to `BatchStacker`. What do you guys think?
I like `event_stack` and `batch_stack` and `BatchStacker`. Let's go with that if no one is opposed; we will probably not find the perfect name anyway.
+1, I would be OK with plate as well, actually. As long as we document it clearly, even if we make a mistake users will adapt (hopefully).
Great, I'll change the names and fix the conflicts tonight so we can merge this
I had a bit more thought about this PR. So first of all, I totally agree that currently our way of handling models like

```python
@pm.model
def model():
    mu = yield pm.Normal('mu', tf.zeros(2), 1)
    x = yield pm.Normal('x', mu, 1, plate=3)
    y = yield pm.MvNormal('y', mu, tf.linalg.diag(tf.ones(2)), plate=3)

value, state = pm.flow.evaluate_model(model())
for name, dist in state.distributions.items():
    print(name,
          'batch_shape:', dist._distribution.batch_shape,
          'event_shape:', dist._distribution.event_shape,
          'shape of dist.sample() -->', dist.sample().shape)
# model/mu batch_shape: (2,) event_shape: () shape of dist.sample() --> (2,)
# model/x batch_shape: (2,) event_shape: (3,) shape of dist.sample() --> (2, 3)
# model/y batch_shape: () event_shape: (3, 2) shape of dist.sample() --> (3, 2)
```

outputs confusing random variables, and `pm.Plate` is potentially useful for moving these kinds of repeated shapes to the "right" axis (rightmost). But I am not completely sure whether adding `Plate` here is the right approach to solve this problem. There are a few design choices TBD here:

```python
@pm.model
def model_with_independent():
    mu = yield pm.Normal('mu', tf.zeros(2), 1)
    x = yield pm.Independent(pm.Normal('x', mu, 1), plate=3)
    y = yield pm.MvNormal('y', mu, tf.linalg.diag(tf.ones(2)), plate=3)
```

Just as currently in TFP land you would do:

```python
def model_jd():
    mu = yield root(tfd.Normal(tf.zeros(2), 1))
    x = yield tfd.Sample(tfd.Independent(tfd.Normal(mu, 1), 1), 3)
    y = yield tfd.Sample(tfd.MultivariateNormalFullCovariance(
        mu, tf.linalg.diag(tf.ones(2))), 3)

jd = tfd.JointDistributionCoroutine(model_jd)
dists, values = jd.sample_distributions()
dists
```

```python
value, state = pm.flow.evaluate_model(model())
for name, dist in state.distributions.items():
    print(name,
          'batch_shape:', dist._distribution.batch_shape,
          'event_shape:', dist._distribution.event_shape,
          'shape of dist.sample() -->', dist.sample().shape)

# Junpeng's expected behavior
# model/mu batch_shape: () event_shape: (2,) shape of dist.sample() --> (2,)
# model/x batch_shape: () event_shape: (3, 2) shape of dist.sample() --> (3, 2)
# model/y batch_shape: () event_shape: (3, 2) shape of dist.sample() --> (3, 2)
```

On the implementation side, this would mean we wrap everything with `tfd.Independent` by default if it is not a root. The restriction is that we will need to assume a static batch shape, so that we can rewrite the `reinterpreted_batch_ndims` at runtime (i.e., in `pm.flow`).
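(For reference, a minimal plain-TFP sketch of the `tfd.Independent` wrapping described above; this is an illustration, not code from this PR.)

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# tfd.Independent reinterprets trailing batch axes as event axes,
# moving them out of batch_shape and into event_shape:
base = tfd.Normal(tf.zeros([3, 2]), 1.0)  # batch_shape (3, 2), event_shape ()
ind = tfd.Independent(base, reinterpreted_batch_ndims=2)
print(ind.batch_shape, ind.event_shape)   # () (3, 2)
print(ind.sample().shape)                 # (3, 2)
```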
Thanks @junpenglao! I totally agree with discussing the design. I like your suggestion of using `pm.Independent`.
@junpenglao, this should be ready for review. In the next weeks I'll write a notebook about the advanced usage. Also, in another PR, I'll write a docstring for …
Sorry about the delay, I will take a look this weekend.

@lucianopaz I don't have any further comments, can you resolve the conflicts?

@junpenglao, I'll resolve the conflicts tonight.

Looks great, thanks @lucianopaz! I'm going to merge as I think it's in a good state; we can still follow up if there are more changes we think we need.
This is a rather big and ugly PR that addresses the following issues:

1. A way to control the plate axis of a `Distribution`.
2. A `tfp.distributions.Distribution` subclass called `Plate`, that enables the above point.
3. A `use_auto_batching` flag for `sample_prior_predictive`, `sample` or `sample_posterior_predictive`.
1 and 2. Problems with the `plate` kwarg

At the moment, we were using `tfd.Sample` to add plate axes to a `Distribution` in the following way:
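(Reconstructed sketch of the usage in question, following the example quoted earlier in this thread; the original code block here was lost.)

```python
@pm.model
def model():
    mu = yield pm.Normal('mu', tf.zeros(2), 1)
    x = yield pm.Normal('x', mu, 1, plate=3)
    y = yield pm.MvNormal('y', mu, tf.linalg.diag(tf.ones(2)), plate=3)
```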
This can have some confusing side effects, because `Sample` adds dimensions to the `event_shape` and leaves the `batch_shape`s as they were. It makes samples from `x` have shape `(2, 3)` instead of the expected `(3, 2)`, because `mu` gets interpreted as multiple batches. However, samples from `y` do come out with shape `(3, 2)`, because `mu` gets interpreted as defining the `MvNormal`'s `event_shape`.
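(A minimal illustration of this `tfd.Sample` behavior in plain TFP; a sketch for clarity, not code from this PR.)

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# tfd.Sample appends the extra axes to the *event* shape and
# leaves the batch shape untouched:
d = tfd.Sample(tfd.Normal(tf.zeros(2), 1.0), sample_shape=3)
print(d.batch_shape, d.event_shape)  # (2,) (3,)
print(d.sample().shape)              # (2, 3), not (3, 2)
```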
The solution was to add a `tfd.Distribution` subclass, `Plate`, that is mostly copied off `tfd.Sample`, but adds the new dimensions to the batch dimensions instead of the event dimensions.

I've added a flag, `plate_events`, to control whether to use `Plate` or `Sample`, depending on the user's wishes.
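(A hypothetical usage sketch of this flag; the kwarg names follow this PR's description and may differ in the merged code, e.g. after the `event_stack`/`batch_stack` renaming discussed above.)

```python
import pymc4 as pm
import tensorflow as tf

@pm.model
def model():
    mu = yield pm.Normal('mu', tf.zeros(2), 1)
    # plate_events=False: use Plate, so the new axis goes into batch_shape
    x = yield pm.Normal('x', mu, 1, plate=3, plate_events=False)
    # plate_events=True: use tfd.Sample, so the new axis goes into event_shape
    y = yield pm.Normal('y', mu, 1, plate=3, plate_events=True)
```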
3. The `use_auto_batching` flag

By far, the auto batching part of this PR was the longest, and it involved some API changes:
- A way to mark a `Distribution` as conditionally independent (see the sketch after this list). Why would one need that? To be able to pass `pm.evaluate_model` a `sample_shape` that draws multiple samples in a vectorized way for conditionally independent nodes. The user will then have to be sure that this doesn't screw up something later in the model's flow, and that all the operations work in a vectorized way across batches. This will enable potentially faster forward sampling, at the expense that the user will have to deal with all the shape problems.
- A way to interpret a `Distribution`'s `batch_shape` axes as `event_shape`s. This enables the user to have fine-grained control over how the distribution's `log_prob` deals with batches, and eventually lets it not depend on `vectorized_map` in `inference.sampling.sample`.
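(A hypothetical sketch of the first bullet; the kwarg name follows the `conditional_independent` naming discussed earlier in this thread, and the exact `evaluate_model` signature is an assumption.)

```python
import pymc4 as pm
import tensorflow as tf

@pm.model
def model():
    # Mark the root node as conditionally independent, so a sample_shape
    # can be pushed through it in one vectorized forward pass
    # (hypothetical kwarg, per the discussion above):
    mu = yield pm.Normal('mu', tf.zeros(2), 1., conditional_independent=True)
    x = yield pm.Normal('x', mu, 1.)

# Hypothetical call per the description: draw 100 vectorized prior samples.
_, state = pm.flow.evaluate_model(model(), sample_shape=(100,))
```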
Things left out

I started out writing this with the intention of allowing users to vectorize draws from the prior, but regrettably that problem also affects how to do vectorized draws from the posterior and the posterior predictive, so in the end I had to do all three tasks in a single big PR (sorry).

What I didn't do, and will do in January, is write a notebook guide to using these features.