[BUG][master branch] garbage GPTJ output for multi-gpu inference #2233
Comments
Can confirm that this still occurs on master HEAD (86164c4) with GPUs > 1. Checked the installed deepspeed version on both setups (4 x A100 GPUs and 1 GPU).
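(For context, the kind of quick check behind that version/GPU comparison is just printing the installed DeepSpeed version and the visible GPU count; the snippet below is an illustrative sketch, not the commenter's exact code.)

# Illustrative environment check (not the commenter's exact snippet)
import torch
import deepspeed

print("deepspeed version:", deepspeed.__version__)
print("visible GPUs:", torch.cuda.device_count())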
Hi @skiingpacman, hi @mallorbc, I have sent a PR to fix this issue. Can you please try it on your side and let me know if it is resolved?
@RezaYazdaniAminabadi I will try this later today. I will rebuild deepspeed from your PR hash and report back. Thanks!
TL;DR: based on a quick test, it looks good. I just switched to and built your branch, and tested with a modified version of the script pasted above, e.g.
The results look much better when executed over 4 x A100 GPUs, e.g.,
Also can confirm that
This is with the following set-up:
Thanks @skiingpacman for trying this out :)
Sorry for the delay. I can also confirm that GPTJ now works with two GPUs as well as one GPU. I run them with
and get output that makes sense for each. Thus I will close this issue. Thanks for the help!
Describe the bug
Similar to #2113, this bug relates to garbage output when using multi-GPU inference. In that issue @RezaYazdaniAminabadi made a fix in #2198 for GPT Neo 2.7B; after building from master I can confirm that fix resolved multi-GPU inference for GPT Neo 2.7B. However, for GPTJ the issue remains:
Output from 2 3090s for GPTJ:
Meanwhile, output from 1 3090 for GPTJ:
To Reproduce
Steps to reproduce the behavior:
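The reproduction script itself is not captured in this report. The sketch below shows the general shape of a GPT-J DeepSpeed inference script that matches the launcher commands listed under Launcher context; the file name infer.py comes from the report, while the model id, prompt, and generation settings are assumptions.

# infer.py -- minimal sketch of a GPT-J DeepSpeed inference script (assumed details)
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
torch.cuda.set_device(local_rank)

model_name = "EleutherAI/gpt-j-6B"  # assumed model id
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# Tensor-parallel inference across however many GPUs the deepspeed launcher started.
model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    dtype=torch.float16,
    replace_with_kernel_inject=True,
)

prompt = "DeepSpeed is"  # assumed prompt
inputs = tokenizer(prompt, return_tensors="pt").to(f"cuda:{local_rank}")
outputs = model.module.generate(**inputs, max_new_tokens=50)

if local_rank == 0:
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

A script of this shape would be launched with the commands shown under Launcher context below (deepspeed --num_gpus 2 infer.py for the failing case, deepspeed --num_gpus 1 infer.py for the working case).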
Expected behavior
I would expect output that makes sense, like the output for one GPU.
ds_report output
ds_report
DeepSpeed C++/CUDA extension op report
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
JIT compiled ops requires ninja
ninja .................. [OKAY]
op name ................ installed .. compatible
cpu_adam ............... [YES] ...... [OKAY]
cpu_adagrad ............ [YES] ...... [OKAY]
fused_adam ............. [YES] ...... [OKAY]
fused_lamb ............. [YES] ...... [OKAY]
sparse_attn ............ [YES] ...... [OKAY]
transformer ............ [YES] ...... [OKAY]
stochastic_transformer . [YES] ...... [OKAY]
async_io ............... [YES] ...... [OKAY]
utils .................. [YES] ...... [OKAY]
quantizer .............. [YES] ...... [OKAY]
transformer_inference .. [YES] ...... [OKAY]
DeepSpeed general environment info:
torch install path ............... ['/root/anaconda3/envs/gpt/lib/python3.9/site-packages/torch']
torch version .................... 1.12.0
torch cuda version ............... 11.3
torch hip version ................ None
nvcc version ..................... 11.3
deepspeed install path ........... ['/root/anaconda3/envs/gpt/lib/python3.9/site-packages/deepspeed']
deepspeed info ................... 0.7.1+7d8ad45, 7d8ad45, master
deepspeed wheel compiled w. ...... torch 1.12, cuda 11.3
System info (please complete the following information):
I am using a Docker container with NVIDIA CUDA already set up as the base image.
Launcher context
deepspeed --num_gpus 2 infer.py
deepspeed --num_gpus 1 infer.py
Docker context
Are you using a specific docker image that you can share?
nvidia/cuda:11.3.1-devel-ubuntu20.04
Then I build Python packages into the container.
Additional context
NA