Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix failing default transform for LKJCorr #7065

Merged
merged 11 commits into from
Dec 13, 2023

Conversation

juanitorduz
Copy link
Contributor

@juanitorduz juanitorduz commented Dec 13, 2023

Closes #7002

Wt take a different direction from #7023


📚 Documentation preview 📚: https://pymc--7065.org.readthedocs.build/en/7065/

@juanitorduz juanitorduz marked this pull request as draft December 13, 2023 09:40
Copy link

codecov bot commented Dec 13, 2023

Codecov Report

Merging #7065 (67b80c6) into main (2e05854) will increase coverage by 0.00%.
Report is 2 commits behind head on main.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #7065   +/-   ##
=======================================
  Coverage   92.19%   92.20%           
=======================================
  Files         101      101           
  Lines       16893    16901    +8     
=======================================
+ Hits        15575    15584    +9     
+ Misses       1318     1317    -1     
Files Coverage Δ
pymc/distributions/multivariate.py 93.53% <100.00%> (+0.16%) ⬆️
pymc/logprob/transform_value.py 93.75% <100.00%> (ø)

@ricardoV94 ricardoV94 added the bug label Dec 13, 2023
@ricardoV94 ricardoV94 changed the title Change default transform LKJCOrr Fix failing default transform for LKJCorr Dec 13, 2023
@ricardoV94
Copy link
Member

Is it good on your end? Asking because it's marked as a draft

@juanitorduz
Copy link
Contributor Author

juanitorduz commented Dec 13, 2023

@ricardoV94 With the new suggestion the test fails locally with the

NotImplementedError: Univariate transform MultivariateIntervalTransform cannot be applied to multivariate lkjcorr_rv{1, (0, 0), floatX, False}

🤔

@ricardoV94
Copy link
Member

@ricardoV94 With the new suggestion the test fails with the

NotImplementedError: Univariate transform MultivariateIntervalTransform cannot be applied to multivariate lkjcorr_rv{1, (0, 0), floatX, False}

🤔

Where is that check coming from? We might need to add some meta-info to the Transform

@ricardoV94
Copy link
Member

The problem is the logp of the distribution is incorrectly implemented. It's returning a scalar instead of a vector of shape=(3,)

@juanitorduz
Copy link
Contributor Author

Yeah! It is failing locally. It's good that you caught up on this with the test! Do you think there is an "easy" fix?

@ricardoV94
Copy link
Member

ricardoV94 commented Dec 13, 2023

We should add a NotImplementedError in the logp for when value.ndim > 1, since it's not properly supported AFAICT: #5383

You can parametrize the test to have two cases shape=(2,) and shape=(3,2) the second of which is expected to fail. You can mark specific parametrizations with pytest.param like here:

pytest.param(
np.array([0.1441, 0.1363, 0.1385, 0.1348, 0.1521, 0.1500, 0.1442]),
4,
None,
np.array([1, 1, 1, 1, 0, 0, 0]),
marks=pytest.mark.xfail(
rises=AssertionError, reason="Known failure in mode approximation "
),

Also this test shouldn't be in the TestLKJCorr class, since that is just for the RandomVariable checks. You can make it a standalone test

@ricardoV94
Copy link
Member

ricardoV94 commented Dec 13, 2023

Yeah! It is failing locally. It's good that you caught up on this with the test! Do you think there is an "easy" fix?

It's not trivial, it requires thinking carefully about batch dimensions, like we did in this PR: #6897
But it would be a great addition. I think it should be done in a separate PR though

@ricardoV94
Copy link
Member

Also could you reintroduce the change from the other PR where we always run this check instead of being in the else branch?

else:
# Check there is no broadcasting between logp and jacobian
if logp.type.broadcastable != log_jac_det.type.broadcastable:
raise ValueError(
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
"There is a bug in the implementation of either one."

@juanitorduz juanitorduz marked this pull request as ready for review December 13, 2023 11:22
@juanitorduz
Copy link
Contributor Author

Added the suggested changes :)

@ricardoV94
Copy link
Member

We should add a check in the logp similar to this:

(value,) = values
if value.ndim > 1:
raise ValueError("_LKJCholeskyCov logp is only implemented for vector values (ndim=1)")

Should be a NotImplementedError though

@juanitorduz
Copy link
Contributor Author

We now have two tests failing because of the new NotImplementedError. Should I remove such cases (via xfail)?

tests/distributions/test_multivariate.py::TestMoments::test_lkjcorr_moment[3-1-1-expected2] FAILED [ 87%]
tests/distributions/test_multivariate.py::TestMoments::test_lkjcorr_moment[5-1-size3-expected3] FAILED [ 87%]

@ricardoV94
Copy link
Member

We now have two tests failing because of the new NotImplementedError. Should I remove such cases (via xfail)?

tests/distributions/test_multivariate.py::TestMoments::test_lkjcorr_moment[3-1-1-expected2] FAILED [ 87%]
tests/distributions/test_multivariate.py::TestMoments::test_lkjcorr_moment[5-1-size3-expected3] FAILED [ 87%]

Yup

@juanitorduz
Copy link
Contributor Author

@ricardoV94 we are back to 🟢 :)

@ricardoV94 ricardoV94 merged commit 851c991 into pymc-devs:main Dec 13, 2023
22 checks passed
@juanitorduz juanitorduz deleted the issue_7002 branch December 13, 2023 13:59
@juanitorduz
Copy link
Contributor Author

Thank you for all your help @ricardoV94 ❤️

@ricardoV94
Copy link
Member

Thanks @juanitorduz

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: LKJCorr default transform raises error
2 participants