Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid unclear TypeError when using theano.shared variables as input to distribution parameters #4445

Merged
merged 2 commits into from
Jan 31, 2021

Conversation

kc611
Copy link
Contributor

@kc611 kc611 commented Jan 28, 2021

Fixes #3139

As suggested by @rpgoldman over here #3139 (comment). This PR just makes the error message a bit clearer (I just re-arranged it a bit, the error message suggesting to add a test_val argument was already present).

@kc611
Copy link
Contributor Author

kc611 commented Jan 28, 2021

We could instead also do a tt.isinf check for such cases (Hence adding support for tt.shared variables) but I guess the new RandomVariable OPs will make these changes obsolete anyway.

@ricardoV94
Copy link
Member

ricardoV94 commented Jan 28, 2021

Interesting. Do we want to nudge people towards pm.Data (which can be safely used for distribution parameters since #3925).

If we want to nudge I would mention it in the AttributeError (splitting the infinite and the TypeError messages).

If we want to support theano.shared I wouldn't close the linked issue.

Edit: Just for completeness, since #3925 shared tensors will work properly (get_value() is called by getattr_value), but shared scalars or shared tensors created from non-numpy arrays (which are of type theano.tensor.sharedvar.ScalarSharedVariable and theano.compile.sharedvalue.SharedVariable, respectively) still fail because they are not asked to return get_value(): https://github.com/pymc-devs/pymc3/blob/03d7af5b6dd5ad99ab2f3bd8ca7987a744dbef46/pymc3/distributions/distribution.py#L170-L171

@ricardoV94
Copy link
Member

But I do agree with @kc611 that this may all be made redundant with the RandomVariable Op

@ricardoV94
Copy link
Member

ricardoV94 commented Jan 29, 2021

I think we can avoid the TypeError altogether if we change: https://github.com/pymc-devs/pymc3/blob/03d7af5b6dd5ad99ab2f3bd8ca7987a744dbef46/pymc3/distributions/distribution.py#L170-L171

To:

if isinstance(val, (
    tt.sharedvar.TensorSharedVariable,              # pm.Data or theano.shared tensor
    theano.tensor.sharedvar.ScalarSharedVariable,   # theano.shared scalar
    theano.compile.sharedvalue.SharedVariable,      # theano.shared tensor from non-numpy array such as list

):
    return val.get_value()

I don't see any obvious drawbacks and it makes the API more forgiving. However, we should then add some unittest along these lines, to make sure it is working as intended: https://github.com/pymc-devs/pymc3/blob/03d7af5b6dd5ad99ab2f3bd8ca7987a744dbef46/pymc3/tests/test_data_container.py#L139-L157

@AlexAndorra since you worked on enabling the pm.Data as input to other rvs, do you think there is any reason to not accomodate these other shared types?

@AlexAndorra
Copy link
Contributor

Thanks for the deep dive @ricardoV94 !

@AlexAndorra since you worked on enabling the pm.Data as input to other rvs, do you think there is any reason to not accomodate these other shared types?

No, adding these looks really fine to me 👌

@ricardoV94
Copy link
Member

ricardoV94 commented Jan 29, 2021

@kc611 is this something that you want to do in this PR? (it's totally fine if you are not interested)

@kc611
Copy link
Contributor Author

kc611 commented Jan 30, 2021

No issues on my side. I'll make the changes shortly.

pymc3/tests/test_data_container.py Outdated Show resolved Hide resolved
pymc3/distributions/distribution.py Outdated Show resolved Hide resolved
@ricardoV94
Copy link
Member

ricardoV94 commented Jan 30, 2021

I just realized that the changes checking for theano.compile.sharedvalue.SharedVariable (e.g., theano.shared([5.0, 5.0]) will not work for most distributions, because parameters are usually coerced at the beginning via theanof::floatX or theanof::intX, which fails with lists. Here is an illustration:

with pm.Model() as m:
    shared_var = theano.shared([5.0, 5.0])  # Fails: cannot be safely coerced into theano.config.floatX
    v = pm.Normal("v", mu=shared_var, shape=2)

Raises ValueError: setting an array element with a sequence in theanof::floatX.

So I think we can just drop that condition / check. The best would be to have a way to nudge users into using np.array, but I don't see an easy way to achieve that. Note that the original goal of this PR would not address this issue either.

The current PR still solves the case where theano.shared is a scalar (which was the failing example that motivated this PR), which is already more forgiving compared to what we had before:

with pm.Model() as m:
    shared_var = theano.shared(5.0)  # Failed before
    v = pm.Normal("v", mu=shared_var, shape=2)

And just for completeness, it is still fine to use a theano.shared with a numpy array (or equivalent, pm.Data):

with pm.Model() as m:
    shared_var = theano.shared(np.array([5.0, 5.0]))  # Still fine
    v = pm.Normal("v", mu=shared_var, shape=2)

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Last suggestion I promise!

In the meanwhile, you can go ahead and add a maintenance Release Note mentioning that theano ScalarSharedVariable can now also be used as input to other RVs.

pymc3/distributions/distribution.py Outdated Show resolved Hide resolved
@ricardoV94 ricardoV94 changed the title The error catching in get_test_val() now gives a more clear message Avoid unclear TypeError when using theano.shared variables as input to distributions parameters Jan 30, 2021
@ricardoV94 ricardoV94 changed the title Avoid unclear TypeError when using theano.shared variables as input to distributions parameters Avoid unclear TypeError when using theano.shared variables as input to distribution parameters Jan 30, 2021
@kc611
Copy link
Contributor Author

kc611 commented Jan 30, 2021

Last suggestion I promise!

Here I'm feeling stupid for not noticing such obvious stuff. :-P . Anyway thanks for your huge support in this PR. (I wasn't exactly familiar with Distribution class so ended up just following whatever you said.)

@ricardoV94
Copy link
Member

ricardoV94 commented Jan 30, 2021

Here I'm feeling stupid for not noticing such obvious stuff. :-P . Anyway thanks for your huge support in this PR. (I wasn't exactly familiar with Distribution class so ended up just following whatever you said.)

I also didn't notice the "obvious" stuff before!

I am sorry if I came across a bit heavy-handed, I was unsure of how willing / comfortable you were with the suggested changes, since the original PR / issue had a very different angle to it. I really appreciate your patience and effort!

@ricardoV94
Copy link
Member

Leaving it open just in case anyone finds issues with it. Will merge in a day or so otherwise.

@michaelosthege michaelosthege merged commit 07679ec into pymc-devs:master Jan 31, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Bug: Normal distribution throws errors when using shared variables
4 participants