[BUG][master branch] garbage GPTJ output for multi-gpu inference #2233

Closed
mallorbc opened this issue Aug 18, 2022 · 8 comments
Labels
bug Something isn't working inference

Comments

@mallorbc

Describe the bug

Similar to #2113, this bug produces garbage output during multi-GPU inference. In that issue, @RezaYazdaniAminabadi made a fix (#2198) for a similar problem with GPT Neo 2.7B, and after building from master I can confirm that it solved multi-GPU inference for GPT Neo 2.7B. However, for GPTJ the issue remains:

Output from 2 × RTX 3090 for GPTJ:

[{'generated_text': 'DeepSpeed is,: to,,/ &.. by and.. a\n.. and- and.. the,,\n of\n [.,.\n:, &-. and a- the,\n\n). the'}]

Meanwhile, output from 1 × RTX 3090 for GPTJ:

[{'generated_text': 'DeepSpeed is a leading deep learning framework designed for distributed training and inference on heterogeneous accelerators and CPUs. Our paper (https://arxiv.org/abs/1811.11540) describes an optimized deep architecture and inference engine and'}]

To Reproduce
Steps to reproduce the behavior:

  1. Install DeepSpeed from source on master
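  2. Install transformers (pip install transformers)
  3. Run the script below with 2 GPUs (deepspeed --num_gpus 2 infer.py) to get bad output
  4. Run it with 1 GPU (deepspeed --num_gpus 1 infer.py) to get good output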

```python
import os

import deepspeed
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name = 'EleutherAI/gpt-j-6B'
# model_name = "EleutherAI/gpt-neo-2.7B"

# Set by the deepspeed launcher; the defaults cover a plain single-process run.
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

# Load the fp16 model and tokenizer once and hand them to the pipeline,
# instead of letting the pipeline download and load the checkpoint a second time.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=local_rank)

# Split the model across all launched GPUs (mp_size) and inject DeepSpeed's inference kernels.
generator.model = deepspeed.init_inference(generator.model,
                                           mp_size=world_size,
                                           dtype=torch.half,
                                           replace_method='auto',
                                           replace_with_kernel_inject=True)

string = generator("DeepSpeed is", do_sample=True, min_length=50)
# Print from rank 0 only so multi-GPU runs don't repeat the output.
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    print(string)
```
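For context: `mp_size=world_size` asks DeepSpeed to split the model across every GPU the launcher started (model/tensor parallelism), so the 2-GPU run goes through weight-slicing code that a 1-GPU run (`mp_size=1`) never exercises. The pure-Python sketch below illustrates the invariant any such slicing must preserve, assuming a generic Megatron-style column split of a single linear layer: each rank computes its own slice of the output features, and concatenating the slices in rank order reproduces the unsharded result exactly. It is not DeepSpeed's implementation, and `matmul`, `split_columns`, and `column_parallel_forward` are hypothetical helpers.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Naive x @ w for row-major lists of lists."""
    return [
        [sum(row[k] * w[k][j] for k in range(len(w))) for j in range(len(w[0]))]
        for row in x
    ]


def split_columns(w: Matrix, world_size: int) -> List[Matrix]:
    """Give each (simulated) rank a contiguous block of w's output columns."""
    cols = len(w[0])
    if cols % world_size != 0:
        raise ValueError("output dimension must divide evenly across ranks")
    step = cols // world_size
    return [[row[r * step:(r + 1) * step] for row in w] for r in range(world_size)]


def column_parallel_forward(x: Matrix, w: Matrix, world_size: int) -> Matrix:
    """Simulate world_size ranks, each computing its own output slice, then gather."""
    partials = [matmul(x, shard) for shard in split_columns(w, world_size)]
    # Gather: concatenate each row's slices in rank order.
    return [[v for part in partials for v in part[i]] for i in range(len(x))]


if __name__ == "__main__":
    x = [[1.0, 2.0], [3.0, 4.0]]
    w = [[1.0, 0.0, 2.0, -1.0], [0.5, 1.0, 0.0, 3.0]]
    print(column_parallel_forward(x, w, world_size=2) == matmul(x, w))  # True
```

A slicing or gather mistake in a scheme like this raises no error; it silently changes the numbers, which is why comparing a sharded run against the single-GPU result is a useful check.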

Expected behavior

I would expect output that makes sense, like the output for one GPU.
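To turn "output that makes sense" into something a script can check across runs, one crude heuristic is to measure how much of the non-whitespace text is letters, since the garbled samples in this thread are dominated by stray punctuation. The sketch below assumes exactly that; `letter_fraction`, `looks_garbled`, and the 0.85 threshold are hypothetical choices tuned by eye against the outputs quoted here, not a validated quality metric, and legitimate text full of URLs or numbers could trip it.

```python
def letter_fraction(text: str) -> float:
    """Return the share of non-whitespace characters that are letters."""
    chars = [c for c in text if not c.isspace()]
    if not chars:
        return 0.0
    return sum(c.isalpha() for c in chars) / len(chars)


def looks_garbled(text: str, threshold: float = 0.85) -> bool:
    """Flag text whose letter share falls below a hand-picked threshold."""
    return letter_fraction(text) < threshold


if __name__ == "__main__":
    # Outputs quoted in this issue: 2 x RTX 3090 vs 1 x RTX 3090 (GPT-J).
    bad = ("DeepSpeed is,: to,,/ &.. by and.. a\n.. and- and.. the,,\n of\n"
           " [.,.\n:, &-. and a- the,\n\n). the")
    good = ("DeepSpeed is a leading deep learning framework designed for distributed "
            "training and inference on heterogeneous accelerators and CPUs. Our paper "
            "(https://arxiv.org/abs/1811.11540) describes an optimized deep architecture "
            "and inference engine and")
    print(looks_garbled(bad), looks_garbled(good))  # True False
```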

ds_report output

ds_report

DeepSpeed C++/CUDA extension op report

NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.

JIT compiled ops requires ninja
ninja .................. [OKAY]

op name ................ installed .. compatible

cpu_adam ............... [YES] ...... [OKAY]
cpu_adagrad ............ [YES] ...... [OKAY]
fused_adam ............. [YES] ...... [OKAY]
fused_lamb ............. [YES] ...... [OKAY]
sparse_attn ............ [YES] ...... [OKAY]
transformer ............ [YES] ...... [OKAY]
stochastic_transformer . [YES] ...... [OKAY]
async_io ............... [YES] ...... [OKAY]
utils .................. [YES] ...... [OKAY]
quantizer .............. [YES] ...... [OKAY]
transformer_inference .. [YES] ...... [OKAY]

DeepSpeed general environment info:
torch install path ............... ['/root/anaconda3/envs/gpt/lib/python3.9/site-packages/torch']
torch version .................... 1.12.0
torch cuda version ............... 11.3
torch hip version ................ None
nvcc version ..................... 11.3
deepspeed install path ........... ['/root/anaconda3/envs/gpt/lib/python3.9/site-packages/deepspeed']
deepspeed info ................... 0.7.1+7d8ad45, 7d8ad45, master
deepspeed wheel compiled w. ...... torch 1.12, cuda 11.3

System info:

  • OS: Ubuntu 20.04
  • GPU count and types: 2 × RTX 3090
  • Interconnects: single machine with 2 × RTX 3090
  • Python version: 3.9.13

I am using a Docker container with NVIDIA CUDA already set up in the base image.

Launcher context

deepspeed --num_gpus 2 infer.py
deepspeed --num_gpus 1 infer.py

Docker context

Docker image: nvidia/cuda:11.3.1-devel-ubuntu20.04, with the Python packages then built into the container.

Additional context

NA

mallorbc added the bug label Aug 18, 2022
@skiingpacman

I can confirm that this still occurs on master HEAD (86164c4) with more than one GPU.

>>> deepspeed.__version__
'0.7.1+86164c48'
>>> transformers.__version__
'4.21.1'

With 4 x A100 GPUs:
deepspeed gpt-j-6b-generation.py
[{'generated_text': 'DeepSpeed is and a to the\n and-.\n to the-- the,., the- and- and in the\n and the-,---- to,- a,-, as and-\n,, '}]

With 1 GPU:
deepspeed --num_gpus 1 gpt-j-6b-generation.py
[{'generated_text': 'DeepSpeed is a new project of the Google AI Language team, and is the first release of that team’s product in more than 3 years. DeepSpeed aims to dramatically increase the speed of machine translation, from a cost of years and years'}]

@skiingpacman

The code used is the tutorial example, with the model name changed to EleutherAI/gpt-j-6B.


@RezaYazdaniAminabadi
Contributor

Hi @skiingpacman, Hi @mallorbc

I have sent a PR to fix this issue. Can you please try it on your side and let me know if it is resolved?
Thanks,
Reza

@RezaYazdaniAminabadi
Contributor

Hi @skiingpacman, @mallorbc
Have you had a chance to try this? I want to merge this PR as soon as possible if it works and fixes the issue.
Thanks,
Reza

@mallorbc
Author

@RezaYazdaniAminabadi I will try this later today. I will rebuild deepspeed from your PR hash and report back. Thanks!

@skiingpacman

skiingpacman commented Aug 29, 2022

TL;DR: based on a quick test, it looks good.

Hi @RezaYazdaniAminabadi,

I just switched to your branch ds-inference/fix-mp2 and built it, which produced DeepSpeed version 0.7.3+9eea4ee4.

Testing with a modified version of the tutorial script mentioned above, i.e.:

< generator = pipeline('text-generation', model='EleutherAI/gpt-neo-2.7B',
---
> generator = pipeline('text-generation', model='EleutherAI/gpt-j-6B',

The results look much better when run on 4 x A100 GPUs, e.g.:

[{'generated_text': 'DeepSpeed is a free, open source program and service. As such, the following describes how the DeepSpeed program is configured. This overview is specific to the DeepSpeed program as it is configured for the DeepSpeed.com.com web server ('}]

I can also confirm that deepspeed --num_gpus 1 gpt-j-6b-generation.py still works with 1 x A100 GPU:

[{'generated_text': 'DeepSpeed is committed to delivering high quality, high reliability and easy of use and to providing a long term relationship with our clients.\n\nWe offer a comprehensive portfolio of financial, technology and infrastructure services including IT Consultancy, Systems Integration, Consulting,'}]

This is with the following set-up:

>>> deepspeed.__version__
'0.7.3+9eea4ee4'
>>> transformers.__version__
'4.21.1'
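When re-testing a branch like this, it helps to report the exact installed versions, as above. A minimal sketch using only the standard library's `importlib.metadata` (Python 3.8+) reads them from the installed package metadata without importing torch or DeepSpeed. The helper name `installed_versions` is hypothetical, and it assumes the distributions are named `torch`, `transformers`, and `deepspeed`; unlike `ds_report`, it says nothing about compiled ops or CUDA.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(
    dists: Iterable[str] = ("torch", "transformers", "deepspeed"),
) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: Dict[str, Optional[str]] = {}
    for name in dists:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:<13} {ver or 'not installed'}")
```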

@RezaYazdaniAminabadi
Contributor

Thanks @skiingpacman for trying this out :)

@mallorbc
Author

Sorry for the delay. I can also confirm that GPTJ now works with two GPUs as well as with one. I ran it with

deepspeed --num_gpus 2 infer.py
deepspeed --num_gpus 1 infer.py

and got sensible output in each case, so I will close this issue. Thanks for the help!

4 participants