
[trainer] deepspeed bug fixes and tests #10039

Merged
merged 2 commits on Feb 8, 2021
Conversation

stas00 (Contributor) commented Feb 6, 2021

This PR:

  • fixes a bug with model.no_sync(), which is not supported by DeepSpeed - we had a test for it, but it's not running on CI
  • fixes a bug when --train is not used - we have to .to(device) in that case - reported here ([DeepSpeed] [success] trained t5-11b on 1x 40GB gpu #9996 (comment)); both guards are sketched right after this list
  • splits the deepspeed tests into their own dedicated file, which will slowly be built up - it's good enough for now, especially since we are going to switch over to run_seq2seq.py, except that it's not fully ready yet for adoption
  • adds a new test which doesn't use --train
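
A minimal sketch of the two guards described above, assuming a Trainer-like object where self.args.deepspeed holds the path to the DeepSpeed config file and self.deepspeed is only set once the DeepSpeed engine has been initialized; the actual change lives in src/transformers/trainer.py and may differ in detail:

import contextlib

def no_sync_context(trainer, model):
    # DeepSpeed manages gradient synchronization itself and does not support
    # model.no_sync(), so only use it for plain DDP runs
    if trainer.args.local_rank != -1 and trainer.args.deepspeed is None:
        return model.no_sync()
    return contextlib.nullcontext()

def ensure_model_on_device(trainer):
    # with DeepSpeed, model.to() is postponed until train(); if training never
    # ran (e.g. evaluation only), the model is still on CPU and must be moved
    if trainer.deepspeed is None:
        trainer.model = trainer.model.to(trainer.args.device)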

sgugger (Collaborator) left a comment

Thanks for the PR. Concerning the new test_deepspeed.py file, the goal is to remove things from the seq2seq folder to make it less intimidating to new users, not add stuff in it ;-)

Maybe we should put all tests in examples/tests/, that would be easier.

Review thread on src/transformers/trainer.py (outdated):
# In the future we probably can run deepspeed for inference too, but this will require some thinking about how to best run it - since while it works DeepSpeed wasn't designed for inference

# since we have to postpone model.to() till training for DeepSpeed, if there was no training, we must put the model on the right device
self.model = self.model.to(self.args.device)
sgugger (Collaborator):
Mmm, this means a future training might have it on the device already now? Maybe we should just put the model used on the device (so model = self.model.to(self.args.device), but not stored in self.model).
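
A rough sketch of this suggestion, moving only a local reference onto the device so that self.model is left untouched for a later train(); the placement in the evaluation path is illustrative, not the actual Trainer code:

# inside the evaluation/prediction path
model = self.model.to(self.args.device)
# use the local `model` below; self.model is not overwritten, so a later
# train() still starts from the model exactly as DeepSpeed expects it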

stas00 (Contributor, Author):

This is the case where one bypasses the training stage. Remember, the last PR here had to make a special case for DeepSpeed not to preload the model onto the device, so that it could load a model in fp16?

Next I'm experimenting with DeepSpeed for inference only, so this will change again. But for now this is a bug fix.

sgugger (Collaborator):

Yes, I understand that. But what if someone does:

trainer = Trainer(...)
trainer.evaluate()  # with this hotfix, self.model gets moved to the device here
trainer.train()     # a later training then starts with the model already on the device

(Agreed, it would be weird, but I'm trying to make the bug fix general.)

stas00 (Contributor, Author):

This is yet another combination I hadn't thought of. Thank you for thinking of it, @sgugger.

As I mentioned, I'm already working on DeepSpeed for inference, so this code will change again shortly. If I manage to do it, this code will be replaced with deepspeed_init and there will be no switching to device at all. So this area is a WIP and this PR is a temporary patch.

So do let me know whether you prefer a more general fix, or hopefully today/tomorrow I will have a new version, if DeepSpeed supports that. I just started working on it, and in the worst case, if it doesn't let me init it for inference (i.e. without optimizer/scheduler), I'll just init DeepSpeed as I would for training, so really it'd be the same as train. Down the road, as DeepSpeed avails itself for inference, this will improve again. That's the plan at the moment.

And yes, I need to test all these different variations you're pointing at.
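
A rough sketch of that fallback plan, as a hypothetical helper; the exact deepspeed.initialize arguments are an assumption about how the integration passes its config, not the final Trainer code:

import deepspeed

def init_deepspeed_for_inference(model, ds_config):
    # try an inference-only init first, i.e. without an optimizer or
    # lr_scheduler; if DeepSpeed rejects that, fall back to the same
    # init as used for training
    try:
        engine, _, _, _ = deepspeed.initialize(model=model, config_params=ds_config)
    except Exception:
        engine, _, _, _ = deepspeed.initialize(
            model=model,
            model_parameters=model.parameters(),
            config_params=ds_config,
        )
    return engine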

sgugger (Collaborator):

Ok for the quick hotfix then! Just want to make sure the proper fix down the road supports all kinds of combinations of train/eval.

stas00 (Contributor, Author) commented Feb 8, 2021:

Yes, let's merge it, and I will work on the new tests location and then add new tests for all the different combinations.

stas00 (Contributor, Author) commented Feb 8, 2021

> Thanks for the PR. Concerning the new test_deepspeed.py file, the goal is to remove things from the seq2seq folder to make it less intimidating to new users, not add stuff in it ;-)
>
> Maybe we should put all tests in examples/tests/, that would be easier.

I'm open to suggestions. This is not really an addition but a split of a larger test file, as it was becoming unnecessarily complicated.

This PR is a bug fix, and the scene will evolve to be more user-friendly, but let's discuss this post-merge to enable users to do their work.

I will start a new issue discussing your suggestion. #10076

sgugger (Collaborator) commented Feb 8, 2021

Well, it's not just a bug fix, since you split the test file with it ;-) But I agree we can do the regrouping in another PR.

stas00 (Contributor, Author) commented Feb 8, 2021

Because I had to add new tests, and the previous setup would have made that too complicated: the bug fix required a new test, and that was the last straw that triggered the need for a dedicated test file. That is, I couldn't easily add a test the way things were, and the test needed to go in with the bug fix. So the split was done out of necessity.
