
fix crash when using torch.nn.DataParallel for LORA inference #805

Merged
merged 1 commit into huggingface:main on Aug 8, 2023

Conversation

sywangyi
Contributor

@sywangyi sywangyi commented Aug 8, 2023

```python
# Load the LoRA model
inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)
inference_model = torch.nn.DataParallel(inference_model)
inference_model.to(device)
inference_model.eval()
```

Then just run inference in inference mode and the issue can be reproduced. The call trace looks like:
```
Original Traceback (most recent call last):
  File "/mnt/disk1/wangyi/miniconda3/envs/alpaca/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/mnt/disk1/wangyi/miniconda3/envs/alpaca/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/disk3/wangyi/peft/src/peft/peft_model.py", line 720, in forward
    return self.base_model(
  File "/mnt/disk1/wangyi/miniconda3/envs/alpaca/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/disk3/wangyi/transformers/src/transformers/models/roberta/modeling_roberta.py", line 1196, in forward
    outputs = self.roberta(
  File "/mnt/disk1/wangyi/miniconda3/envs/alpaca/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/disk3/wangyi/transformers/src/transformers/models/roberta/modeling_roberta.py", line 837, in forward
    embedding_output = self.embeddings(
  File "/mnt/disk1/wangyi/miniconda3/envs/alpaca/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/disk3/wangyi/transformers/src/transformers/models/roberta/modeling_roberta.py", line 125, in forward
    inputs_embeds = self.word_embeddings(input_ids)
  File "/mnt/disk1/wangyi/miniconda3/envs/alpaca/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/disk1/wangyi/miniconda3/envs/alpaca/lib/python3.9/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/mnt/disk1/wangyi/miniconda3/envs/alpaca/lib/python3.9/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA__index_select)
```
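The failure pattern here can be illustrated with a CPU-only toy (hypothetical class names, not the actual PEFT code): `DataParallel` replicates a module's registered submodules onto each device, but if a wrapper resolved its base model ahead of time and keeps calling that cached reference, every replica still hits the cuda:0 copy, producing exactly this cross-device mismatch. Resolving the base model through the replica's own attribute at call time avoids it.

```python
import copy


class BaseModel:
    """Stand-in for the wrapped model; records which 'device' it lives on."""

    def __init__(self, device):
        self.device = device

    def __call__(self, x):
        return f"computed on {self.device}"


class CachingWrapper:
    """Resolves the base model once at construction and keeps calling it."""

    def __init__(self, base):
        self.base = base
        self._cached = base  # pinned to the original (cuda:0) object

    def forward(self, x):
        return self._cached(x)


class ResolvingWrapper:
    """Looks the base model up on *this* replica at every call."""

    def __init__(self, base):
        self.base = base

    def forward(self, x):
        return self.base(x)


def replicate(wrapper, device):
    # Crude stand-in for DataParallel replication: registered submodules
    # are recreated per device, but privately cached references are not.
    replica = copy.copy(wrapper)
    replica.base = BaseModel(device)
    return replica


caching_replica = replicate(CachingWrapper(BaseModel("cuda:0")), "cuda:1")
resolving_replica = replicate(ResolvingWrapper(BaseModel("cuda:0")), "cuda:1")

print(caching_replica.forward(0))    # computed on cuda:0  (wrong device)
print(resolving_replica.forward(0))  # computed on cuda:1  (correct)
```

This is only a sketch of the mechanism; the real `DataParallel.replicate` broadcasts parameters and rebuilds the module tree, but the same principle applies to any reference a wrapper captures before replication.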

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Aug 8, 2023

The documentation is not available anymore as the PR was closed or merged.

@sywangyi
Contributor Author

sywangyi commented Aug 8, 2023

I used the example https://github.com/huggingface/peft/blob/main/examples/sequence_classification/LoRA.ipynb and just added `inference_model = torch.nn.DataParallel(inference_model)` to run on multiple GPUs.

@sywangyi
Contributor Author

sywangyi commented Aug 8, 2023

@pacman100 @younesbelkada please help review the issue

Contributor

@pacman100 pacman100 left a comment


Hello @sywangyi, thank you for the fix, LGTM!

Member

@BenjaminBossan BenjaminBossan left a comment


@sywangyi I couldn't reproduce the error locally, did you run into it while using the notebook or did you use a script?

Regardless, I'm okay with the suggested change. The only downside I see is that we lose any possible docstring from the original forward method (but see #784).
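On the docstring point: a plain delegating function drops the wrapped method's `__doc__` unless it is explicitly carried over, e.g. with `functools.wraps`. A minimal sketch (toy class, not the PEFT code):

```python
import functools


class Base:
    def forward(self, x):
        """Base forward docstring."""
        return x


base = Base()


def plain_forward(*args, **kwargs):
    # Delegates, but exposes no docstring of its own.
    return base.forward(*args, **kwargs)


@functools.wraps(base.forward)
def wrapped_forward(*args, **kwargs):
    # functools.wraps copies __doc__, __name__, etc. from base.forward.
    return base.forward(*args, **kwargs)


print(plain_forward.__doc__)    # None
print(wrapped_forward.__doc__)  # Base forward docstring.
```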

@sywangyi
Contributor Author

sywangyi commented Aug 8, 2023

@BenjaminBossan, I used a script, but the code was copied from the notebook; I just added `inference_model = torch.nn.DataParallel(inference_model)` before inference.

@BenjaminBossan
Member

> I use script, but code are copied from the notebook. then, just add inference_model = torch.nn.DataParallel(inference_model) before inference

Okay, it probably requires actually running on multiple GPUs to trigger (I only tested on one).

@BenjaminBossan BenjaminBossan merged commit 7d44026 into huggingface:main Aug 8, 2023
11 checks passed
@sywangyi sywangyi deleted the lora_dp_fix branch August 8, 2023 14:00