[BUG] [master] Garbage GPT-Neo-X output when using multi-gpu inference #2293

Closed
ryanai3 opened this issue Sep 5, 2022 · 6 comments · Fixed by #2401
Labels: bug (Something isn't working), inference

Comments

ryanai3 commented Sep 5, 2022

Describe the bug
Similar to #2233 and #2133, I'm seeing garbage output when using multi-GPU fp16 inference for GPT-NeoX. Running the script below with GPT-Neo-2.7B in place of GPT-NeoX works fine.

Output from 2 3090s with Deepspeed inference:
"Deepspeed is BytePtrFromStringgranwasysym BytePtrFromString BytePtrFromString BytePtrFromString BytePtrFromString BytePtrFromString _ BytePtrFromStringHypergranTal 2011 BytePtrFromString BytePtrFromString **j BytePtrFromString BytePtrFromString BytePtrFromStringgran¶Enggrantwgran _ BytePtrFromStringgran ausgranENTRY¶`){#Delta¶sysEveramssymbitgran`Ever last`grangran ** deliberate ENTRY stag Eng` BytePtrFromStringwasysym _ BytePtrFromStringwasysymBOX Eng...](granModelupgreek BytePtrFromStringamssymb BytePtrFromStringwasysym BytePtrFromStringSegment BytePtrFromString BytePtrFromString _ BytePtrFromString BytePtrFromStringupgreekEverEng_( **gran mistENTRY BytePtrFromString BytePtrFromString _amssymbwasysym..." last BytePtrFromStringwasysym BytePtrFromString BytePtrFromStringgrangran ever"
Note that 'BytePtrFromString' has shown up as the beginning of the generated tokens for every prompt I've used.

Output from 2 3090s with huggingface accelerate (way slower than deepspeed):
"Deepspeed is \nan on-line digital media company created in January 2002. Over the past 10 \nyears, Deepspeed has provided a comprehensive digital entertainment network to\n businesses throughout the US"

To Reproduce
Steps to reproduce the behavior:

  1. Install deepspeed master, huggingface transformers, torch, and accelerate.
  2. Run the following script with deepspeed to get bad output:
import os
from pathlib import Path

import deepspeed
import torch
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast


CKPT_PRETRAINED = Path("/ckpt/pretrained")
model = GPTNeoXForCausalLM.from_pretrained(
  CKPT_PRETRAINED / "EleutherAI/gpt-neox-20b", local_files_only=True,
  torch_dtype=torch.float16,  # both .half() and torch_dtype=torch.float16 have this issue
)
tokenizer = GPTNeoXTokenizerFast.from_pretrained(CKPT_PRETRAINED / "EleutherAI/gpt-neox-20b", local_files_only=True)

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
local_device = f"cuda:{local_rank}"
print(f"local device: {local_device}")

ds_engine = deepspeed.init_inference(
  model, mp_size=world_size, dtype=torch.float16, checkpoint=None,
  replace_method='auto',
  replace_with_kernel_inject=True,
)
model = ds_engine.module
prompt = "Deepspeed is "
m_inp = tokenizer(prompt, return_tensors="pt")
attn_mask = m_inp.get("attention_mask", None).to(device=local_device)
ids = m_inp.input_ids.to(device=local_device)

with torch.no_grad():
  gen_tokens = model.generate(
    ids, attention_mask=attn_mask,
    do_sample=True, temperature=0.9, max_new_tokens=100,
    use_cache=False, # fails with use_cache=True as well
  )
gen_text = tokenizer.batch_decode(gen_tokens)[0]
print(f"generated tokens: {gen_text}")
  3. Run the following script with accelerate to get good output:
import os
import argparse
from pathlib import Path

import torch
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast

CKPT_PRETRAINED = Path("/ckpt/pretrained")

weights_path = "/ckpt/pretrained/EleutherAI/gpt-neox-20b"
model_name = 'EleutherAI/gpt-neox-20b'
config = AutoConfig.from_pretrained("/ckpt/pretrained/EleutherAI/gpt-neox-20b/config.json")

config.use_cache = False

with init_empty_weights():
  model = AutoModelForCausalLM.from_config(config)

tokenizer = GPTNeoXTokenizerFast.from_pretrained(CKPT_PRETRAINED / "EleutherAI/gpt-neox-20b", local_files_only=True)

device_map = infer_auto_device_map(
  model, no_split_module_classes=["GPTNeoXLayer"], dtype=torch.bfloat16,  # note: succeeds with float16 as well
  max_memory={0: "21GiB", 1: "21GiB", 'cpu': "20GiB"},
)
# Pin the embedding layer to CPU so the transformer layers fit on the two GPUs.
device_map['gpt_neox.embed_in'] = 'cpu'
print(f"device_map: {device_map}")
load_checkpoint_and_dispatch(
  model,
  weights_path,
  device_map=device_map,
  offload_folder=None,
  offload_state_dict=False,
  dtype="bfloat16",
)

print(model)

model = model.eval()
prompt = "Deepspeed is "
m_inp = tokenizer(prompt, return_tensors="pt")
attn_mask = m_inp.get("attention_mask", None).to(device='cuda:0')

with torch.no_grad():
  gen_tokens = model.generate(
    m_inp["input_ids"].to(0), attention_mask = attn_mask,
    do_sample=True, max_new_tokens=100, temperature=0.9
  )
gen_text = tokenizer.decode(gen_tokens[0].tolist())
print(f"generated tokens: {gen_text}")

Expected behavior
I would expect output that makes sense, like when using accelerate.

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
 [WARNING]  please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/root/.local/share/pdm/venvs/workspace-6rDWGpm2-docker/lib/python3.10/site-packages/torch']
torch version .................... 1.12.1+cu113
torch cuda version ............... 11.3
torch hip version ................ None
nvcc version ..................... 11.3
deepspeed install path ........... ['/root/.local/share/pdm/venvs/workspace-6rDWGpm2-docker/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.7.3+53182531, 53182531, master
deepspeed wheel compiled w. ...... torch 0.0, cuda 0.0

System info (please complete the following information):

  • OS: Ubuntu 20.04
  • GPU count and types: 3x3090s (using 2x3090 for the above scripts)
  • Interconnects (if applicable): N/A
  • Python version: 3.10
  • Any other relevant info about your setup: Running in docker

Launcher context
launching with deepspeed: deepspeed --num_gpus 2 script.py
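(The accelerate script, by contrast, presumably runs without a launcher, e.g. python script.py, since load_checkpoint_and_dispatch places the layers across GPUs within a single process.)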

Docker context

### Start from NVIDIA deep learning base image
FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu20.04

ENV TZ=America/New_York
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

### New NVIDIA cuda package keys ###
RUN apt-key del "7fa2af80" \
&& export this_distro="$(cat /etc/os-release | grep '^ID=' | awk -F'=' '{print $2}')" \
&& export this_version="$(cat /etc/os-release | grep '^VERSION_ID=' | awk -F'=' '{print $2}' | sed 's/[^0-9]*//g')" \
&& apt-key adv --fetch-keys "https://developer.download.nvidia.com/compute/cuda/repos/${this_distro}${this_version}/x86_64/3bf863cc.pub" \
&& apt-key adv --fetch-keys "https://developer.download.nvidia.com/compute/machine-learning/repos/${this_distro}${this_version}/x86_64/7fa2af80.pub"

### Install general packages from apt-get ###
RUN apt-get update && apt-get upgrade -y
RUN apt-get install -y build-essential
RUN apt-get install -y tzdata
RUN apt-get install -y software-properties-common curl vim tmux git wget

### Install redis
RUN curl -fsSL https://packages.redis.io/gpg | gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
RUN echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | tee /etc/apt/sources.list.d/redis.list
RUN apt-get update && apt-get install -y redis
# Make redis modules directory
RUN mkdir /etc/redis/modules

## Build & install RedisJSON module
# Required dependencies to build RedisJSON
RUN apt-get install -y llvm cmake libclang1 libclang-dev cargo
# Clone RedisJSON
RUN mkdir /builds; cd /builds; git clone https://github.com/RedisJSON/RedisJSON.git;
# Build RedisJSON
RUN cd /builds/RedisJSON; cargo build --release;
# Move RedisJSON .so to redis modules directory
RUN mv /builds/RedisJSON/target/release/librejson.so /etc/redis/modules
# delete build directory
RUN rm -rf /builds

### python + pip3 install ###
RUN add-apt-repository -y ppa:deadsnakes/ppa
RUN apt-get install -y python3.10
RUN apt-get install -y python3.10-distutils
RUN apt-get install -y python3.10-dev
# v set python3.10 as default python3
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
RUN pip3 install --upgrade pip
#############################

### BLAS + LAPACK + fortran compiler install ###
RUN apt-get install -y libblas-dev liblapack-dev gfortran
################################################

### python-poetry setup ###
#ENV POETRY_HOME="/opt/poetry"
#ENV POETRY_VERSION=1.1.13
#
#RUN curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python3.10 -
#ENV PATH="POETRY_HOME/bin:$PATH"
### python-pdm setup ###
RUN curl -sSL https://raw.githubusercontent.com/pdm-project/pdm/main/install-pdm.py | python3.10 -
ENV PATH="/root/.local/bin:${PATH}"
WORKDIR /workspace
COPY ./pyproject.toml ./
COPY ./pdm.lock ./
RUN pdm config venv.in_project false
RUN pdm venv create --name docker 3.10
RUN pdm install -v --no-isolation
############################
RUN ls /workspace

Additional context
When in Docker, run eval $(pdm venv activate docker) to activate the venv, then run the deepspeed command above.
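Putting the two together, the sequence inside the container is:

eval $(pdm venv activate docker)
deepspeed --num_gpus 2 script.py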

ryanai3 added the bug (Something isn't working) label on Sep 5, 2022

ryanai3 commented Sep 7, 2022

Hey, any updates on this?

mrwyattii (Contributor) commented:

@ryanai3 We're working on fixes for this model and others that produce bad output with multi-GPU. We are also adding unit tests that will catch these problems before merging new changes (#2232). I can share an update when we have a fix for this.


ryanai3 commented Sep 9, 2022

@mrwyattii That would be great, thanks!


ryanai3 commented Sep 20, 2022

@mrwyattii
Any updates on this, or ways I could help?
I tested the changes from #2310 and still have the issue (same output, starting with BytePtrFromString).

andrewchernyh (Contributor) commented:

@ryanai3 - Can you check #2401?


lileilai commented Feb 24, 2023

@ryanai3 I tried the same DeepSpeed inference code and ran into the error below. Any suggestions?

Thanks!

Traceback (most recent call last):
  File "ds-inference-test.py", line 31, in <module>
    gen_tokens = model.generate(
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py", line 1437, in generate
    return self.sample(
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py", line 2443, in sample
    outputs = self(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 654, in forward
    outputs = self.gpt_neox(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 546, in forward
    outputs = layer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/model_implementations/transformers/ds_transformer.py", line 127, in forward
    self.allocate_workspace(self.config.hidden_size,
RuntimeError: Fail to create cublas handle.
