BLOOM Inference via DeepSpeed-Inference, Accelerate and DeepSpeed-ZeRO #308

Merged: 58 commits, merged on Aug 10, 2022. The changes shown below are from 52 of the 58 commits.

Commits:
- efa354d: hardcode the dtype depending on the model (stas00, Jul 10, 2022)
- cafc3f5: change the mp based on the world_size (Jul 10, 2022)
- daeb293: remove hardcoded world_size (stas00, Jul 10, 2022)
- 7d5f7d4: add bigscience/bigscience-small-testing (stas00, Jul 10, 2022)
- 2d3d271: Merge branch 'bloom-inference' of https://github.com/bigscience-works… (Jul 10, 2022)
- 1ff0f69: fixes (stas00, Jul 10, 2022)
- 56b24ed: add zero-inference script (stas00, Jul 10, 2022)
- 67aab37: fixes (stas00, Jul 11, 2022)
- 328ab0c: fix (stas00, Jul 11, 2022)
- f2628b0: working script (stas00, Jul 12, 2022)
- 195288e: renames (stas00, Jul 12, 2022)
- 3c7b2cb: fixes (stas00, Jul 12, 2022)
- 6c5c23b: fix for offline use (stas00, Jul 13, 2022)
- 6b19227: add benchmark (stas00, Jul 13, 2022)
- 10cbb2d: add benchmark (stas00, Jul 13, 2022)
- 494c212: update (stas00, Jul 13, 2022)
- 2b67c0d: cleanup (stas00, Jul 13, 2022)
- 3853724: update (stas00, Jul 13, 2022)
- 1896739: msecs (stas00, Jul 13, 2022)
- 7c9daaf: cleanup (stas00, Jul 13, 2022)
- dca2c8f: improve (stas00, Jul 13, 2022)
- 85580c0: fix benchmark, add warmup (stas00, Jul 13, 2022)
- 5ea3dee: update (stas00, Jul 13, 2022)
- 737c681: fix; thanks Michael Wyatt (stas00, Jul 13, 2022)
- 6be0cca: clarify (stas00, Jul 13, 2022)
- fea3902: Merge branch 'bloom-inference' of https://github.com/bigscience-works… (Jul 13, 2022)
- fc9b458: add bloom batch-inference script (Jul 13, 2022)
- 7b0edef: removed the names :-) (Jul 13, 2022)
- 2120dd2: fold the bs functionality from the other script (stas00, Jul 13, 2022)
- 78bcbb7: fix (stas00, Jul 13, 2022)
- e7468cd: restore do_sample (stas00, Jul 13, 2022)
- 68f5ca6: dump generate args (stas00, Jul 13, 2022)
- 1eca7c5: fix (stas00, Jul 14, 2022)
- 8815fc3: fix (stas00, Jul 14, 2022)
- 034cc6f: support any batchsize (stas00, Jul 14, 2022)
- 155c3c3: div by bs (stas00, Jul 14, 2022)
- 73a8b7b: mul by bs (stas00, Jul 14, 2022)
- 09d7408: add cpu_offload; sync scripts (stas00, Jul 14, 2022)
- 695265d: wip (stas00, Jul 14, 2022)
- 1a7e891: improvements (stas00, Jul 15, 2022)
- aba4055: fixes (stas00, Jul 15, 2022)
- 5e92d55: fixes (stas00, Jul 15, 2022)
- 3992112: add accelerate script (stas00, Jul 15, 2022)
- 5a7057b: fix (stas00, Jul 15, 2022)
- 4758531: wip (stas00, Jul 16, 2022)
- 7550ee0: wip (stas00, Jul 16, 2022)
- 5153c40: stats (stas00, Jul 18, 2022)
- cb50ea5: add OnDevice and remove zero-inference (#316) (jeffra, Jul 19, 2022)
- a53fcaa: wip (stas00, Jul 19, 2022)
- 7252879: rework generate + benchmark (stas00, Jul 19, 2022)
- 2aa419d: figure out the memory map dynamically (stas00, Jul 19, 2022)
- 4bd8ca5: bug fix (stas00, Jul 19, 2022)
- b76e516: fix ds-zero-inference wrt device (stas00, Jul 19, 2022)
- ecfd577: bug fix (stas00, Jul 20, 2022)
- fd26b9c: update (stas00, Jul 20, 2022)
- e2bfe91: update (stas00, Jul 22, 2022)
- b9a67ea: fix (stas00, Aug 9, 2022)
- 3862ef0: Merge remote-tracking branch 'origin/main' into bloom-inference (stas00, Aug 9, 2022)
195 changes: 195 additions & 0 deletions scripts/inference/README.md
# Inference scripts for BLOOM

## BLOOM Inference solutions

Here are some stats on JeanZay's 8x80GB A100 node w/ 512GB of CPU memory:

All benchmarks do greedy generation of 100-token outputs:
```
Generate args {'min_length': 100, 'max_length': 100, 'do_sample': False}
```
The inputs are just a few tokens.
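
The per-token msecs below are total wall-clock time divided by the total number of new tokens generated, roughly as in the `--benchmark` path of the scripts further down. A minimal sketch of that computation (the helper below is illustrative, not part of the scripts):

```
import time

def benchmark(generate, cycles=5):
    """msecs per generated token; generate() returns (input, output, num_new_tokens) triples"""
    t0 = time.time()
    total_new_tokens = 0
    for _ in range(cycles):
        total_new_tokens += sum(new_tokens for _, _, new_tokens in generate())
    return (time.time() - t0) / total_new_tokens * 1000
```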

Throughput per token in msecs (lower is better):

| project \ bs | 1      | 8     | 16    | 32    | 64   | 128  |
| :----------- | :----- | :---- | :---- | :---- | :--- | :--- |
| accelerate   | 230.38 | 31.78 | 17.84 | 10.89 | oom  | oom  |
| ds-inference | 40.57  | 5.23  |       |       | 2.77 | 0.66 |
| ds-zero      | 283    | 34.88 | oom   | oom   | oom  | oom  |


Start to ready to generate in secs:

| project \ bs | 1 | 8 | 16 | 32 | 64 | 128 |
| :----------- | :--- | :--- | :--- | :--- | :--- | :--- |
| accelerate | 121 | 120 | 113 | 118 | | |
| ds-inference | 662 | 673 | | | 685 | 654 |
| ds-zero | 462 | 463 | | | | |
Comment on lines +22 to +28
@mayank31398 (Collaborator) commented on Jul 21, 2022:

Hey @stas00, I am closely monitoring this issue ❤️.
What is throughput? Is it the generation time per token?
Also, any idea why it decreases with BS?

I tried to run inference using HF (which uses accelerate), but the generation times were way too high even with 8 GPUs (A100 80GB).

https://huggingface.co/bigscience/bloom/discussions/59

Follow-up from the Collaborator:

Never mind, I figured it out.
Turns out using a symmetric sharding of weights works worse than your benchmarks ❤️

@stas00 (Contributor, Author) replied:

This is not throughput; it is time-to-ready-to-generate. It isn't benchmarked, just a single measurement, so it fluctuates. I should probably remove all but one column here.

@stas00 (Contributor, Author) commented on Jul 21, 2022:

The DeepSpeed team is working on a special checkpoint that already has the TP weights pre-sharded, which should reduce the load time to about 1-2 min, on par with accelerate.

A user asked:

@stas00 Do you know of any tracking issue or place to watch for those pre-sharded weights? I am seeing up to 15 mins of setup time.

@stas00 (Contributor, Author) replied:

Yes, 10-15 min is the current load time, depending on the speed of your IO.

As I'm not the one doing that work, I am not sure where to watch the progress, but I will update this PR as soon as there is a working sharded version so that you can play with it. I doubt it will take very long.

Another user asked:

Hey @stas00, do you know if the DeepSpeed team's code for pre-sharding the TP weights is open-sourced?



DS-Inference load time (start to ready to generate) will become much faster soon, once we stop relying on ds-zero to instantiate the model on GPU. The plan is to pre-shard the weights TP-wise for 8x and 16x GPUs and load them directly onto each GPU, which should bring the load time to under 1 min.


## Deepspeed-Inference

Tensor-Parallelism and efficient fused CUDA kernels:
https://www.deepspeed.ai/tutorials/inference-tutorial/
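
A rough sketch of what a DeepSpeed-Inference setup looks like (simplified relative to `bloom-ds-inference.py`, which also handles checkpoint loading and dtype selection; the small test model is used here since loading the full `bigscience/bloom` this way needs far more memory). Launch it with the `deepspeed` launcher as in the Run section below:

```
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

model_name = "bigscience/bigscience-small-testing"  # swap in bigscience/bloom on a big enough node
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# inject the fused kernels and shard the weights tensor-parallel across the gpus
engine = deepspeed.init_inference(
    model,
    mp_size=world_size,                 # tensor-parallel degree
    dtype=torch.half,
    replace_with_kernel_inject=True,
)
model = engine.module  # the kernel-injected model; generate() works as usual

inputs = tokenizer("DeepSpeed is a machine learning framework", return_tensors="pt").to(f"cuda:{local_rank}")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0], skip_special_tokens=True))
```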

### Setup

```
git clone https://github.com/microsoft/DeepSpeed
cd DeepSpeed
git checkout ds-inference/bloom-support
pip install .
```

### Run

```
deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom
```

Performance on a single node of 8x80GB A100 w/ 512GB of CPU RAM (JeanZay), starting with a batch size of 1 (running a larger batch is more efficient).

Add `--benchmark` to activate the benchmarks.


BS=1
```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-ds-inference_bs=1.txt
[...]

```

Memory usage per process while processing:

- GPU: ~50GB
- CPU: ~10GB


BS=8
```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-ds-inference_bs=8.txt
[...]
*** Performance stats:
Throughput per token including tokenize: 5.23 msecs
Start to ready to generate: 683.397 secs
Tokenize and generate 800 (bs=8) tokens: 4.241 secs
Start to finish: 687.638 secs
```

BS=64

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 64 --benchmark 2>&1 | tee bloom-ds-inference_bs=64.txt




```

BS=128

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 128 --benchmark 2>&1 | tee bloom-ds-inference_bs=128.txt




```

## Deepspeed ZeRO-Inference

https://www.deepspeed.ai/tutorials/zero/
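
ZeRO-Inference here is essentially ZeRO stage 3 with the parameters offloaded to CPU and no optimizer. A minimal sketch of such a setup, launched with the `deepspeed` launcher (config values are illustrative, not the exact ones `bloom-ds-zero-inference.py` uses):

```
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.deepspeed import HfDeepSpeedConfig

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

ds_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
    },
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": world_size,
}

model_name = "bigscience/bigscience-small-testing"  # swap in bigscience/bloom for the real thing
# must be created *before* from_pretrained and kept alive, so the weights get ZeRO-3 partitioned at load time
dschf = HfDeepSpeedConfig(ds_config)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("DeepSpeed is a machine learning framework", return_tensors="pt").to(f"cuda:{local_rank}")
print(tokenizer.decode(ds_engine.module.generate(**inputs, max_new_tokens=20)[0], skip_special_tokens=True))
```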

### Setup

```
pip install deepspeed
```


### Run

Note that the script currently runs the same inputs on all GPUs, but you could feed each GPU a different input stream and get `n_gpu` times higher throughput (see the sketch below) - something you can't do with Deepspeed-Inference.
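
A minimal sketch of that idea, assuming the usual `LOCAL_RANK`/`WORLD_SIZE` environment variables set by the `deepspeed` launcher (the slicing scheme is illustrative):

```
import os

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

all_prompts = ["prompt 0", "prompt 1", "prompt 2", "prompt 3"]  # the full input stream
my_prompts = all_prompts[local_rank::world_size]                # each rank gets its own shard
# run generate() over my_prompts only, then gather/aggregate the outputs as needed
```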


BS=1

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=1.txt
[...]
*** Performance stats:
Throughput per token including tokenize: 282.93 msecs
Start to ready to generate: 501.871 secs
Tokenize and generate 800 (bs=1) tokens: 226.188 secs
Start to finish: 728.060 secs
```


BS=8

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=8.txt
[...]

*** Performance stats:
Throughput per token including tokenize: 34.57 msecs
Start to ready to generate: 482.132 secs
Tokenize and generate 6400 (bs=8) tokens: 221.236 secs
Start to finish: 703.368 secs
```

BS=16 and higher OOMs

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 16 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=16.txt
[...]
OOM

```



## HF Accelerate

https://github.com/huggingface/accelerate
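
The core of the script further down is letting `device_map="auto"` shard the weights across the available GPUs (and CPU, if necessary) at load time. A stripped-down sketch, using the small test model:

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "bigscience/bigscience-small-testing"  # swap in bigscience/bloom on a big enough node
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",           # spread the layers over all visible gpus
    torch_dtype=torch.bfloat16,
)

inputs = tokenizer("DeepSpeed is a machine learning framework", return_tensors="pt").to("cuda:0")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0], skip_special_tokens=True))
```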

### Setup

```
pip install transformers accelerate
```



### Run




BS=1
```
$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=1.txt
[...]


```

BS=8
```
$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=8.txt
[...]


```

BS=16
```
$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 16 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=16.txt
[...]


```
184 changes: 184 additions & 0 deletions scripts/inference/bloom-accelerate-inference.py
import argparse
import time
import os
import gc
import torch
import math
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers")
    parser.add_argument("--name", type=str, help="Name path", required=True)
    parser.add_argument("--batch_size", default=1, type=int, help="batch size")
    parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark")
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--top-k", type=int, default=0)
    parser.add_argument("--top-p", type=float, default=0.)

    return parser.parse_args()

def get_max_memory_per_gpu_dict(dtype, model_name):
    """ try to generate the memory map based on what we know about the model and the available hardware """

    # figure out the memory map - the minimum per gpu required to load the model
    n_gpus = torch.cuda.device_count()

    if model_name == "bigscience/bloom" and n_gpus == 8 and torch.cuda.get_device_properties(0).total_memory > 79*2**30:
        # hand crafted optimized memory map for 8x80 setup over BLOOM
        # this works with bs=48
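        # gpu0 is deliberately given no weights: the inputs are moved to cuda:0 below, so it is
        # likely best kept free for activations during generate (an assumption, not measured here)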
        return {0: '0GIB', 1: '51GIB', 2: '51GIB', 3: '51GIB', 4: '51GIB', 5: '51GIB', 6: '51GIB', 7: '51GIB'}

    try:
        # model_params calculation, as we don't have a model yet to do:
        #model_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())

        config = AutoConfig.from_pretrained(model_name)
        h = config.n_embed
        l = config.n_layer
        v = config.vocab_size
        # from https://github.com/bigscience-workshop/bigscience/tree/a3e451498ee8189d2a9dd47be19aa89b0e16cd89/math#model-sizing
        model_params = l*(12*h**2 + 13*h) + v*h + 4*h
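        # e.g. for bigscience/bloom (h=14336, l=70, v=250880) this comes to ~176.2B params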
    except:
        print(f"The model {model_name} has a broken config file. Please notify the owner")
        raise

    bytes = torch.finfo(dtype).bits / 8
    param_memory_total_in_bytes = model_params * bytes
    # add 5% since weight sizes aren't the same and some GPU may need more memory
    param_memory_per_gpu_in_bytes = int(param_memory_total_in_bytes / n_gpus * 1.05)
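    # e.g. bloom in bf16 (2 bytes/param): ~352GB of weights in total, roughly 43GiB per gpu on an 8-gpu node with the 5% margin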
print(f"Estimating {param_memory_per_gpu_in_bytes/2**30:0.2f}GB per gpu for weights")

# check the real available memory
# load cuda kernels first and only measure the real free memory after loading (shorter by ~2GB)
torch.ones(1).cuda()
max_memory_per_gpu_in_bytes = torch.cuda.mem_get_info(0)[0]
if max_memory_per_gpu_in_bytes < param_memory_per_gpu_in_bytes:
raise ValueError(f"Unable to generate the memory map automatically as the needed estimated memory per gpu ({param_memory_per_gpu_in_bytes/2**30:0.2f}GB) is bigger than the available per gpu memory ({max_memory_per_gpu_in_bytes/2**30:0.2f}GB)")

return {i: param_memory_per_gpu_in_bytes for i in range(torch.cuda.device_count())}

t_start = time.time()

num_tokens = 100

args = get_args()

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

rank = local_rank

model_name = args.name
if rank == 0:
    print(f"Loading model {model_name}")


tokenizer = AutoTokenizer.from_pretrained(model_name)

# XXX: can't automatically derive dtype via config's `from_pretrained`
dtype = torch.bfloat16 if model_name in ["bigscience/bloom", "bigscience/bigscience-small-testing"] else torch.float16

#print(get_max_memory_per_gpu_dict())


model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    max_memory=get_max_memory_per_gpu_dict(dtype, model_name),
    torch_dtype=dtype,
)


if args.benchmark:
    t_ready = time.time()



### Generate

if rank == 0:
    print(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}")

input_sentences = [
    "DeepSpeed is a machine learning framework",
    "He is working on",
    "He has a",
    "He got all",
    "Everyone is happy and I can",
    "The new movie that got Oscar this year",
    "In the far far distance from our galaxy,",
    "Peace is the only way",
]

if args.batch_size > len(input_sentences):
    # dynamically extend to support larger bs by repetition
    input_sentences *= math.ceil(args.batch_size / len(input_sentences))

generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=False)

if rank == 0:
    print(f"Generate args {generate_kwargs}")

inputs = input_sentences[:args.batch_size]

def generate():
    """ returns a list of zipped inputs, outputs and number of new tokens """

    input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to("cuda:0")

    outputs = model.generate(**input_tokens, **generate_kwargs)

    input_tokens_lengths = [x.shape[0] for x in input_tokens.input_ids]
    output_tokens_lengths = [x.shape[0] for x in outputs]

    total_new_tokens = [o-i for i,o in zip(input_tokens_lengths, output_tokens_lengths)]
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    return zip(inputs, outputs, total_new_tokens)

# warmup is a must if measuring speed as it's when all the optimizations are performed
# e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
_ = generate()

t_generate_start = time.time()
generated = generate()
t_generate_span = time.time() - t_generate_start
if rank == 0:
    for i,o,_ in generated:
        print(f"{'-'*60}\nin={i}\nout={o}\n")


if args.benchmark:
    torch.cuda.empty_cache()
    gc.collect()

### Benchmark

if args.benchmark:
    if rank == 0:
        print(f"*** Running benchmark")

    # warm up
    for i in range(1):
        _ = generate()
    torch.cuda.synchronize()

    # benchmark
    t0 = time.time()
    cycles = 5
    total_new_tokens_generated = 0
    for i in range(cycles):
        generated = generate()
        total_new_tokens_generated += sum(new_tokens for _,_,new_tokens in generated)
    torch.cuda.synchronize()
    if rank == 0:
        throughput = (time.time() - t0)/(total_new_tokens_generated)
        print(f"""
*** Performance stats:
Throughput per token including tokenize: {throughput*1000:.2f} msecs
Start to ready to generate: {t_ready - t_start:.3f} secs
Tokenize and generate {total_new_tokens_generated} (bs={args.batch_size}) tokens: {t_generate_span:.3f} secs
Start to finish: {t_ready - t_start + t_generate_span:.3f} secs
""")