BLOOM Inference via DeepSpeed-Inference, Accelerate and DeepSpeed-ZeRO #308

Merged (58 commits) on Aug 10, 2022

Conversation

stas00
Contributor

@stas00 stas00 commented Jul 10, 2022

Update: I expanded the PR to include Accelerate and DeepSpeed ZeRO - please see the README for full details.


This PR is sorting out the inference script for BLOOM via DeepSpeed-Inference (microsoft/DeepSpeed#2083).

I pushed the main script into main already, so this is just the fixes to that script.

set up transformers

make sure you are on the latest transformers@main

set up DeepSpeed

Get the DeepSpeed master branch:

git clone https://github.com/microsoft/DeepSpeed
cd DeepSpeed
pip install -e .

set up Meg-DS

git clone https://github.com/bigscience-workshop/Megatron-DeepSpeed
cd Megatron-DeepSpeed
git checkout bloom-inference

run the script:

deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --benchmark --batch_size 8

Adapt to the number of GPUs you want, and use the larger models if needed.

p.s. I also added a ZeRO-3 inference script:

deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom

but you must edit the nvme path, and this one is super-slow - but it works and requires only 1x 24GB GPU and not 8x 80GB :)
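For reference, a minimal sketch of the kind of DeepSpeed config the ZeRO-3 script relies on for NVMe parameter offload - the keys are standard ZeRO-3 options, but the values and the nvme_path below are placeholders to adapt, not the script's exact config:

# sketch of a ZeRO-3 inference config with parameters offloaded to NVMe
# (illustrative; nvme_path is a placeholder and must point at a fast local NVMe mount)
ds_config = {
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/local/nvme",  # placeholder - edit for your machine
            "pin_memory": True,
        },
    },
    "train_micro_batch_size_per_gpu": 1,
}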

On JZ you must do:

srun --pty --account=six@a100 --constraint=a100 --reservation=hug --partition=gpu_p5 --gres=gpu:8 --nodes=1 --cpus-per-task=64 --time 4:00:00 --tasks-per-node=1 bash
cd $six_ALL_CCFRWORK/code/inference/Megatron-DeepSpeed

export TRANSFORMERS_CACHE=$six_ALL_CCFRWORK/models
export HF_DATASETS_CACHE=$six_ALL_CCFRWORK/datasets
export HF_MODULES_CACHE=$six_ALL_CCFRWORK/modules
export HF_METRICS_CACHE=$six_ALL_CCFRWORK/metrics
export HF_DATASETS_OFFLINE=1
export TRANSFORMERS_OFFLINE=1

deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom

(I already pre-cached bigscience/bloom-350m and bigscience/bloom, so it should work in offline mode.)
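If you are reproducing this outside JZ, here is a rough sketch of pre-caching a model for offline mode: run it once with network access, and the TRANSFORMERS_OFFLINE=1 export above then makes transformers read only from the cache. The model name is just an example.

# pre-download a model into TRANSFORMERS_CACHE so it can later be loaded offline
from transformers import AutoTokenizer, AutoModelForCausalLM

name = "bigscience/bloom-350m"  # example; swap in the model you need
AutoTokenizer.from_pretrained(name)
AutoModelForCausalLM.from_pretrained(name)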

@RezaYazdaniAminabadi, @jeffra

@stas00 stas00 changed the title from "hardcode the dtype depending on the model" to "BLOOM Inference via DeepSpeed-Inference" on Jul 10, 2022
@stas00
Contributor Author

stas00 commented Jul 10, 2022

With a single A100 it fails with:

$ deepspeed --num_gpus 1 bloom-inference.py --name bigscience/bloom-350m
[...]
DeepSpeed Transformer Inference config is  {'layer_id': 23, 'hidden_size': 1024, 'intermediate_size': 4096, 'heads': 16, 'num_hidden_layers': -1, 'fp16': True, 'pre_layer_norm': True, 'local_rank': -1, 'stochastic_mode': False, 'epsilon': 1e-05, 'mp_size': 1, 'q_int8': False, 'scale_attention': True, 'triangular_masking': True, 'local_attention': False, 'window_size': 1, 'rotary_dim': -1, 'rotate_half': False, 'rotate_every_two': True, 'return_tuple': True, 'mlp_after_attn': True, 'specialized_mode': False, 'training_mp_size': 1, 'bigscience_bloom': True}
[2022-07-09 22:58:55,924] [INFO] [engine.py:143:__init__] Place model to device: 0
!!!! kernel execution error. (m: 1024, n: 3, k: 4096, error: 13) 
!!!! kernel execution error. (m: 3072, n: 3, k: 1024, error: 13) 
Traceback (most recent call last):
  File "bloom-inference.py", line 154, in <module>
    min_length=50,
  File "/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/generation_utils.py", line 1288, in generate
    return self.greedy_search(
  File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/generation_utils.py", line 1683, in greedy_search
    outputs = self(
  File "/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/models/bloom/modeling_bloom.py", line 919, in forward
    transformer_outputs = self.transformer(
  File "/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/models/bloom/modeling_bloom.py", line 806, in forward
    outputs = block(
  File "/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 839, in forward
    self.attention(input,
  File "/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 553, in forward
    output = DeepSpeedSelfAttentionFunction.apply(
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 474, in forward
    output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp()
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 438, in selfAttention_fp
    context_layer, key_layer, value_layer = compute_attention(qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 390, in compute_attention
    context_layer, presents = backup_attention(qkv_out, layer_past, alibi, input_mask, norm_factor)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 204, in backup_attention
    value_layer) = split_tensor_along_last_dim(mixed_x_layer,
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 189, in split_tensor_along_last_dim
    return tuple(chunk.contiguous() for chunk in tensor_list)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 189, in <genexpr>
    return tuple(chunk.contiguous() for chunk in tensor_list)
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
[2022-07-09 22:58:59,156] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 693904

@RezaYazdaniAminabadi
Collaborator

It is very weird, let me try it on my side and I will push a fix soon.

@RezaYazdaniAminabadi
Collaborator

I used my own script and could run this fine, I am gonna try it with yours and see if it works.

@RezaYazdaniAminabadi
Collaborator

I also produced good results with your script on my side:
in=DeepSpeed is
out=DeepSpeed is a function of the number of bits in the data stream, and the number of bits in the data stream is a function of the number of bits in the data stream. The number of bits in the data stream is a function of the

@RezaYazdaniAminabadi
Collaborator

RezaYazdaniAminabadi commented Jul 10, 2022

I found an issue with your script when running with multi-GPU: it results in an illegal memory access. I will push a fix now.

@RezaYazdaniAminabadi
Collaborator

I found an issue with your script when running with multi-GPU: it results in an illegal memory access. I will push a fix now.

So, I realize I cannot push here. But the change is simple and you can do it on your side. Just change the mp_size from 1 to world_size when passing to init_inference:

model = deepspeed.init_inference(model,
                                 mp_size=world_size,
                                 dtype=torch.half,
                                 checkpoint=checkpoints_json,
                                 #injection_policy={BloomBlock: ('self_attention.dense', 'mlp.dense_4h_to_h')}
                                 replace_with_kernel_inject=True
                                 )
model = model.module
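For context, the deepspeed launcher exports WORLD_SIZE and LOCAL_RANK for every process, so a minimal sketch of the surrounding setup could look like the code below; the model name and loading details are illustrative, not the exact script.

# sketch: pick up world_size from the launcher's env vars and pass it as mp_size
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-350m", torch_dtype=torch.float16)
model = deepspeed.init_inference(model,
                                 mp_size=world_size,  # was hardcoded to 1, which broke multi-GPU runs
                                 dtype=torch.half,
                                 replace_with_kernel_inject=True)
model = model.module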

@RezaYazdaniAminabadi
Collaborator

Btw, is there any chance you may have been running this on 2 GPUs? Can you please retry this using --num_nodes 1?

@stas00
Contributor Author

stas00 commented Jul 10, 2022

So, I realize I cannot push here.

I sent you an invite to this repo yesterday - please check your emails from GitHub.

The problem was that I hadn't built a kernel for this card. I was able to see that by adding CUDA_LAUNCH_BLOCKING=1, after which I got:

RuntimeError: CUDA error: no kernel image is available for execution on the device

Fixed that now and passed that point (and yes, mp_size should be fixed as well, but it was fine for 1 GPU) - thank you!

Next, it's now crashing here:

DeepSpeed Transformer Inference config is  {'layer_id': 23, 'hidden_size': 1024, 'intermediate_size': 4096, 'heads': 16, 'num_hidden_layers': -1, 'fp16': True, 'pre_layer_norm': True, 'local_rank': -1, 'stochastic_mode': False, 'epsilon': 1e-05, 'mp_size': 1, 'q_int8': False, 'scale_attention': True, 'triangular_masking': True, 'local_attention': False, 'window_size': 1, 'rotary_dim': -1, 'rotate_half': False, 'rotate_every_two': True, 'return_tuple': True, 'mlp_after_attn': True, 'specialized_mode': False, 'training_mp_size': 1, 'bigscience_bloom': True}
[2022-07-10 07:02:11,724] [INFO] [engine.py:143:__init__] Place model to device: 0
Traceback (most recent call last):
  File "bloom-inference.py", line 153, in <module>
    gen_tokens = model.generate(
  File "/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/generation_utils.py", line 1288, in generate
    return self.greedy_search(
  File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/generation_utils.py", line 1683, in greedy_search
    outputs = self(
  File "/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/models/bloom/modeling_bloom.py", line 919, in forward
    transformer_outputs = self.transformer(
  File "/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/models/bloom/modeling_bloom.py", line 806, in forward
    outputs = block(
  File "/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 839, in forward
    self.attention(input,
  File "/home/stas/anaconda3/envs/py38-pt112/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 553, in forward
    output = DeepSpeedSelfAttentionFunction.apply(
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 474, in forward
    output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp()
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 438, in selfAttention_fp
    context_layer, key_layer, value_layer = compute_attention(qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 390, in compute_attention
    context_layer, presents = backup_attention(qkv_out, layer_past, alibi, input_mask, norm_factor)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 280, in backup_attention
    context_layer = torch.bmm(attention_probs_reshaped,
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmStridedBatchedExFix( handle, opa, opb, m, n, k, (void*)(&falpha), a, CUDA_R_16F, lda, stridea, b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta), c, CUDA_R_16F, ldc, stridec, num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
[2022-07-10 07:02:14,292] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 729431

@RezaYazdaniAminabadi
Collaborator

Hi @stas00, can you please see if you can run this without kernel injection? Just remove replace_with_kernel_inject from the init_inference call and pass the injection_policy instead.
Thanks

@stas00
Contributor Author

stas00 commented Jul 10, 2022

Reza figured it out - it was a bug in transformers' BLOOM model: alibi wasn't placed on the correct device. Will send a fix soon.

diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index ba8edde14..e1723dd68 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -774,6 +774,7 @@ class BloomModel(BloomPreTrainedModel):
         if past_key_values[0] is not None:
             current_sequence_length += past_key_values[0][0].shape[1]
         alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
+        alibi = alibi.to(device=hidden_states.device)

         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

@stas00
Contributor Author

stas00 commented Jul 13, 2022

Status update - after many hard working days and nights everything works fast and great! Reza++!

Let's generate some text:

model.generate(**tokens, min_length=100, max_length=100, do_sample=False)

in=DeepSpeed is a machine learning framework

out=DeepSpeed is a machine learning framework that is designed to be used by researchers and developers who are interested in applying deep learning to their own problems. It is a Python library that provides a set of tools for training and evaluating deep neural networks. It is designed to be easy to use and to provide a flexible environment for experimentation. DeepSpeed is built on top of the Caffe deep learning framework, and it provides a set of tools for training and evaluating deep neural networks. It is designed to
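For completeness, a sketch of the full round-trip around that generate call, assuming the tokenizer and the DeepSpeed-wrapped model from the snippets above are already loaded; names are illustrative, not the exact script.

# assumes `tokenizer` and `model` are already set up as shown earlier in this thread
import torch

text = "DeepSpeed is a machine learning framework"
tokens = tokenizer(text, return_tensors="pt")
tokens = {k: v.to(torch.cuda.current_device()) for k, v in tokens.items()}
gen_tokens = model.generate(**tokens, min_length=100, max_length=100, do_sample=False)
print(tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0])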

@stas00
Contributor Author

stas00 commented Jul 13, 2022

OK, I instrumented the script to run various benchmarks. On 8x 80GB A100:

deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py \
--name bigscience/bloom --benchmark

*** Performance stats:
Throughput per token: 40.73 msecs
Start to ready to generate: 673.429 secs
Tokenize and generate 100 tokens: 4.089 secs
Start to finish: 677.518 secs

Memory per process while processing:

  • GPU: ~50GB
  • CPU: ~10GB

Radical!

@RezaYazdaniAminabadi ran the same on 2x 8x 40GB A100s and it was 50 msec per token. The slowness is due to internode communication, so the slowdown will depend on the internode connectivity - a faster network will lead to faster throughput.
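For clarity, the "throughput per token" number is essentially wall-clock generation time divided by the number of newly generated tokens. A rough sketch of that measurement, not the exact benchmark code, assuming model and input_tokens are already prepared:

# rough per-token latency measurement (illustrative only)
import time
import torch

torch.cuda.synchronize()
t0 = time.time()
outputs = model.generate(**input_tokens, max_new_tokens=100, do_sample=False)
torch.cuda.synchronize()
elapsed = time.time() - t0

new_tokens = outputs.shape[1] - input_tokens["input_ids"].shape[1]
print(f"Throughput per token: {elapsed / new_tokens * 1000:.2f} msecs")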

@mayank31398
Collaborator

mayank31398 commented Jul 25, 2022

@RezaYazdaniAminabadi
Am I supposed to use some DeepSpeed branch?
Full trace:

[2022-07-25 08:20:58,333] [WARNING] [runner.py:159:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2022-07-25 08:21:01,366] [INFO] [runner.py:457:main] cmd = /net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMSwgMiwgMywgNCwgNSwgNiwgN119 --master_addr=127.0.0.1 --master_port=29500 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --benchmark
[2022-07-25 08:21:02,279] [INFO] [launch.py:103:main] WORLD INFO DICT: {'localhost': [0, 1, 2, 3, 4, 5, 6, 7]}
[2022-07-25 08:21:02,279] [INFO] [launch.py:109:main] nnodes=1, num_local_procs=8, node_rank=0
[2022-07-25 08:21:02,279] [INFO] [launch.py:122:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0, 1, 2, 3, 4, 5, 6, 7]})
[2022-07-25 08:21:02,279] [INFO] [launch.py:123:main] dist_world_size=8
[2022-07-25 08:21:02,279] [INFO] [launch.py:125:main] Setting CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
[2022-07-25 08:21:03,698] [INFO] [comm.py:423:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
*** Loading the model bigscience/bloom
[2022-07-25 08:21:12,069] [INFO] [utils.py:827:see_memory_usage] pre-from-pretrained
[2022-07-25 08:21:12,070] [INFO] [utils.py:828:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.0 GB         Max_CA 0 GB
[2022-07-25 08:21:12,070] [INFO] [utils.py:836:see_memory_usage] CPU Virtual Memory:  used = 11.16 GB, percent = 0.9%
[2022-07-25 08:21:12,219] [INFO] [utils.py:827:see_memory_usage] post-from-pretrained
[2022-07-25 08:21:12,220] [INFO] [utils.py:828:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.0 GB         Max_CA 0 GB
[2022-07-25 08:21:12,220] [INFO] [utils.py:836:see_memory_usage] CPU Virtual Memory:  used = 11.17 GB, percent = 0.9%
[2022-07-25 08:21:12,266] [INFO] [utils.py:827:see_memory_usage] post-init-ds-zero-init
[2022-07-25 08:21:12,266] [INFO] [utils.py:828:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.0 GB         Max_CA 0 GB
[2022-07-25 08:21:12,267] [INFO] [utils.py:836:see_memory_usage] CPU Virtual Memory:  used = 11.17 GB, percent = 0.9%
[2022-07-25 08:21:21,406] [INFO] [utils.py:827:see_memory_usage] pre-ds-inference-init
[2022-07-25 08:21:21,407] [INFO] [utils.py:828:see_memory_usage] MA 0.0 GB         Max_MA 0.0 GB         CA 0.0 GB         Max_CA 0 GB
[2022-07-25 08:21:21,407] [INFO] [utils.py:836:see_memory_usage] CPU Virtual Memory:  used = 32.56 GB, percent = 2.6%
[2022-07-25 08:21:21,407] [INFO] [logging.py:69:log_dist] [Rank 0] DeepSpeed info: version=0.7.0+8413b7f8, git-hash=8413b7f8, git-branch=master
[2022-07-25 08:21:21,407] [INFO] [logging.py:69:log_dist] [Rank 0] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
Installed CUDA version 11.6 does not match the version torch was compiled with 11.3 but since the APIs are compatible, accepting this combination
[the same message is printed by each of the 8 ranks]
Using /home/gpttest/.cache/torch_extensions/py38_cu113 as PyTorch extensions root...
[the same message is printed by each of the 8 ranks]
/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/torch/utils/cpp_extension.py:295: UserWarning:

                               !! WARNING !!

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler (c++) is not compatible with the compiler Pytorch was
built with for this platform, which is g++ on linux. Please
use g++ to to compile your extension. Alternatively, you may
compile PyTorch from source using c++, and then you can also use
c++ to compile your extension.

See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help
with compiling PyTorch from source.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

                              !! WARNING !!

  warnings.warn(WRONG_COMPILER_WARNING.format(
Detected CUDA files, patching ldflags
Emitting ninja build file /home/gpttest/.cache/torch_extensions/py38_cu113/transformer_inference/build.ninja...
Building extension module transformer_inference...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module transformer_inference...
Time to load transformer_inference op: 0.30841994285583496 seconds
[2022-07-25 08:21:22,096] [INFO] [logging.py:69:log_dist] [Rank 0] DeepSpeed-Inference config: {'layer_id': 0, 'hidden_size': 14336, 'intermediate_size': 57344, 'heads': 112, 'num_hidden_layers': -1, 'fp16': True, 'pre_layer_norm': True, 'local_rank': -1, 'stochastic_mode': False, 'epsilon': 1e-05, 'mp_size': 8, 'q_int8': False, 'scale_attention': True, 'triangular_masking': True, 'local_attention': False, 'window_size': 1, 'rotary_dim': -1, 'rotate_half': False, 'rotate_every_two': True, 'return_tuple': True, 'mlp_after_attn': True, 'specialized_mode': False, 'training_mp_size': 1, 'bigscience_bloom': True}
Loading extension module transformer_inference...
Time to load transformer_inference op: 0.35624146461486816 seconds
Loading extension module transformer_inference...
Loading extension module transformer_inference...
Loading extension module transformer_inference...
Loading extension module transformer_inference...
Time to load transformer_inference op: 0.35947346687316895 seconds
Time to load transformer_inference op: 0.35944700241088867 seconds
Time to load transformer_inference op: 0.3590846061706543 seconds
Time to load transformer_inference op: 0.35912346839904785 seconds
Loading extension module transformer_inference...
Loading extension module transformer_inference...
Time to load transformer_inference op: 0.3587977886199951 seconds
Time to load transformer_inference op: 0.35920000076293945 seconds
Loading 72 checkpoint shards:   0%|          | 0/72 [13:06<?, ?it/s]
[2022-07-25 08:34:29,134] [INFO] [engine.py:145:__init__] Place model to device: 7
Loading 72 checkpoint shards:   0%|          | 0/72 [13:11<?, ?it/s]
[2022-07-25 08:34:34,207] [INFO] [engine.py:145:__init__] Place model to device: 6
Loading 72 checkpoint shards:   0%|          | 0/72 [13:13<?, ?it/s]
[2022-07-25 08:34:36,378] [INFO] [engine.py:145:__init__] Place model to device: 3
Loading 72 checkpoint shards: 100%|██████████| 72/72 [13:14<00:00, 11.04s/it]
[2022-07-25 08:34:37,256] [INFO] [engine.py:145:__init__] Place model to device: 0
[2022-07-25 08:34:37,391] [INFO] [utils.py:827:see_memory_usage] post-ds-inference-init
[2022-07-25 08:34:37,392] [INFO] [utils.py:828:see_memory_usage] MA 47.04 GB         Max_MA 47.24 GB         CA 47.04 GB         Max_CA 47 GB
[2022-07-25 08:34:37,392] [INFO] [utils.py:836:see_memory_usage] CPU Virtual Memory:  used = 55.12 GB, percent = 4.4%
*** Starting to generate 100 tokens with bs=1
Generate args {'max_new_tokens': 100, 'do_sample': False}
Loading 72 checkpoint shards:   0%|          | 0/72 [13:19<?, ?it/s]
[2022-07-25 08:34:42,599] [INFO] [engine.py:145:__init__] Place model to device: 1
Loading 72 checkpoint shards:   0%|          | 0/72 [13:25<?, ?it/s]
[2022-07-25 08:34:48,165] [INFO] [engine.py:145:__init__] Place model to device: 5
Loading 72 checkpoint shards:   0%|          | 0/72 [13:28<?, ?it/s]
[2022-07-25 08:34:51,199] [INFO] [engine.py:145:__init__] Place model to device: 2
Loading 72 checkpoint shards:   0%|          | 0/72 [13:34<?, ?it/s]
[2022-07-25 08:34:57,001] [INFO] [engine.py:145:__init__] Place model to device: 4
Traceback (most recent call last):
  File "scripts/inference/bloom-ds-inference.py", line 257, in <module>
    _ = generate()
  File "scripts/inference/bloom-ds-inference.py", line 244, in generate
    outputs = model.generate(**input_tokens, **generate_kwargs)
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/transformers/generation_utils.py", line 1288, in generate
    return self.greedy_search(
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/transformers/generation_utils.py", line 1683, in greedy_search
    outputs = self(
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/deepspeed/inference/engine.py", line 508, in forward
    outputs = self.model_orig_fwd(*inputs, **kwargs)
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/transformers/models/bloom/modeling_bloom.py", line 919, in forward
    transformer_outputs = self.transformer(
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/transformers/models/bloom/modeling_bloom.py", line 806, in forward
    outputs = block(
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/transformer_inference.py", line 831, in forward
    self.attention(input,
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/transformer_inference.py", line 543, in forward
    output = DeepSpeedSelfAttentionFunction.apply(
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/transformer_inference.py", line 466, in forward
    dist.all_reduce(output, group=mp_group)
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/deepspeed/comm/comm.py", line 312, in all_reduce
    return cdb.all_reduce(tensor, op, group, async_op)
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/deepspeed/comm/torch.py", line 49, in all_reduce
    return torch.distributed.all_reduce(tensor=tensor,
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1287, in all_reduce
    work = group.allreduce([tensor], opts)
RuntimeError: NCCL error in: /opt/conda/conda-bld/pytorch_1639180588308/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:957, unhandled cuda error, NCCL version 21.0.3
ncclUnhandledCudaError: Call to CUDA function failed.
[the same traceback is printed by each of the remaining 7 ranks]
[2022-07-25 08:35:50,274] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 915820
[2022-07-25 08:35:50,274] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 915821
[2022-07-25 08:35:50,275] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 915822
[2022-07-25 08:35:50,275] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 915823
[2022-07-25 08:35:50,275] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 915824
[2022-07-25 08:35:50,275] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 915825
[2022-07-25 08:35:50,275] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 915826
[2022-07-25 08:35:50,275] [INFO] [launch.py:178:sigkill_handler] Killing subprocess 915827
[2022-07-25 08:35:50,275] [ERROR] [launch.py:184:sigkill_handler] ['/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/llmpt/bin/python', '-u', 'scripts/inference/bloom-ds-inference.py', '--local_rank=7', '--name', 'bigscience/bloom', '--benchmark'] exits with return code = -6

@stas00
Contributor Author

stas00 commented Jul 25, 2022

Thank you for the full traceback, @xuyifanbupt - as we have 3 different unrelated implementations in this PR, could you please create a new issue for the one you reported, since it's with Accelerate and most of this thread is debugging the ds-inference one, and tag @stas00 and @sgugger on it.

Please repaste your script and the full traceback and the rest of the notes you shared. Thank you!

@RezaYazdaniAminabadi
Collaborator

Hi @mayank31398

You should be using the master branch of both DeepSpeed and HuggingFace. Just note that with pip install, you may not get the latest versions.
I have seen the same NCCL error, which I am debugging in another scenario where the batch size is very large, but I am very surprised that you get it with batch size 1!

Thanks,
Reza

@mayank31398
Collaborator

mayank31398 commented Jul 26, 2022

Hi @mayank31398

You should be using the master branch of both DeepSpeed and HuggingFace. Just note that with pip install, you may not get the latest versions. I have seen the same NCCL error, which I am debugging in another scenario where the batch size is very large, but I am very surprised that you get it with batch size 1!

Thanks, Reza

I am still getting this issue even with DeepSpeed and HF installed from the master branch.
I am not sure what is going wrong here.

I tried with NCCL_DEBUG=INFO

llm-test-cluster-9:1281342:1283501 [1] include/alloc.h:50 NCCL WARN Cuda failure 'an illegal memory access was encountered'
llm-test-cluster-9:1281342:1283501 [1] NCCL INFO channel.cc:20 -> 1
llm-test-cluster-9:1281342:1283501 [1] NCCL INFO init.cc:373 -> 1
llm-test-cluster-9:1281342:1283501 [1] NCCL INFO init.cc:774 -> 1
llm-test-cluster-9:1281342:1283501 [1] NCCL INFO init.cc:904 -> 1
llm-test-cluster-9:1281342:1283501 [1] NCCL INFO group.cc:72 -> 1 [Async thread]

I see this just before the error traceback

@RezaYazdaniAminabadi yes, this is with batch size 1.
Is this something to be fixed on my end? I am not sure, but this seems like a bug in DeepSpeed.
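
To pin down the failing call, one option is to force synchronous kernel launches and verbose NCCL logging. A minimal sketch (the env vars are standard CUDA/NCCL knobs; placing them at the top of the inference script is my own choice, setting them in the shell before launching deepspeed works just as well):

import os

# Hedged sketch: make kernel launches synchronous and turn on verbose NCCL logs
# so the "illegal memory access" is reported at the real failing call.
# Must run before torch/deepspeed initialize CUDA. Expect a big slowdown.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
os.environ.setdefault("NCCL_DEBUG", "INFO")

import torch
import deepspeed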

@mayank31398
Copy link
Collaborator

Also, @RezaYazdaniAminabadi, can you point out the branch which might contain the fix for this?
Inference with HF is quite slow and DS is not working correctly.
My team and I are working on some RL-based approaches using the output of LLMs, which need inference to be quick.
So I want to monitor this issue closely if possible.

@stas00
Copy link
Contributor Author

stas00 commented Jul 26, 2022

microsoft/DeepSpeed#2132

Copy link
Collaborator

@mayank31398 mayank31398 left a comment


The bloom-fix branch in DeepSpeed has been merged into master, and bloom-ds-inference.py is working fine now.
Thanks for this PR.
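
For anyone skimming the thread, the kernel-injection path in bloom-ds-inference.py boils down to roughly the following (a simplified sketch with illustrative values, not the exact script, which also handles checkpoint sharding, offline mode and benchmarking):

import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

# illustrative values; the real script takes --name and is launched with
# `deepspeed --num_gpus 8 ...`, which sets LOCAL_RANK / WORLD_SIZE
model_name = "bigscience/bloom-350m"
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# shard the model across the GPUs and swap in DeepSpeed's fused inference kernels
model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    dtype=torch.float16,
    replace_with_kernel_inject=True,
)

inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(f"cuda:{local_rank}")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))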

@mayank31398
Copy link
Collaborator

Can we merge this?
❤️

@stas00
Copy link
Contributor Author

stas00 commented Aug 3, 2022

I don't think this is going to go in here, since it has nothing to do with Meg-DS. I was just using it as a dev PR so that the DeepSpeed team could push directly into it.

Once things are stable (there are still some issues to resolve on the server side), we will merge this into transformers, where it really belongs.

@mayank31398
Copy link
Collaborator

I don't think this is going to go in here, since it has nothing to do with Meg-DS. I was just using it as a dev PR so that the DeepSpeed team could push directly into it.

Once things are stable (there are still some issues to resolve on the server side), we will merge this into transformers, where it really belongs.

There is already an inference script in the main branch under the same directory, and I'm not sure if it works.
So I thought it would be better to have this branch merged into main.

@mayank31398
Copy link
Collaborator

mayank31398 commented Aug 4, 2022

@stas00, I have currently deployed BLOOM in server mode using accelerate with batch size = 1.
I see the application slowing down over time, and GPU memory fills up.
Why does your code call torch.cuda.empty_cache() during benchmarking?
Is it necessary during inference?

The scripts can be found here: #325

This was when I initially deployed:
[screenshot: GPU memory usage shortly after deployment]

and this is after a long period of usage:
[screenshot: GPU memory usage after a long period of usage]
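
One way to make the growth measurable, rather than eyeballing nvidia-smi, is to log allocated vs. reserved memory after each request. A sketch with a hypothetical helper, not part of the deployed server code:

import torch

def log_gpu_memory(tag: str) -> None:
    # A real leak shows up as "allocated" growing request after request;
    # "reserved" growing alone is just PyTorch's caching allocator at work.
    for i in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(i) / 2**30
        reserved = torch.cuda.memory_reserved(i) / 2**30
        print(f"[{tag}] cuda:{i} allocated={allocated:.2f}GiB reserved={reserved:.2f}GiB")

# sketch of a request handler:
# with torch.no_grad():                       # no autograd buffers during generation
#     output_ids = model.generate(**input_tokens, max_new_tokens=100)
# text = tokenizer.batch_decode(output_ids)   # keep only python strings around
# del input_tokens, output_ids                # drop GPU tensor references
# log_gpu_memory("after request")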

@stas00
Copy link
Contributor Author

stas00 commented Aug 5, 2022

I see the application slowing down over time, and GPU memory fills up.

Perhaps there is a memory leak somewhere? Let's ask @sgugger, as he developed accelerate

Why does your code call torch.cuda.empty_cache() during benchmarking?
Is it necessary during inference?

Not at all; you don't want to do that in production most of the time, unless there is a special situation where you want to control memory freeing. I was just using it to see how much memory was really used, since PyTorch tends to cache memory, so nvidia-smi doesn't tell you the real story without it.
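
In other words, the benchmarking pattern is roughly the following (a sketch of the idea, not the exact code in the script):

import gc
import torch

def report_real_memory_usage(tag: str) -> None:
    # Drop dangling python objects and PyTorch's cached blocks first, so that
    # memory_allocated() and nvidia-smi reflect what the model actually needs.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    allocated = torch.cuda.memory_allocated() / 2**30
    peak = torch.cuda.max_memory_allocated() / 2**30
    print(f"[{tag}] allocated={allocated:.2f}GiB peak={peak:.2f}GiB")

# report_real_memory_usage("after generate")  # in production you'd skip the empty_cache()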

@stas00
Copy link
Contributor Author

stas00 commented Aug 5, 2022

There is already an inference script in the main branch under the same directory, and I'm not sure if it works.
So I thought it would be better to have this branch merged into main.

It was my bad: I pushed the initial version into master thinking it was done, but then opened this PR, so that initial version is very outdated.

@mayank31398
Copy link
Collaborator

mayank31398 commented Aug 7, 2022

I would like to add the generation server scripts from #325 to this branch.

@mayank31398
Copy link
Collaborator

mayank31398 commented Aug 8, 2022

I am still seeing an illegal memory access error with batch size = 2, @RezaYazdaniAminabadi @stas00,
and UnhandledCudaError: Call to CUDA function failed. with batch size = 4.

@mayank31398
Copy link
Collaborator

mayank31398 commented Aug 8, 2022

@jeffra
This might be the error. It is not in DS-MII though; it is in the bloom-ds-inference script here, but it looks similar.

!!!! kernel execution error. (m: 5376, n: 2, k: 14336, error: 13)
Traceback (most recent call last):
  File "scripts/inference/bloom-ds-inference.py", line 257, in <module>
    _ = generate()
  File "scripts/inference/bloom-ds-inference.py", line 244, in generate
    outputs = model.generate(**input_tokens, **generate_kwargs)
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/bloom/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/bloom/lib/python3.8/site-packages/transformers/generation_utils.py", line 1288, in generate
    return self.greedy_search(
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/bloom/lib/python3.8/site-packages/transformers/generation_utils.py", line 1683, in greedy_search
    outputs = self(
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/bloom/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/net/llm-shared-nfs/nfs/mayank/DeepSpeed/deepspeed/inference/engine.py", line 521, in forward
    outputs = self.model_orig_fwd(*inputs, **kwargs)
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/bloom/lib/python3.8/site-packages/transformers/models/bloom/modeling_bloom.py", line 919, in forward
    transformer_outputs = self.transformer(
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/bloom/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/bloom/lib/python3.8/site-packages/transformers/models/bloom/modeling_bloom.py", line 806, in forward
    outputs = block(
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/bloom/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/net/llm-shared-nfs/nfs/mayank/DeepSpeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 829, in forward
    self.attention(input,
  File "/net/llm-shared-nfs/nfs/yelkurdi/conda/miniconda3/envs/bloom/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/net/llm-shared-nfs/nfs/mayank/DeepSpeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 541, in forward
    output = DeepSpeedSelfAttentionFunction.apply(
  File "/net/llm-shared-nfs/nfs/mayank/DeepSpeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 461, in forward
    output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp()
  File "/net/llm-shared-nfs/nfs/mayank/DeepSpeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 425, in selfAttention_fp
    context_layer, key_layer, value_layer = compute_attention(qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask)
  File "/net/llm-shared-nfs/nfs/mayank/DeepSpeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 373, in compute_attention
    context_layer, presents = backup_attention(qkv_out, layer_past, alibi, input_mask, norm_factor)
  File "/net/llm-shared-nfs/nfs/mayank/DeepSpeed/deepspeed/ops/transformer/inference/transformer_inference.py", line 194, in backup_attention
    alibi = alibi.to(torch.cuda.current_device())
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

@mayank31398
Copy link
Collaborator

@stas00, I see the memory configuration is [0, 51, 51, 51, 51, 51, 51, 51].
I tried using [45, 45, 45, 45, 45, 45, 45, 45] (a symmetrical config) and I didn't see any change in the generation throughput.
Can you confirm?

@mayank31398
Copy link
Collaborator

mayank31398 commented Aug 10, 2022

Also, @stas00, #325 uses a lot of code that is written in your scripts. If #325 can be merged into this branch, that would be helpful.
Let me know if you feel that is right or not.

I have also done a bit of code refactoring so these duplicate methods can be reused across the scripts.

@stas00
Copy link
Contributor Author

stas00 commented Aug 10, 2022

@stas00, I see the memory configuration is [0, 51, 51, 51, 51, 51, 51, 51]. I tried using [45, 45, 45, 45, 45, 45, 45, 45] (a symmetrical config) and I didn't see any change in the generation throughput. Can you confirm?

The only reason to keep the first GPU free of model weights is to allow for a much higher batch size.

If you don't need the higher batch size, you don't need to do that.
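
For reference, with the accelerate-based script this asymmetric placement is expressed through a max_memory map, along these lines (a sketch; the 0GiB/51GiB values are illustrative for 8x 80GB A100s and depend on the model, dtype and desired batch size):

import torch
from transformers import AutoModelForCausalLM

# keep GPU 0 (nearly) free of weights so large-batch activations fit on it;
# each remaining GPU may hold up to 51GiB of bf16 weights
max_memory = {0: "0GiB", **{i: "51GiB" for i in range(1, 8)}}

model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom",
    device_map="auto",        # accelerate places layers subject to max_memory
    max_memory=max_memory,
    torch_dtype=torch.bfloat16,
)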

@stas00
Copy link
Contributor Author

stas00 commented Aug 10, 2022

Wrt merging, let me merge this PR for now.

Then rebase your PR and refactor, and then we will see how to proceed.

Later we will move the whole thing into transformers.

I gave you write access to the repo, so it'll be easier for you to contribute. Just please don't push directly without a PR, and wait for at least one approval from another member before merging.

@stas00 stas00 merged commit 3932c74 into main Aug 10, 2022
@stas00 stas00 deleted the bloom-inference branch August 10, 2022 19:28
Muennighoff added a commit that referenced this pull request Aug 17, 2022
* Reshape deepspeed checkpoint (#239)

* Reshape deepspeed checkpoint

* add checkpoint tests

* Validate input folder

* Tests for tp/pp reshape

* remove debug folders

* fix test_checkpoint_reshaping_empty_dir

* Fix unit tests

* Remove deepspeed checkpoint utils

* Use DS 3D reshaping utils

* convert to bf16

* wip universal chkpt

* rename

* rename

* wip on fragments dealing

* cleanup

* Loading universal checkpoint with reshaping

* all gpu1<->2 reshapes work

* param attrs

* make the tests adaptable to the number of available gpus

* WIP

* WIP

* WIP

* WIP

* Debug functions

* args should be required, don't create another latest file

* Parallelize shard extraction

* close+join pool; add tqdm; comment out noise

* rename

* parameterize

* Parallel slice merging

* Cleanup

* allow inspection on a machine w/o gpus

* test against the right DS branch

* DS size was merged

Co-authored-by: Stas Bekman <[email protected]>

* BLOOM Inference via DeepSpeed-Inference, Accelerate and DeepSpeed-ZeRO (#308)

* hardcode the dtype depending on the model

* change the mp based on the world_size

* remove hardcoded world_size

* add bigscience/bigscience-small-testing

* fixes

* add zero-inference script

* fixes

* fix

* working script

* renames

* fixes

* fix for offline use

* add benchmark

* add benchmark

* update

* cleanup

* update

* msecs

* cleanup

* improve

* fix benchmark, add warmup

* update

* fix; thanks Michael Wyatt

* clarify

* add bloom batch-inference script

* removed the names :-)

* fold the bs functionality from the other script

* fix

* restore do_sample

* dump generate args

* fix

* fix

* support any batchsize

* div by bs

* mul by bs

* add cpu_offload; sync scripts

* wip

* improvements

* fixes

* fixes

* add accelerate script

* fix

* wip

* wip

* stats

* add OnDevice and remove zero-inference (#316)

* wip

* rework generate + benchmark

* figure out the memory map dynamically

* bug fix

* fix ds-zero-inference wrt device

* bug fix

* update

* update

* fix

Co-authored-by: Reza Yazdani <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Stas Bekman <[email protected]>
Co-authored-by: Stas Bekman <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
younesbelkada pushed a commit to younesbelkada/Megatron-DeepSpeed that referenced this pull request Sep 28, 2022
BLOOM Inference via DeepSpeed-Inference, Accelerate and DeepSpeed-ZeRO (bigscience-workshop#308)
