from_pretrained()'s load() blocks forever in subprocess #8649

Closed
1 of 2 tasks
levon003 opened this issue Nov 19, 2020 · 9 comments

Comments

@levon003

levon003 commented Nov 19, 2020

Environment info

  • transformers version: 3.5.1
  • Platform: Linux-5.4.58 x86_64-with-glibc2.10
  • Python version: 3.8.3
  • PyTorch version (GPU?): 1.6.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: Yes

Who can help

Anyone familiar with the from_pretrained() code path. Perhaps @sgugger? Thank you!

Information

from_pretrained()'s load() blocks forever when loading roberta-base, specifically in the call to nn.Module._load_from_state_dict that loads "embeddings.word_embeddings". The hang occurs when the model is loaded first in the main process and then again in a subprocess started via multiprocessing.

I observe the same behavior whether loading by model name or from local files cached via save_pretrained().

Model I am using (Bert, XLNet ...): roberta-base

The problem arises when using:

  • the official example scripts:
  • my own modified scripts: see sample script below.

To reproduce

Steps to reproduce the behavior:

import torch
import transformers
import multiprocessing as mp

def load_model_in_subprocess():
    print("Started subprocess.")
    # This call never returns (it blocks inside _load_from_state_dict).
    model2 = transformers.RobertaModel.from_pretrained('roberta-base')
    print("Model loaded in subprocess.")

def main():
    # The first load, in the main process, succeeds.
    model1 = transformers.RobertaModel.from_pretrained('roberta-base')
    print("Model loaded in main process.")

    p = mp.Process(target=load_model_in_subprocess, daemon=True)
    p.start()
    p.join()
    print("Main thread terminating.")

if __name__ == "__main__":
    main()

Output:

Model loaded in main process.
Started subprocess.
<never terminates>

Expected behavior

Model loads and is functional in both main process and subprocess.

@LysandreJik
Member

If you've identified the issue to be coming from nn.Module._load_from_state_dict, then I guess this is more of a PyTorch issue than a transformers one? Do you have an idea what might cause this hang with that method?

@levon003
Author

Well, I don't know enough about torch state_dict behavior to understand why transformers directly calls the underscored "internal use" method _load_from_state_dict in the first place, but it strikes me that transformers is making assumptions about this internal method that may not hold in practice. I don't see anything obvious in _load_from_state_dict that would cause it to lock up under these (or any) conditions, but we may be violating a usage assumption (e.g. providing a bad pre-load hook).
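
One way to probe that assumption would be to exercise the internal method directly on a bare embedding of the same shape, outside transformers entirely. A rough sketch (I have not run this on the affected machine; the shape mirrors roberta-base's word embedding):

import torch
import torch.nn as nn
import multiprocessing as mp

def load_embedding(tag):
    # Call the internal method directly, the way transformers' load() does,
    # on a fresh Embedding shaped like roberta-base's word embeddings.
    emb = nn.Embedding(50265, 768, padding_idx=1)
    state_dict = {"weight": torch.randn(50265, 768)}
    missing, unexpected, errors = [], [], []
    emb._load_from_state_dict(state_dict, "", {"version": 1}, True,
                              missing, unexpected, errors)
    print(f"Loaded in {tag}:", missing, unexpected, errors)

if __name__ == "__main__":
    load_embedding("main process")
    p = mp.Process(target=load_embedding, args=("subprocess",), daemon=True)
    p.start()
    p.join()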

@LysandreJik
Member

Oh, I see. Looking at torch's load_state_dict, however, it doesn't seem to do anything very different from what we do. Have you managed to load several models using torch.load() with the same multiprocessing approach you have used?

@levon003
Author

Well, a fair test would be to load the same (RoBERTa-base) model, but I'm not sure how to write the code to do that... that's why I'm using transformers! But it's easy to verify that there's no general problem with multi-process loading of PyTorch models:

import torch
import torch.nn as nn
import multiprocessing as mp

USE_STATE_DICT = True

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(768, 1)
        
    def forward(self, x):
        x = self.fc1(x)
        return x


def save_model():
    model = SimpleNet()
    torch.save(model, './full_model.pt')
    torch.save(model.state_dict(), './model_state_dict.pt')

def load_model_in_subprocess():
    print("Started subprocess.")
    if USE_STATE_DICT:
        model = SimpleNet()
        model.load_state_dict(torch.load('./model_state_dict.pt'))
    else:
        model = torch.load('./full_model.pt')
    print(f"Model loaded in subprocess: {model}")

def main():
    save_model()
    print("Saved model.")
    
    if USE_STATE_DICT:
        model = SimpleNet()
        model.load_state_dict(torch.load('./model_state_dict.pt'))
    else:
        model = torch.load('./full_model.pt')
    print(f"Model loaded in main process: {model}")

    p = mp.Process(target=load_model_in_subprocess, daemon=True)
    p.start()
    p.join()
    print("Main thread terminating.")
    
if __name__ == "__main__":
    main()

This script terminates fine when loading from state dict or a pickled model file.

@levon003
Author

Adding some debug prints to the load() function in transformers' modeling_utils.py, I can confirm that the hang is in the call to nn.Module._load_from_state_dict when:

prefix = roberta.embeddings.word_embeddings.
local_metadata = {'version': 1}
missing_keys = ['roberta.embeddings.position_ids']
unexpected_keys = []
error_msgs = []
strict = True

The keys in the state_dict are:

roberta.embeddings.word_embeddings.weight
roberta.embeddings.position_embeddings.weight
roberta.embeddings.token_type_embeddings.weight
roberta.embeddings.LayerNorm.weight
roberta.embeddings.LayerNorm.bias
<snipping all of the individual layer keys e.g. roberta.encoder.layer.0.attention.self.query.weight>
roberta.pooler.dense.weight
roberta.pooler.dense.bias
lm_head.bias
lm_head.dense.weight
lm_head.dense.bias
lm_head.layer_norm.weight
lm_head.layer_norm.bias
lm_head.decoder.weight

The shape of the roberta.embeddings.word_embeddings.weight tensor is [50265, 768].

(Note: the same blocking behavior occurs when loading bert-base-uncased.)
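
For reference, the keys and shapes above can be listed directly from the cached checkpoint with plain torch; the path below is a placeholder:

import torch

# Placeholder path: point this at the cached roberta-base weights,
# e.g. the pytorch_model.bin written by save_pretrained().
state_dict = torch.load("path/to/roberta-base/pytorch_model.bin", map_location="cpu")
for key, tensor in state_dict.items():
    print(key, tuple(tensor.shape))
# e.g. roberta.embeddings.word_embeddings.weight (50265, 768)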

@levon003
Author

Okay, I did some more investigation, and the problem is a blocking call to Tensor.copy_ that copies the Parameter in the state_dict into the Parameter in the Module (in this case, the weight of the Embedding(50265, 768, padding_idx=1) in the RoBERTa model).

The documentation indicates a non_blocking parameter that can be used when copying between CPU and GPU, but here we are copying between CPU and CPU. I confirmed that non_blocking has no effect and that both Parameters are on the cpu device.
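
As a minimal sketch of the narrowed-down operation (untested in isolation; whether a bare copy of this shape reproduces the hang on its own is an assumption):

import torch
import multiprocessing as mp

def copy_in_subprocess():
    print("Started subprocess.")
    dst = torch.empty(50265, 768)
    dst.copy_(torch.randn(50265, 768))  # the CPU-to-CPU copy that blocks during loading
    print("copy_ finished in subprocess.")

if __name__ == "__main__":
    # Exercise the same copy path in the main process first, mirroring the repro.
    dst = torch.empty(50265, 768)
    dst.copy_(torch.randn(50265, 768))
    print("copy_ finished in main process.")
    p = mp.Process(target=copy_in_subprocess, daemon=True)
    p.start()
    p.join()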

That's where I'm going to stop pursuing this bug. I don't know the structure of the C++ code, but it seems likely that this is an issue with PyTorch's CPU copy implementation and the idiosyncrasies of the specific OS I'm using. If this problem can be reproduced on other systems it may be worth investigating further, but the fault probably lies with PyTorch rather than with transformers. Hopefully this affects only a small set of OSes.

@LysandreJik, you may want to close this issue?

@LysandreJik
Member

Thank you very much for your deep investigation of this issue. Unfortunately, I don't see how we could change anything on our side to make this work, so we'll close this for now.

If we get other reports of this we'll investigate further.

@clbenoit

clbenoit commented May 7, 2022

Got the same issue in this environment:

Platform: Linux clem-MacBookAir 5.13.0-40-generic #45~20.04.1-Ubuntu x86_64 x86_64 x86_64 GNU/Linux
Python version: 3.9.7
PyTorch version (GPU?): 1.11.0+cu102 (False)
Tensorflow version (GPU?): not installed (NA)
Using GPU in script?: No
Using distributed or parallel set-up in script?: Yes

@dukleryoni

Experiencing the same issue as well with a torchvision.models model; it also seems to come from nn.Module._load_from_state_dict when running in a subprocess on CPU. Unsure why this has only just started happening.

Moving the model to the GPU before loading works as a workaround:

model.to(device)
model.load_state_dict(ckpt)
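
Spelled out for a torchvision model (assumes a CUDA device is available; "ckpt.pt" is a placeholder for your checkpoint file):

import torch
import torchvision

device = torch.device("cuda")
model = torchvision.models.resnet18()
model.to(device)                                   # move to the GPU before loading
ckpt = torch.load("ckpt.pt", map_location=device)  # placeholder checkpoint path
model.load_state_dict(ckpt)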
