
Trainer is using DataParallel on parallelized models #9577

Closed
jncasey opened this issue Jan 13, 2021 · 6 comments · Fixed by #9578

Comments

@jncasey
Contributor

jncasey commented Jan 13, 2021

Environment info

  • transformers version: 4.2.0
  • Platform: Ubuntu 20.04
  • Python version: 3.8.5
  • PyTorch version (GPU?): 1.7.1 / CUDA 11.2
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help

@sgugger @stas00

Information

I'm trying out the 4.2.0 release with a training script that had been working in 4.1.1.

I'm parallelizing my model over two GPUs, and I had been using the --model_parallel training arg in the previous version. Now that it's no longer used, I removed the arg from my training command, but I'm getting an error as though DataParallel is being used and the model isn't being detected as parallelized:
RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cuda:1

I did some debugging, and everything seems okay with my model (trainer.is_model_parallel returns True). But trainer.args.n_gpu is still 2.

I admit that I don't totally understand what's happening in the trainer code, but could it be an error on line 289?
self.args._n_gpu = 1

Should that be self.args.n_gpu = 1, without the leading underscore?
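
For context, here is a rough sketch of the gating I think is involved (the function and its names are my own simplification, not the actual Trainer code): if args.n_gpu stays at 2, the model gets handed to torch.nn.DataParallel, which expects every parameter on cuda:0, and a model spread over two GPUs with parallelize() can't satisfy that.

    import torch

    def maybe_wrap_data_parallel(model, n_gpu, is_model_parallel):
        # Simplified sketch: a model that is already spread over several GPUs
        # must not be wrapped in DataParallel, which requires all parameters
        # on cuda:0 and otherwise raises the RuntimeError above.
        if is_model_parallel:
            return model  # leave the model-parallelized model alone
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)
        return model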

To reproduce

Steps to reproduce the behavior:

  1. Parallelize a model
  2. Train on a machine with multiple GPUs
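
A minimal sketch of those two steps (the checkpoint, device map, and tiny dataset are placeholders of mine; it assumes a machine with at least two GPUs):

    from torch.utils.data import Dataset
    from transformers import GPT2LMHeadModel, GPT2TokenizerFast, Trainer, TrainingArguments

    class TinyLMDataset(Dataset):
        # Dummy causal-LM dataset, just enough to make the Trainer run.
        def __init__(self, tokenizer):
            self.ids = tokenizer("hello world " * 64, return_tensors="pt")["input_ids"][0]

        def __len__(self):
            return 8

        def __getitem__(self, i):
            return {"input_ids": self.ids, "labels": self.ids}

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.parallelize({0: range(0, 6), 1: range(6, 12)})  # split gpt2's 12 blocks over two GPUs

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="out", per_device_train_batch_size=1, num_train_epochs=1),
        train_dataset=TinyLMDataset(tokenizer),
    )
    trainer.train()  # in 4.2.0 this hits the DataParallel RuntimeError above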
@sgugger
Collaborator

sgugger commented Jan 13, 2021

The self.args._n_gpu = 1 is there to avoid parallelizing the data, so it has nothing to do with your problem (and it is correct: we can't set self.args.n_gpu, which is a property, but that's a whole different story!)
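
To make that concrete, n_gpu roughly follows this pattern (a simplified sketch, not the real TrainingArguments code): the public name is a read-only property backed by a private _n_gpu attribute, so only the underscored attribute can be assigned.

    import torch

    class ArgsSketch:
        def __init__(self):
            self._n_gpu = torch.cuda.device_count()

        @property
        def n_gpu(self):
            # read-only view over the private attribute
            return self._n_gpu

    args = ArgsSketch()
    args._n_gpu = 1   # what the Trainer does to disable data parallelism
    # args.n_gpu = 1  # would raise AttributeError: can't set attribute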

How is your model parallelized? Without that piece of code we can't reproduce the bug and help you.

@jncasey
Contributor Author

jncasey commented Jan 13, 2021

Thanks @sgugger.

In my test, I'm using some code originally derived from the run_clm.py example. I'm trying to fine-tune a GPT2 model I've trained from scratch. The model was parallelized with the following lines, and this exact fine-tuning script ran successfully yesterday in 4.1.1, using the --model_parallel training arg.

    device_map = {0: range(0, 15),
                  1: range(15, 32)}
    model.parallelize(device_map)

The error I'm getting now looks a lot like what would happen if I left out the --model_parallel flag in 4.1.1.

@stas00
Contributor

stas00 commented Jan 13, 2021

RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cuda:1

Please post the full trace.

I have only experimented with t5 and bart MP so far, but gpt2 is supposed to be very similar.

Most likely the outputs aren't being copied back to the 0th GPU on return, so this won't have anything to do with the Trainer; the issue you encountered probably has to do with evaluation rather than training.

I had to fix t5-MP to do that, but the PR with the fix hasn't been merged.

    if self.model_parallel:
        encoder_outputs, decoder_outputs = model_parallel_inputs_to_specific_device(
            self.main_device, encoder_outputs, decoder_outputs
        )

I won't be surprised if gpt2 is missing that too.

model_parallel_inputs_to_specific_device is a new function that isn't in master but part of these 2 PRs: #9323 and #9384. It relies on another function; the full new file is here: https://github.com/huggingface/transformers/blob/fe21c43745fcf3f7958c17c2ac461bd784094205/src/transformers/utils/model_parallel_utils.py
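
Conceptually, that helper does something along these lines (a rough sketch of mine, not the code from those PRs): walk the returned structures and move every tensor onto one target device, e.g. back to the 0th GPU, before handing them back.

    import torch

    def move_to_device(obj, device):
        # Hypothetical illustration of the idea behind
        # model_parallel_inputs_to_specific_device: recursively move tensors
        # nested in tuples/lists/dicts onto the target device.
        if torch.is_tensor(obj):
            return obj.to(device)
        if isinstance(obj, (list, tuple)):
            return type(obj)(move_to_device(x, device) for x in obj)
        if isinstance(obj, dict):
            return {k: move_to_device(v, device) for k, v in obj.items()}
        return obj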

The current MP implementations are very limited, and at the moment I highly recommend you look at DeepSpeed instead; see:
#8771 (comment) and
#8771 (comment)
You will need master for that, as it was just merged 2 days ago.

We also removed --model_parallel in trainer master as it wasn't fully baked in the first place.

@sgugger
Collaborator

sgugger commented Jan 13, 2021

@stas00 This is linked to how TrainingArguments.n_gpu was computed. I could reproduce the bug and tested that the fix in #9578 removes it.

@stas00
Contributor

stas00 commented Jan 13, 2021

That's easy then. The error, though, very much reminded me of the issue I described in my comment above.

@jncasey
Contributor Author

jncasey commented Jan 14, 2021

Thanks both!

@stas00 Definitely excited to check out DeepSpeed – that's the reason I started testing my code in 4.2.0.
