Add LLaMA end-to-end benchmarking (#19985)
### Description

This PR adds a benchmarking script to measure end-to-end performance and
saves the results in a CSV file.

### Motivation and Context

With this PR, end-to-end performance can be easily measured for many large language models such as
LLaMA-2. The performance numbers for LLaMA-2 are located
[here](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/python/models/llama).
kunal-vaishnavi authored and rachguo committed Mar 25, 2024
1 parent 6469bb5 commit e6c3d56
Showing 11 changed files with 957 additions and 23 deletions.
132 changes: 132 additions & 0 deletions onnxruntime/python/tools/transformers/models/llama/README.md
@@ -1,7 +1,14 @@
# Contents
- [LLaMA-2](#llama-2)
- [Prerequisites](#prerequisites)
- [Exporting LLaMA-2](#exporting-llama-2)
- [Examples of Exporting LLaMA-2](#examples-of-exporting-llama-2)
- [Parity Checking LLaMA-2](#parity-checking-llama-2)
- [Benchmark LLaMA-2](#benchmark-llama-2)
- [Variants](#variants)
- [Benchmark All](#benchmark-all)
- [Benchmark E2E](#benchmark-e2e)
- [E2E Inference with LLaMA-2](#e2e-inference-with-llama-2)
- [Mistral](#mistral)
- [Exporting Mistral](#exporting-mistral)
- [Optimizing and Quantizing Mistral](#optimizing-and-quantizing-mistral)
@@ -229,6 +236,55 @@ $ ./build.sh --config Release --use_cuda --cuda_home /usr/local/cuda-12.2 --cudn
$ CUDA_VISIBLE_DEVICES=0,1,2,3 bash convert_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-distributed --precision fp16 --execution_provider cuda --use_gqa
```

## Parity Checking LLaMA-2

Here are some examples of how you can use the parity checker to verify your LLaMA-2 ONNX model.

1. Merged ONNX model, FP32 CPU
```
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.llama_parity \
--model_name meta-llama/Llama-2-7b-hf \
--onnx_model_path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
--merged \
--execution_provider cpu \
--precision fp32 \
--cache_dir ./model_cache
```

2. Merged ONNX model, FP32 CUDA
```
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.llama_parity \
--model_name meta-llama/Llama-2-7b-hf \
--onnx_model_path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
--merged \
--execution_provider cuda \
--precision fp32 \
--cache_dir ./model_cache
```

3. Merged ONNX model, FP16 CUDA
```
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.llama_parity \
--model_name meta-llama/Llama-2-7b-hf \
--onnx_model_path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx \
--merged \
--execution_provider cuda \
--precision fp16 \
--cache_dir ./model_cache
```

4. Merged ONNX model, FP16 CUDA with GroupQueryAttention
```
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.llama_parity \
--model_name meta-llama/Llama-2-7b-hf \
--onnx_model_path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx \
--merged \
--use_gqa \
--execution_provider cuda \
--precision fp16 \
--cache_dir ./model_cache
```
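
Under the hood, the parity check comes down to running the same inputs through the PyTorch model and the exported ONNX model and comparing their outputs within a tolerance. Below is a minimal sketch of that comparison; the tolerance values are illustrative assumptions, not the exact thresholds `llama_parity` uses.

```python
# Minimal sketch of the core parity comparison: given logits produced by the
# PyTorch model and by the ONNX model for the same inputs, the two must agree
# elementwise within a tolerance. Tolerances here are illustrative assumptions.
import numpy as np

def check_parity(pt_logits: np.ndarray, ort_logits: np.ndarray,
                 rtol: float = 1e-3, atol: float = 1e-3) -> bool:
    max_diff = np.max(np.abs(pt_logits - ort_logits))
    print(f"Max absolute difference: {max_diff}")
    return np.allclose(pt_logits, ort_logits, rtol=rtol, atol=atol)
```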

## Benchmark LLaMA-2

Here are some examples of how you can benchmark LLaMA-2.
@@ -240,6 +296,7 @@ Here are some examples of how you can benchmark LLaMA-2.
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-pt-eager \
--model-name meta-llama/Llama-2-7b-hf \
--cache-dir ./model_cache \
--precision fp32 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
@@ -252,6 +309,7 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-pt-compile \
--model-name meta-llama/Llama-2-7b-hf \
--cache-dir ./model_cache \
--precision fp16 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
@@ -265,6 +323,7 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
--hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
--model-name meta-llama/Llama-2-7b-hf \
--cache-dir ./model_cache \
--precision fp32 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
@@ -278,6 +337,7 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
--hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
--model-name meta-llama/Llama-2-7b-hf \
--cache-dir ./model_cache \
--precision fp16 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
@@ -291,6 +351,7 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type ort-msft \
--ort-model-path ./llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \
--model-name meta-llama/Llama-2-7b-hf \
--cache-dir ./model_cache \
--precision fp32 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
@@ -303,6 +364,7 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type ort-msft \
--ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
--model-name meta-llama/Llama-2-7b-hf \
--cache-dir ./model_cache \
--precision fp16 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
@@ -315,6 +377,7 @@ CUDA_VISIBLE_DEVICES=1 python3 -m models.llama.benchmark \
--benchmark-type ort-convert-to-onnx \
--ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
--model-name meta-llama/Llama-2-7b-hf \
--cache-dir ./model_cache \
--precision fp32 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
@@ -327,6 +390,7 @@ CUDA_VISIBLE_DEVICES=4 python3 -m models.llama.benchmark \
--benchmark-type ort-convert-to-onnx \
--ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx \
--model-name meta-llama/Llama-2-7b-hf \
--cache-dir ./model_cache \
--precision fp16 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
@@ -339,6 +403,7 @@ CUDA_VISIBLE_DEVICES=4,5,6,7 bash benchmark_70b_model.sh 4 \
--benchmark-type ort-convert-to-onnx \
--ort-model-path ./llama2-70b-dis/rank_{}_Llama-2-70b-hf_decoder_merged_model_fp16.onnx \
--model-name meta-llama/Llama-2-70b-hf \
--cache-dir ./model_cache \
--precision fp16 \
--device cuda \
--warmup-runs 5 \
@@ -357,6 +422,7 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_all \
--ort-convert-to-onnx-model-path ./llama2-7b-fp16/Llama-2-7b-hf_decoder_merged_model_fp16.onnx \
--ort-msft-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
--model-name meta-llama/Llama-2-7b-hf \
--cache-dir ./model_cache \
--precision fp16 \
--batch-sizes "1 2" \
--sequence-lengths "8 16" \
@@ -366,6 +432,72 @@ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_all \
--timeout 60 # number of minutes before moving to the next benchmark
```

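The `--timeout` option caps how long each individual benchmark may run before `benchmark_all` moves on to the next one. Below is a minimal sketch of how such a per-benchmark timeout can be enforced; it is an assumption about the mechanism, not the script's exact implementation, and the command shown is illustrative.

```python
# Minimal sketch (assumption, not benchmark_all's exact implementation): run each
# benchmark variant as a subprocess and abandon it if it exceeds the timeout.
import subprocess

def run_benchmark(benchmark_cmd: list[str], timeout_minutes: int = 60) -> bool:
    try:
        subprocess.run(benchmark_cmd, check=True, timeout=timeout_minutes * 60)
        return True
    except subprocess.TimeoutExpired:
        print("Benchmark timed out, moving to the next one")
        return False
    except subprocess.CalledProcessError as e:
        print(f"Benchmark failed with exit code {e.returncode}")
        return False

# Example usage (illustrative command):
run_benchmark(["python3", "-m", "models.llama.benchmark", "--benchmark-type", "hf-pt-eager"])
```
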
### Benchmark E2E
You can use `benchmark_e2e.py` to benchmark the full end-to-end scenario and automatically store the results in a CSV file. This tool uses `argmax` sampling (greedy decoding) to standardize the benchmarking process.
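
Because every framework generates with argmax sampling, they all produce the same token sequence, which keeps the timings comparable. The sketch below shows what that greedy generation loop looks like in PyTorch; the prompt, token budget, and cache directory are illustrative assumptions, and `benchmark_e2e.py` additionally reuses the KV cache and records latency around each step.

```python
# Minimal sketch of argmax (greedy) decoding: at each step, pick the highest-
# scoring token from the last position's logits. Prompt and token budget are
# illustrative; the real script also uses past key/values and measures timings.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./model_cache")
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./model_cache").eval()

inputs = tokenizer("What is ONNX Runtime?", return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

max_new_tokens = 16
with torch.no_grad():
    for _ in range(max_new_tokens):
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy choice
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
```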

1. PyTorch without `torch.compile`, FP32
```
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_e2e \
--benchmark-type pt-eager \
--model-name meta-llama/Llama-2-7b-hf \
--cache-dir ./model_cache \
--prompts-file ./models/llama/prompts.json \
--precision fp32 \
--batch-sizes "1 2" \
--prompt-lengths "16 64" \
--device cpu \
--auth
```

2. PyTorch with `torch.compile`, FP16
```
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_e2e \
--benchmark-type pt-compile \
--model-name meta-llama/Llama-2-7b-hf \
--cache-dir ./model_cache \
--prompts-file ./models/llama/prompts.json \
--precision fp16 \
--batch-sizes "1 2" \
--prompt-lengths "16 64" \
--device cuda \
--auth
```

3. ONNX Runtime with `convert_to_onnx`, FP32
```
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_e2e \
--benchmark-type ort \
--model-name meta-llama/Llama-2-7b-hf \
--cache-dir ./model_cache \
--onnx-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \
--prompts-file ./models/llama/prompts.json \
--precision fp32 \
--batch-sizes "1 2" \
--prompt-lengths "16 64" \
--device cpu \
--auth
```

4. ONNX Runtime with `convert_to_onnx`, FP16
```
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_e2e \
--benchmark-type ort \
--model-name meta-llama/Llama-2-7b-hf \
--cache-dir ./model_cache \
--onnx-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx \
--prompts-file ./models/llama/prompts.json \
--precision fp16 \
--batch-sizes "1 2" \
--prompt-lengths "16 64" \
--device cuda \
--use_buffer_share \
--auth
```
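
Results for each (batch size, prompt length) combination are written to a CSV file. A minimal sketch of that bookkeeping is below; the column names, file name, and placeholder rows are illustrative assumptions rather than the script's exact schema.

```python
# Minimal sketch of collecting per-configuration results and saving them as CSV.
# Column names and file name are illustrative assumptions.
import csv

def save_results(results: list[dict], output_path: str = "benchmark_e2e_results.csv") -> None:
    with open(output_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(results[0].keys()))
        writer.writeheader()
        writer.writerows(results)

# Example usage with placeholder rows (no measured numbers):
save_results([
    {"batch_size": 1, "prompt_length": 16, "avg_latency_s": None, "tokens_per_s": None},
    {"batch_size": 2, "prompt_length": 64, "avg_latency_s": None, "tokens_per_s": None},
])
```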

## E2E Inference with LLaMA-2

For end-to-end inference, please visit the [ONNX Runtime Inference Examples folder](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/python/models/llama) for a step-by-step walkthrough, code examples, and performance metrics.

# Mistral

## Introduction
38 changes: 20 additions & 18 deletions onnxruntime/python/tools/transformers/models/llama/benchmark.py
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import datetime
import gc
@@ -14,11 +19,12 @@
from benchmark_helper import measure_memory, setup_logger
from dist_settings import get_rank, get_size
from llama_inputs import (
add_io_bindings,
add_io_bindings_as_ortvalues,
get_merged_sample_with_past_kv_inputs,
get_msft_sample_inputs,
get_sample_inputs,
get_sample_with_past_kv_inputs,
verify_ort_inputs,
)
from optimum.onnxruntime import ORTModelForCausalLM
from torch.profiler import ProfilerActivity, profile, record_function
@@ -203,6 +209,7 @@ def get_model(args: argparse.Namespace):
torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
use_auth_token=args.auth,
use_cache=True,
cache_dir=args.cache_dir,
).to(args.target_device)
end_time = time.time()

@@ -444,24 +451,12 @@ def get_logits(inputs):

def run_ort_inference(args, init_inputs, iter_inputs, model):
def prepare_ort_inputs(inputs, kv_cache_ortvalues):
# Check that all model inputs will be provided
model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs()))
user_inputs = set(inputs.keys())
missing_inputs = model_inputs - user_inputs
if len(missing_inputs):
logger.error(f"The following model inputs are missing: {missing_inputs}")
raise Exception("There are missing inputs to the model. Please add them and try again.")

# Remove unnecessary inputs from model inputs
unnecessary_inputs = user_inputs - model_inputs
if len(unnecessary_inputs):
for unnecessary_input in unnecessary_inputs:
logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
del inputs[unnecessary_input]
# Verify model inputs
inputs = verify_ort_inputs(model, inputs)

# Add IO bindings for non-CPU execution providers
if args.device != "cpu":
io_binding, kv_cache_ortvalues = add_io_bindings(
io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
model, inputs, args.device, int(args.rank), args.use_gqa, kv_cache_ortvalues
)
setattr(args, "io_binding", io_binding) # noqa: B010
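
For context, the `verify_ort_inputs` helper that replaces the inline checks above presumably consolidates the logic removed in this hunk. A rough sketch reconstructed from the removed code (not necessarily the helper's exact implementation in `llama_inputs.py`):

```python
# Rough sketch of the consolidated input verification, reconstructed from the
# inline checks this diff removes; the real helper in llama_inputs.py may differ.
def verify_ort_inputs(model, ort_inputs):
    model_inputs = {model_input.name for model_input in model.get_inputs()}
    user_inputs = set(ort_inputs.keys())

    # Fail fast if the session expects inputs the caller did not provide.
    missing_inputs = model_inputs - user_inputs
    if missing_inputs:
        raise RuntimeError(f"The following model inputs are missing: {missing_inputs}")

    # Drop extra inputs the session does not accept so session.run succeeds.
    for unnecessary_input in user_inputs - model_inputs:
        del ort_inputs[unnecessary_input]

    return ort_inputs
```
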
@@ -612,6 +607,13 @@ def get_args(rank=0):
parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
parser.add_argument("--verbose", default=False, action="store_true")
parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
parser.add_argument(
"--cache-dir",
type=str,
required=True,
default="./model_cache",
help="Cache dir where Hugging Face files are stored",
)

args = parser.parse_args()

@@ -662,8 +664,8 @@ def main():

args.rank = rank
args.world_size = world_size
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
config = AutoConfig.from_pretrained(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.cache_dir)
config = AutoConfig.from_pretrained(args.model_name, cache_dir=args.cache_dir)
target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device
use_fp16 = args.precision == "fp16"

@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import datetime
import json
@@ -78,6 +83,13 @@ def get_args():
help="Path to ONNX model from convert_to_onnx",
)

parser.add_argument(
"--cache-dir",
type=str,
default="./model_cache",
help="Cache dir where Hugging Face files are stored",
)

parser.add_argument(
"--model-name",
type=str,
@@ -332,6 +344,8 @@ def main():
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
"--auth",
]
logger.info("Benchmark PyTorch without torch.compile")
@@ -362,6 +376,8 @@ def main():
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
"--auth",
]
logger.info("Benchmark PyTorch with torch.compile")
@@ -394,6 +410,8 @@ def main():
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
"--auth",
]
logger.info("Benchmark Optimum + ONNX Runtime")
@@ -426,6 +444,8 @@ def main():
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
]
logger.info("Benchmark Microsoft model in ONNX Runtime")
results = benchmark(args, benchmark_cmd, "ort-msft")
@@ -457,6 +477,8 @@ def main():
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
]
logger.info("Benchmark convert_to_onnx model in ONNX Runtime")
results = benchmark(args, benchmark_cmd, "onnxruntime")