from_pretrained()'s load() blocks forever in subprocess #8649
Comments
If you've identified the issue to be coming from
Well, I don't know enough about torch state_dict behavior to understand why
Oh, I see. Looking at the
Well, a fair test would be to load the same (roBERTa-base) model, but I'm not sure how to write the code to do that... that's why I'm using the plain-PyTorch script below as a control:

```python
import torch
import torch.nn as nn
import multiprocessing as mp

# Toggle between loading a state_dict and loading a fully pickled module.
USE_STATE_DICT = True

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(768, 1)

    def forward(self, x):
        x = self.fc1(x)
        return x

def save_model():
    model = SimpleNet()
    torch.save(model, './full_model.pt')
    torch.save(model.state_dict(), './model_state_dict.pt')

def load_model_in_subprocess():
    print("Started subprocess.")
    if USE_STATE_DICT:
        model = SimpleNet()
        model.load_state_dict(torch.load('./model_state_dict.pt'))
    else:
        model = torch.load('./full_model.pt')
    print(f"Model loaded in subprocess: {model}")

def main():
    save_model()
    print("Saved model.")
    if USE_STATE_DICT:
        model = SimpleNet()
        model.load_state_dict(torch.load('./model_state_dict.pt'))
    else:
        model = torch.load('./full_model.pt')
    print(f"Model loaded in main process: {model}")
    p = mp.Process(target=load_model_in_subprocess, daemon=True)
    p.start()
    p.join()
    print("Main thread terminating.")

if __name__ == "__main__":
    main()
```

This script terminates fine whether it loads from a state dict or from a pickled model file.
Adding some debug prints to _load_from_state_dict:
The keys in the state_dict are:
The shape of the
(note: same blocking behavior when loading bert-base-uncased)
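For context, a minimal sketch of how such tracing can be wired up (this is an illustration, not the original debug code; the print format is an assumption):

```python
import torch.nn as nn

# Wrap nn.Module._load_from_state_dict so each call reports which submodule
# prefix it is loading; the hang can then be pinned to a specific parameter copy.
_original_load = nn.Module._load_from_state_dict

def _traced_load(self, state_dict, prefix, *args, **kwargs):
    print(f"loading prefix {prefix!r} ({len(state_dict)} keys in state_dict)")
    return _original_load(self, state_dict, prefix, *args, **kwargs)

nn.Module._load_from_state_dict = _traced_load
```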
Okay, I did more investigation, and the problem is a blocking copy call inside _load_from_state_dict. The documentation indicates a non_blocking parameter that can be used when copying between CPU and GPU, but we are copying between CPU and CPU. I confirmed that non_blocking does nothing here.

That's where I'm going to stop pursuing this bug. I don't know the structure of the C++ code, but it seems likely that this is an issue with the PyTorch CPU copy implementation and the idiosyncrasies of the specific OS I'm using. If this problem can be reproduced on other systems it may be worth investigating further, but it does seem like the fault probably lies with PyTorch rather than transformers.

@LysandreJik, you may want to close this issue?
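For anyone who wants to try reproducing the hang outside of transformers, a minimal sketch of isolating that CPU-to-CPU copy might look like the following; the checkpoint path and the state_dict key are assumptions, so inspect the actual keys first:

```python
import torch

# Assumed path to a locally cached roberta-base checkpoint.
CHECKPOINT = "./pytorch_model.bin"

state_dict = torch.load(CHECKPOINT, map_location="cpu")
print(list(state_dict.keys())[:5])  # verify the real key names

# Key name is an assumption; the prefix may differ in the actual checkpoint.
src = state_dict["roberta.embeddings.word_embeddings.weight"]

dst = torch.empty_like(src)
# Mirrors the CPU->CPU copy done while loading the state dict;
# non_blocking has no effect for CPU->CPU copies, as noted above.
dst.copy_(src, non_blocking=True)
print("copy finished:", dst.shape)
```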
Thank you very much for your deep investigation of this issue. Unfortunately I don't see how we could change that on our front to make it work, so we'll close this for now. If we get other reports of this we'll investigate further.
Got the same issue in this environment:
Experiencing the same issue as well; moving the model to the GPU before loading works as a workaround.
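A sketch of that workaround, as I understand it (not verified; the model class and checkpoint path are assumptions): build the model from its config, move it to the GPU first, and only then load the checkpoint weights, so the copies happen CPU-to-GPU rather than CPU-to-CPU.

```python
import torch
from transformers import RobertaConfig, RobertaModel

config = RobertaConfig.from_pretrained("roberta-base")
model = RobertaModel(config).to("cuda")  # move to GPU before loading weights

# Path to a locally cached checkpoint is an assumption.
state_dict = torch.load("./pytorch_model.bin", map_location="cpu")
# strict=False because key prefixes in the checkpoint may not match RobertaModel.
model.load_state_dict(state_dict, strict=False)
```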
Environment info
transformers version: 3.5.1

Who can help
Anyone familiar with the from_pretrained() code path. Perhaps @sgugger? Thank you!
Information
from_pretrained()'s load() blocks forever loading roberta-base, due specifically to the call to nn.Module._load_from_state_dict that would load the "embeddings.word_embeddings". Occurs when loading the model in both the first process and a subprocess started via multiprocessing. I observe the same behavior when loading via keyword vs loading local files cached via save_pretrained.

Model I am using (Bert, XLNet ...): roberta-base
The problem arises when using:
To reproduce
Steps to reproduce the behavior:
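The original reproduction snippet is not shown above; a minimal sketch of the reproduction described in this report might look like the following (the model name comes from the report, everything else is assumed):

```python
import multiprocessing as mp
from transformers import AutoModel

def load_in_subprocess():
    print("Loading roberta-base in subprocess...")
    # Reported to block forever inside nn.Module._load_from_state_dict
    # while loading embeddings.word_embeddings.
    model = AutoModel.from_pretrained("roberta-base")
    print("Loaded in subprocess:", type(model).__name__)

if __name__ == "__main__":
    print("Loading roberta-base in main process...")
    model = AutoModel.from_pretrained("roberta-base")
    print("Loaded in main process:", type(model).__name__)

    p = mp.Process(target=load_in_subprocess, daemon=True)
    p.start()
    p.join()
    print("Main process terminating.")
```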
Output:
Expected behavior
Model loads and is functional in both main process and subprocess.