BLOOM Inference via DeepSpeed-Inference, Accelerate and DeepSpeed-ZeRO (bigscience-workshop#308)

* hardcode the dtype depending on the model
* change the mp based on the world_size
* remove hardcoded world_size
* add bigscience/bigscience-small-testing
* fixes
* add zero-inference script
* fixes
* fix
* working script
* renames
* fixes
* fix for offline use
* add benchmark
* add benchmark
* update
* cleanup
* update
* msecs
* cleanup
* improve
* fix benchmark, add warmup
* update
* fix; thanks Michael Wyatt
* clarify
* add bloom batch-inference script
* removed the names :-)
* fold the bs functionality from the other script
* fix
* restore do_sample
* dump generate args
* fix
* fix
* support any batchsize
* div by bs
* mul by bs
* add cpu_offload; sync scripts
* wip
* improvements
* fixes
* fixes
* add accelerate script
* fix
* wip
* wip
* stats
* add OnDevice and remove zero-inference (bigscience-workshop#316)
* wip
* rework generate + benchmark
* figure out the memory map dynamically
* bug fix
* fix ds-zero-inference wrt device
* bug fix
* update
* update
* fix

Co-authored-by: Reza Yazdani <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
1 parent c9f196e · commit e52d34c
Showing 5 changed files with 890 additions and 153 deletions.
@@ -1 +1,195 @@
# Inference scripts for BLOOM

## BLOOM Inference solutions

Here are some stats on JeanZay's 8x80GB A100 node w/ 512GB of CPU memory:

All benchmarks are doing greedy generation of 100 token outputs:
```
Generate args {'min_length': 100, 'max_length': 100, 'do_sample': False}
```
The inputs are just a few tokens.

Throughput per token in msecs:

| project \ bs | 1      | 8     | 16    | 32    | 64   | 128  |
| :----------- | :----- | :---- | :---- | :---- | :--- | :--- |
| accelerate   | 230.38 | 31.78 | 17.84 | 10.89 | oom  | oom  |
| ds-inference | 40.57  | 5.23  |       |       | 2.77 | 0.66 |
| ds-zero      | 283    | 34.88 | oom   | oom   | oom  | oom  |

Start to ready to generate in secs:

| project \ bs | 1   | 8   | 16  | 32  | 64  | 128 |
| :----------- | :-- | :-- | :-- | :-- | :-- | :-- |
| accelerate   | 121 | 120 | 113 | 118 |     |     |
| ds-inference | 662 | 673 |     |     | 685 | 654 |
| ds-zero      | 462 | 463 |     |     |     |     |

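For reference, the per-token msecs above are measured the way the scripts' `--benchmark` mode does it: time several full `generate()` calls and divide by the total number of new tokens produced. A minimal, runnable sketch (the `generate()` stub here is only a stand-in for the real batched model call):

```python
import time

def generate(batch_size=8, new_tokens=100):
    """Stand-in for the real batched model.generate() call in the scripts;
    returns (input, output, n_new_tokens) triples for one batch."""
    time.sleep(0.01)  # pretend work
    return [("in", "out", new_tokens)] * batch_size

cycles = 5
t0 = time.time()
total_new_tokens = 0
for _ in range(cycles):
    total_new_tokens += sum(n for _, _, n in generate())
msecs_per_token = (time.time() - t0) / total_new_tokens * 1000
print(f"Throughput per token: {msecs_per_token:.2f} msecs")
```
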
DS-Inference load time (start to ready to generate) will become much faster soon, once we stop relying on ds-zero to instantiate the model on GPU. The plan is to pre-shard the weights TP-wise for 8x and 16x GPUs and load them directly on each GPU, which should bring the load time to under 1 minute.

## Deepspeed-Inference

Tensor-Parallelism and efficient fused CUDA kernels:
https://www.deepspeed.ai/tutorials/inference-tutorial/

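For orientation, here is a minimal sketch of what this path boils down to: load the model, then let `deepspeed.init_inference` shard it tensor-parallel across the launched GPUs and inject the fused kernels. The model name below is a small placeholder; the actual `bloom-ds-inference.py` script handles BLOOM's dtype and checkpoint loading.

```python
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model for illustration; the benchmarks use bigscience/bloom.
model_name = "bigscience/bigscience-small-testing"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.half)

# Shard tensor-parallel over the GPUs started by the `deepspeed` launcher
# and replace supported modules with fused inference kernels.
model = deepspeed.init_inference(
    model,
    mp_size=int(os.getenv("WORLD_SIZE", "1")),
    dtype=torch.half,
    replace_with_kernel_inject=True,
)
model = model.module  # unwrap the inference engine before calling generate()

inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(torch.cuda.current_device())
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0], skip_special_tokens=True))
```

Launch such a sketch the same way as the full script, via the `deepspeed` launcher with the desired `--num_gpus`.
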
### Setup

```
git clone https://github.com/microsoft/DeepSpeed
cd DeepSpeed
pip install .
```

### Run

```
deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom
```

Performance on a single node of 8x80GB A100 w/ 512GB CPU RAM (JeanZay), with just a batch of 1 (a larger batch would be more efficient).

Add `--benchmark` to activate the benchmarks.

BS=1
```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-ds-inference_bs=1.txt
[...]
```

Memory usage per process while processing:

- GPU: ~50GB
- CPU: ~10GB

BS=8
```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-ds-inference_bs=8.txt
[...]
*** Performance stats:
Throughput per token including tokenize: 5.23 msecs
Start to ready to generate: 683.397 secs
Tokenize and generate 800 (bs=8) tokens: 4.241 secs
Start to finish: 687.638 secs
```

BS=64

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 64 --benchmark 2>&1 | tee bloom-ds-inference_bs=64.txt
```

BS=128

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-inference.py --name bigscience/bloom --batch_size 128 --benchmark 2>&1 | tee bloom-ds-inference_bs=128.txt
```

## Deepspeed ZeRO-Inference

https://www.deepspeed.ai/tutorials/zero/

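The gist of this approach is ZeRO stage 3 with the parameters offloaded to CPU, so each GPU only materializes the layers it is currently running. A rough sketch of the setup, under the assumption of an illustrative config and a small placeholder model (`bloom-ds-zero-inference.py` in this commit does the real BLOOM-sized version and is meant to be launched with the `deepspeed` launcher):

```python
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.deepspeed import HfDeepSpeedConfig

model_name = "bigscience/bigscience-small-testing"  # placeholder; the benchmarks use bigscience/bloom

ds_config = {
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
    },
    "train_micro_batch_size_per_gpu": 1,
}

# Keep this object alive before from_pretrained so the weights are loaded ZeRO-3-sharded.
dschf = HfDeepSpeedConfig(ds_config)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.half)

# Wrap with ZeRO-3; only the wrapped module is needed for inference.
ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
ds_engine.module.eval()

device = int(os.getenv("LOCAL_RANK", "0"))
inputs = tokenizer("DeepSpeed ZeRO-Inference", return_tensors="pt").to(device)
with torch.no_grad():
    out = ds_engine.module.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```
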
### Setup

```
pip install deepspeed
```

### Run

Note that the script currently runs the same inputs on all GPUs, but you could feed each GPU a different stream of inputs and get `n_gpu` times the throughput (see the sketch below); that is not possible with Deepspeed-Inference, where all GPUs cooperate on the same batch.

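A minimal sketch of the per-rank sharding idea (the names are illustrative; the shipped script does not do this yet):

```python
import os

# Hypothetical per-rank input sharding: each process/GPU takes a different slice
# of the prompts, so aggregate throughput scales with the number of ranks.
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

all_prompts = [f"prompt {i}" for i in range(64)]  # illustrative workload
my_prompts = all_prompts[rank::world_size]        # round-robin shard for this rank

# each rank would then tokenize and generate() over my_prompts independently
print(f"rank {rank}/{world_size} handles {len(my_prompts)} prompts")
```
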
BS=1

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=1.txt
[...]
*** Performance stats:
Throughput per token including tokenize: 282.93 msecs
Start to ready to generate: 501.871 secs
Tokenize and generate 800 (bs=1) tokens: 226.188 secs
Start to finish: 728.060 secs
```

BS=8

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=8.txt
[...]
*** Performance stats:
Throughput per token including tokenize: 34.57 msecs
Start to ready to generate: 482.132 secs
Tokenize and generate 6400 (bs=8) tokens: 221.236 secs
Start to finish: 703.368 secs
```

BS=16 and higher OOMs

```
$ deepspeed --num_gpus 8 scripts/inference/bloom-ds-zero-inference.py --name bigscience/bloom --batch_size 16 --benchmark 2>&1 | tee bloom-ds-zero-inference_bs=16.txt
[...]
OOM
```

## HF Accelerate

https://github.com/huggingface/accelerate

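A stripped-down sketch of what the accelerate-based script does: `device_map="auto"` (backed by accelerate) spreads the weights over all visible GPUs, and `max_memory` can cap each GPU's share the way `bloom-accelerate-inference.py` does for BLOOM. The small model name below is a placeholder.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "bigscience/bigscience-small-testing"  # placeholder; the benchmarks use bigscience/bloom

tokenizer = AutoTokenizer.from_pretrained(model_name)

# device_map="auto" splits the layers across the available GPUs;
# pass max_memory={gpu_idx: "..GIB", ...} to bound what each device may hold.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# assumes at least one GPU, matching the full script, which moves inputs to cuda:0
inputs = tokenizer("The BLOOM model is", return_tensors="pt").to("cuda:0")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0], skip_special_tokens=True))
```
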
### Setup

```
pip install transformers
```

### Run

BS=1
```
$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 1 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=1.txt
[...]
```

BS=8
```
$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 8 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=8.txt
[...]
```

BS=16
```
$ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom --batch_size 16 --benchmark 2>&1 | tee bloom-accelerate-inference_bs=16.txt
[...]
```
@@ -0,0 +1,186 @@
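# Accelerate-based inference/benchmark script: loads the model sharded across all
# visible GPUs via device_map="auto", greedily generates 100 new tokens for a batch
# of prompts, and with --benchmark reports per-token latency and load/generate timings.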
import argparse
import time
import os
import gc
import torch
import math
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers")
    parser.add_argument("--name", type=str, help="Name path", required=True)
    parser.add_argument("--batch_size", default=1, type=int, help="batch size")
    parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark")
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--top-k", type=int, default=0)
    parser.add_argument("--top-p", type=float, default=0.)

    return parser.parse_args()

def get_max_memory_per_gpu_dict(dtype, model_name):
    """ try to generate the memory map based on what we know about the model and the available hardware """

    # figure out the memory map - the minimum per gpu required to load the model
    n_gpus = torch.cuda.device_count()

    if model_name == "bigscience/bloom" and n_gpus == 8 and torch.cuda.get_device_properties(0).total_memory > 79*2**30:
        # hand crafted optimized memory map for 8x80 setup over BLOOM
        # this works with bs=40
        return {0: '0GIB', 1: '51GIB', 2: '51GIB', 3: '51GIB', 4: '51GIB', 5: '51GIB', 6: '51GIB', 7: '51GIB'}

    try:
        # model_params calculation, as we don't have a model yet to do:
        #model_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())

        config = AutoConfig.from_pretrained(model_name)
        h = config.n_embed
        l = config.n_layer
        v = config.vocab_size
        # from https://github.com/bigscience-workshop/bigscience/tree/6917a3b5fefcf439d3485ca184b4d9f6ab605150/math#model-sizing
        model_params = l*(12*h**2 + 13*h) + v*h + 4*h
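        # rough sanity check with BLOOM's published config (h=14336, l=70, v=250880):
        # 70*(12*14336**2 + 13*14336) + 250880*14336 + 4*14336 ~= 176.2e9 params, i.e. ~176B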
    except:
        print(f"The model {model_name} has a broken config file. Please notify the owner")
        raise

    bytes_per_param = torch.finfo(dtype).bits / 8
    param_memory_total_in_bytes = model_params * bytes_per_param
    # add 5% since weight sizes aren't the same and some GPU may need more memory
    param_memory_per_gpu_in_bytes = int(param_memory_total_in_bytes / n_gpus * 1.05)
    print(f"Estimating {param_memory_per_gpu_in_bytes/2**30:0.2f}GB per gpu for weights")

    # check the real available memory
    # load cuda kernels first and only measure the real free memory after loading (shorter by ~2GB)
    torch.ones(1).cuda()
    max_memory_per_gpu_in_bytes = torch.cuda.mem_get_info(0)[0]
    if max_memory_per_gpu_in_bytes < param_memory_per_gpu_in_bytes:
        raise ValueError(f"Unable to generate the memory map automatically as the needed estimated memory per gpu ({param_memory_per_gpu_in_bytes/2**30:0.2f}GB) is bigger than the available per gpu memory ({max_memory_per_gpu_in_bytes/2**30:0.2f}GB)")

    return {i: param_memory_per_gpu_in_bytes for i in range(torch.cuda.device_count())}

t_start = time.time()

num_tokens = 100

args = get_args()

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

rank = local_rank

model_name = args.name
if rank == 0:
    print(f"Loading model {model_name}")


tokenizer = AutoTokenizer.from_pretrained(model_name)

# XXX: can't automatically derive dtype via config's `from_pretrained`
dtype = torch.bfloat16 if model_name in ["bigscience/bloom", "bigscience/bigscience-small-testing"] else torch.float16

#print(get_max_memory_per_gpu_dict())


model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    max_memory=get_max_memory_per_gpu_dict(dtype, model_name),
    torch_dtype=dtype,
)


if args.benchmark:
    t_ready = time.time()


### Generate

if rank == 0:
    print(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}")

input_sentences = [
    "DeepSpeed is a machine learning framework",
    "He is working on",
    "He has a",
    "He got all",
    "Everyone is happy and I can",
    "The new movie that got Oscar this year",
    "In the far far distance from our galaxy,",
    "Peace is the only way"
]

if args.batch_size > len(input_sentences):
    # dynamically extend to support larger bs by repetition
    input_sentences *= math.ceil(args.batch_size / len(input_sentences))

generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=False)
#generate_kwargs = dict(max_new_tokens=num_tokens, use_cache=False, do_sample=False)
#generate_kwargs = dict(min_length=num_tokens, max_length=num_tokens, do_sample=False)

if rank == 0:
    print(f"Generate args {generate_kwargs}")
inputs = input_sentences[:args.batch_size]
def generate():
    """ returns a list of zipped inputs, outputs and number of new tokens """

    input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to("cuda:0")

    outputs = model.generate(**input_tokens, **generate_kwargs)

    input_tokens_lengths = [x.shape[0] for x in input_tokens.input_ids]
    output_tokens_lengths = [x.shape[0] for x in outputs]

    total_new_tokens = [o-i for i,o in zip(input_tokens_lengths, output_tokens_lengths)]
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    return zip(inputs, outputs, total_new_tokens)

# warmup is a must if measuring speed as it's when all the optimizations are performed
# e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
_ = generate()

t_generate_start = time.time()
generated = generate()
t_generate_span = time.time() - t_generate_start
if rank == 0:
    for i,o,_ in generated:
        print(f"{'-'*60}\nin={i}\nout={o}\n")

if args.benchmark:
    torch.cuda.empty_cache()
    gc.collect()

### Benchmark

if args.benchmark:
    if rank == 0:
        print(f"*** Running benchmark")

    # warm up
    for i in range(1):
        _ = generate()
    torch.cuda.synchronize()

    # benchmark
    t0 = time.time()
    cycles = 5
    total_new_tokens_generated = 0
    for i in range(cycles):
        generated = generate()
        total_new_tokens_generated += sum(new_tokens for _,_,new_tokens in generated)
    torch.cuda.synchronize()
    if rank == 0:
        throughput = (time.time() - t0)/(total_new_tokens_generated)
        print(f"""
*** Performance stats:
Throughput per token including tokenize: {throughput*1000:.2f} msecs
Start to ready to generate: {t_ready - t_start:.3f} secs
Tokenize and generate {total_new_tokens_generated} (bs={args.batch_size}) tokens: {t_generate_span:.3f} secs
Start to finish: {t_ready - t_start + t_generate_span:.3f} secs
""")