Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default to Scan in postprocessing of jax samplers #6922

Merged
merged 8 commits into from
Sep 25, 2023
Merged

Conversation

ferrine
Copy link
Member

@ferrine ferrine commented Sep 22, 2023

What is this PR about?
Fixing sansible defaults in jax post processing and updating to blackjax 1.0.0 api

Checklist

Major / Breaking Changes

  • ...

New features

  • ...

Bugfixes

  • ...

Documentation

  • ...

Maintenance

  • ...

📚 Documentation preview 📚: https://pymc--6922.org.readthedocs.build/en/6922/

@codecov
Copy link

codecov bot commented Sep 22, 2023

Codecov Report

Merging #6922 (0994c86) into main (9227827) will decrease coverage by 0.03%.
Report is 6 commits behind head on main.
The diff coverage is 79.16%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6922      +/-   ##
==========================================
- Coverage   92.16%   92.13%   -0.03%     
==========================================
  Files         100      100              
  Lines       16839    16853      +14     
==========================================
+ Hits        15519    15528       +9     
- Misses       1320     1325       +5     
Files Changed Coverage Δ
pymc/sampling/jax.py 96.31% <77.27%> (-1.97%) ⬇️
pymc/sampling_jax.py 100.00% <100.00%> (ø)

@ferrine
Copy link
Member Author

ferrine commented Sep 22, 2023

@junpenglao can you please have a look?

.github/workflows/tests.yml Outdated Show resolved Hide resolved
pymc/sampling/jax.py Show resolved Hide resolved
pymc/sampling/jax.py Outdated Show resolved Hide resolved
pymc/sampling/jax.py Outdated Show resolved Hide resolved
@ferrine
Copy link
Member Author

ferrine commented Sep 23, 2023

The failed tests seem to be unrelated, anything else I need to take care of? @junpenglao @ricardoV94

pymc/sampling/jax.py Outdated Show resolved Hide resolved
pymc/sampling/jax.py Outdated Show resolved Hide resolved
pymc/sampling/jax.py Outdated Show resolved Hide resolved
pymc/sampling/jax.py Show resolved Hide resolved
pymc/sampling/jax.py Outdated Show resolved Hide resolved
pymc/sampling/jax.py Outdated Show resolved Hide resolved
@ferrine
Copy link
Member Author

ferrine commented Sep 24, 2023

commited the suggestions by @ricardoV94

1 similar comment
@ferrine
Copy link
Member Author

ferrine commented Sep 24, 2023

commited the suggestions by @ricardoV94

pymc/sampling/jax.py Outdated Show resolved Hide resolved
pymc/sampling/jax.py Outdated Show resolved Hide resolved
@ferrine ferrine merged commit 15fbf0e into main Sep 25, 2023
22 checks passed
@ferrine ferrine deleted the fix-jax-sampling branch September 25, 2023 08:27
@ricardoV94 ricardoV94 changed the title Fix jax sampling Default to Scan in postprocessing of jax samplers Oct 4, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants