Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide JAX Ops from Optional tensorflow-probability dependency #403

Merged
merged 1 commit into from
Jul 27, 2023

Conversation

ricardoV94
Copy link
Member

Closes #43
Closes #91
Closes #256

@ricardoV94 ricardoV94 added enhancement New feature or request jax labels Jul 27, 2023
@ricardoV94 ricardoV94 changed the title Provide JAX Ops from Optional tfp dependency Provide JAX Ops from Optional tensorflow-probability dependency Jul 27, 2023
@ricardoV94 ricardoV94 force-pushed the jax_tfp_ops branch 2 times, most recently from a384d94 to 5b0eb87 Compare July 27, 2023 13:08
@codecov-commenter
Copy link

Codecov Report

Merging #403 (e4766a8) into main (8ac8342) will increase coverage by 0.00%.
The diff coverage is 90.90%.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main     #403   +/-   ##
=======================================
  Coverage   80.46%   80.47%           
=======================================
  Files         156      156           
  Lines       45558    45579   +21     
  Branches    11162    11167    +5     
=======================================
+ Hits        36660    36679   +19     
- Misses       6695     6697    +2     
  Partials     2203     2203           
Files Changed Coverage Δ
pytensor/link/jax/dispatch/scalar.py 97.27% <90.90%> (-1.14%) ⬇️

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, seems very straight-forward. Does anything else need to be done after this is merged to JAX compile models with distributions that rely on these functions?

@ricardoV94
Copy link
Member Author

Nope, users just have to install tfp, but they will get the informative message if that's the case

@ricardoV94 ricardoV94 merged commit 9ada945 into pymc-devs:main Jul 27, 2023
52 checks passed
@maresb
Copy link
Contributor

maresb commented Jul 27, 2023

Do you need a version pin in the PyTensor feedstock? You can add a run_constrained which says "if TFP is installed, then make sure it satisfies this version spec".

@ricardoV94
Copy link
Member Author

ricardoV94 commented Jul 27, 2023

I don't think so, we didn't need it for numpyro? We can wait and see if people have issues with it?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request jax
Projects
None yet
4 participants