ENH: add Flat and HalfFlat Distributions #219
Conversation
pymc4/distributions/continuous.py
        return tfd.Uniform(low=0.0, high=np.inf)

    def log_prob(self, value):
        return tf.cond(
You should use tf.where here
Done. Thanks!
@tirthasheshpatel, thanks again. Contrary to what it may look like, these two distributions will need more tests than the rest, for two big reasons:
Point 1 cannot be done until we fix #167, and point 2 is a pit in shape hell, so for now I think that we should leave this PR open, and once we fix #167 we can come back to it.
@lucianopaz Sorry for the late reply and thanks for such a detailed review! I will think more about the "shape hell" and try to raise relevant errors. Meanwhile, I will also try to catch up with #167 and try to resolve it!
@tirthasheshpatel, for now don't worry about #167. I'm working on it in this branch, but it requires #193 to be finished first. My recommendation is to do one of the following:
Some distributions in #44 don't have a tfp equivalent. So do we have to wait for it to be implemented, or is there a way around it? (I was thinking about implementing classes inheriting from tfp's …)
pymc4/distributions/continuous.py
        # raise ValueError("Rank of input tensor less than distribution's event shape")
        # # if the rightmost axis of `value` doesn't match the distribution's `event_shape`, raise an error
        # if (
        #     len(self._distribution.event_shape)
- There is no tfp equivalent, so we must test that the returned log_prob values are correct. That sounds simple, because the log_prob should always be zero. The difficulty here is shape handling. If the value passed to log_prob has a rank that is lower than the distribution's event shape, we should raise an error; if the value has a shape that doesn't broadcast with the distribution's batch+event shape, we should raise an error. Actually, I think it's even more stringent than that: the rightmost axis of the passed values should exactly match the distribution's event shape, not just broadcast with it, or an error should be raised. Furthermore, the log prob should sum reduce over the event shape axis, in order to only get back a batch of zeroes.
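For instance, a minimal correctness test could look like this (a hedged sketch only; the pm.Flat usage mirrors the dist_class(name=...) pattern used in the test file, the rest is assumed):

import numpy as np
import pymc4 as pm

def test_flat_log_prob_is_zero():
    dist = pm.Flat(name="Flat")
    value = np.zeros((7, 3))          # an arbitrary batch of values
    log_prob = dist.log_prob(value)   # should be zero everywhere
    np.testing.assert_allclose(log_prob, np.zeros((7, 3)))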
As the distribution is univariate, the event_shape and batch_shape will always be (), so I think we don't need to worry about it, right?
The defaults will be (), but we can stack multiple independent variables that follow this distribution with something similar to tfd.Sample, and that will lead to an event shape that could be a tuple of any length.
Furthermore, in #167, we are also aiming to provide a mechanism similar to tfd.Sample, but instead of stacking events, we will stack batches. So in the end, any distribution can have any number of event and batch axes. I just realized that I wrote the singular axis in my previous comment, but I meant it to be the plural, axes.
pymc4/distributions/continuous.py
            # raise ValueError(
            #     "Batch shape of input tensor not consistent with the distribution's batch shape"
            # )
            if tf.rank(value) > 2:
I have reduced sum along the event_shape axis when the shape of value is (n_samples, batch_shape, event_shape). Shouldn't we reduce sum over the n_samples axis instead?
We must only reduce sum over the event shape axes. That's what truly distinguishes between the batch and event shapes. For example,
>>> import tensorflow as tf
>>> from tensorflow_probability import distributions as tfd
>>>
>>> d = tfd.Normal(loc=tf.zeros((1, 2)), scale=1)
>>> d = tfd.Sample(d, sample_shape=(3, 4, 5))
>>> d.batch_shape.as_list()
[1, 2]
>>> d.event_shape.as_list()
[3, 4, 5]
>>>
>>> x = d.sample(sample_shape=(6, 7))
>>> x.numpy().shape
(6, 7, 1, 2, 3, 4, 5)
>>> v = d.log_prob(x)
>>> v.numpy().shape
(6, 7, 1, 2)
Oh! This is very clear now. Thanks!!
        return tfd.Uniform(low=-np.inf, high=np.inf)

    def log_prob(self, value):
        # convert the value to tensor
I think that you're doing too many if statements. I recommend that you start out with expected = tf.zeros(self._distribution.batch_shape + self._distribution.event_shape). This would be the expected shape of a call to tfd.Distribution.sample().
You should first check whether value has an event shape part that matches expected; if yes, then broadcast expected and value together and create a zeros tensor with the resulting shape. Finally, you'll have to reduce sum over all event shape axes.
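Putting those pieces together, a minimal sketch of that flow might look like this (not the PR's final code; self.batch_shape and self.event_shape are assumed to be TensorShape properties):

import tensorflow as tf

def log_prob(self, value):
    value = tf.convert_to_tensor(value)
    expected = tf.zeros(self.batch_shape + self.event_shape)
    event_ndims = len(self.event_shape)
    # the rightmost axes of `value` must match the event shape exactly
    if event_ndims and value.shape[-event_ndims:] != self.event_shape:
        raise ValueError("values not consistent with the event shape of distribution")
    # broadcast `expected` and `value` together and build zeros of that shape
    zeros = tf.zeros(tf.broadcast_dynamic_shape(tf.shape(expected), tf.shape(value)))
    # finally, reduce sum over all event shape axes to get back a batch of zeroes
    return tf.reduce_sum(zeros, axis=list(range(-event_ndims, 0)))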
Pinging @tirthasheshpatel, just to let you know that I've opened #227. Once that gets merged in, you should merge it into your branch and you'll be able to run all of the tests required by the …
@lucianopaz sorry for getting back to you late. I will complete this by today.
Great work @tirthasheshpatel! However, I think that some important tests are still missing before being able to merge this.
tests/test_distributions.py
@@ -348,6 +362,8 @@ def test_rvs_test_point_are_valid(tf_seed, distribution_conditions):
    dist_class = getattr(pm, distribution_name)
    dist = dist_class(name=distribution_name, **conditions)
    test_value = dist.test_value
    if distribution_name in ["Flat", "HalfFlat"]:
        pytest.skip("Flat and HalfFlat distributions don't support sampling.")
We shouldn't skip this test. In the case of Flat and HalfFlat, instead of comparing the test_value.shape against the test_sample.shape, we should compare it against the sample_shape + batch_shape + event_shape that we expect to get.
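For example, the skip could be replaced with something like this (a sketch; sample_shape is assumed to be known to the test):

    if distribution_name in ["Flat", "HalfFlat"]:
        # compare against the shape a draw would have, instead of sampling
        expected_shape = sample_shape + dist.batch_shape + dist.event_shape
        assert test_value.shape == expected_shape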
            and value.shape[-len(self._distribution.event_shape) :]
            != self._distribution.event_shape
        ):
            raise ValueError("values not consistent with the event shape of distribution")
You should write a test that raises this exception to make sure it's working
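For example (a sketch only; the event_stack argument used here to give Flat a non-scalar event shape is hypothetical):

import pytest
import tensorflow as tf
import pymc4 as pm

def test_flat_log_prob_event_shape_error():
    dist = pm.Flat(name="Flat", event_stack=(3, 4))  # hypothetical argument
    with pytest.raises(ValueError, match="event shape"):
        # rightmost axes (3, 5) don't exactly match the event shape (3, 4)
        dist.log_prob(tf.zeros((2, 3, 5)))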
pymc4/distributions/continuous.py
        try:
            expected = tf.broadcast_to(expected, value.shape)
        except tf.errors.InvalidArgumentError:
            raise ValueError("value can't be broadcasted to expected shape")
You should write a test that raises this exception to make sure it's working
pymc4/distributions/continuous.py
        expected = tf.zeros(self._distribution.batch_shape + self._distribution.event_shape)
        # check if the event shape matches
        if (
            len(self._distribution.event_shape)
A minor nitpick: our distributions now have the event_shape and batch_shape properties, so you don't have to access the _distribution attribute to look them up. You can do self.event_shape and self.batch_shape instead.
            and value.shape[-len(self._distribution.event_shape) :]
            != self._distribution.event_shape
        ):
            raise ValueError("values not consistent with the event shape of distribution")
The same comment as for the Flat distribution: we need a test that raises this exception.
pymc4/distributions/continuous.py
        try:
            expected = tf.broadcast_to(expected, value.shape) + value
        except tf.errors.InvalidArgumentError:
            raise ValueError("value can't be broadcasted to expected shape")
The same comment as for the Flat distribution: we need a test that raises this exception.
pymc4/distributions/continuous.py
            raise ValueError("values not consistent with the event shape of distribution")
        # broadcast expected to shape of value
        try:
            expected = tf.broadcast_to(expected, value.shape)
Does this work if value.shape=(3,) but the distribution has batch_shape=(4,) and event_shape=(3,), so the expected shape is (4, 3)? Won't it raise an error in that kind of situation?
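For concreteness, this is what plain tf.broadcast_to does in that situation:

import tensorflow as tf

expected = tf.zeros((4, 3))  # batch_shape (4,) + event_shape (3,)
value = tf.zeros((3,))       # a single event, no batch axes

# tf.broadcast_to(expected, value.shape) raises InvalidArgumentError,
# because shape (4, 3) cannot be broadcast down to (3,).
# Broadcasting in the other direction works fine:
tf.broadcast_to(value, expected.shape)  # -> shape (4, 3)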
Yes. I think this can be fixed by checking len(value.shape) < len(self.batch_shape + self.event_shape). If this condition evaluates to True, we only check whether the values are consistent with batch_shape, and we don't need to broadcast; if the condition is False, we let tf.broadcast_to handle it. What do you say? Is there a way around it?
Code
# broadcast expected to shape of value
if len(value.shape) < len(self.batch_shape + self.event_shape):
    expected = expected + value
    if value.shape[:-len(self.event_shape)] != list(reversed(self.batch_shape))[:(len(value.shape) - len(self.event_shape))]:
        raise ValueError("batch shape of values is not consistent with distribution's batch shape")
else:
    try:
        expected = tf.broadcast_to(expected, value.shape) + value
    except tf.errors.InvalidArgumentError:
        raise ValueError("value can't be broadcasted to expected shape")
I think that this is almost done. I left a few comments above
pymc4/distributions/continuous.py
        # broadcast expected to shape of value
        if len(value.shape) < len(self.batch_shape + self.event_shape):
            if (
                value.shape[: -len(self.event_shape)]
This won't work if the event shape is a scalar; you'll end up with value.shape[:0]. You will have to change this to value.shape[:len(value.shape) - len(event_shape)].
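The pitfall in plain Python slicing:

shape = (6, 7)    # value.shape for a distribution with scalar events
event_ndims = 0   # len(event_shape) == 0 for a scalar event shape

shape[:-event_ndims]              # shape[:0] -> (), drops everything
shape[:len(shape) - event_ndims]  # shape[:2] -> (6, 7), keeps it all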
pymc4/distributions/continuous.py
        if len(value.shape) < len(self.batch_shape + self.event_shape):
            if (
                value.shape[: -len(self.event_shape)]
                != list(reversed(self.batch_shape))[: (len(value.shape) - len(self.event_shape))]
Why do you need to reverse this?
If the batch_shape is (1, 2) and the event_shape is (3, 4), and we want to sample (2, 3, 4), then the target must broadcast to (1, 2, 3, 4). So we need to check the values of batch_shape starting from the rightmost axis.
Oh, I see the problem! I have to check batch_shape[len(self.batch_shape) - len(value.shape):] instead of reversing it.
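A sketch of the corrected comparison, slicing batch_shape from the right instead of reversing it (variable names follow the snippet above):

# number of batch axes actually present in `value`
batch_ndims = len(value.shape) - len(self.event_shape)
# compare against the rightmost `batch_ndims` entries of batch_shape
if value.shape[:batch_ndims] != self.batch_shape[len(self.batch_shape) - batch_ndims:]:
    raise ValueError("batch shape of values is not consistent with distribution's batch shape")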
tests/test_distributions.py
@pytest.mark.parametrize("distribution_name", ["Flat", "HalfFlat"])
@pytest.mark.parametrize("sample", [tf.zeros(1), tf.zeros((1, 3, 4)), tf.zeros((1, 5, 3, 4))])
We are using fixtures instead of mark.parametrize. Could you change this to use fixtures?
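A sketch of the fixture-based equivalent (fixture names are assumptions, mirroring the parametrize arguments above):

import pytest
import tensorflow as tf

@pytest.fixture(params=["Flat", "HalfFlat"])
def distribution_name(request):
    return request.param

@pytest.fixture(params=[tf.zeros(1), tf.zeros((1, 3, 4)), tf.zeros((1, 5, 3, 4))])
def sample(request):
    return request.param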
Sure!
Want to finish this up, @lucianopaz?
Sorry @tirthasheshpatel for not being able to finish this up sooner. Covid quarantine and all, I can hardly put any time into reviewing pymc stuff. I'll merge this PR because it seems ready to go. Great job!
Addresses #44.
Summary of my changes:
- Flat distribution
- HalfFlat distribution

Question
Do these need tests?