forked from bigscience-workshop/Megatron-DeepSpeed
Commit
Add generation server scripts using HF accelerate and DS-inference (bigscience-workshop#328)

* first step towards making libs
* HF accelerate model
* refactor accelerate
* refactor DS inference
* refactor DS ZeRO
* make inference library
* cli
* server
* request
* remove MaxTokensError
* fix batch size error with DS inference server
* type fix
* add latency
* add latency
* add min_length to default kwargs
* str kwargs
* str kwargs
* fix comma
* add old scripts back
* move scripts
* drop data
* minor changes + add README
* update README
* drop nccl
* fix
* default values
* resolve issues
* handle keyboard interrupt
* remove caching
* use snapshot_download
* make server class
* fix snapshot download

Co-authored-by: Mayank Mishra <[email protected]>
1 parent e6daa19, commit 4fa35e9
Showing 21 changed files with 1,436 additions and 0 deletions.
4 files renamed without changes.
scripts/bloom-inference-server/README.md
@@ -0,0 +1,107 @@
## Inference solutions for BLOOM 176B

We support HuggingFace accelerate and DeepSpeed Inference for generation.

Install required packages:

```shell
pip install fastapi uvicorn accelerate "huggingface_hub>=0.9.0"
```
To install [DeepSpeed](https://github.com/microsoft/DeepSpeed):
```shell
git clone https://github.com/microsoft/DeepSpeed
cd DeepSpeed
CFLAGS="-I$CONDA_PREFIX/include/" LDFLAGS="-L$CONDA_PREFIX/lib/" TORCH_CUDA_ARCH_LIST="7.0" DS_BUILD_CPU_ADAM=1 DS_BUILD_AIO=1 DS_BUILD_UTILS=1 pip install -e . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check
```

To install [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII):
```shell
git clone https://github.com/microsoft/DeepSpeed-MII
cd DeepSpeed-MII
pip install .
```
All the provided scripts are tested on 8 A100 80GB GPUs for BLOOM 176B; they might not work for other models or a different number of GPUs.

DS inference only supports fp16 for the CLI and server applications, while the benchmark script supports both fp16 and bf16. bf16 support for DS inference will be added once DeepSpeed provides the required CUDA kernels.

DS inference is deployed using the DeepSpeed-MII library, which requires checkpoints resharded for 8-way tensor parallelism. The HuggingFace checkpoints can be resharded and cached using the following command:
```shell
deepspeed --num_gpus 8 scripts/bloom-inference-server/cache_ds_checkpoints.py --model_name bigscience/bloom --dtype fp16 --save_mp_checkpoint_path <PATH TO DS CACHED MODEL>
```
Note: running the above script will consume ~350 GB of disk space and will take some time (~30 minutes), depending on the speed of your GPUs and storage.
Note: sometimes GPU memory is not freed when the DS inference deployment is shut down. You can free this memory by running:
```python
import mii
mii.terminate("ds_inference_grpc_server")
```
or, alternatively, by running `killall python` in the terminal.
#### BLOOM inference via command-line
This asks for generate_kwargs on every request. Example generate_kwargs:
```json
{"min_length": 100, "max_new_tokens": 100, "do_sample": false}
```

1. using HF accelerate
```shell
python scripts/bloom-inference-server/cli.py --model_name bigscience/bloom --dtype bf16 --deployment_framework hf_accelerate --generate_kwargs '{"min_length": 100, "max_new_tokens": 100, "do_sample": false}'
```

2. using DS inference
```shell
python scripts/bloom-inference-server/cli.py --model_name bigscience/bloom --dtype fp16 --deployment_framework ds_inference --save_mp_checkpoint_path <PATH TO DS CACHED MODEL> --generate_kwargs '{"min_length": 100, "max_new_tokens": 100, "do_sample": false}'
```
#### BLOOM server deployment
1. using HF accelerate
```shell
python scripts/bloom-inference-server/server.py --model_name bigscience/bloom --dtype bf16 --deployment_framework hf_accelerate --host <HOST ADDRESS> --port <PORT> --allowed_max_new_tokens 100
```

2. using DS inference
```shell
python scripts/bloom-inference-server/server.py --model_name bigscience/bloom --dtype fp16 --deployment_framework ds_inference --save_mp_checkpoint_path <PATH TO DS CACHED MODEL> --host <HOST ADDRESS> --port <PORT> --allowed_max_new_tokens 100
```
An example [script](examples/server_request.py) to query the BLOOM server is provided. To run it:
```shell
python scripts/bloom-inference-server/examples/server_request.py --host <HOST ADDRESS> --port <PORT>
```
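For quick manual testing, the same request can also be issued directly from Python. The snippet below is only a minimal sketch: the endpoint path (`/generate/`) and the payload keys (`text`, `generate_kwargs`) are assumptions about the FastAPI server's interface, so check `examples/server_request.py` for the actual names used.
```python
import requests

# Hypothetical endpoint and payload shape -- verify against examples/server_request.py.
url = "http://<HOST ADDRESS>:<PORT>/generate/"
payload = {
    "text": ["DeepSpeed is a machine learning framework"],
    "generate_kwargs": {"min_length": 100, "max_new_tokens": 100, "do_sample": False},
}

response = requests.post(url, json=payload)
print(response.json())
```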
#### Benchmark system for BLOOM inference
1. using HF accelerate
```shell
python scripts/bloom-inference-server/benchmark.py --model_name bigscience/bloom --dtype bf16 --deployment_framework hf_accelerate --benchmark_cycles 5
```

2. using DS inference
```shell
deepspeed --num_gpus 8 scripts/bloom-inference-server/benchmark.py --model_name bigscience/bloom --dtype fp16 --deployment_framework ds_inference --save_mp_checkpoint_path <PATH TO DS CACHED MODEL> --benchmark_cycles 5
```

3. using DS ZeRO
```shell
deepspeed --num_gpus 8 scripts/bloom-inference-server/benchmark.py --model_name bigscience/bloom --dtype bf16 --deployment_framework ds_zero --benchmark_cycles 5
```
Alternatively, the following shell script will benchmark different batch sizes for the model.
```shell
mkdir -p logs

for bs in {1,2,4,8,16,32,64,128}
do
    python scripts/bloom-inference-server/benchmark.py --model_name bigscience/bloom --dtype bf16 --deployment_framework hf_accelerate --benchmark_cycles 5 --batch_size $bs 2>&1 | tee logs/hf-$bs.log

    deepspeed --num_gpus 8 scripts/bloom-inference-server/benchmark.py --model_name bigscience/bloom --dtype fp16 --deployment_framework ds_inference --save_mp_checkpoint_path <PATH TO DS CACHED MODEL> --benchmark_cycles 5 --batch_size $bs 2>&1 | tee logs/ds-$bs.log

    deepspeed --num_gpus 8 scripts/bloom-inference-server/benchmark.py --model_name bigscience/bloom --dtype bf16 --deployment_framework ds_zero --benchmark_cycles 5 --batch_size $bs 2>&1 | tee logs/ds-zero-$bs.log
done
```
The following will benchmark the effect of sequence length for batch size = 1 on DS inference.
```shell
for sq in {1,10,50,100,200,300,400,500,600,700,800,900,1000,1500,2000,2500,3000,3500,4000,4500,5000}
do
    deepspeed --num_gpus 8 scripts/bloom-inference-server/benchmark.py --model_name bigscience/bloom --dtype fp16 --batch_size 1 --benchmark_cycles 5 --deployment_framework ds_inference --generate_kwargs '{"do_sample": false, "min_length": '$sq', "max_new_tokens": '$sq'}' 2>&1 | tee logs/ds_$sq.log
done
```
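The loops above leave one log file per run in `logs/`. A small helper can pull the headline throughput out of them afterwards; this is only a sketch that assumes the `Throughput (including tokenization) = ... tokens/sec` line emitted by `benchmark.py` (see `get_benchmark_results` below), so adjust the pattern if that format changes.
```python
import re
from pathlib import Path

# Collect the tokens/sec figure from every benchmark log written by the loops above.
pattern = re.compile(r"Throughput \(including tokenization\) = ([\d.]+) tokens/sec")

for log in sorted(Path("logs").glob("*.log")):
    match = pattern.search(log.read_text())
    throughput = float(match.group(1)) if match else None
    print(f"{log.name}: {throughput} tokens/sec")
```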
scripts/bloom-inference-server/benchmark.py
@@ -0,0 +1,159 @@
import argparse
import gc
import os

import deepspeed
import torch

import utils
from ds_inference import DSInferenceModel
from ds_zero import DSZeROModel
from hf_accelerate import HFAccelerateModel
from utils import (
    BENCHMARK,
    DS_INFERENCE,
    DS_ZERO,
    HF_ACCELERATE,
    GenerateRequest,
    Model,
    get_argument_parser,
    get_dummy_batch,
    parse_generate_kwargs,
    print_rank_n,
    run_and_log_time
)


def benchmark_generation(model: Model,
                         request: GenerateRequest,
                         cycles: int = 5):
    # run the same request several times and count the tokens generated
    total_new_tokens_generated = 0
    for _ in range(cycles):
        response = model.generate(request)
        total_new_tokens_generated += sum(
            new_tokens for new_tokens in response.num_generated_tokens)
    return total_new_tokens_generated


def get_benchmark_results(benchmark_time: float,
                          initialization_time: float,
                          total_new_tokens_generated: int,
                          batch_size: int,
                          cycles: int) -> str:
    throughput = total_new_tokens_generated / benchmark_time
    latency = benchmark_time / cycles
    return f"""
*** Performance stats:
Throughput (including tokenization) = {throughput:.2f} tokens/sec
Throughput (including tokenization) = {1000 / throughput:.2f} msecs/token
Model loading time = {initialization_time:.2f} secs
Total tokens generated = {total_new_tokens_generated} with batch size = {batch_size}
Latency = {latency:.2f} secs
Model loading time + generation time per batch = {initialization_time + latency:.2f} secs
"""


def benchmark_end_to_end(args: argparse.Namespace,
                         model_class: Model,
                         zero_activated: bool = False) -> None:
    model, initialization_time = run_and_log_time(
        (model_class, {"args": args})
    )

    request = parse_generate_kwargs(
        get_dummy_batch(args.batch_size),
        args.generate_kwargs
    )

    print_rank_n(f"generate_kwargs = {args.generate_kwargs}")
    print_rank_n(f"batch_size = {args.batch_size}")

    # warmup is a must if measuring speed as it's when all the optimizations are performed
    # e.g. on 8x80GB A100 the first pass of 100 tokens takes 23 secs, and the next one is 4 secs
    response = model.generate(request)

    for i, (o, _) in zip(request.text, zip(response.text, response.num_generated_tokens)):
        print_rank_n(f"{'-' * 60}\nin = {i}\nout = {o}\n")

    if (args.benchmark_cycles > 0):
        print_rank_n("*** Running benchmark")

        torch.cuda.empty_cache()
        gc.collect()

        # warm up
        model.generate(request)
        torch.cuda.synchronize()

        # benchmark
        total_new_tokens_generated, benchmark_time = run_and_log_time(
            (
                benchmark_generation,
                {
                    "model": model,
                    "request": request,
                    "cycles": args.benchmark_cycles
                }
            )
        )

        # with ZeRO every GPU is generating batch_size * sequence_length tokens
        if (zero_activated):
            world_size = int(os.getenv('WORLD_SIZE', '1'))
            total_new_tokens_generated *= world_size

        print_rank_n(
            get_benchmark_results(
                benchmark_time,
                initialization_time,
                total_new_tokens_generated,
                args.batch_size,
                args.benchmark_cycles
            )
        )


def get_args() -> argparse.Namespace:
    parser = get_argument_parser()

    group = parser.add_argument_group(title="launch config")
    group.add_argument("--benchmark_cycles", type=int,
                       default=0, help="additionally run benchmark")
    group.add_argument("--local_rank", required=False,
                       type=int, help="used by dist launchers")
    group.add_argument("--batch_size", default=1, type=int, help="batch size")
    group.add_argument("--cpu_offload", action="store_true",
                       help="whether to activate CPU offload for DS ZeRO")

    args = utils.get_args(parser, BENCHMARK)

    launched_with_deepspeed = args.deployment_framework in [
        DS_INFERENCE, DS_ZERO]

    if (not launched_with_deepspeed):
        assert args.local_rank is None, "local_rank must be None if not launched with DeepSpeed"

    if (args.cpu_offload):
        assert args.deployment_framework == DS_ZERO, "cpu_offload only works with DS_ZeRO"

    return args


def main() -> None:
    args = get_args()

    if (args.deployment_framework == HF_ACCELERATE):
        benchmark_end_to_end(args, HFAccelerateModel)
    elif (args.deployment_framework == DS_INFERENCE):
        deepspeed.init_distributed("nccl")
        benchmark_end_to_end(args, DSInferenceModel)
    elif (args.deployment_framework == DS_ZERO):
        deepspeed.init_distributed("nccl")
        benchmark_end_to_end(args, DSZeROModel, zero_activated=True)
    else:
        raise ValueError(
            f"Unknown deployment framework {args.deployment_framework}")


if (__name__ == "__main__"):
    main()
scripts/bloom-inference-server/cli.py
@@ -0,0 +1,67 @@
import argparse
import json
import sys

import utils
from ds_inference import DSInferenceGRPCServer
from hf_accelerate import HFAccelerateModel
from utils import CLI, DS_INFERENCE, HF_ACCELERATE, get_argument_parser, parse_generate_kwargs, print_rank_n


def get_args() -> argparse.Namespace:
    parser = get_argument_parser()

    group = parser.add_argument_group(title="launch config")
    group.add_argument("--shutdown_command", required=False,
                       type=str, default="__shutdown__", help="This string will exit the script")

    args = utils.get_args(parser, CLI)

    return args


def main() -> None:
    args = get_args()

    if (args.deployment_framework == HF_ACCELERATE):
        model = HFAccelerateModel(args)
    elif (args.deployment_framework == DS_INFERENCE):
        model = DSInferenceGRPCServer(args)
    else:
        raise ValueError(
            f"Unknown deployment framework {args.deployment_framework}")

    generate_kwargs = args.generate_kwargs

    # interactive loop: read a prompt, optionally update generate_kwargs, then generate
    while (True):
        try:
            input_text = input("Input text: ")

            if (input_text == args.shutdown_command):
                model.shutdown()

            if (input("change generate_kwargs? [y/n] ") == "y"):
                while (True):
                    try:
                        generate_kwargs = json.loads(
                            input("Generate kwargs: "))
                        break
                    except KeyboardInterrupt:
                        model.shutdown()
                    except Exception as e:
                        e_type, e_message, _ = sys.exc_info()
                        print("error =", e_type.__name__)
                        print("message =", e_message)
                        continue

            request = parse_generate_kwargs([input_text], generate_kwargs)
            response = model.generate(request)

            print_rank_n("Output text:", response.text[0])
            print_rank_n("Generated tokens:", response.num_generated_tokens[0])
        except KeyboardInterrupt:
            model.shutdown()


if (__name__ == "__main__"):
    main()
scripts/bloom-inference-server/ds_inference/__init__.py
@@ -0,0 +1,2 @@
from .grpc_server import DSInferenceGRPCServer
from .model import DSInferenceModel
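Together with the other modules in this commit, these exports are meant to be used roughly as in `cli.py`. The sketch below mirrors that flow; it assumes the same `utils` helpers (`get_argument_parser`, `get_args`, `parse_generate_kwargs`) shown in the files above and is not an additional script from the commit.
```python
# Minimal usage sketch, mirroring cli.py above.
import utils
from ds_inference import DSInferenceGRPCServer
from utils import CLI, get_argument_parser, parse_generate_kwargs

# build args the same way cli.py does
args = utils.get_args(get_argument_parser(), CLI)

model = DSInferenceGRPCServer(args)
request = parse_generate_kwargs(
    ["DeepSpeed is a machine learning framework"],
    {"min_length": 100, "max_new_tokens": 100, "do_sample": False},
)
response = model.generate(request)
print(response.text[0], response.num_generated_tokens[0])

model.shutdown()  # shuts down the deployment (see cli.py and the README note on mii.terminate)
```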