BLOOM Inference via DeepSpeed-Inference, Accelerate and DeepSpeed-ZeRO #308
Merged
Commits (58)
efa354d
hardcode the dtype depending on the model
stas00 cafc3f5
change the mp based on the world_size
daeb293
remove hardcoded world_size
stas00 7d5f7d4
add bigscience/bigscience-small-testing
stas00 2d3d271
Merge branch 'bloom-inference' of https://github.com/bigscience-works…
1ff0f69
fixes
stas00 56b24ed
add zero-inference script
stas00 67aab37
fixes
stas00 328ab0c
fix
stas00 f2628b0
working script
stas00 195288e
renames
stas00 3c7b2cb
fixes
stas00 6c5c23b
fix for offline use
stas00 6b19227
add benchmark
stas00 10cbb2d
add benchmark
stas00 494c212
update
stas00 2b67c0d
cleanup
stas00 3853724
update
stas00 1896739
msecs
stas00 7c9daaf
cleanup
stas00 dca2c8f
improve
stas00 85580c0
fix benchmark, add warmup
stas00 5ea3dee
update
stas00 737c681
fix; thanks Michael Wyatt
stas00 6be0cca
clarify
stas00 fea3902
Merge branch 'bloom-inference' of https://github.com/bigscience-works…
fc9b458
add bloom batch-inference script
7b0edef
removed the names :-)
2120dd2
fold the bs functionality from the other script
stas00 78bcbb7
fix
stas00 e7468cd
restore do_sample
stas00 68f5ca6
dump generate args
stas00 1eca7c5
fix
stas00 8815fc3
fix
stas00 034cc6f
support any batchsize
stas00 155c3c3
div by bs
stas00 73a8b7b
mul by bs
stas00 09d7408
add cpu_offload; sync scripts
stas00 695265d
wip
stas00 1a7e891
improvements
stas00 aba4055
fixes
stas00 5e92d55
fixes
stas00 3992112
add accelerate script
stas00 5a7057b
fix
stas00 4758531
wip
stas00 7550ee0
wip
stas00 5153c40
stats
stas00 cb50ea5
add OnDevice and remove zero-inference (#316)
jeffra a53fcaa
wip
stas00 7252879
rework generate + benchmark
stas00 2aa419d
figure out the memory map dynamically
stas00 4bd8ca5
bug fix
stas00 b76e516
fix ds-zero-inference wrt device
stas00 ecfd577
bug fix
stas00 fd26b9c
update
stas00 e2bfe91
update
stas00 b9a67ea
fix
stas00 3862ef0
Merge remote-tracking branch 'origin/main' into bloom-inference
stas00
@@ -1 +1,196 @@
# Inference scripts for BLOOM

## BLOOM Inference solutions

Here are some stats on JeanZay's 8x80GB A100 node w/ 512GB of CPU memory:

All benchmarks perform greedy generation of 100-token outputs:
```
Generate args {'min_length': 100, 'max_length': 100, 'do_sample': False}
```
The inputs are just a few tokens.
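
For reference, this is roughly what those generate args correspond to in `transformers` terms (a minimal illustration only; `model` and `tokenizer` are assumed to be already loaded, e.g. as in the accelerate script further below):

```
# greedy decoding of a fixed 100-token output: do_sample=False, min_length == max_length
inputs = tokenizer("DeepSpeed is a machine learning framework", return_tensors="pt").to("cuda:0")
outputs = model.generate(**inputs, min_length=100, max_length=100, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```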

Throughput per token in msecs (lower is better):

| project \ bs | 1      | 8     | 16    | 32    | 64   | 128  |
| :----------- | :----- | :---- | :---- | :---- | :--- | :--- |
| accelerate   | 230.38 | 31.78 | 17.84 | 10.89 | oom  | oom  |
| ds-inference | 40.57  | 5.23  |       |       | 2.77 | 0.66 |
| ds-zero      | 283    | 34.88 | oom   | oom   | oom  | oom  |

Start to ready to generate in secs:

| project \ bs | 1   | 8   | 16  | 32  | 64  | 128 |
| :----------- | :-- | :-- | :-- | :-- | :-- | :-- |
| accelerate   | 121 | 120 | 113 | 118 |     |     |
| ds-inference | 662 | 673 |     |     | 685 | 654 |
| ds-zero      | 462 | 463 |     |     |     |     |

DS-Inference load time (start to ready to generate) will become much faster soon, once we stop relying on ds-zero to instantiate the model on GPU. The plan is to pre-shard the weights TP-wise for 8x and 16x GPUs and load them directly on each GPU, which should bring the load time to under 1 min.
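
To make the first table concrete: the numbers are wall-clock milliseconds per generated token, summed over all sequences in the batch (see the benchmark section of the script below), which is why they drop as the batch size grows. A quick sanity check against the ds-inference bs=8 run reported further down:

```
# 100 new tokens per sequence x 8 sequences = 800 tokens in ~4.241 secs
print(f"{4.241 / 800 * 1000:.2f} msecs/token")  # ~5.30, in line with the 5.23 above
```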

## Deepspeed-Inference

Tensor-Parallelism and efficient fused CUDA kernels:
https://www.deepspeed.ai/tutorials/inference-tutorial/

### Setup

```
git clone https://github.com/microsoft/DeepSpeed
cd DeepSpeed
git checkout ds-inference/bloom-support
pip install .
```

### Run

```
deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom
```
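
Under the hood `bloom-ds-inference.py` wraps the loaded model with DeepSpeed-Inference's tensor parallelism and kernel injection. The sketch below shows only that core call, with illustrative arguments; the real script additionally uses DeepSpeed's OnDevice/meta-tensor loading (see the "add OnDevice" commit above) so that a full copy of the 176B model is never materialized per rank:

```
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

world_size = int(os.getenv("WORLD_SIZE", "1"))

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom", torch_dtype=torch.float16)

# shard the model TP-wise across the GPUs and inject DeepSpeed's fused CUDA kernels
model = deepspeed.init_inference(
    model,
    mp_size=world_size,                # one tensor-parallel shard per GPU
    dtype=torch.float16,
    replace_with_kernel_inject=True,   # swap in the fused transformer kernels
)
model = model.module  # unwrap the inference engine before calling generate()
```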

Performance on a single node of 8x80GB A100s w/ 512GB of CPU RAM (JeanZay) - this is just a batch of 1 (it would be more efficient to run a larger batch).

Add `--benchmark` to activate the benchmarks.

BS=1
```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-ds-inference_bs=1.txt
[...]
```

Memory usage per process while processing:

- GPU: ~50GB
- CPU: ~10GB

BS=8
```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-ds-inference_bs=8.txt
[...]
*** Performance stats:
Throughput per token including tokenize: 5.23 msecs
Start to ready to generate: 683.397 secs
Tokenize and generate 800 (bs=8) tokens: 4.241 secs
Start to finish: 687.638 secs
```

BS=64

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 64 --benchmark 2>&1 | tee bloom-ds-inference_bs=64.txt
```

BS=128

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 128 --benchmark 2>&1 | tee bloom-ds-inference_bs=128.txt
```

## Deepspeed ZeRO-Inference

https://www.deepspeed.ai/tutorials/zero/

### Setup

```
pip install deepspeed
```

### Run

Note that the script currently runs the same inputs on all GPUs, but you can run a different stream on each GPU, and get `n_gpu` times faster throughput. You can't do that with Deepspeed-Inference.
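
For orientation: ZeRO-Inference shards the weights across all GPUs with ZeRO stage 3 (optionally offloading them to CPU or NVMe) and gathers each layer on the fly during the forward pass. A minimal sketch of such a setup with `transformers` follows - it is illustrative only, and the config values and dtype choices may differ from what `bloom-ds-zero-inference.py` actually uses:

```
import torch
import deepspeed
from transformers import AutoModelForCausalLM
from transformers.deepspeed import HfDeepSpeedConfig

model_name = "bigscience/bloom"

ds_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,                                               # shard params across GPUs
        "offload_param": {"device": "cpu", "pin_memory": True},   # optional CPU offload
    },
    "train_micro_batch_size_per_gpu": 1,  # required field, unused for inference
}

# must be instantiated *before* from_pretrained and kept alive, so the weights are
# loaded directly into ZeRO stage-3 shards instead of a full copy per rank
dschf = HfDeepSpeedConfig(ds_config)

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()
model = ds_engine.module  # generate() can now be called on each rank
```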

BS=1

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=1.txt
[...]
*** Performance stats:
Throughput per token including tokenize: 282.93 msecs
Start to ready to generate: 501.871 secs
Tokenize and generate 800 (bs=1) tokens: 226.188 secs
Start to finish: 728.060 secs
```

BS=8

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=8.txt
[...]
*** Performance stats:
Throughput per token including tokenize: 34.57 msecs
Start to ready to generate: 482.132 secs
Tokenize and generate 6400 (bs=8) tokens: 221.236 secs
Start to finish: 703.368 secs
```

BS=16 and higher OOMs

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 16 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=16.txt
[...]
OOM
```

## HF Accelerate

https://github.com/huggingface/accelerate

### Setup

```
pip install transformers accelerate
```

### Run

BS=1
```
$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=1.txt
[...]
```

BS=8
```
$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=8.txt
[...]
```

BS=16
```
$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 16 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=16.txt
[...]
```
@@ -0,0 +1,184 @@
import argparse
import time
import os
import gc
import torch
import math
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers")
    parser.add_argument("--name", type=str, help="Name path", required=True)
    parser.add_argument("--batch_size", default=1, type=int, help="batch size")
    parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark")
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--top-k", type=int, default=0)
    parser.add_argument("--top-p", type=float, default=0.)

    return parser.parse_args()

def get_max_memory_per_gpu_dict(dtype, model_name):
    """ try to generate the memory map based on what we know about the model and the available hardware """

    # figure out the memory map - the minimum per gpu required to load the model
    n_gpus = torch.cuda.device_count()

    if model_name == "bigscience/bloom" and n_gpus == 8 and torch.cuda.get_device_properties(0).total_memory > 79*2**30:
        # hand crafted optimized memory map for 8x80 setup over BLOOM
        # this works with bs=48
        return {0: '0GIB', 1: '51GIB', 2: '51GIB', 3: '51GIB', 4: '51GIB', 5: '51GIB', 6: '51GIB', 7: '51GIB'}
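    # editorial note: GPU 0 is deliberately given '0GIB' of weights in this hand-crafted
    # map, presumably to keep it free for activations, the past-key-value cache and the
    # generated sequences; the remaining 7 GPUs then hold ~51GiB of weights each
    # (7 x 51GiB ~= 357GiB, enough for the ~330GiB of bf16 weights of the 176B model).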

    try:
        # model_params calculation, as we don't have a model yet to do:
        #model_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())

        config = AutoConfig.from_pretrained(model_name)
        h = config.n_embed
        l = config.n_layer
        v = config.vocab_size
        # from https://github.com/bigscience-workshop/bigscience/tree/a3e451498ee8189d2a9dd47be19aa89b0e16cd89/math#model-sizing
        model_params = l*(12*h**2 + 13*h) + v*h + 4*h
    except:
        print(f"The model {model_name} has a broken config file. Please notify the owner")
        raise

    bytes = torch.finfo(dtype).bits / 8
    param_memory_total_in_bytes = model_params * bytes
    # add 5% since weight sizes aren't the same and some GPU may need more memory
    param_memory_per_gpu_in_bytes = int(param_memory_total_in_bytes / n_gpus * 1.05)
    print(f"Estimating {param_memory_per_gpu_in_bytes/2**30:0.2f}GB per gpu for weights")

    # check the real available memory
    # load cuda kernels first and only measure the real free memory after loading (shorter by ~2GB)
    torch.ones(1).cuda()
    max_memory_per_gpu_in_bytes = torch.cuda.mem_get_info(0)[0]
    if max_memory_per_gpu_in_bytes < param_memory_per_gpu_in_bytes:
        raise ValueError(f"Unable to generate the memory map automatically as the needed estimated memory per gpu ({param_memory_per_gpu_in_bytes/2**30:0.2f}GB) is bigger than the available per gpu memory ({max_memory_per_gpu_in_bytes/2**30:0.2f}GB)")

    return {i: param_memory_per_gpu_in_bytes for i in range(torch.cuda.device_count())}

t_start = time.time()

num_tokens = 100

args = get_args()

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
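# editorial note: LOCAL_RANK/WORLD_SIZE are set by distributed launchers such as
# `deepspeed`; when the script is launched with plain `python` they default to 0/1
# and the single process spreads the model over all visible GPUs via device_map below.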

rank = local_rank

model_name = args.name
if rank == 0:
    print(f"Loading model {model_name}")

tokenizer = AutoTokenizer.from_pretrained(model_name)

# XXX: can't automatically derive dtype via config's `from_pretrained`
dtype = torch.bfloat16 if model_name in ["bigscience/bloom", "bigscience/bigscience-small-testing"] else torch.float16
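# editorial note: the bigscience/bloom checkpoint is published in bfloat16, hence the
# hardcoded bf16 here; every other model name falls back to float16 in this script.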

#print(get_max_memory_per_gpu_dict())

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    max_memory=get_max_memory_per_gpu_dict(dtype, model_name),
    torch_dtype=dtype,
)
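# editorial note: device_map="auto" lets accelerate split the checkpoint across the GPUs
# according to the max_memory map computed above and load each shard directly onto its
# target device, so no single device ever has to hold the full model.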

if args.benchmark:
    t_ready = time.time()

### Generate

if rank == 0:
    print(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}")

input_sentences = [
    "DeepSpeed is a machine learning framework",
    "He is working on",
    "He has a",
    "He got all",
    "Everyone is happy and I can",
    "The new movie that got Oscar this year",
    "In the far far distance from our galaxy,",
    "Peace is the only way"
]

if args.batch_size > len(input_sentences):
    # dynamically extend to support larger bs by repetition
    input_sentences *= math.ceil(args.batch_size / len(input_sentences))

generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=False)

if rank == 0:
    print(f"Generate args {generate_kwargs}")

inputs = input_sentences[:args.batch_size]

def generate():
    """ returns a list of zipped inputs, outputs and number of new tokens """

    input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to("cuda:0")
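    # editorial note: accelerate's dispatch hooks move tensors to each layer's device
    # during the forward pass, so placing the whole batch on a single GPU here is enough.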

    outputs = model.generate(**input_tokens, **generate_kwargs)

    input_tokens_lengths = [x.shape[0] for x in input_tokens.input_ids]
    output_tokens_lengths = [x.shape[0] for x in outputs]

    total_new_tokens = [o-i for i,o in zip(input_tokens_lengths, output_tokens_lengths)]
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    return zip(inputs, outputs, total_new_tokens)

# warmup is a must if measuring speed as it's when all the optimizations are performed
# e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
_ = generate()

t_generate_start = time.time()
generated = generate()
t_generate_span = time.time() - t_generate_start
if rank == 0:
    for i,o,_ in generated:
        print(f"{'-'*60}\nin={i}\nout={o}\n")

if args.benchmark:
    torch.cuda.empty_cache()
    gc.collect()

### Benchmark

if args.benchmark:
    if rank == 0:
        print(f"*** Running benchmark")

    # warm up
    for i in range(1):
        _ = generate()
    torch.cuda.synchronize()

    # benchmark
    t0 = time.time()
    cycles = 5
    total_new_tokens_generated = 0
    for i in range(cycles):
        generated = generate()
        total_new_tokens_generated += sum(new_tokens for _,_,new_tokens in generated)
    torch.cuda.synchronize()
    if rank == 0:
        throughput = (time.time() - t0)/(total_new_tokens_generated)
        print(f"""
*** Performance stats:
Throughput per token including tokenize: {throughput*1000:.2f} msecs
Start to ready to generate: {t_ready - t_start:.3f} secs
Tokenize and generate {total_new_tokens_generated} (bs={args.batch_size}) tokens: {t_generate_span:.3f} secs
Start to finish: {t_ready - t_start + t_generate_span:.3f} secs
""")
Hey @stas00, I am closely monitoring this issue ❤️.
What is throughput? Is it the generation time per token?
Also, any idea why it decreases with BS?
I tried to run inference using HF (which uses accelerate), but the generation times were way too high even using 8 GPUs (A100 80GB).
https://huggingface.co/bigscience/bloom/discussions/59
Never mind, I figured it out.
It turns out that a symmetric sharding of the weights works worse than your benchmarks ❤️
This is not throughput, it is time-to-ready-to-generate. And it's not benchmarked - just a single measurement, so it fluctuates. I should probably remove all but one column here.
The DeepSpeed team is working on a special checkpoint that already has the TP weights pre-sharded, which should reduce the load time to about 1-2 min - on par with accelerate.
@stas00 Do you know of any tracking issue or place to watch for those pre-sharded weights? I am seeing up to 15 mins of setup time.
Yes, 10-15 min is the current load time, depending on the speed of your IO.
As I'm not the one doing that work I'm not sure where to watch the progress, but I will update this PR as soon as there is a working sharded version that you can play with. I doubt it will take very long.
Hey @stas00, do you know if the code from the DeepSpeed team to pre-shard the TP weights is open-sourced?