GPU device placement issues #610
Conversation
Looks good, thanks! Regarding the prior device: there is a function for checking the prior device, but it is not used currently; we could use this function here. Regarding the embedding and density estimator: this happens outside of the SBI loop, no? I would tend to b). And you? @michaeldeistler @jan-matthis, what do you think?
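The prior-device check mentioned above is referenced but not shown in this thread. A minimal sketch of what such a check could look like (the name `check_prior_device` and its body are assumptions, not sbi's actual API) is to draw a sample from the prior and compare its device against the expected one:

```python
import torch
from torch.distributions import MultivariateNormal


def check_prior_device(prior, device: str = "cpu") -> None:
    """Hypothetical check: sample once from the prior and compare devices."""
    sample_device = prior.sample((1,)).device
    if sample_device != torch.device(device):
        raise RuntimeError(
            f"Prior samples live on '{sample_device}', expected '{device}'."
        )


# A CPU prior passes the check for device="cpu" and fails for "cuda".
prior = MultivariateNormal(torch.zeros(2), torch.eye(2))
check_prior_device(prior, "cpu")
```

The comparison itself needs no GPU: `torch.device("cuda") != torch.device("cpu")` holds even on CPU-only machines, so the check can run anywhere.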
Codecov Report
@@            Coverage Diff             @@
##             main     #610      +/-   ##
==========================================
+ Coverage   68.44%   68.72%   +0.27%
==========================================
  Files          67       67
  Lines        4443     4476      +33
==========================================
+ Hits         3041     3076      +35
+ Misses       1402     1400       -2
I'm having trouble understanding which cases are going wrong. Is the issue that …
Yes, that's one of the problems. We designed …
Hey! Sorry, I had a bunch of stuff to work on; I'll work on the aforementioned issues some more tonight. @janfb I personally like APIs that do not silently accept errors but indicate clearly what went wrong (so option a). But …
@michaeldeistler I can show you the cases that go wrong in screenshare, like I did with @janfb at some point, if you want?
I've written a couple of functions that check what seem to be the most common error cases. I'm not sure I have propagated those functions everywhere they are needed, though. For the …
Arf, I realize I redefined a prior-checking function. I'll remove it.
@LouisRouillard I think quickest for now would be to just do a single commit with the black formatting changes. Then I'll quickly have a look :)
sbi/utils/user_input_checks.py
Outdated
@@ -372,6 +372,23 @@ def check_prior_support(prior):
    )

def check_embedding_net_device(embedding_net: nn.Module, batch_y: torch.Tensor) -> None:
    batch_y_device = batch_y.device
What if the user passes data on GPU? Then this check would pass and the embedding would remain on the GPU, causing issues with the density estimator that is constructed on CPU, no? If we assume that we always build the density estimator on the CPU, then we could just always move the embedding to the CPU, no? (And warn, or throw an error, depending on whether we want to warn-and-correct or not.)
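The warn-and-correct option discussed here could be sketched as below. The helper name `ensure_embedding_on` is hypothetical, not sbi's actual API; the point is only to show the "warn, then move" behavior:

```python
import warnings

import torch
from torch import nn


def ensure_embedding_on(embedding_net: nn.Module, device: str = "cpu") -> nn.Module:
    """Warn-and-correct variant: move the net to `device` if it lives elsewhere.

    Hypothetical helper sketching option b) from the discussion above.
    """
    params = list(embedding_net.parameters())
    if params and params[0].device != torch.device(device):
        warnings.warn(
            f"Embedding net was on '{params[0].device}'; moving it to '{device}'."
        )
        embedding_net.to(device)
    return embedding_net


net = nn.Linear(3, 2)  # freshly built modules live on CPU by default
ensure_embedding_on(net, "cpu")  # already on the right device: no warning, no move
```

`Module.to(device)` moves parameters in place, so callers holding a reference to the original net also see the corrected device.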
I'm not sure this situation will arise, in the sense that the mismatch between `prior` and `estimator` is checked before that step? So before this is run we know that `batch_y` (generated from the `prior`) is on the same device as the `estimator`, and all that remains is to check that the `embedder` is also in that case? I guess the question is: is there a use case (unknown to me) where the `build` functions are used outside of the classic "pipeline"?
Ok, so I advanced on the tests. I think I have good coverage over the few functions I implemented. I expanded the integration test …
I sadly have conflicts, I'll try to rebase over main and pray it's not too bad ^^'
Watch out for …
Great! Is this ready for review?
Ah, just saw Discord... I'll review it now
This is great! Thanks a lot for looking into this and understanding and fixing the issue. Good to go from my side!
All tests in …
Great code, great tests, thanks a lot!
I added small comments on types and some refactoring. This is good to go in once they are addressed.
sbi/neural_nets/mdn.py
Outdated
assert batch_x.device == batch_y.device, (
    "Mismatch in fed data's device: "
    f"batch_x has device '{batch_x.device}' whereas "
    f"batch_x has device '{batch_x.device}'. Please "
typo: `y` vs `x`
Given we have >5 of those asserts, you could write a function that checks two tensors for their device and prints a (slightly more general) error message?
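A shared helper along the lines suggested here might look like the following. The name `assert_same_device` and its keyword-based interface are assumptions for illustration, not the function that ended up in sbi:

```python
import torch


def assert_same_device(**named_tensors: torch.Tensor) -> None:
    """Raise one general error if any two named tensors disagree on device."""
    devices = {name: t.device for name, t in named_tensors.items()}
    if len(set(devices.values())) > 1:
        listing = ", ".join(f"'{n}' on '{d}'" for n, d in devices.items())
        raise AssertionError(f"Device mismatch between tensors: {listing}.")


batch_x = torch.zeros(10, 3)
batch_y = torch.ones(10, 2)
assert_same_device(batch_x=batch_x, batch_y=batch_y)  # same device: passes
```

Passing the tensors as keyword arguments lets the error message name the offending tensors, which also avoids the copy-paste `x`/`y` typo flagged above.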
Will do!
sbi/utils/torchutils.py
Outdated
    device = "cpu"

return device

def check_if_prior_on_device(device, prior: Optional[Any] = None):
    if prior is not None:
General suggestion: we could have a

    if prior is None:
        pass
    else:
        ...

to improve readability?
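Putting the suggested early `None` branch together with a device comparison, the function could be sketched as follows. The body here is an assumption (the real sbi check may compare differently); only the signature comes from the diff above:

```python
from typing import Any, Optional

import torch


def check_if_prior_on_device(device, prior: Optional[Any] = None) -> None:
    """Sketch with the suggested explicit None branch; body is hypothetical."""
    if prior is None:
        pass  # nothing to check without a prior
    else:
        prior_device = prior.sample((1,)).device
        if prior_device != torch.device(device):
            raise RuntimeError(
                f"Prior is on '{prior_device}' but the passed device is '{device}'."
            )


prior = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
check_if_prior_on_device("cpu", prior)  # CPU prior on CPU device: passes
```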
Will change it, thanks :)
Thanks @janfb for the comments !
Great, thanks a lot! Good to go in once CI is passing.
I think most of the issues encountered by users result from `device` string processing and mismatches. I'm trying a bunch of error-prone combinations and trying to see how those mismatches can be caught and corrected. For now I only tried very small changes.
2 (known) issues remain: …
But both items do not expose any `device` attribute for now, so maybe more invasive changes could be needed.
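One way to tame the device-string processing mentioned here is to normalize user input once, up front. This is a hypothetical sketch (the function name and the `"gpu"` alias handling are assumptions): it leans on the fact that `torch.device` itself rejects invalid strings, so typos surface early rather than deep inside training:

```python
import torch


def process_device_str(device: str) -> str:
    """Hypothetical normalization for user-supplied device strings."""
    if device == "gpu":
        # map the informal "gpu" to an explicit backend, falling back to CPU
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # torch.device raises RuntimeError on invalid strings, so typos
    # like "not_a_device" are caught here instead of mid-pipeline
    return str(torch.device(device))


print(process_device_str("cpu"))  # -> cpu
```

Validating at the API boundary like this keeps the later per-tensor checks simple: they only ever see canonical device strings.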