jiting the DnCNN prox on GPU #104

Closed
Michael-T-McCann opened this issue Nov 17, 2021 · 6 comments · Fixed by #105
Labels: bug (Something isn't working)

Comments

@Michael-T-McCann (Contributor)

When I run pytest scico/test/test_functional.py -k DnCNN with a GPU install of SCICO, I get a failure in TestDnCNN.test_prox, which compares the jitted and unjitted versions of the DnCNN prox. The results are disturbingly different:

E       Mismatched elements: 1020 / 1024 (99.6%)
E       Max absolute difference: 0.70825666
E       Max relative difference: 83.99731
E        x: array([[-1.311581e+00, -4.477247e-01, -2.993065e-01, ...,  1.231516e-01,
E                2.924881e-01,  1.054722e-01],
E              [ 2.548240e-01,  7.504759e-01, -3.427111e-01, ...,  1.717497e-01,...
E        y: array([[-1.290338, -0.550369, -0.359681, ...,  0.004449,  0.33103 ,
E                0.088125],
E              [ 0.211299,  0.899408, -0.37546 , ...,  0.460677, -0.817596,...

Might this be a JAX bug? One could try bumping the jax version and seeing whether that fixes it.

@bwohlberg added the bug (Something isn't working) label Nov 17, 2021
@bwohlberg added this to the Release 0.1.0 milestone Nov 17, 2021
@lukepfister (Contributor)

A JAX bug is possible, but I'd start by checking whether the code (init or eval) has any side effects that are being missed by the jit.

You could do something like jax.jit(lambda x: f.prox(x)) and check whether it matches f.prox(x). That is, jit a function that closes over the fully instantiated denoiser object.
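A minimal sketch of that check, assuming a hypothetical ToyDenoiser class as a stand-in for the DnCNN functional (a fixed random linear map followed by soft-thresholding, not SCICO's implementation). It only illustrates the pattern of jitting a closure over an already-constructed object and comparing against the eager call.

```python
import jax
import jax.numpy as jnp
import numpy as np


class ToyDenoiser:
    """Hypothetical stand-in for a functional with a prox method (not SCICO's DnCNN)."""

    def __init__(self, key, n):
        # Parameters are created once, at init time, and stored on the object.
        self.weight = jax.random.normal(key, (n, n)) / jnp.sqrt(n)

    def prox(self, x, lam=1.0):
        # Linear map followed by soft-thresholding with threshold lam.
        z = self.weight @ x
        return jnp.sign(z) * jnp.maximum(jnp.abs(z) - lam, 0.0)


f = ToyDenoiser(jax.random.PRNGKey(0), 32)
x = jax.random.normal(jax.random.PRNGKey(1), (32, 32))

# Jit a closure over the fully instantiated object: f.weight is captured as a
# trace-time constant, so any init-time side effects happened before tracing.
prox_jit = jax.jit(lambda x: f.prox(x))

y_eager = f.prox(x)
y_jit = prox_jit(x)

# On CPU these should agree to float32 rounding; a large mismatch would point
# at something jit is not seeing (or, on GPU, at nondeterminism).
np.testing.assert_allclose(
    np.asarray(y_jit), np.asarray(y_eager), rtol=1e-5, atol=1e-5
)
```

If the closure version agrees with the eager call but a jitted path that builds or mutates state inside the traced function does not, that would suggest side effects the jit is missing rather than a JAX bug.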

@FernandoDavis (Contributor) commented Nov 17, 2021 via email

@crstngc (Contributor) commented Nov 17, 2021

The issue seems to be related to the fact that JAX is nondeterministic on GPU, as discussed in google/flax#33 (comment), jax-ml/jax#565, and jax-ml/jax#4823.
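Part of why GPU nondeterminism shows up as numerical differences at all: parallel reductions may combine partial sums in an order that varies between runs (or between jitted and unjitted code paths), and floating-point addition is not associative. A minimal pure-Python illustration of the non-associativity:

```python
# Floating-point addition is not associative, so the order in which a
# reduction combines terms can change the result.
a, b, c = 0.1, 0.2, 0.3

left_to_right = (a + b) + c  # 0.30000000000000004 + 0.3
right_to_left = a + (b + c)  # 0.1 + 0.5

print(left_to_right)                   # 0.6000000000000001
print(right_to_left)                   # 0.6
print(left_to_right == right_to_left)  # False
```

A single rounding difference like this is tiny; the mismatch in the original report is far larger, presumably because such differences get amplified as they propagate through the network's layers.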

@crstngc (Contributor) commented Nov 17, 2021

Branch cristina/jitDnCNNprox includes a fix to the DnCNN tests in test_functional.py, following the suggestion in jax-ml/jax#4823 (comment). @Michael-T-McCann, could you please check it?
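A hedged sketch of what such a test-side fix can look like, assuming it follows the environment-variable approach discussed in jax-ml/jax#4823 (the branch itself is not reproduced here). The helper name request_deterministic_gpu is hypothetical, and the flag names reflect jaxlib releases from around late 2021; they are an assumption and may not apply to current versions.

```python
import os


def request_deterministic_gpu():
    """Hypothetical helper: ask XLA and cuDNN for deterministic GPU behavior.

    Must run before JAX initializes its GPU backend, i.e. in practice before
    `import jax` at the top of the test module. Flag names are assumptions
    based on jaxlib releases from late 2021; newer XLA versions replaced the
    reductions flag with --xla_gpu_deterministic_ops.
    """
    flag = "--xla_gpu_deterministic_reductions"
    existing = os.environ.get("XLA_FLAGS", "")
    if flag not in existing.split():
        # Append rather than overwrite, so other XLA flags are preserved.
        os.environ["XLA_FLAGS"] = (existing + " " + flag).strip()
    # Ask cuDNN to pick deterministic convolution algorithms.
    os.environ["TF_CUDNN_DETERMINISTIC"] = "1"


request_deterministic_gpu()
# import jax  # only after the flags are set
```

Appending to XLA_FLAGS instead of replacing it keeps any flags a developer already set, and the membership check makes repeated calls harmless.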

@crstngc (Contributor) commented Nov 17, 2021

The flags (or the need for them) may change when an updated jaxlib is released, as mentioned in jax-ml/jax#565 (comment).

@lukepfister (Contributor)

Great catch, folks. That is way worse than I'd expect from GPU nondeterminism! (FWIW, most things on a GPU are nondeterministic unless specifically instructed otherwise.)
