Deepspeed sharding and load from checkpoint with custom lightning module - setup() not called during checkpoint loading #290
Labels: question (further information is requested)
❓ Questions and Help
What is your question?
Hi, I'm training from scratch with DeepSpeed, PyTorch Lightning, and Transformers in a multi-node setting, and I'd like to know how to set up the code to handle loading from a PyTorch checkpoint.
Going off of the docs here, I see that the model is intended to be defined in setup(). However, this doesn't work when loading from a state dict, since setup() is not called. What's the right way to structure the code here? Does enable_transformers_pretrained_deepspeed_sharding need to be called in setup(), or can it be called in the constructor?
This has been my potential workaround in the constructor, though it does seem to fail on certain ranks:
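The author's actual code block is not shown in the issue. As a rough illustration only, a constructor-based workaround of the kind described might look like the sketch below; the class name, model choice, and import path for the sharding utility are assumptions, not taken from the issue.

```python
import pytorch_lightning as pl
from transformers import AutoModelForCausalLM

# Utility named in the question; import path assumed from lightning-transformers.
from lightning_transformers.utilities.deepspeed import (
    enable_transformers_pretrained_deepspeed_sharding,
)


class MyTransformer(pl.LightningModule):
    def __init__(self, pretrained_model_name: str = "gpt2"):
        super().__init__()
        self.save_hyperparameters()
        # Workaround: enable sharded instantiation already in the constructor so
        # the model exists before a checkpoint's state dict is loaded. The issue
        # author reports this pattern failing on certain ranks.
        enable_transformers_pretrained_deepspeed_sharding(self)
        self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name)

    def forward(self, **batch):
        return self.model(**batch)
```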
As opposed to the setup()-based approach from the docs (the author's code block is collapsed in the original issue):
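For contrast, here is a minimal sketch of the setup()-based pattern the docs describe; again the class, model, and guard logic are illustrative assumptions rather than the author's code.

```python
import pytorch_lightning as pl
from transformers import AutoModelForCausalLM
from lightning_transformers.utilities.deepspeed import (
    enable_transformers_pretrained_deepspeed_sharding,
)


class MyTransformer(pl.LightningModule):
    def __init__(self, pretrained_model_name: str = "gpt2"):
        super().__init__()
        self.save_hyperparameters()

    def setup(self, stage=None):
        # setup() runs once the trainer/strategy is attached, so sharded
        # instantiation can see the DeepSpeed configuration. Guard against
        # setup() being invoked more than once.
        if not hasattr(self, "model"):
            enable_transformers_pretrained_deepspeed_sharding(self)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.hparams.pretrained_model_name
            )

    def forward(self, **batch):
        return self.model(**batch)
```

As the question points out, though, restoring from a checkpoint loads the state dict without running setup(), so with this pattern the model attribute does not yet exist at load time.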
What have you tried?
What's your environment?
Linux, conda/pip
deepspeed==0.7.3
pytorch-lightning==1.6.5
lightning-transformers==0.2.1
Thanks in advance for the help!