[BUG] GPT-J + init_inference + replace_with_kernel_inject returns copy error with multiple GPUs #1719

Closed
TiesdeKok opened this issue Jan 23, 2022 · 12 comments · Fixed by #1724
Labels
bug Something isn't working

Comments

@TiesdeKok

Describe the bug

Using the replace_with_kernel_inject option in init_inference returns an error when using multiple GPUs (with a GPT-J model).

To Reproduce
Steps to reproduce the behavior:

  1. Create an inference script using HF Transformers and GPT-J
  2. Run the deepspeed command with multiple GPUs
import os
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import deepspeed
from transformers import pipeline as t_pipeline

model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
generator = t_pipeline('text-generation', model=model, tokenizer=tokenizer, eos_token_id=50256,  device=local_rank)

generator.model = deepspeed.init_inference(generator.model,
                                           mp_size=world_size,
                                           dtype=torch.float16,
                                           replace_method='auto',
                                           replace_with_kernel_inject=True)

input_list = ["This is the input "]

res_ds = generator(input_list, do_sample=True, max_length = 1000, eos_token_id=50256, temperature=0.25, pad_token_id=50257)

Expected behavior
No error.

ds_report output
Unavailable; I am not currently on the compute node.

Screenshots
[screenshot of the copy error traceback]

System info (please complete the following information):

  • OS: Linux - Ubuntu
  • One machine with 8x A100 40GB PCIe
  • Python 3.8
  • Using the following docker image: pytorch/pytorch:1.9.1-cuda11.1-cudnn8-devel

Launcher context
Deepspeed command line

Docker context
Base image is: pytorch/pytorch:1.9.1-cuda11.1-cudnn8-devel

Additional context

@TiesdeKok TiesdeKok added the bug Something isn't working label Jan 23, 2022
@TiesdeKok TiesdeKok changed the title [BUG] init_inference + replace_with_kernel_inject returns copy error with multiple GPUs [BUG] GPT-J + init_inference + replace_with_kernel_inject returns copy error with multiple GPUs Jan 23, 2022
@TiesdeKok
Author

In addition, I was able to replicate the issue on a different box running Fedora with 8x A6000 GPUs.

@RezaYazdaniAminabadi
Contributor

Hi @TiesdeKok,

I will take a look at this.

Thanks,
Reza

@TiesdeKok
Author

Thanks a lot!

@RezaYazdaniAminabadi
Contributor

Hi @TiesdeKok
Can you please try this PR and see if this is fixed?

Thanks.

@TiesdeKok
Author

Appreciate the quick turnaround here @RezaYazdaniAminabadi!

The copy error is gone and the inference starts now, so that appears resolved. 🥳

However, I am running into another problem: everything works great with one GPU, but with multiple GPUs the inference hangs indefinitely. I can open a separate issue if you prefer, but let me describe what I am observing:

  • "EleutherAI/gpt-j-6B" with float16 with one GPU without kernel inject --> works
  • "EleutherAI/gpt-j-6B" with float16 with one GPU with kernel inject --> works
  • "EleutherAI/gpt-j-6B" with float16 with 2+ GPU without kernel inject --> hangs indefinitely
  • "EleutherAI/gpt-j-6B" with float16 with 2+ GPU with kernel inject --> hangs indefinitely

No errors are shown; it just pins the GPUs at 100% and nothing happens. I have tried this on two different machines and the behavior is the same. I noticed the same issue yesterday without the kernel inject, and letting it run for hours (on one prompt) makes it clear that things are stuck.

To dig into this further, I have also tried the distilgpt2 model, and the same issue pops up:

  • "distilgpt2" with float 16 with one GPU --> works
  • "distilgpt2" with float 16 with 2+ GPU --> hangs indefinitely
  • "distilgpt2" with float 32 with one GPU --> works
  • "distilgpt2" with float 32 with 2+ GPU --> hangs indefinitely

I am a little lost here; the code I am running is essentially the same as:
https://github.com/microsoft/DeepSpeedExamples/blob/fix-inferen-test/inference/huggingface/gpt-neo.py

which I run with deepspeed --num_gpus X gpt-neo.py (pseudo-code).

I tried looking for a verbose option to see if I could get better logging once things are on the GPUs, but I could not find one. Any ideas on what might be happening here? 😕

@RezaYazdaniAminabadi
Contributor

@TiesdeKok,

This is a known issue that we have with the integration of DeepSpeed Inference and HF. It happens because one GPU finishes its generation while the other is still waiting to generate the next token. Would you mind setting min_length and max_length to the same number and seeing if that resolves the issue?
Thanks
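
For reference, a minimal sketch of that suggestion applied to the repro script above (the value 200 is only an illustrative placeholder; the point is that min_length == max_length makes every rank generate the same number of tokens, so no GPU finishes early):

res_ds = generator(input_list,
                   do_sample=True,
                   min_length=200,   # placeholder length
                   max_length=200,   # equal to min_length so all ranks stay in sync
                   eos_token_id=50256,
                   pad_token_id=50257)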

@TiesdeKok
Author

After reading your description, it immediately hit me that the hanging issue is caused by a random.shuffle() line in my code, which created a different input for every GPU and caused everything to hang. 🤦🏻‍♂️
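
For anyone who hits the same symptom, a minimal sketch of one way to avoid that divergence (assuming input_list is the prompt list from the repro script above; a fixed seed is just one option, shuffling on rank 0 and broadcasting the result would also work):

import random

# Seed the RNG identically on every rank so random.shuffle() produces the same
# order everywhere; otherwise the ranks generate different sequences and the
# collective operations never line up.
random.seed(1234)
random.shuffle(input_list)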

With that out of the way, I am now seeing weird behavior with the kernel inject:

  • multiple GPUs without kernel inject works great.
  • multiple GPUs with kernel inject and do_sample=False completes without errors but it generates garbage output. The output looks like this (\n####\n is where my prompt ends):

\n####\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"

  • multiple GPUs with kernel inject and do_sample=True throws an error:

[screenshot of the error, raised at the torch.multinomial() sampling step]

I added a quick print statement right before that torch.multinomial() step and it shows:

print(probs)
#tensor([[nan, nan, nan,  ..., nan, nan, nan]], device='cuda:1')
#tensor([[nan, nan, nan,  ..., nan, nan, nan]], device='cuda:2')
#tensor([[nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0')
#tensor([[nan, nan, nan,  ..., nan, nan, nan]], device='cuda:3')
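
For completeness, a hypothetical sketch of the kind of guard that could catch this earlier (not part of transformers or the script above; probs stands for the tensor handed to torch.multinomial in the sampling loop, and local_rank for the env-derived rank from the repro script):

import torch

# Flag invalid sampling probabilities before torch.multinomial() is called.
if torch.isnan(probs).any() or torch.isinf(probs).any():
    print(f"rank {local_rank}: invalid probabilities, "
          f"min={probs.min().item()}, max={probs.max().item()}")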

The above issue and error also occur when setting max_length and min_length to the same value.

Any thoughts on what might be the issue here? Thanks again for your help!

P.S. My torch version is 1.10.0 and my transformers version is 4.16.0.dev0.

@RezaYazdaniAminabadi
Contributor

I did test this on the same versions you mentioned, except that I am using PyTorch 1.9. The code snippet I am using is as follows:

import os
import torch
import transformers

import deepspeed  # needed for deepspeed.init_inference below
from deepspeed import module_inject
from transformers import pipeline

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

generator = pipeline('text-generation',
                     model='EleutherAI/gpt-j-6B',
                     device=local_rank,
                     )
generator.model = deepspeed.init_inference(generator.model,
                                           mp_size=world_size,
                                           dtype=torch.half,
                                           replace_method='auto',
                                           replace_with_kernel_inject=True)
string = generator("DeepSpeed is ", do_sample=True, min_length=50)
print(string)

@RezaYazdaniAminabadi
Contributor

Here is part of the result I am seeing for this example:

[screenshot of the generated text]

@TiesdeKok
Author

That little code snippet was very helpful for debugging what is happening here. My observations:

I was using the float16 revision, so I had to download the float32 version; I figured that might be it, but it didn't change anything. I got the same error as before when running the exact code you provided (I only fixed the deepspeed import):

[screenshot of the same error as before]

When turning off sampling I also saw the same weird behavior with the exclamation marks:

[screenshot of the exclamation-mark output]

However, given that it worked for you, there had to be something about my setup causing it, so I started turning dials:

  • Changing to transformers==4.15 --> no change
  • Changing to 2 GPUs --> no change

But then I tried deepspeed==0.5.10 and it all works again! Both your code snippet and my own code started working. This suggests to me that something introduced after 0.5.10 is causing things to break.

[screenshot of the correct output with deepspeed==0.5.10]

@tomerip

tomerip commented Feb 27, 2022

Hi @TiesdeKok,
I think this issue I opened might be relevant to your use case and worth a look:
#1797
It at least explains why you got the exclamation-mark outputs, and it should also draw your attention to the outputs you are getting in case you pad some of your inputs.

@lanking520

Hi @TiesdeKok, I am also facing the garbage output issue. Not sure whether it is related to the issue you were having previously: #2113
