
GPU device placement issues #610

Merged · 2 commits · Feb 1, 2022

Conversation

LouisRouillard (Contributor)

I think most of the issues users encounter result from device-string processing and mismatches. I'm trying a bunch of error-prone combinations to see how those mismatches can be caught and corrected.

For now I have only tried very small changes.

2 (known) issues remain:

  • prior device mismatch
  • embedding net with custom density estimator mismatch

But neither item exposes a device attribute for now, so more invasive changes may be needed.
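As a purely illustrative sketch of the kind of device-string processing being discussed (the function name and accepted aliases are hypothetical, not sbi's actual API), normalization might look like:

```python
def process_device(device: str) -> str:
    """Normalize a user-supplied device string (illustrative only).

    Maps common aliases ("gpu", bare "cuda") to an explicit "cuda:0",
    leaves "cpu" and already-indexed "cuda:N" strings untouched, and
    rejects anything else early with a clear error.
    """
    device = device.strip().lower()
    if device in ("gpu", "cuda"):
        return "cuda:0"
    if device == "cpu" or device.startswith("cuda:"):
        return device
    raise ValueError(f"Unrecognized device string: '{device}'")
```

Normalizing once at the API boundary means every downstream comparison can be a plain string equality check.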

@LouisRouillard changed the title from "Small changes to device st processing functions" to "GPU device placement issues" on Jan 26, 2022
@janfb (Contributor) commented Jan 26, 2022

looks good, thanks!

Regarding the prior device: there is a function for checking the prior device, but it is not used currently:
https://github.com/mackelab/sbi/blob/bc4d43bf60ec714790ef8baa4b08cc6e5b821e45/sbi/utils/torchutils.py#L42-L51

We could use this function in sbi/inference/base.py to check the prior (if it is not None).

Regarding the embedding and density estimator: this happens outside of the SBI loop, no? E.g., the user would use get_nn_models function to build their density estimator with embedding net, right? Here one quick fix would be to check the embedding net for its device, and then
a) throw an error that it should be on the CPU for now, or
b) throw a warning and move it to CPU to then compose it with the density estimator on the CPU.

I would tend to b). And you?

@michaeldeistler @jan-matthis what do you think?
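A minimal sketch of option b), the "warn and correct" behavior. A stub class stands in for torch.nn.Module so the idea is self-contained; the real check would inspect the device of the module's parameters:

```python
import warnings


class StubNet:
    """Stand-in for an nn.Module; tracks its device as a string."""

    def __init__(self, device: str = "cpu"):
        self.device = device

    def to(self, device: str) -> "StubNet":
        self.device = device
        return self


def move_embedding_net_to_cpu(embedding_net: StubNet) -> StubNet:
    """If the embedding net is not on CPU, warn the user and move it
    there before composing it with the density estimator on CPU."""
    if embedding_net.device != "cpu":
        warnings.warn(
            f"Embedding net is on '{embedding_net.device}'; "
            "moving it to 'cpu' to match the density estimator."
        )
        embedding_net = embedding_net.to("cpu")
    return embedding_net
```

Option a) would simply replace the `warnings.warn(...)` plus move with a `raise ValueError(...)`.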

@codecov-commenter commented Jan 26, 2022

Codecov Report

Merging #610 (9807de0) into main (8657c55) will increase coverage by 0.27%.
The diff coverage is 80.00%.


@@            Coverage Diff             @@
##             main     #610      +/-   ##
==========================================
+ Coverage   68.44%   68.72%   +0.27%     
==========================================
  Files          67       67              
  Lines        4443     4476      +33     
==========================================
+ Hits         3041     3076      +35     
+ Misses       1402     1400       -2     
Flag       Coverage Δ
unittests  68.72% <80.00%> (+0.27%) ⬆️

(Flags with carried forward coverage won't be shown.)

Impacted Files                    Coverage Δ
sbi/utils/torchutils.py           63.97% <50.00%> (+3.67%) ⬆️
sbi/utils/user_input_checks.py    87.62% <73.33%> (-0.84%) ⬇️
sbi/inference/base.py             72.84% <100.00%> (+0.18%) ⬆️
sbi/neural_nets/classifier.py     100.00% <100.00%> (ø)
sbi/neural_nets/flow.py           87.20% <100.00%> (+1.13%) ⬆️
sbi/neural_nets/mdn.py            100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 8657c55...9807de0.

@michaeldeistler (Contributor) commented Jan 26, 2022

I'm having trouble understanding which cases are going wrong.

Is the issue that get_posterior_nn() (etc) can not handle the embedding_net being on the GPU?

@janfb (Contributor) commented Jan 26, 2022

Yes, that's one of the problems. We designed sbi such that the user would just say device="cuda" and we take care of moving everything to the device. However, it happens that users have their priors, data, or embeddings on the device already, and then posterior_nn() would compose a Flow with mixed devices and things break.

@LouisRouillard (Contributor, Author)

Hey! Sorry, I had a bunch of stuff to work on; I'll work on the aforementioned issues some more tonight. @janfb, I personally like APIs that are not error-tolerant but indicate clearly what went wrong (so option a). But sbi is generally built around the "warn and correct" concept, so I'll try to implement b).

@LouisRouillard (Contributor, Author)

@michaeldeistler I can show you the cases that go wrong in screenshare like I did with @janfb at some point if you want?

@LouisRouillard (Contributor, Author)

I've written a couple of functions that check what seem to be the most common error cases. I'm not sure I have propagated those functions everywhere needed, though. For the embedding_net I managed to implement a "warn and correct" behavior, but not for the prior check, because I feel I would have needed to "wrap" a potentially custom prior in another prior that replicates the sampling and then moves the data accordingly. Tell me what you think?
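The "wrapping" idea mentioned above could look roughly like this. Stubs stand in for a torch tensor and a custom prior, so this is only a sketch of the pattern, not sbi code:

```python
class StubSample:
    """Stand-in for a tensor that records which device it lives on."""

    def __init__(self, device: str):
        self.device = device

    def to(self, device: str) -> "StubSample":
        return StubSample(device)


class GpuPrior:
    """Stand-in for a custom prior whose samples come back on the GPU."""

    def sample(self):
        return StubSample("cuda:0")


class DeviceWrappedPrior:
    """Wrap a potentially custom prior so that every sample is moved
    to the target device right after sampling, replicating the
    original prior's sampling interface."""

    def __init__(self, prior, device: str):
        self.prior = prior
        self.device = device

    def sample(self, *args, **kwargs):
        return self.prior.sample(*args, **kwargs).to(self.device)
```

The cost of this approach is exactly what the comment above hints at: the wrapper must mirror every method users might call on the prior (log_prob, support, ...), which is why it feels invasive.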

@LouisRouillard (Contributor, Author)

Arf, I realize I redefined a prior-checking function. I'll remove it.

@michaeldeistler (Contributor)

@LouisRouillard I think quickest for now would be to just do a single commit with the black formatting changes. Then I'll quickly have a look :)

@@ -372,6 +372,23 @@ def check_prior_support(prior):
)


def check_embedding_net_device(embedding_net: nn.Module, batch_y: torch.Tensor) -> None:
batch_y_device = batch_y.device
Contributor

what if the user passes data on GPU? Then this check would pass and the embedding would remain on the GPU and cause issues with the density estimator that is constructed on CPU, no?

If we assume that we always build the density estimator on the CPU, then we could just always move the embedding to the GPU, no? (and warn, or throw an error depending on whether we want to warn and correct or not. )

Contributor (Author)

I'm not sure this situation will arise, in the sense that the mismatch between prior and estimator is checked before that step? So before this is run we know that batch_y (generated from the prior) is on the same device as the estimator, and all that remains is to check that the embedder is too. I guess the question is: is there a use case (unknown to me) where the build functions are used outside of the classic "pipeline"?

@LouisRouillard (Contributor, Author) commented Jan 28, 2022

Ok, so I advanced on the tests. I think I have good coverage over the few functions I implemented. I expanded the integration test test_train_with_different_data_and_training_device and added a new integration test test_embedding_nets_integration_training_device that tackles the use case that was failing for me. It checks that everything runs even with wrong device assignments, and that the user is warned when data or modules are moved automatically.

@LouisRouillard (Contributor, Author)

I sadly have conflicts; I'll try to rebase onto main and pray it's not too bad ^^'

@michaeldeistler (Contributor)

Watch out for sbi/utils/plot.py -- the file got deleted along the way

@michaeldeistler (Contributor)

Great! Is this ready for review?

@michaeldeistler (Contributor)

Ah just saw discord...I'll review it now

@michaeldeistler (Contributor) left a comment

This is great! Thanks a lot for looking into this and understanding and fixing the issue. Good to go from my side!

@LouisRouillard (Contributor, Author)

All tests in tests/inference_on_device_test.py and tests/user_input_checks_test.py, including the slow and gpu ones, pass locally!

@janfb (Contributor) left a comment

great code, great tests, thanks a lot!
I added small comments on types and some refactoring. This is good to go in once they are addressed.

assert batch_x.device == batch_y.device, (
"Mismatch in fed data's device: "
f"batch_x has device '{batch_x.device}' whereas "
f"batch_x has device '{batch_x.device}'. Please "
Contributor

typo: y vs x

Contributor

given we have >5 of those asserts you could write a function that checks two tensors for their device and prints a (slightly more general) error message?
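Such a shared helper might look like the following sketch. It is duck-typed over a `.device` attribute so it runs without torch here; in sbi it would take torch.Tensors, and the names are illustrative:

```python
def check_device_match(a, b, name_a: str = "batch_x", name_b: str = "batch_y") -> None:
    """Raise with a descriptive message if two tensors live on different
    devices. Centralizing the check avoids repeating the assert (and its
    copy-paste typos) at every call site."""
    if a.device != b.device:
        raise AssertionError(
            f"Mismatch in fed data's device: {name_a} is on '{a.device}' "
            f"whereas {name_b} is on '{b.device}'. Please move both to "
            "the same device."
        )
```

Each of the >5 assert sites then collapses to a single `check_device_match(batch_x, batch_y)` call.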

Contributor (Author)

Will do !

device = "cpu"

return device
return device


def check_if_prior_on_device(device, prior: Optional[Any] = None):
if prior is not None:
Contributor

general suggestion: we could have a

if prior is None: 
    pass
else: 
    ...

to improve readability?

Contributor (Author)

Will change it thanks :)

@LouisRouillard (Contributor, Author) left a comment

Thanks @janfb for the comments!

assert batch_x.device == batch_y.device, (
"Mismatch in fed data's device: "
f"batch_x has device '{batch_x.device}' whereas "
f"batch_x has device '{batch_x.device}'. Please "
Contributor (Author)

Will do !

@janfb (Contributor) left a comment

Great, thanks a lot! Good to go in once CI is passing.

@janfb janfb merged commit d373ff6 into sbi-dev:main Feb 1, 2022
4 participants