jitting the DnCNN prox on GPU
#104
Comments
A JAX bug is possible, but I'd start by checking whether the code (init or eval) has any side effects that the jit is missing. You could do something like
The issue seems to be related to the fact that JAX is non-deterministic on GPU, as discussed in
Branch
The flags (or the need for them) may change when an updated jaxlib is pushed, as mentioned in jax-ml/jax#565 (comment)
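For readers who want to try a deterministic GPU configuration, here is a minimal sketch of setting an XLA flag from Python. The flag names used in this thread are not preserved, so the one below is an assumption: `--xla_gpu_deterministic_ops=true` is the name used by recent XLA releases, and older jaxlib builds may expect a different one. The helper `with_flag` is hypothetical and only exists for this illustration.

```python
import os

# Assumption: flag name from recent XLA releases, not taken from this thread.
# Older jaxlib builds may use a different flag or not need one at all.
DETERMINISTIC_FLAG = "--xla_gpu_deterministic_ops=true"


def with_flag(existing, flag):
    """Append `flag` to a space-separated XLA_FLAGS string, avoiding duplicates."""
    parts = existing.split()
    if flag not in parts:
        parts.append(flag)
    return " ".join(parts)


# Preserve any flags the user already set, then add the determinism flag.
os.environ["XLA_FLAGS"] = with_flag(os.environ.get("XLA_FLAGS", ""), DETERMINISTIC_FLAG)

# Import jax only after this point, so the GPU backend is initialized
# with the flag already in the environment.
```

Setting the variable in the shell before launching pytest should have the same effect; doing it in Python mainly helps when the setting has to live next to the test code.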
Great catch, folks. That is way worse than I'd expect from GPU nondeterminism! (FWIW, most things on a GPU are nondeterministic unless specifically instructed otherwise.)
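One common source of GPU nondeterminism is that parallel reductions may add values in a different order from run to run, and floating-point addition is not associative. The sketch below shows the effect in plain Python with no GPU involved. The helper `reduce_in_order` is hypothetical and is not part of SCICO or JAX.

```python
def reduce_in_order(values):
    """Sum the values strictly left to right, like a sequential reduction."""
    total = 0.0
    for v in values:
        total += v
    return total


# Same three numbers, two summation orders.
left_to_right = reduce_in_order([1e16, 1.0, -1e16])  # (1e16 + 1.0) - 1e16
reordered = reduce_in_order([1e16, -1e16, 1.0])      # (1e16 - 1e16) + 1.0

print(left_to_right)  # 0.0, because the 1.0 is lost when added to 1e16
print(reordered)      # 1.0
```

Differences of this kind are normally tiny relative to the values involved, which is why a large jitted versus unjitted discrepancy is surprising.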
When I run
pytest scico/test/test_functional.py -k DnCNN
with a GPU install of SCICO, I get a failure in TestDnCNN.test_prox, which compares the jitted and unjitted versions of the DnCNN prox. These results are disturbingly different:

Might this be a JAX bug? One could try bumping the JAX version to see if it is fixed.
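A comparison of this kind can be reproduced outside the test suite with a few lines of JAX. The sketch below is not the SCICO test: `smooth` is a hypothetical stand-in for the DnCNN prox, and the tolerance is an assumed value chosen for float32 arithmetic.

```python
import jax
import jax.numpy as jnp
import numpy as np


def smooth(x, w):
    # Hypothetical stand-in for the DnCNN prox: a small pure function.
    return jnp.tanh(x @ w) + 0.5 * x


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 8))
w = jax.random.normal(key_w, (8, 8))

y_eager = smooth(x, w)         # op-by-op execution
y_jit = jax.jit(smooth)(x, w)  # traced and compiled by XLA

# For a pure function the two results should agree up to rounding.
max_abs_diff = float(jnp.max(jnp.abs(y_eager - y_jit)))
print("max abs difference:", max_abs_diff)
np.testing.assert_allclose(y_jit, y_eager, rtol=1e-4, atol=1e-4)
```

If a check like this passes for simple functions but fails for the real prox on GPU, that points toward something specific to the model evaluation (for example nondeterministic GPU kernels) rather than to `jit` in general.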