diff --git a/applications/DeepSpeed-Chat/README.md b/applications/DeepSpeed-Chat/README.md index 91e4854fb..fa8fa9d38 100644 --- a/applications/DeepSpeed-Chat/README.md +++ b/applications/DeepSpeed-Chat/README.md @@ -251,9 +251,9 @@ bash training_scripts/opt/single_gpu/run_1.3b.sh ### 🐼 Adding and using your own datasets in DeepSpeed-Chat -In addition to the datasets used in our example scripts, you can also add and use your own datasets. To do so, first you need to add a new Class in [training/utils/data/raw_datasets.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py) to define the format when using your data. You need to make sure to follow the APIs and format defined in the PromptRawDataset class to ensure a consistent data format that DeepSpeed-Chat relies on. You can look at the existing classes to learn how to do so. +In addition to the datasets used in our example scripts, you can also add and use your own datasets. To do so, first you need to add a new Class in [dschat/utils/data/raw_datasets.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/utils/data/raw_datasets.py) to define the format when using your data. You need to make sure to follow the APIs and format defined in the PromptRawDataset class to ensure a consistent data format that DeepSpeed-Chat relies on. You can look at the existing classes to learn how to do so. -Second, you need to add an if condition in function get_raw_dataset in [training/utils/data/data_utils.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/data/data_utils.py) corresponding to your new dataset. The dataset_name string in the if condition should be the dataset name you will provide as a arg for the training scripts. Last, you need to add your new dataset's dataset_name into your "--data_path" arg in your training scripts. +Second, you need to add an if condition in function get_raw_dataset in [dschat/utils/data/data_utils.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py) corresponding to your new dataset. The dataset_name string in the if condition should be the dataset name you will provide as a arg for the training scripts. Last, you need to add your new dataset's dataset_name into your "--data_path" arg in your training scripts. If you have downloaded huggingface datasets manually, you can add your local path into "--data_path", such as "--data_path ./relative/Dahoas/rm-static" and "--data_path /absolute/Dahoas/rm-static". Remember you should not make `data/` in your local path, it may cause an exception to `load_dataset`. One thing to note is that some datasets may only have one response instead of two responses. For those datasets, you can only use them in step 1. And in such case, you should add the dataset_name as part of the "--sft_only_data_path" arg instead of the "--data_path" arg. One thing to note is that: If you plan to only do step 1 SFT, adding more single-response datasets is definitely beneficial. However, if you do plan to do steps 2 and 3, then adding too many single-response datasets during SFT could backfire: these data could be different from the data used for steps 2/3, generating different distributions which could cause training instability/worse model quality during step 2/3. That is part of the reason why we focused on trying the datasets with two responses and the preference, and always split a dataset into all 3 steps. diff --git a/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py b/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py index 5b6778cc2..0e67efcf9 100755 --- a/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py +++ b/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py @@ -268,23 +268,14 @@ def _init_reward(self, critic_model_name_or_path): # If critic is ZeRO-3 then we use it for everything, otherwise assume we have enough memory zero_stage = 0 - ds_config = get_eval_ds_config(offload=self.args.offload, + ds_config = get_eval_ds_config(offload=self.args.offload_reward_model, dtype=self.args.dtype, stage=zero_stage) - ds_config[ - 'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size - ds_config[ - 'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size( - ) * self.args.gradient_accumulation_steps - - ds_eval_config = get_eval_ds_config(offload=False, - dtype=self.args.dtype, - stage=zero_stage) # We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine. - ds_eval_config[ + ds_config[ 'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size - ds_eval_config[ + ds_config[ 'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size( ) * self.args.gradient_accumulation_steps @@ -292,7 +283,7 @@ def _init_reward(self, critic_model_name_or_path): reward_model = create_critic_model( model_name_or_path=critic_model_name_or_path, tokenizer=self.tokenizer, - ds_config=ds_eval_config, + ds_config=ds_config, num_padding_at_beginning=self.args.num_padding_at_beginning, rlhf_training=True, dropout=self.args.critic_dropout, diff --git a/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py b/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py index 62c153644..53479a1eb 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py @@ -149,9 +149,13 @@ def __len__(self): def __getitem__(self, idx): if self.train_phase == 1: return { - "input_ids": self.chosen_dataset[idx]["input_ids"], - "attention_mask": self.chosen_dataset[idx]["attention_mask"], - "labels": self.chosen_dataset[idx]["input_ids"] + "input_ids": + self.chosen_dataset[idx]["input_ids"], + "attention_mask": + self.chosen_dataset[idx]["attention_mask"], + "labels": + torch.where(self.chosen_dataset[idx]["attention_mask"].bool(), + self.chosen_dataset[idx]["input_ids"], -100) } elif self.train_phase == 2: return self.chosen_dataset[idx]["input_ids"], self.chosen_dataset[idx]["attention_mask"], \ diff --git a/applications/DeepSpeed-Chat/dschat/utils/ds_utils.py b/applications/DeepSpeed-Chat/dschat/utils/ds_utils.py index 9c15e5143..0cf1c28ab 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/ds_utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/ds_utils.py @@ -33,6 +33,7 @@ def get_train_ds_config(offload, dtype_config = {"enabled": True} zero_opt_dict = { "stage": stage, + "overlap_comm": True, "offload_param": { "device": device }, diff --git a/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py b/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py index 97d3bff15..050819a22 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py @@ -10,7 +10,7 @@ AutoModel, ) from huggingface_hub import snapshot_download -from transformers.deepspeed import HfDeepSpeedConfig +from transformers.integrations.deepspeed import HfDeepSpeedConfig from dschat.utils.model.reward_model import RewardModel from dschat.utils.utils import load_state_dict_into_model, print_rank_0 diff --git a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py index c37d1f4cd..d9527af54 100755 --- a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py @@ -270,6 +270,7 @@ def main(): args.seed, tokenizer, args.max_seq_len, + end_of_conversation_token=tokenizer.eos_token, sft_only_data_path=args.sft_only_data_path) # DataLoaders creation: if args.local_rank == -1: diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/README.md b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/README.md index ede072a79..3c62b9f82 100644 --- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/README.md +++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/README.md @@ -6,7 +6,7 @@ Finetuning the Reward Model (RM) is more or less similar to Step-1 Supervised F For SFT finetuning, the data is the concatenation of a query and an answer. However, for RM finetuning, each batch of data consists of two query-answer pairs, i.e., the same query with a high-score answer and a low-score answer. This also leads to the second difference as describe below. -👉**The training objective difference** +👉 **The training objective difference** For RW, the training objective is the pairwise ranking score, i.e., for the two query-answer pairs, RM is supposed to give a higher score to the better answer. There are multiple ways to achieve this. In our implementation, we use either the end token of the sequence or the first padding token as the aggregated score and compare them. Others may also use the average score for the entire answer as an alternative. diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py index 6a8211988..c247b53e8 100644 --- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py @@ -258,8 +258,17 @@ def main(): # the LN that precedes it. force_optimize_params = [] if "bigscience/bloom-" in args.model_name_or_path: - torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight) - torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias) + zero_init_enabled = (args.zero_stage == 3) + params = [ + rm_model.rwtranrsformer.ln_f.weight, + rm_model.rwtranrsformer.ln_f.bias + ] + with deepspeed.zero.GatheredParameters(params, + modifier_rank=0, + enabled=zero_init_enabled): + if deepspeed.comm.get_rank() == 0 or not zero_init_enabled: + torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight) + torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias) force_optimize_params.extend( ['rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias']) diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py index a5be5671b..1378dc4e6 100644 --- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py @@ -246,6 +246,9 @@ def parse_args(): '--offload_reference_model', action='store_true', help='Enable ZeRO Offload techniques for reference model') + parser.add_argument('--offload_reward_model', + action='store_true', + help='Enable ZeRO Offload techniques for reward model') parser.add_argument( '--actor_zero_stage', type=int, diff --git a/applications/DeepSpeed-VisualChat/utils/model/modeling_dsvl.py b/applications/DeepSpeed-VisualChat/utils/model/modeling_dsvl.py index eb9db9428..1407c1dfc 100755 --- a/applications/DeepSpeed-VisualChat/utils/model/modeling_dsvl.py +++ b/applications/DeepSpeed-VisualChat/utils/model/modeling_dsvl.py @@ -15,7 +15,7 @@ os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) import data.DST as DST # default special tokens from torch.utils.data import DataLoader -from transformers.deepspeed import HfDeepSpeedConfig +from transformers.integrations.deepspeed import HfDeepSpeedConfig import numpy as np from .vis_proj import VisProjection_vit, VisProjection_perceiver diff --git a/benchmarks/communication/README.md b/benchmarks/communication/README.md index 535b5d308..15ce1995b 100644 --- a/benchmarks/communication/README.md +++ b/benchmarks/communication/README.md @@ -1,6 +1,6 @@ # The DeepSpeed Communication Benchmarking Suite -The intent of these benchmarks is to measure communication latency/bw of deepspeed and/or pytorch distributed communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/) and [NCCL Tests](https://github.com/NVIDIA/nccl-tests) in that users can: +The intent of these benchmarks is to measure communication latency/bw of deepspeed and/or pytorch distributed communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/) , [NCCL Tests](https://github.com/NVIDIA/nccl-tests) and [oneCCL Benchmark](https://oneapi-src.github.io/oneCCL/benchmark.html) in that users can: - Easily debug which layer of the communication software stack hangs or performance degradations originate from. - Measure the expected communication performance of either DeepSpeed comms or pure PyTorch distributed @@ -77,6 +77,14 @@ Finally, users can choose specific communication operations to run in `run_all.p deepspeed run_all.py --scan --all-reduce --all-to-all --broadcast +## CPU Support +Those benchmarks could also support other devices like Intel CPU via oneCCL. +Users just need to append one more argument "--device cpu" for all python scripts to run on Intel CPU. +For example, run with a single large message size on Intel CPU: +
+deepspeed all_reduce.py --device cpu
+
+ # Adding Communication Benchmarks diff --git a/benchmarks/communication/all_gather.py b/benchmarks/communication/all_gather.py index 8aa33581d..76c4f3b1e 100644 --- a/benchmarks/communication/all_gather.py +++ b/benchmarks/communication/all_gather.py @@ -17,6 +17,9 @@ # Run all_gather and print metrics def timed_all_gather(input, output, start_event, end_event, args): + if args.device == "cpu": + print_rank_0(f"No Event support on CPU to measure time for now") + return if args.dist == 'torch': import torch.distributed as dist @@ -64,8 +67,15 @@ def run_all_gather(local_rank, args): global_rank = dist.get_rank() world_size = dist.get_world_size() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + if args.device == "xpu": + start_event = torch.xpu.Event(enable_timing=True) + end_event = torch.xpu.Event(enable_timing=True) + elif args.device == "cpu": + start_event = torch.cpu.Event() + end_event = torch.cpu.Event() + else: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) if args.scan: # Create list of message sizes diff --git a/benchmarks/communication/all_reduce.py b/benchmarks/communication/all_reduce.py index b9decfd98..41c3116ee 100644 --- a/benchmarks/communication/all_reduce.py +++ b/benchmarks/communication/all_reduce.py @@ -15,6 +15,9 @@ def timed_all_reduce(input, start_event, end_event, args): + if args.device == "cpu": + print_rank_0(f"No Event support on CPU to measure time for now") + return if args.dist == 'torch': import torch.distributed as dist elif args.dist == 'deepspeed': @@ -60,8 +63,15 @@ def run_all_reduce(local_rank, args): world_size = dist.get_world_size() global_rank = dist.get_rank() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + if args.device == "xpu": + start_event = torch.xpu.Event(enable_timing=True) + end_event = torch.xpu.Event(enable_timing=True) + elif args.device == "cpu": + start_event = torch.cpu.Event() + end_event = torch.cpu.Event() + else: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) if args.scan: M_LIST = [] diff --git a/benchmarks/communication/all_to_all.py b/benchmarks/communication/all_to_all.py index 7eccfa824..dc10b9ec9 100644 --- a/benchmarks/communication/all_to_all.py +++ b/benchmarks/communication/all_to_all.py @@ -15,6 +15,9 @@ def timed_all_to_all(input, output, start_event, end_event, args): + if args.device == "cpu": + print_rank_0(f"No Event support on CPU to measure time for now") + return if args.dist == 'torch': import torch.distributed as dist elif args.dist == 'deepspeed': @@ -59,8 +62,15 @@ def run_all_to_all(local_rank, args): # Prepare benchmark header print_header(args, 'all_to_all') - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + if args.device == "xpu": + start_event = torch.xpu.Event(enable_timing=True) + end_event = torch.xpu.Event(enable_timing=True) + elif args.device == "cpu": + start_event = torch.cpu.Event() + end_event = torch.cpu.Event() + else: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) if args.scan: M_LIST = [] diff --git a/benchmarks/communication/broadcast.py b/benchmarks/communication/broadcast.py index 860c9555b..d05303be1 100644 --- a/benchmarks/communication/broadcast.py +++ b/benchmarks/communication/broadcast.py @@ -15,6 +15,9 @@ def timed_broadcast(input, start_event, end_event, args): + if args.device == "cpu": + print_rank_0(f"No Event support on CPU to measure time for now") + return if args.dist == 'torch': import torch.distributed as dist elif args.dist == 'deepspeed': @@ -60,8 +63,15 @@ def run_broadcast(local_rank, args): world_size = dist.get_world_size() global_rank = dist.get_rank() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + if args.device == "xpu": + start_event = torch.xpu.Event(enable_timing=True) + end_event = torch.xpu.Event(enable_timing=True) + elif args.device == "cpu": + start_event = torch.cpu.Event() + end_event = torch.cpu.Event() + else: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) if args.scan: M_LIST = [] diff --git a/benchmarks/communication/constants.py b/benchmarks/communication/constants.py index ae9fa261b..60df98ed2 100644 --- a/benchmarks/communication/constants.py +++ b/benchmarks/communication/constants.py @@ -12,4 +12,5 @@ DEFAULT_UNIT = 'Gbps' DEFAULT_DIST = 'deepspeed' DEFAULT_MAXSIZE = 24 +DEFAULT_DEVICE = 'cuda' TORCH_DISTRIBUTED_DEFAULT_PORT = 29500 diff --git a/benchmarks/communication/pt2pt.py b/benchmarks/communication/pt2pt.py index 57eab9a66..ec3252eb8 100644 --- a/benchmarks/communication/pt2pt.py +++ b/benchmarks/communication/pt2pt.py @@ -15,6 +15,9 @@ def timed_pt2pt(input, start_event, end_event, args): + if args.device == "cpu": + print_rank_0(f"No Event support on CPU to measure time for now") + return if args.dist == 'torch': import torch.distributed as dist elif args.dist == 'deepspeed': @@ -78,8 +81,15 @@ def run_pt2pt(local_rank, args): global_rank = dist.get_rank() world_size = dist.get_world_size() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + if args.device == "xpu": + start_event = torch.xpu.Event(enable_timing=True) + end_event = torch.xpu.Event(enable_timing=True) + elif args.device == "cpu": + start_event = torch.cpu.Event() + end_event = torch.cpu.Event() + else: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) if args.scan: # Create list of message sizes diff --git a/benchmarks/communication/utils.py b/benchmarks/communication/utils.py index a74d24e28..6f6dd83a1 100644 --- a/benchmarks/communication/utils.py +++ b/benchmarks/communication/utils.py @@ -108,6 +108,11 @@ def get_bw(comm_op, size, duration, args): n = dist.get_world_size() tput = 0 busbw = 0 + + if duration == 0: + print_rank_0("Error. Duration is 0.") + return tput, busbw + if comm_op == "all_to_all": tput = (size / duration) busbw = (size / duration) * ((n - 1) / n) @@ -235,4 +240,5 @@ def benchmark_parser(): default=.3, help='Proportion of max available GPU memory to use for single-size evals') parser.add_argument("--debug", action="store_true", help='Enables all_to_all debug prints') + parser.add_argument("--device", type=str, default=DEFAULT_DEVICE, help='target device') return parser diff --git a/benchmarks/inference/deepspeedometer/README.md b/benchmarks/inference/deepspeedometer/README.md new file mode 100644 index 000000000..b327916c5 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/README.md @@ -0,0 +1,88 @@ +# DeepSpeedometer + +NOTE: This is an experimental tool and is not currently being supported since it's not fully functional. Please use the MII benchmark which can be found here: +https://github.com/microsoft/DeepSpeedExamples/tree/master/benchmarks/inference/mii + +This benchmark is designed to measure performance of LLM serving solutions. Using a different number of parallel clients sending requests to an inference server, we gather data to plot throughput-latency curves and find the saturation point of an inference server that demonstrates the maximum performance. + +## Installation + +To install the benchmark, clone this repository and install using `pip`: +```shell +git clone https://github.com/Microsoft/DeepSpeedExamples +cd ./DeepSpeedExamples/benchmarks/deepspeedometer +pip install . +``` + +## Usage + +To quickly test the benchmark code without creating an inference server, run the following: +``` +python3 -m deepspeedometer.benchmark_runner --model facebook/opt-125m --api dummy +``` + +### Supports APIs + +The benchmark supports different APIs, each with their own client type. Depending on the client, you may need to run the benchmark against a locally hosted inference server or a remote inference server. Adding support for new serving solutions can be achieved by creating a new client class that defines a few basic methods. See the section below on adding new clients for more information. + +The clients (i.e., APIs) curently supported (and configuration options for each) are listed below. You can see more information about the configuration options by looking at the `*ClientConfig` classes located in `clients/*.py`: + +1. `fastgen`: Runs a local model inference server with DeepSpeed's FastGen. Config options include: + - `model`: Which model to use for serving (required) + - `deployment_name`: Name of the deployment server + - `tp_size`: Tensor parallel size for each model replicas + - `num_replicas`: Number of model replicas + - `max_ragged_batch_size`: Max number of requests running per model replicas + - `quantization_mode`: Type of quantization to use +2. `vllm`: Runs a local model inference server with vLLM. + - `model`: Which model to use for serving (required) + - `tp_size`: Tensor parallel size for model + - `port`: Which port to use for REST API +3. `azureml`: Interfaces with remote AzureML online endpoint/deployment. + - `api_url`: AzureML endpoint API URL (required) + - `api_key`: AzureML token key for connecting to endpoint (required) + - `deployment_name`: Name of deployment hosted in given endpoint (required) + +### Benchmark Configuration + +The Benchmark has many options for tailoring performance measurements to a specific use-cases. For additional information and default values, see the `BenchmarkConfig` class defined in `benchmark_runner.py`. + +- `api`: Which API to use +- `warmup_requests`: Number of warm-up requests to run before measuring performance +- `result_dir`: Directory where results will be written out (as JSON files) +- `use_threading`: Whether to use threading for the benchmark clients. Default is to use multi-processing +- `config_file`: One or more config YAML files that contain values for any of the Prompt configuration options (see below section on prompt configuration) +- `num_clients`: One or more integer values for the number of parallel clients to run +- `num_requests_per_client`: Number of requests that will be run by each of the parallel clients +- `min_requests`: Minimum number of requests to be sent during duration of benchmark. Useful when there is a low number of clients to ensure good measurement. +- `prompt_text_source`: Text file or string that will be sampled to generate request prompts +- `early_stop_latency`: When running multiple values for `num_clients`, if the average latency per request exceeds this value (in seconds) the benchmark will not test a larger number of parallel clients +- `force`: Force the overwrite of result files. By default, if a result file exists, the benchmark is skipped + +### Prompt Configuration + +These options allow users to modify the prompt input and generation behavior of the served models. Note that you can run multiple prompt configurations in a single command by using the `config_file` input as described in the Benchmark Configuration section. + +- `model`: Which model to use for tokenizing prompts (required) +- `prompt_generator_seed`: Seed value for random number generation +- `max_prompt_length`: The maximum prompt length allowed +- `prompt_length`: Target mean prompt length +- `prompt_lenght_var`: Variance of generated prompt lengths +- `max_new_tokens`: Target mean number of generated tokens +- `max_new_tokens_var`: Variance of generated tokens +- `streaming`: Whether to enabled streaming output for generated tokens + +#### About Prompt Generation + +To mimic real-world serving scenarios, this benchmark samples prompt length and generated token length values from a normal distribution. This distribution can be manipulated with the `prompt_length*` and `max_new_tokens*` values in the prompt configuration. To get all prompt lengths and generation lengths to match exactly, set the `*_var` values to 0. + +## Adding New Client APIs + +The DeepSpeedometer benchmark was designed to allow easily adding support for new inference server solutions. To do so: + +1. Create a new `*_client.py` file in the `clients/` directory. +2. Define a `*Client` class that inherits from the `BaseClient` class in `clients/base.py`. This class should define 5 methods: `start_service`, `stop_service`, `prepare_request`, `send_request`, and `process_response`. Take a look at the type hints for these methods in the `BaseClient` class to understand the expected inputs and outputs for each method. +3. Define a `*ClientConfig` class that inherits from the `BaseConfigModel` class. Place any configuration options (i.e., user-passed command line arguments) necessary for your defined `*Client` class in here. +4. Import the newly added `*Client` and `*ClientConfig` into `clients/__init__.py` and add them to the `client_config_classes` and `client_classes` dictionaries. + +For the simplest example of adding a new client, take a look at the `clients/dummy_client.py` file where we have defined a client that does not stand up a server and only returns a sample of the input prompt after a short sleep cycle. We use this as a light-weight class for unit testing. diff --git a/benchmarks/inference/deepspeedometer/configs/128k-120.yaml b/benchmarks/inference/deepspeedometer/configs/128k-120.yaml new file mode 100644 index 000000000..574e8e05e --- /dev/null +++ b/benchmarks/inference/deepspeedometer/configs/128k-120.yaml @@ -0,0 +1,5 @@ +prompt_length: 128000 +prompt_length_var: 0.1 +max_prompt_length: 131072 +max_new_tokens: 120 +max_new_tokens_var: 0.3 diff --git a/benchmarks/inference/deepspeedometer/configs/1300-120.yaml b/benchmarks/inference/deepspeedometer/configs/1300-120.yaml new file mode 100644 index 000000000..874a24c27 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/configs/1300-120.yaml @@ -0,0 +1,4 @@ +prompt_length: 1300 +prompt_lenght_var: 0.3 +max_new_tokens: 120 +max_new_tokens_var: 0.3 diff --git a/benchmarks/inference/deepspeedometer/configs/2600-60.yaml b/benchmarks/inference/deepspeedometer/configs/2600-60.yaml new file mode 100644 index 000000000..f7674f819 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/configs/2600-60.yaml @@ -0,0 +1,4 @@ +prompt_length: 2600 +prompt_lenght_var: 0.3 +max_new_tokens: 60 +max_new_tokens_var: 0.3 diff --git a/benchmarks/inference/deepspeedometer/configs/500-500.yaml b/benchmarks/inference/deepspeedometer/configs/500-500.yaml new file mode 100644 index 000000000..72389b37d --- /dev/null +++ b/benchmarks/inference/deepspeedometer/configs/500-500.yaml @@ -0,0 +1,4 @@ +prompt_length: 500 +prompt_lenght_var: 0.3 +max_new_tokens: 500 +max_new_tokens_var: 0.3 diff --git a/benchmarks/inference/deepspeedometer/pyproject.toml b/benchmarks/inference/deepspeedometer/pyproject.toml new file mode 100644 index 000000000..c15a27035 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/pyproject.toml @@ -0,0 +1,32 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" +[project] +name = "deepspeedometer" +version = "0.0.1" +authors = [ + { name="Ammar Ahmad Awan", email="ammar.awan@microsoft.com" }, + { name="Arash Bakhitiari", email="abakhtiari@microsoft.com" }, + { name="Connor Holmes"}, + { name="Lev Kurilenko", email="lev.kurilenko@microsoft.com" }, + { name="Heyang Qin", email="heyangqin@microsoft.com" }, + { name="Masahiro Tanaka", email="mtanaka@microsoft.com" }, + { name="Michael Wyatt", email="michaelwyatt@microsoft.com" }, +] +description = "LLM benchmarking tool" +readme = "README.md" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Python :: 3", +] +dependencies = [ + "loguru", + "pydantic>=2.0.0", + "torch", + "tqdm", + "transformers", +] + +[project.urls] +Homepage = "https://github.com/Microsoft/DeepSpeedExamples/tree/master/benchmarks/inference/deepspeedometer" +Issues = "https://github.com/Microsoft/DeepSpeedExamples/issues" diff --git a/benchmarks/inference/deepspeedometer/run_example.sh b/benchmarks/inference/deepspeedometer/run_example.sh new file mode 100644 index 000000000..42fef231d --- /dev/null +++ b/benchmarks/inference/deepspeedometer/run_example.sh @@ -0,0 +1 @@ +python -m src.deepspeedometer.benchmark_runner --model "facebook/opt-125m" --api dummy --config_file ./configs/1300-120.yaml diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/__init__.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/__init__.py new file mode 100644 index 000000000..32cb0a0f9 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/__init__.py @@ -0,0 +1,2 @@ +from .arg_parsing import parse_args_to_configs +from .benchmark_runner import BenchmarkRunner diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/arg_parsing.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/arg_parsing.py new file mode 100644 index 000000000..8be6d0d42 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/arg_parsing.py @@ -0,0 +1,51 @@ +import argparse +from typing import List, Tuple + +from .benchmark_runner import BenchmarkConfig +from .clients import client_config_classes +from .config import BaseConfigModel + + +def parse_args_to_configs(args: List[str]) -> Tuple[BenchmarkConfig, BaseConfigModel]: + def add_model(parser: argparse.ArgumentParser, model: BaseConfigModel): + """Adds fields from pydantic model to the parser.""" + for name, field in model.model_fields.items(): + field_type = field.annotation + + # Get information about number of arguments expected + nargs = None + if getattr(field.annotation, "_name", "") == "List": + nargs = "+" + field_type = field.annotation.__args__[0] + + # Add field to parser + parser.add_argument( + f"--{name}", + dest=name, + nargs=nargs, + type=field_type, + required=getattr(field, "required", False), + default=getattr(field, "default", None), + help=getattr(field, "description", ""), + ) + + # Parse benchmark config fields + parser = argparse.ArgumentParser(allow_abbrev=False) + add_model(parser, BenchmarkConfig) + benchmark_args, remaining_args = parser.parse_known_args(args) + benchmark_config = BenchmarkConfig(**vars(benchmark_args)) + unused_args = set(remaining_args) + + # Parse client config fields + client_config_class = client_config_classes[benchmark_config.api] + parser = argparse.ArgumentParser(allow_abbrev=False) + add_model(parser, client_config_class) + client_args, remaining_args = parser.parse_known_args(args) + client_config = client_config_class(**vars(client_args)) + + # Check for unused arguments + unused_args = unused_args.intersection(remaining_args) + if unused_args: + raise ValueError(f"Unused arguments: {unused_args}") + + return benchmark_config, client_config diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/benchmark_runner.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/benchmark_runner.py new file mode 100644 index 000000000..96dd3a0da --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/benchmark_runner.py @@ -0,0 +1,390 @@ +import itertools +import json +import multiprocessing +import os +import queue +import sys +import threading +import time +import yaml +from pathlib import Path +from typing import List, Iterable, Tuple + +from loguru import logger +from tqdm import tqdm + +from .clients import client_classes, BaseClient +from .config import BaseConfigModel +from .prompt import Prompt, PromptConfig, PromptGenerator +from .response import Response +from .sample_input import sample_input_text + + +class BenchmarkConfig(PromptConfig): + api: str = "azure_ml" + """ Which API to use for benchmarking. New APIs can be added by creating a new client class in the `clients` directory. """ + + warmup_requests: int = 1 + """ Number of requests to run (per client) as a warm-up before starting the benchmark. """ + + result_dir: Path = Path("./results") + """ Top directory where results will be saved. """ + + use_threading: bool = False + """ Whether to use threading or multiprocessing for parallel client requests. Default is multiprocessing. """ + + config_file: List[Path] = [] + """ Path to YAML file(s) containing benchmark configuration settings. """ + + num_clients: List[int] = [1, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32] + """ Number of clients to run in parallel. """ + + num_requests_per_client: int = 16 + """ Number of requests to run per client. """ + + min_requests: int = 128 + """ Minimum number of request to create (regardless of num_requests_per_client). """ + + prompt_text_source: str = sample_input_text + """ Text file or string to use for generated prompts. """ + + early_stop_latency: float = 10.0 + """ Maximum mean latency (in seconds) to allow before stopping the benchmark early. """ + + force: bool = False + """ Whether to overwrite existing result files. """ + + +class ClientLauncher: + def __init__( + self, + client_class: BaseClient, + client_config: BaseConfigModel, + warmup_requests: int, + use_threading: bool, + prompt_generator: PromptGenerator, + ): + self.client_class = client_class + self.client_config = client_config + self.client_obj = client_class(client_config) + self.warmup_requests = warmup_requests + self.prompt_generator = prompt_generator + + if use_threading: + self.runnable_cls = threading.Thread + self.barrier_cls = threading.Barrier + self.queue_cls = queue.Queue + else: + self.runnable_cls = multiprocessing.Process + self.barrier_cls = multiprocessing.Barrier + self.queue_cls = multiprocessing.Queue + + def run_parallel_clients(self, num_clients: int) -> None: + logger.info(f"Launching {num_clients} client(s)") + + total_requests = self.request_queue.qsize() + + self.barrier = self.barrier_cls(num_clients + 1) + processes = [ + self.runnable_cls( + target=self._run_client, + args=( + i, + self.barrier, + self.request_queue, + self.response_queue, + self.client_class, + self.client_config, + self.warmup_requests, + ), + ) + for i in range(num_clients) + ] + for p in processes: + p.start() + + self.barrier.wait() # Barrier 1 for master process + + self._progress_bar(total_requests - self.warmup_requests * num_clients) + + self.barrier.wait() # Barrier 2 for master process + + def _progress_bar(self, total_requests: int) -> None: + pbar = tqdm(total=total_requests) + num_responses = 0 + while num_responses != total_requests: + num_responses = self.response_queue.qsize() + pbar.update(num_responses - pbar.n) + time.sleep(1) + pbar.close() + + @staticmethod + def _run_client( + client_id: int, + barrier: multiprocessing.Barrier, + request_queue: multiprocessing.Queue, + response_queue: multiprocessing.Queue, + client_class: BaseClient, + client_config: BaseConfigModel, + warmup_requests: int, + ): + client = client_class(client_config) + + for _ in range(warmup_requests): + prompt = request_queue.get(timeout=1.0) + _ = client.send_request(prompt.request_kwargs) + + barrier.wait() # Barrier 1 for client process + try: + while True: + prompt = request_queue.get(timeout=1.0) + start_time = time.time() + raw_response = client.send_request(prompt.request_kwargs) + end_time = time.time() + request_time = end_time - start_time + response = Response( + prompt_text=prompt.text, + prompt_tokens=prompt.num_tokens, + raw_response=raw_response, + request_time=request_time, + client_id=client_id, + ) + response_queue.put_nowait(response) + except queue.Empty: + pass + + barrier.wait() # Barrier 2 for client process + + def add_request(self, prompt: Prompt) -> None: + request_kwargs = self.client_obj.prepare_request(prompt) + prompt.request_kwargs = request_kwargs + self.request_queue.put(prompt) + + def get_response(self) -> Response: + response = self.response_queue.get(timeout=1.0) + processed_response = self.client_obj.process_response(response.raw_response) + response.generated_output = processed_response + response.generated_tokens = self.prompt_generator.count_tokens( + processed_response + ) + return response + + def clear_queues(self) -> None: + self.request_queue = self.queue_cls() + self.response_queue = self.queue_cls() + + def start_service(self) -> None: + self.client_obj.start_service() + + def stop_service(self) -> None: + self.client_obj.stop_service() + + +class BenchmarkRunner: + def __init__( + self, benchmark_config: BaseConfigModel, client_config: BaseConfigModel + ) -> None: + logger.info("Initializing Benchmark Runner") + self.config = benchmark_config + self.client_config = client_config + self.client_class = client_classes[self.config.api] + self.prompt_generator = PromptGenerator( + self.config.model, self.config.prompt_text_source + ) + self.client_launcher = ClientLauncher( + client_class=self.client_class, + client_config=self.client_config, + warmup_requests=self.config.warmup_requests, + use_threading=self.config.use_threading, + prompt_generator=self.prompt_generator, + ) + self.all_responses = [] + + def _benchmark_settings(self) -> Iterable[Tuple[List[int], PromptConfig]]: + prompt_config_keys = list(PromptConfig.model_fields.keys()) + + configs_list = [] + for f in self.config.config_file: + logger.info(f"Generating benchmark run settings from config file: {f}") + with open(f, "r") as fh: + file_config = yaml.safe_load(fh) + + # Get any prompt config values stored in config files + for key in prompt_config_keys + ["num_clients"]: + if key not in file_config: + file_config[key] = getattr(self.config, key) + configs_list.append(file_config) + + if not configs_list: + logger.info(f"Generating benchmark run settings from command line args") + configs_list.append( + { + key: getattr(self.config, key) + for key in prompt_config_keys + ["num_clients"] + } + ) + + all_config_product = [] + for config in configs_list: + # Ensure all config values are iterable types (i.e., list or tuple) + for k, v in config.items(): + if not isinstance(v, list) or isinstance(v, tuple): + config[k] = [v] + + # We treat num_clients differently to enable early stopping + num_clients = config.pop("num_clients") + + # Generate all possible combinations of prompt config values + for vals in itertools.product(*[config[k] for k in prompt_config_keys]): + config_product = {k: v for k, v in zip(prompt_config_keys, vals)} + config_product["num_clients"] = num_clients + all_config_product.append(config_product) + + logger.info(f"Generated {len(all_config_product)} benchmark run setting(s)") + + for config in all_config_product: + num_clients = config.pop("num_clients") + prompt_config = PromptConfig(**config) + yield num_clients, prompt_config + + def _generate_requests(self, prompt_config: PromptConfig, num_clients: int) -> None: + logger.info("Generating Prompts") + + warmup_prompts = self.config.warmup_requests * num_clients + workload_prompts = max( + self.config.min_requests, self.config.num_requests_per_client * num_clients + ) + for prompt in self.prompt_generator( + config=prompt_config, num_prompts=warmup_prompts + workload_prompts + ): + self.client_launcher.add_request(prompt) + + logger.info( + f"Generated {warmup_prompts} warmup and {workload_prompts} workload prompts." + ) + + def _get_output_dir(self) -> Path: + return self.config.result_dir / self.config.api / self.config.model + + def _get_output_path(self, prompt_config: PromptConfig, num_clients: int) -> Path: + output_dir = self._get_output_dir() + output_file = f"prompt{prompt_config.prompt_length}_gen{prompt_config.max_new_tokens}_clients{num_clients}.json" + return output_dir / output_file + + def _process_responses( + self, prompt_config: PromptConfig, num_clients: int + ) -> List[Response]: + output_path = self._get_output_path( + prompt_config=prompt_config, num_clients=num_clients + ) + + logger.info(f"Saving results to {output_path}") + + all_responses = [] + while True: + try: + all_responses.append(self.client_launcher.get_response()) + except queue.Empty: + break + + os.makedirs(output_path.parent, exist_ok=True) + with open(output_path, "w") as fh: + json.dump([r.to_dict() for r in all_responses], fh, indent=2) + + logger.info(f"Saved {len(all_responses)} responses to {output_path}") + + return all_responses + + def _print_result_summary( + self, all_responses: List[Response], num_clients: int + ) -> None: + num_responses = int(len(all_responses)) + mean_latency = sum([r.request_time for r in all_responses]) / num_responses + query_throughput = num_clients / mean_latency + mean_prompt_length = int( + sum([r.prompt_tokens for r in all_responses]) / num_responses + ) + mean_gen_length = int( + sum([r.generated_tokens for r in all_responses]) / num_responses + ) + logger.info( + f"Result summary - # Requests: {num_responses:d}, Mean Prompt Length: {mean_prompt_length:d} tokens, Mean Generation Length: {mean_gen_length:d} tokens, Mean Latency: {mean_latency:.2f} s, Throughput: {query_throughput:.2f} queries/s," + ) + + def _check_early_stop(self, all_responses: List[Response]) -> bool: + if not all_responses: + return False + mean_latency = sum([r.request_time for r in all_responses]) / len(all_responses) + if mean_latency >= self.config.early_stop_latency: + logger.info( + f"Mean latency of {mean_latency:.2f} exceeds early stopping threshold of {self.config.early_stop_latency}. Stopping early." + ) + return True + return False + + def _skip_existing_result( + self, prompt_config: PromptConfig, num_clients: int + ) -> bool: + output_path = self._get_output_path( + prompt_config=prompt_config, num_clients=num_clients + ) + if output_path.exists(): + if self.config.force: + logger.info( + f"Result already exists, but force flag is set. Overwriting benchmark with {num_clients} client(s) and prompt config: {prompt_config}" + ) + return False + else: + logger.info( + f"Result already exists, skipping benchmark with {num_clients} client(s) and prompt config: {prompt_config}" + ) + return True + return False + + def run(self) -> None: + # Start the client service + self.client_launcher.start_service() + + # Generate all benchmark settings from user config(s) + for num_clients_list, prompt_config in self._benchmark_settings(): + all_responses = [] + for num_clients in sorted(num_clients_list): + if self._skip_existing_result( + prompt_config=prompt_config, num_clients=num_clients + ): + continue + + if self._check_early_stop(all_responses): + break + + logger.info( + f"Running benchmark with {num_clients} client(s) and prompt config: {prompt_config}" + ) + # Clear out queues and generate request prompts + self.client_launcher.clear_queues() + self._generate_requests( + prompt_config=prompt_config, num_clients=num_clients + ) + + # Launch the clients and process requests + self.client_launcher.run_parallel_clients(num_clients=num_clients) + + # Process raw responses and save results to file + all_responses = self._process_responses( + prompt_config=prompt_config, num_clients=num_clients + ) + + self._print_result_summary( + all_responses=all_responses, num_clients=num_clients + ) + + # Stop the client service + self.client_launcher.stop_service() + + +if __name__ == "__main__": + from .arg_parsing import parse_args_to_configs + + benchmark_config, client_config = parse_args_to_configs(sys.argv[1:]) + benchmark_runner = BenchmarkRunner(benchmark_config, client_config) + benchmark_runner.run() diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/__init__.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/__init__.py new file mode 100644 index 000000000..ac1891112 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/__init__.py @@ -0,0 +1,22 @@ +from .base import BaseClient + +from .azure_ml_client import AzureMLClientConfig, AzureMLClient +from .dummy_client import DummyClientConfig, DummyClient +from .fastgen_client import FastGenClientConfig, FastGenClient +from .vllm_client import vLLMClientConfig, vLLMClient +from .openai_client import openaiClientConfig, openaiClient + +client_config_classes = { + "dummy": DummyClientConfig, + "azure_ml": AzureMLClientConfig, + "fastgen": FastGenClientConfig, + "vllm": vLLMClientConfig, + "openai": openaiClientConfig +} +client_classes = { + "dummy": DummyClient, + "azure_ml": AzureMLClient, + "fastgen": FastGenClient, + "vllm": vLLMClient, + "openai": openaiClient, +} diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/azure_ml_client.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/azure_ml_client.py new file mode 100644 index 000000000..5bedff692 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/azure_ml_client.py @@ -0,0 +1,79 @@ +import json +import requests +from typing import Any, Dict + +from loguru import logger + +from .base import BaseClient +from ..config import BaseConfigModel +from ..prompt import Prompt + + +class AzureMLClientConfig(BaseConfigModel): + api_url: str = "" + """ URL for the AzureML REST API. """ + + api_key: str = "" + """ REST API key for the AzureML deployment. """ + + deployment_name: str = "" + """ Name of the AzureML deployment. """ + + +class AzureMLClient(BaseClient): + def __init__(self, config: AzureMLClientConfig) -> None: + super().__init__(config) + self.api_url = config.api_url + self.api_key = config.api_key + self.deployment_name = config.deployment_name + + def start_service(self) -> None: + # Verify that the server exists, this could be extended to actually + # start an AML deployment. However currently we assume one exists. + test_prompt = Prompt("hello world", num_tokens=5, max_new_tokens=16) + _ = self.process_response(self.send_request(self.prepare_request(test_prompt))) + + def stop_service(self) -> None: + pass + + def prepare_request(self, prompt: Prompt) -> Dict[str, Any]: + # TODO: add support for OpenAI chat completion template + if prompt.streaming: + raise ValueError("AzureMLClient does not support streaming prompts.") + + headers = { + "Content-Type": "application/json", + "Authorization": ("Bearer " + self.api_key), + "azureml-model-deployment": self.deployment_name, + } + pload = { + "input_data": { + "input_string": [ + prompt.text, + ], + "parameters": { + "max_tokens": prompt.max_new_tokens, + "return_full_text": prompt.return_full_text, + }, + } + } + return {"url": self.api_url, "headers": headers, "json": pload, "timeout": 180} + + def send_request(self, request_kwargs: Dict[str, Any]) -> Any: + while True: + try: # Sometimes the AML endpoint will return an error, so we send the request again + response = requests.post(**request_kwargs) + output = json.loads(response.content) + assert ( + response.status_code == 200 + ), f"Status code: {response.status_code}" + assert output[0]["0"], f"Empty response" + break + except Exception as e: + logger.debug(f"Connection failed with {e}. Retrying AML request") + + return output + + def process_response(self, raw_response: Any) -> str: + response_text = raw_response[0]["0"] + return response_text diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/base.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/base.py new file mode 100644 index 000000000..40a38e057 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/base.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict + +from ..config import BaseConfigModel +from ..prompt import Prompt + + +class BaseClient(ABC): + def __init__(self, config: BaseConfigModel) -> None: + self.config = config + + @abstractmethod + def start_service(self) -> None: + pass + + @abstractmethod + def stop_service(self) -> None: + pass + + @abstractmethod + def prepare_request(self, prompt: Prompt) -> Dict[str, Any]: + pass + + @abstractmethod + def send_request(self, request_kwargs: Dict[str, Any]) -> Any: + pass + + @abstractmethod + def process_response(self, raw_response: Any) -> str: + pass diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/dummy_client.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/dummy_client.py new file mode 100644 index 000000000..f10b1e94e --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/dummy_client.py @@ -0,0 +1,45 @@ +import time +import random +from typing import Any, Dict + +from transformers import AutoTokenizer + +from .base import BaseClient +from ..config import BaseConfigModel +from ..prompt import Prompt + + +class DummyClientConfig(BaseConfigModel): + model: str + dummy_client_latency_time: float = 0.1 + + +class DummyClient(BaseClient): + def __init__(self, config: DummyClientConfig) -> None: + super().__init__(config) + self.tokenizer = AutoTokenizer.from_pretrained(self.config.model) + self.latency_time = config.dummy_client_latency_time + + def start_service(self) -> None: + pass + + def stop_service(self) -> None: + pass + + def prepare_request(self, prompt: Prompt) -> Dict[str, Any]: + return {"input_text": prompt.text, "max_new_tokens": prompt.max_new_tokens} + + def send_request(self, request_kwargs: Dict[str, Any]) -> Any: + time.sleep( + abs(random.uniform(self.latency_time - 0.1, self.latency_time + 0.2)) + ) + output_text = self.tokenizer.decode( + random.choices( + self.tokenizer.encode(request_kwargs["input_text"]), + k=request_kwargs["max_new_tokens"], + ) + ) + return output_text + + def process_response(self, raw_response: Any) -> str: + return raw_response diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/fastgen_client.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/fastgen_client.py new file mode 100644 index 000000000..c3f3a086f --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/fastgen_client.py @@ -0,0 +1,91 @@ +import time +from typing import Any, Dict, Optional + +from loguru import logger +from pydantic import Field + +from .base import BaseClient +from ..config import BaseConfigModel +from ..prompt import Prompt + + +class FastGenClientConfig(BaseConfigModel): + model: str = Field(..., description="HuggingFace.co model name") + deployment_name: str = "fastgen-benchmark-deployment" + tp_size: int = 1 + num_replicas: int = 1 + max_ragged_batch_size: int = 768 + quantization_mode: Optional[str] = None + + +class FastGenClient(BaseClient): + def __init__(self, config: FastGenClientConfig): + super().__init__(config) + try: + import mii + except ImportError as e: + logger.error( + "Please install the `deepspeed-mii` package to use this client." + ) + raise e + + self.mii_client = mii.client(config.deployment_name) + self.streaming = config.streaming + + def start_service(self) -> None: + import mii + from deepspeed.inference import RaggedInferenceEngineConfig, DeepSpeedTPConfig + from deepspeed.inference.v2.ragged import DSStateManagerConfig + + tp_config = DeepSpeedTPConfig(tp_size=self.config.tp_size) + mgr_config = DSStateManagerConfig( + max_ragged_batch_size=self.config.max_ragged_batch_size, + max_ragged_sequence_count=self.config.max_ragged_batch_size, + ) + inference_config = RaggedInferenceEngineConfig( + tensor_parallel=tp_config, state_manager=mgr_config + ) + mii.serve( + self.config.model, + deployment_name=self.config.deployment_name, + tensor_parallel=self.config.tp_size, + inference_engine_config=inference_config, + replica_num=self.config.num_replicas, + quantization_mode=self.config.quantization_mode, + ) + + def stop_service(self) -> None: + import mii + + mii.client(self.config.deployment_name).terminate_server() + + def _streaming_callback(self, raw_response) -> None: + self.streaming_response_tokens.append(raw_response[0].generated_text) + time_now = time.time() + self.streaming_token_gen_time.append(time_now - time_last_token) + time_last_token = time_now + + def prepare_request(self, prompt: Prompt) -> Dict[str, Any]: + request_kwargs = { + "prompts": prompt.text, + "max_new_tokens": prompt.max_new_tokens, + } + if self.streaming: + self.streaming_response_tokens = [] + self.streaming_token_gen_time = [] + self.streaming_time_last_token = None + request_kwargs["streaming_fn"] = self._streaming_callback + return request_kwargs + + def send_request(self, request_kwargs: Dict[str, Any]) -> Any: + if self.streaming: + self.streaming_time_last_token = time.time() + response = self.mii_client(**request_kwargs) + if self.streaming: + response = self.streaming_response_tokens + + return response + + def process_response(self, raw_response: Any) -> str: + if not self.streaming: + return raw_response[0].generated_text diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/openai_client.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/openai_client.py new file mode 100644 index 000000000..76eadfc5c --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/openai_client.py @@ -0,0 +1,57 @@ +import os +import json +import requests +import subprocess +import time +from typing import Any, Dict + +from loguru import logger +from pydantic import Field + +from .base import BaseClient +from ..config import BaseConfigModel +from ..prompt import Prompt + + +# client to test any openai API +class openaiClientConfig(BaseConfigModel): + model: str = Field(..., description="HuggingFace.co model name") + url: str = "http://127.0.0.1:26500/v1/completions" + + +class openaiClient(BaseClient): + def __init__(self, config: openaiClientConfig): + super().__init__(config) + + def start_service(self) -> None: + pass + + def stop_service(self) -> None: + pass + + def prepare_request(self, prompt: Prompt) -> Dict[str, Any]: + api_url = self.config.url + headers = { + "User-Agent": "Benchmark Client", + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + pload = { + "prompt": prompt.text, + "model": self.config.model, + "n": 1, + "use_beam_search": False, + "temperature": 1.0, + "top_p": 0.9, + "max_tokens": prompt.max_new_tokens, + "ignore_eos": False, + } + return {"url": api_url, "headers": headers, "json": pload, "timeout": 180} + + def send_request(self, request_kwargs: Dict[str, Any]) -> Any: + response = requests.post(**request_kwargs) + output = json.loads(response.content) + return output + + def process_response(self, raw_response: Any) -> str: + return raw_response["choices"][0]["text"] diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/vllm_client.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/vllm_client.py new file mode 100644 index 000000000..563c66e9d --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/vllm_client.py @@ -0,0 +1,88 @@ +import json +import requests +import subprocess +import time +from typing import Any, Dict + +from loguru import logger +from pydantic import Field + +from .base import BaseClient +from ..config import BaseConfigModel +from ..prompt import Prompt + + +class vLLMClientConfig(BaseConfigModel): + model: str = Field(..., description="HuggingFace.co model name") + tp_size: int = 1 + port: int = 26500 + + +class vLLMClient(BaseClient): + def __init__(self, config: vLLMClientConfig): + super().__init__(config) + try: + import vllm + except ImportError as e: + logger.error("Please install the `vllm` package to use this client.") + raise e + + def start_service(self) -> None: + vllm_cmd = ( + "python", + "-m", + "vllm.entrypoints.api_server", + "--host", + "127.0.0.1", + "--port", + str(self.config.port), + "--tensor-parallel-size", + str(self.config.tp_size), + "--model", + self.config.model, + ) + p = subprocess.Popen( + vllm_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, close_fds=True + ) + start_time = time.time() + timeout_after = 60 * 5 # 5 minutes + while True: + line = p.stderr.readline().decode("utf-8") + if "Application startup complete" in line: + break + if "error" in line.lower(): + p.terminate() + # self.stop_service(config) + raise RuntimeError(f"Error starting VLLM server: {line}") + if time.time() - start_time > timeout_after: + p.terminate() + # self.stop_service(config) + raise TimeoutError("Timed out waiting for VLLM server to start") + time.sleep(0.01) + + def stop_service(self) -> None: + vllm_cmd = ("pkill", "-f", "vllm.entrypoints.api_server") + p = subprocess.Popen(vllm_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p.wait() + + def prepare_request(self, prompt: Prompt) -> Dict[str, Any]: + api_url = "http://localhost:26500/generate" + headers = {"User-Agent": "Benchmark Client"} + pload = { + "prompt": prompt.text, + "n": 1, + "use_beam_search": False, + "temperature": 1.0, + "top_p": 0.9, + "max_tokens": prompt.max_new_tokens, + "ignore_eos": False, + } + return {"url": api_url, "headers": headers, "json": pload, "timeout": 180} + + def send_request(self, request_kwargs: Dict[str, Any]) -> Any: + response = requests.post(**request_kwargs) + output = json.loads(response.content) + return output + + def process_response(self, raw_response: Any) -> str: + return raw_response["text"] diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/config.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/config.py new file mode 100644 index 000000000..d524eb2cf --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/config.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel, ConfigDict + + +class BaseConfigModel(BaseModel): + model_config = ConfigDict( + validate_default=True, + validate_assignment=False, + use_enum_values=True, + populate_by_name=True, + extra="forbid", + arbitrary_types_allowed=True, + protected_namespaces=(), + ) diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/prompt.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/prompt.py new file mode 100644 index 000000000..58bd82d0a --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/prompt.py @@ -0,0 +1,117 @@ +import os +from dataclasses import dataclass +from typing import Iterable, Optional +from typing_extensions import Self + +import numpy as np +import torch +from loguru import logger +from pydantic import model_validator +from transformers import AutoTokenizer + +from .config import BaseConfigModel + +# Avoids a warning from transformers +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@dataclass +class Prompt: + text: str + num_tokens: int + max_new_tokens: int + streaming: bool = False + return_full_text: bool = False + request_kwargs: dict = None + + +class PromptConfig(BaseConfigModel): + model: str + """ Names of the model used to benchmark. Used to load the model/tokenizer from HuggingFace.co. """ + + prompt_generator_seed: Optional[int] = None + """ Seed value for prompt generator. """ + + max_prompt_length: int = 4000 + """ Maximum prompt length for any request. """ + + prompt_length: int = 2600 + """ Mean prompt length for requests. """ + + prompt_length_var: float = 0.3 + """ Variance of prompt length. """ + + max_new_tokens: int = 60 + """ Mean number of new tokens to generate in each request. """ + + max_new_tokens_var: float = 0.3 + """ Variance of new tokens to generate. """ + + streaming: bool = False + """ Whether to enable streaming mode for the client. """ + + @model_validator(mode="after") + def set_max_prompt_length(self) -> Self: + if self.prompt_length > self.max_prompt_length: + logger.warning( + f"Prompt length {self.prompt_length} is greater than max prompt length {self.max_prompt_length}. Setting max prompt length to {self.prompt_length}." + ) + self.max_prompt_length = max(self.max_prompt_length, self.prompt_length) + return self + + +class PromptGenerator: + def __init__(self, model: str, prompt_text_source: str) -> None: + self.tokenizer = AutoTokenizer.from_pretrained(model) + if os.path.isfile(prompt_text_source): + with open(prompt_text_source, "r") as f: + prompt_text_source = f.read() + self.input_text = prompt_text_source + self.tokenized_input = self.tokenizer.encode( + self.input_text, return_tensors="pt", padding=False + )[0] + + def count_tokens(self, text: str) -> int: + return len(self.tokenizer.encode(text)) + + def __call__(self, config: PromptConfig, num_prompts: int) -> Iterable[Prompt]: + tokenized_input = self.tokenized_input + if len(tokenized_input) < config.max_prompt_length: + tokenized_input = torch.cat( + [ + tokenized_input + for _ in range(config.max_prompt_length // len(tokenized_input) + 1) + ] + ).flatten() + + if config.prompt_generator_seed is not None: + np.random.seed(config.prompt_generator_seed) + + for _ in range(num_prompts): + # Take the absolute value here because sometimes the normal + # distribution will return a negative value. This is technically + # wrong, but works out OK for most scenarios. + prompt_length = min( + abs( + int( + np.random.normal( + config.prompt_length, + config.prompt_length * config.prompt_length_var, + ) + ) + ), + config.max_prompt_length, + ) + max_new_tokens = abs( + int( + np.random.normal( + config.max_new_tokens, + config.max_new_tokens * config.max_new_tokens_var, + ) + ) + ) + yield Prompt( + text=self.tokenizer.decode(tokenized_input[:prompt_length]), + num_tokens=prompt_length, + max_new_tokens=max_new_tokens, + ) diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/response.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/response.py new file mode 100644 index 000000000..3842ce5d7 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/response.py @@ -0,0 +1,16 @@ +from dataclasses import asdict, dataclass +from typing import Any + + +@dataclass +class Response: + prompt_text: str = "" + prompt_tokens: int = 0 + generated_output: str = "" + generated_tokens: int = 0 + request_time: float = 0 + raw_response: Any = None + client_id: int = 0 + + def to_dict(self) -> dict: + return asdict(self) diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/sample_input.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/sample_input.py new file mode 100644 index 000000000..0754da724 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/sample_input.py @@ -0,0 +1,225 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This is a sample input consisting of: +# Code & Text + +sample_input_text = """Deep learning involves the use of neural networks, which are computational models inspired by the structure and functioning of the human brain. These networks consist of interconnected nodes called neurons. Each neuron takes input, performs a computation, and produces an output. + During training, the neural network learns to make accurate predictions by adjusting its internal parameters. This adjustment is done using an optimization algorithm called gradient descent. Gradient descent calculates the gradients of a loss function, which measures the discrepancy between the predicted output of the network and the desired output. These gradients indicate the direction and magnitude of parameter updates that will minimize the loss. + The learning rate is an important hyperparameter in gradient descent. It determines the step size taken during parameter updates. A higher learning rate can lead to faster convergence, but it risks overshooting the optimal solution. On the other hand, a lower learning rate may converge more slowly, but it can result in more precise updates. + Activation functions are applied to the output of each neuron in a neural network. They introduce non-linearities, enabling the network to learn complex patterns and relationships in the data. Popular activation functions include the rectified linear unit (ReLU), sigmoid, and hyperbolic tangent (tanh). + By adjusting the parameters of the neural network during training, deep learning models learn to represent and generalize from complex data patterns. They have achieved remarkable success in various tasks, including image recognition, speech recognition, and natural language processing. + Here are the key fundamentals of deep learning for training large language models: + Neural Networks: At the heart of deep learning are artificial neural networks, which are inspired by the structure and functioning of biological neurons in the human brain. These networks consist of interconnected layers of artificial neurons called nodes or units. The nodes receive input, perform computations, and pass the results to the next layer. + Representation Learning: Deep learning models excel at learning meaningful representations of data. In the context of language, the models can automatically learn hierarchical representations of text, capturing complex relationships and semantic structures. + Feedforward and Backpropagation: Deep learning models typically use feedforward neural networks, where information flows from the input layer through intermediate hidden layers to the output layer. The network makes predictions based on the input data, and the prediction error is then backpropagated through the network. Backpropagation calculates gradients that indicate how each parameter in the network should be adjusted to minimize the error. + Activation Functions: Activation functions introduce non-linearities to neural networks, enabling them to learn complex patterns. Common activation functions include the rectified linear unit (ReLU), sigmoid, and hyperbolic tangent (tanh). These functions determine the output of each neuron based on its weighted inputs. + Loss Functions: During training, a loss function is used to measure the discrepancy between the predicted output of the neural network and the desired output. In language modeling tasks, common loss functions include cross-entropy loss, which quantifies the difference in probability distributions. + Optimization Algorithms: Optimization algorithms determine how the network's parameters are updated based on the calculated gradients during backpropagation. Stochastic Gradient Descent (SGD) is a widely used algorithm that iteratively updates the parameters in the direction that minimizes the loss. Variants of SGD, such as Adam or RMSprop, adaptively adjust the learning rate to accelerate convergence. + Regularization Techniques: Deep learning models are prone to overfitting, where they memorize the training data but fail to generalize well to unseen examples. Regularization techniques such as dropout and weight decay are commonly used to prevent overfitting and improve generalization by adding constraints to the model's parameters. + Training on Large-Scale Datasets: Deep learning models, including large language models, require substantial amounts of labeled training data to learn effectively. Large-scale datasets are crucial to expose the model to diverse language patterns and ensure it captures a broad understanding of language. + Parallel Computing: Training large language models is computationally demanding. To accelerate the training process, parallel computing techniques, such as using multiple GPUs or distributed computing systems, are employed. These techniques allow for efficient processing of large datasets and speeding up the training iterations. + Transfer Learning and Fine-tuning: Transfer learning is a technique where a pre-trained model, trained on a large-scale dataset, is used as a starting point for a new task or dataset. Fine-tuning involves adjusting the pre-trained model's parameters on the new dataset to adapt it to the specific task at hand. This approach significantly reduces the training time and data requirements for new models. + The training process of a large language model typically involves the following steps: + Data Collection: A diverse and comprehensive dataset is collected, which typically consists of a vast range of text from sources like books, websites, articles, and other textual resources. The quality and variety of the dataset are crucial to ensure the model learns a broad understanding of language. + Preprocessing: The collected text data is preprocessed to clean and normalize it. This step involves removing irrelevant characters or symbols, converting the text to a consistent format, and organizing it into smaller units such as sentences or paragraphs. + Tokenization: The preprocessed text is divided into individual tokens, which can be as small as words or even subword units. Tokenization helps in representing and processing the text efficiently during training. + Architecture Design: The model architecture, often based on the transformer architecture, is defined. Transformers are neural network models that excel in capturing long-range dependencies in sequential data, making them well-suited for language modeling tasks. + Model Initialization: The model parameters are randomly initialized to start the training process. These parameters will be adjusted iteratively during training to optimize the model's performance. + Training Loop: The model is trained using a large-scale computational infrastructure. The training loop typically involves several iterations over the dataset, known as epochs. During each epoch, the model processes the input data, generates predictions, and compares them with the expected output. The discrepancy between the predicted and expected output is used to compute a loss, which quantifies the model's performance. + Backpropagation and Optimization: Backpropagation is employed to calculate the gradients of the model's parameters with respect to the loss. These gradients indicate the direction and magnitude of the parameter updates needed to minimize the loss. Optimization algorithms, such as stochastic gradient descent (SGD) or its variants, are then used to update the model's parameters based on the computed gradients. + Iterative Refinement: Steps 6 and 7 are repeated for multiple epochs, gradually refining the model's performance. The model's ability to generate coherent and contextually relevant responses improves as it learns from the dataset. + Evaluation: The trained model is evaluated on a separate dataset to assess its performance and identify areas for improvement. Various metrics, such as perplexity or accuracy, can be used to evaluate the model's language generation capabilities. + Fine-tuning and Iteration: Based on the evaluation results, the model may undergo fine-tuning or further iterations of training to enhance its performance. This process helps in addressing specific limitations or biases and aligning the model's output more closely with desired expectations. + It's important to note that training a large language model from scratch is a computationally intensive process that requires substantial computational resources, including powerful hardware like GPUs or specialized hardware accelerators, and large-scale distributed systems to handle the massive amount of data and model parameters involved. + Here are ten highly recommended books that can help you learn deep learning: + "Deep Learning" by Ian Goodfellow, Yoshua Bengio, and Aaron Courville: + This comprehensive book covers the fundamental concepts of deep learning, including neural networks, optimization algorithms, and regularization techniques. It also explores advanced topics like generative models and deep reinforcement learning. + "Deep Learning with Python" by François Chollet: + Written by the creator of the Keras deep learning library, this book provides a practical introduction to deep learning with Python. It covers essential concepts, tools, and techniques, and includes hands-on examples and case studies. + "Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow" by Aurélien Géron: + This book offers a hands-on approach to learning machine learning and deep learning using popular Python libraries such as Scikit-Learn, Keras, and TensorFlow. It covers various algorithms and provides practical examples and exercises. + "Deep Learning for Computer Vision" by Rajalingappaa Shanmugamani: + Focusing on deep learning techniques for computer vision tasks, this book explores topics such as convolutional neural networks (CNNs), image classification, object detection, and image generation. It includes code examples using Python and popular deep learning frameworks. + "Deep Learning: A Practitioner's Approach" by Josh Patterson and Adam Gibson: + This book offers a practical guide to implementing deep learning solutions using the Deeplearning4j library. It covers key concepts, architectures, and techniques, and includes code examples and case studies. + "Grokking Deep Learning" by Andrew Trask: + Geared towards beginners, this book provides an intuitive and accessible introduction to deep learning concepts. It covers neural networks, backpropagation, gradient descent, and other fundamental topics with clear explanations and visualizations. + "Deep Learning for Natural Language Processing" by Palash Goyal, Sumit Pandey, and Karan Jain: + Focusing on deep learning techniques for natural language processing (NLP), this book explores topics like word embeddings, recurrent neural networks (RNNs), and sequence-to-sequence models. It includes code examples using Python and popular NLP libraries. + "Deep Reinforcement Learning" by Pieter Abbeel and John Schulman: + This book provides an in-depth exploration of deep reinforcement learning, a subfield that combines deep learning with reinforcement learning. It covers topics like Q-learning, policy gradients, and deep Q-networks (DQNs) and provides practical examples. + "Deep Learning for Time Series Forecasting" by N.D. Lewis: + Focusing on deep learning techniques for time series data, this book covers topics such as recurrent neural networks (RNNs), long short-term memory (LSTM) networks, and attention models. It includes code examples using Python and popular deep learning frameworks. + "Interpretable Deep Learning" by Christoph Molnar: + This book delves into the challenges and techniques for interpreting and understanding deep learning models. It covers model visualization, feature importance, and other methods for explaining and interpreting deep learning predictions. + These books cover a range of deep learning topics and provide valuable insights and practical guidance for learning and applying deep learning techniques. Choose the ones that align with your interests and learning style to enhance your understanding of deep learning. + Here are 10 popular GitHub projects that can be useful for building large language models (LLMs) or working with natural language processing (NLP) tasks: + TensorFlow: An open-source deep learning framework that provides tools and resources for building and training LLMs. It offers extensive support for various neural network architectures and has a large community. + PyTorch: Another popular deep learning framework that provides a dynamic computational graph and a wide range of tools for building LLMs. It is known for its user-friendly interface and flexibility. + Hugging Face Transformers: A library that provides pre-trained models and a high-level API for natural language understanding (NLU) tasks, including LLMs. It supports popular models like GPT, BERT, and RoBERTa. + Fairseq: A library developed by Facebook AI Research that focuses on sequence modeling tasks, including LLMs. It offers pre-trained models and tools for training and evaluating models using sequence-to-sequence architectures. + AllenNLP: A powerful NLP research library that simplifies the process of building and evaluating deep learning models. It offers pre-built components for common NLP tasks and supports LLMs with various architectures. + OpenAI GPT-3: Although not available on GitHub, OpenAI's GPT-3 language model is widely recognized and can be accessed via the OpenAI API. It offers state-of-the-art language generation capabilities and can be used for various NLP tasks. + BERT: A pre-trained language model developed by Google Research that has achieved exceptional results on various NLP benchmarks. The official implementation is available on GitHub and can be fine-tuned for specific tasks. + spaCy: A popular Python library for NLP tasks that provides efficient and scalable tools for tokenization, named entity recognition, part-of-speech tagging, and more. It integrates well with deep learning frameworks. + FastText: A library developed by Facebook Research that provides efficient tools for text classification and word representation learning. It offers pre-trained word embeddings and supports training LLMs for classification tasks. + NLTK (Natural Language Toolkit): A comprehensive library for NLP tasks in Python. It provides various modules for tokenization, stemming, tagging, parsing, and more. Although it doesn't focus explicitly on LLMs, it is widely used for preprocessing text data in NLP pipelines. + These projects offer a range of resources, pre-trained models, and tools that can assist you in building and working with large language models. Make sure to review the documentation and examples provided by each project to understand their capabilities and how they can be integrated into your workflow. + Here are some popular backend libraries that are commonly used for deep learning: + TensorFlow: Developed by Google's Brain Team, TensorFlow is one of the most widely used deep learning frameworks. It provides a flexible and comprehensive ecosystem for building and deploying machine learning models. TensorFlow offers high-level APIs for easy model construction, as well as lower-level APIs for fine-grained control. It supports distributed computing and has extensive community support. + PyTorch: Developed by Facebook's AI Research lab, PyTorch is known for its simplicity and dynamic computational graph. It allows for intuitive model construction and debugging. PyTorch is widely used in both research and industry due to its flexibility, support for dynamic networks, and strong GPU acceleration capabilities. + Keras: Initially developed as a user-friendly deep learning library, Keras is now integrated as the official high-level API in TensorFlow. It provides a user-friendly and modular interface for building neural networks. Keras abstracts away many complexities and allows users to build models with just a few lines of code. It supports multiple backends, including TensorFlow and Theano. + Theano: Although its development has been discontinued, Theano was one of the first widely-used deep learning libraries. It allows for efficient mathematical operations on multi-dimensional arrays and supports GPU acceleration. Theano was influential in shaping the deep learning landscape and served as a precursor to subsequent frameworks. + Caffe: Developed by the Berkeley Vision and Learning Center (BVLC), Caffe is a popular deep learning framework known for its efficiency and simplicity. It is particularly suitable for convolutional neural networks (CNNs) and image-related tasks. Caffe has a clean and expressive architecture description language that makes it easy to define and train deep models. + MXNet: MXNet is an open-source deep learning framework developed by Apache. It offers a flexible and efficient interface for building and deploying neural networks. MXNet provides a hybrid frontend that allows users to seamlessly switch between symbolic and imperative programming. It is known for its scalability and supports multiple programming languages. + Chainer: Chainer is a flexible deep learning framework that focuses on dynamic neural networks. It allows for intuitive model construction using imperative programming, making it easy to define complex architectures and manipulate data within the network. Chainer is known for its "define-by-run" approach, which facilitates dynamic computations. + Microsoft Cognitive Toolkit (CNTK): CNTK is a deep learning framework developed by Microsoft. It provides a highly efficient and scalable implementation of deep neural networks. CNTK supports both declarative and imperative programming models, making it suitable for both research and production-level deployments. + Deeplearning4j: Deeplearning4j is an open-source deep learning library that focuses on scalability and performance. It is designed to integrate with the Java ecosystem and supports distributed computing. Deeplearning4j provides tools for building various types of neural networks and offers integration with other popular libraries like Hadoop and Spark. + PaddlePaddle: PaddlePaddle (PArallel Distributed Deep LEarning) is a deep learning framework developed by Baidu. It emphasizes scalability and supports large-scale distributed training. PaddlePaddle provides a rich set of built-in models and algorithms, making it accessible to both beginners and advanced users. + Each of these backend libraries offers unique features, performance characteristics, and levels of abstraction. The choice of a backend library depends on factors such as your programming language preferences, the complexity of your models, the availability of community support, and the specific requirements of your deep learning project. + Here's an example code snippet that demonstrates how to create a GPT-Neox20B model using the Hugging Face Transformers library and start fine-tuning it with sample data from the '/tmp/wikitext' directory: + + import torch + from transformers import GPTNeoForCausalLM, GPT2Tokenizer, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments + + # Load the GPT-Neo model and tokenizer + model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B") + tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B") + + # Set the path to the training data + data_path = "/tmp/wikitext" + + # Define the dataset and data collator + dataset = TextDataset(tokenizer=tokenizer, file_path=data_path, block_size=128) + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + # Define the training arguments + training_args = TrainingArguments( + output_dir="./output_dir", # Directory to save the fine-tuned model and logs + overwrite_output_dir=True, + num_train_epochs=3, # Number of training epochs + per_device_train_batch_size=8, + save_steps=500, # Number of updates steps before saving the model + save_total_limit=2, # Limit the total number of saved models + prediction_loss_only=True, + ) + + # Initialize the Trainer + trainer = Trainer( + model=model, + args=training_args, + data_collator=data_collator, + train_dataset=dataset, + ) + + # Start fine-tuning + trainer.train() + + # Save the fine-tuned model + trainer.save_model("./fine-tuned-model") + + To deploy the GPT-Neo model on an HTTP server, you can use a framework like Flask or FastAPI. Here's a basic example using Flask: + + Install Flask: + + pip install flask + + Create a Python file, let's call it app.py, and import the necessary libraries: + + from flask import Flask, request, jsonify + from transformers import GPTNeoForCausalLM, GPT2Tokenizer + + # Initialize Flask app + app = Flask(__name__) + + # Load the pretrained GPT-Neo model and tokenizer + model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B") + tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B") + Define an endpoint for model inference: + + python + Copy code + @app.route("/generate", methods=["POST"]) + def generate_text(): + data = request.get_json() + input_text = data["input_text"] + + # Tokenize the input text + input_ids = tokenizer.encode(input_text, return_tensors="pt") + + # Generate text with the model + output = model.generate(input_ids, max_length=100, num_return_sequences=1) + generated_text = tokenizer.decode(output[0], skip_special_tokens=True) + + return jsonify({"generated_text": generated_text}) + Run the Flask app: + + python + Copy code + if __name__ == "__main__": + app.run(host="0.0.0.0", port=5000) + The code that includes the text generation sampling functions and provides a commented example: + + import requests + import numpy as np + + class TextGeneratorAPI: + def __init__(self, server_url): + self.server_url = server_url + + def generate_text(self, input_text, sampling_algorithm="greedy", temperature=0.7): + url = f"{self.server_url}/generate" + payload = { + "input_text": input_text, + "sampling_algorithm": sampling_algorithm, + "temperature": temperature + } + response = requests.post(url, json=payload) + generated_text = response.json()["generated_text"] + return generated_text + + def greedy_sampling(self, logits): + return np.argmax(logits) + + def random_sampling(self, logits): + probabilities = np.exp(logits / temperature) + probabilities = probabilities / np.sum(probabilities) + return np.random.choice(len(logits), p=probabilities) + + def top_k_sampling(self, logits, k=10): + indices = np.argsort(logits)[-k:] + probabilities = np.exp(logits[indices] / temperature) + probabilities = probabilities / np.sum(probabilities) + return np.random.choice(indices, p=probabilities) + + def top_p_sampling(self, logits, p=0.9): + sorted_logits = np.sort(logits)[::-1] + cumulative_probs = np.cumsum(np.exp(sorted_logits) / temperature) + indices = np.arange(len(sorted_logits)) + selected_indices = indices[cumulative_probs <= p] + probabilities = np.exp(logits[selected_indices] / temperature) + probabilities = probabilities / np.sum(probabilities) + return np.random.choice(selected_indices, p=probabilities) + In this updated code, the TextGeneratorAPI class includes the additional sampling functions: greedy_sampling, random_sampling, top_k_sampling, and top_p_sampling. These functions take logits (output of the model) as input and return the index of the selected token based on the respective sampling algorithm. + The greedy_sampling function selects the token with the highest probability (argmax) as the next token. The random_sampling function applies a temperature scaling to the logits and then samples from the resulting probability distribution. The top_k_sampling function selects from the top-k tokens with the highest probabilities. The top_p_sampling function selects from the tokens with cumulative probabilities below a certain threshold (top-p). + You can now use the updated TextGeneratorAPI class with the sampling functions. Here's an example: + + api = TextGeneratorAPI(server_url="http://localhost:5000") + + input_text = "Once upon a time" + + # Generate text using different sampling algorithms and temperatures + greedy_text = api.generate_text(input_text, sampling_algorithm="greedy") + random_text = api.generate_text(input_text, sampling_algorithm="random") + top_k_text = api.generate_text(input_text, sampling_algorithm="top_k", temperature=0.8) + top_p_text = api.generate_text(input_text, sampling_algorithm="top_p", temperature=0.9) + + print("Greedy Sampling:", greedy_text) + print("Random Sampling:", random_text) + print("Top-k Sampling:", top_k_text) + print("Top-p Sampling:", top_p_text) + Make sure to adjust the server_url with the appropriate URL of your HTTP server, and ensure that the server is running and accessible before making requests through the API. + """ diff --git a/benchmarks/inference/deepspeedometer/tests/README.md b/benchmarks/inference/deepspeedometer/tests/README.md new file mode 100644 index 000000000..15a5f49f9 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/tests/README.md @@ -0,0 +1,3 @@ +To run the unit tests: + +`python3 -m pytest .` \ No newline at end of file diff --git a/benchmarks/inference/deepspeedometer/tests/__init__.py b/benchmarks/inference/deepspeedometer/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/benchmarks/inference/deepspeedometer/tests/conftest.py b/benchmarks/inference/deepspeedometer/tests/conftest.py new file mode 100644 index 000000000..e2f779c44 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/tests/conftest.py @@ -0,0 +1,95 @@ +import pytest + + +@pytest.fixture(scope="function", params=["facebook/opt-125m"]) +def model(request): + return request.param + + +@pytest.fixture(scope="function", params=["dummy"]) +def api(request): + return request.param + + +@pytest.fixture(scope="function", params=[""]) +def result_dir(request, tmpdir): + if request.param: + return str(request.param) + return str(tmpdir) + + +@pytest.fixture(scope="function", params=[5]) +def num_requests_per_client(request): + return str(request.param) + + +@pytest.fixture(scope="function", params=[16]) +def min_requests(request): + return str(request.param) + + +@pytest.fixture(scope="function", params=[(1, 2)]) +def num_clients(request): + if isinstance(request.param, tuple) or isinstance(request.param, list): + return [str(num) for num in request.param] + else: + return [str(request.param)] + + +@pytest.fixture(scope="function", params=[0]) +def num_config_files(request): + return request.param + + +@pytest.fixture(scope="function") +def config_files(num_config_files, tmp_path): + config_files = [] + for i in range(num_config_files): + config_file = tmp_path / f"config_{i}.yaml" + config_file.touch() + config_files.append(str(config_file)) + return config_files + + +@pytest.fixture(scope="function", params=[""]) +def prompt_length_var(request): + return str(request.param) + + +@pytest.fixture(scope="function", params=[""]) +def max_new_tokens_var(request): + return str(request.param) + + +@pytest.fixture(scope="function") +def benchmark_args( + model, + api, + result_dir, + num_requests_per_client, + min_requests, + num_clients, + config_files, + prompt_length_var, + max_new_tokens_var, +): + args = [] + if model: + args.extend(["--model", model]) + if api: + args.extend(["--api", api]) + if result_dir: + args.extend(["--result_dir", result_dir]) + if num_requests_per_client: + args.extend(["--num_requests_per_client", num_requests_per_client]) + if min_requests: + args.extend(["--min_requests", min_requests]) + if num_clients: + args.extend(["--num_clients"] + num_clients) + if config_files: + args.extend(["--config_file"] + config_files) + if prompt_length_var: + args.extend(["--prompt_length_var", prompt_length_var]) + if max_new_tokens_var: + args.extend(["--max_new_tokens_var", max_new_tokens_var]) + return args diff --git a/benchmarks/inference/deepspeedometer/tests/test_benchmark.py b/benchmarks/inference/deepspeedometer/tests/test_benchmark.py new file mode 100644 index 000000000..2b067d39e --- /dev/null +++ b/benchmarks/inference/deepspeedometer/tests/test_benchmark.py @@ -0,0 +1,17 @@ +import pytest + +from deepspeedometer import parse_args_to_configs, BenchmarkRunner + + +def test_benchmark_runner(benchmark_args, num_clients): + benchmark_config, client_config = parse_args_to_configs(benchmark_args) + benchmark_runner = BenchmarkRunner(benchmark_config, client_config) + benchmark_runner.run() + + expected_results = sum(1 for _ in benchmark_runner._benchmark_settings()) * len( + num_clients + ) + actual_results = len(list(benchmark_runner._get_output_dir().glob("*.json"))) + assert ( + expected_results == actual_results + ), f"Number of result files ({actual_results}) does not match expected number ({expected_results})." diff --git a/benchmarks/inference/deepspeedometer/tests/test_config.py b/benchmarks/inference/deepspeedometer/tests/test_config.py new file mode 100644 index 000000000..d20e0981a --- /dev/null +++ b/benchmarks/inference/deepspeedometer/tests/test_config.py @@ -0,0 +1,32 @@ +import pytest + +import yaml + +import pydantic + +from deepspeedometer import BenchmarkRunner, parse_args_to_configs + + +def test_config(benchmark_args): + benchmark_config, client_config = parse_args_to_configs(benchmark_args) + + +@pytest.mark.parametrize("model", [""]) +def test_config_required_fail(benchmark_args): + with pytest.raises(pydantic.ValidationError): + benchmark_config, client_config = parse_args_to_configs(benchmark_args) + + +@pytest.mark.parametrize("num_config_files", [1]) +def test_config_file(benchmark_args, config_files, num_clients): + # Create a config that would generate 6 benchmark settings + config = {"max_prompt_length": [500, 1300, 2600], "num_clients": [1, 2]} + with open(config_files[0], "w") as f: + yaml.dump(config, f) + + benchmark_config, client_config = parse_args_to_configs(benchmark_args) + benchmark_runner = BenchmarkRunner(benchmark_config, client_config) + benchmark_settings = sum(1 for _ in benchmark_runner._benchmark_settings()) * len( + num_clients + ) + assert benchmark_settings == 6 diff --git a/benchmarks/inference/deepspeedometer/tests/test_early_stop.py b/benchmarks/inference/deepspeedometer/tests/test_early_stop.py new file mode 100644 index 000000000..2a63ba206 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/tests/test_early_stop.py @@ -0,0 +1,23 @@ +import pytest + +from deepspeedometer import parse_args_to_configs, BenchmarkRunner + + +@pytest.mark.parametrize("num_clients", [(1, 2, 4)], indirect=True) +def test_early_stop(benchmark_args): + benchmark_args += [ + "--early_stop_latency", + "1", + "--dummy_client_latency_time", + "2.0", + ] + print(benchmark_args) + benchmark_config, client_config = parse_args_to_configs(benchmark_args) + benchmark_runner = BenchmarkRunner(benchmark_config, client_config) + benchmark_runner.run() + + expected_results = 1 + actual_results = len(list(benchmark_runner._get_output_dir().glob("*.json"))) + assert ( + expected_results == actual_results + ), f"Number of result files ({actual_results}) does not match expected number ({expected_results})." diff --git a/benchmarks/inference/deepspeedometer/tests/test_prompt.py b/benchmarks/inference/deepspeedometer/tests/test_prompt.py new file mode 100644 index 000000000..997a82dd5 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/tests/test_prompt.py @@ -0,0 +1,15 @@ +import pytest + +from deepspeedometer import BenchmarkRunner, parse_args_to_configs + + +@pytest.mark.parametrize("prompt_length_var, max_new_tokens_var", [(0, 0)]) +def test_prompt_length(benchmark_args): + benchmark_config, client_config = parse_args_to_configs(benchmark_args) + benchmark_runner = BenchmarkRunner(benchmark_config, client_config) + num_clients, prompt_config = next(benchmark_runner._benchmark_settings()) + + for prompt in benchmark_runner.prompt_generator(prompt_config, num_prompts=10): + prompt_length = benchmark_runner.prompt_generator.count_tokens(prompt.text) + # Using pytest.approx here because often we will have 1-off errors due to tokenization special tokens + assert prompt_length == pytest.approx(benchmark_runner.config.prompt_length, 1) diff --git a/benchmarks/inference/mii/README.md b/benchmarks/inference/mii/README.md index d9e475cdb..726cad462 100644 --- a/benchmarks/inference/mii/README.md +++ b/benchmarks/inference/mii/README.md @@ -1,39 +1,104 @@ -# Benchmarking Scripts for DeepSpeed-FastGen +# Inference Benchmarking Scripts for vLLM, DeepSpeed-FastGen, and Azure ML endpoints ## Run the Benchmark -The benchmarking scripts use DeepSpeed-FastGen in the persistent mode. -You can start the server with the command below: +The benchmarking scripts use DeepSpeed-FastGen in the persistent mode. You can +run the benchmark using `run_benchmark.py`. This script will run several +combinations of inference servers and clients with different tensor parallel +size, number of model replicas (MII only), number of clients, prompt length, and +max new tokens values. By default, the benchmark will run with the `meta-llama/Llama-2-7b-hf` model. ```bash -python server.py [options] start +python run_benchmark.py ``` -Use the -h option to view all available options. To stop the server, use this command: +Use the -h option to view all available options. Several models have pre-defined +default values, including `meta-llama/Llama-2-{7|13|70}b-hf`, +`tiiuae/falcon-{40|180}B`, `microsoft/phi-2`, and `mistralai/Mixtral-8x7B-v0.1`. +These defaults can be overridden if provided to the `run_benchmark.py` script. +For example, to run `meta-llama/Llama-13b-hf` with a tensor parallel size of `1` +and `2` (instead of the default `1`, `2`, and `4`): -```bash -python server.py stop +```bash +python run_benchmark.py --tp_size 1 2 +``` + +By default the benchmark runs with DeepSpeed-MII as the backend inference +server. The benchmark also supports vLLM and Azure endpoints. To change the +backend to vLLM, provide the `--backend vllm` arg: + +```bash +python run_benchmark.py --backend vllm ``` -Once the server is up and running, initiate the client using the command below. The -h option will display all the possible options. +To benchmark against an Azure endpoint, provide the `--backend aml` as well as +the following values: +- `--aml_api_url`: API URL that points to an AML endpoint +- `--aml_api_key`: API key for the given AML endpoint +- `--deployment_name`: The name of the AML endpoint deployment you want to test against +- `--model`: The name of the HuggingFace-hosted model deployed on the AML endpoint. This is used to load a tokenizer and correctly calculate the number of tokens in the prompts and responses. ```bash -python run_benchmark_client.py [options] +python run_benchmark.py --backend aml --model mistralai/Mixtral-8x7B-v0.1 --deployment_name mistralai-mixtral-8x7b-v01-4 --aml_api_url --aml_api_key ``` -The run_all.sh script performs benchmarks across various model sizes and client numbers. For VLLM benchmarks, use the run_all_vllm.sh script. Results are logged in a directory named logs.[BENCHMARK_PARAMETERS]. +The run_all.sh script performs benchmarks across various models, client numbers, +tensor parallel sizes, etc. This script is intended to be run on a system with +8xA100 (80GB) GPUs available. It will run all the benchmarks (including vLLM) +and collect the data used in our [DeepSpeed-Fastgen +blogs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen). +Results are collected in `./results/`. ## Analyze the Benchmark Results -The scripts mentioned below were used for generating the plots featured in our blog. Specify the root directory for log files using --log_dir. +The scripts mentioned below were used for generating the plots featured in our +blog. Specify the root directory for log files using `--log_dir` and the backends you wish to run for, e.g. `--backend vllm fastgen aml`. The generated +figures will be saved to `./plots/` + +- `src/plot_th_lat.py`: This script generates charts for throughput and latency across different model sizes and client counts. +- `src/plot_effective_throughput.py`: Use this to chart effective throughput. +- `src/plot_latency_percentile.py`: This script will plot the 50th, 90th, and 95th percentile latencies. +- `src/plot_repl_scale.py`: This script will plot the throughput and number of replicas for a fixed clients/replica per plot. +- `src/plot_tp_sizes.py`: This script will plot latency and TFLOPs per GPU across different tensor parallelism sizes. + +## Throughput Latency Plot Generation Script +The `plot_th_lat.py` throughput-latency plot generation script is generalized for any result output directory, irrespective of where it was run. + +The script uses an **_optional_** `plot_config.yaml` that resides within each result directory and allows for overrides in the plot formatting. An example config file may look like this: +```yaml +label: "vLLM" +color: "purple" +marker: "o" +linestyle: "--" +polyfit_degree: 0 +x_max : 30 +y_max : 10 +``` + +Each of the config parameters is optional, allowing for overriding of only the specific plot aspects required, however, all parameters may also be provided. + +A few nuances for the `polyfit_degree` and `x/y_max` parameters: +- `polyfit_degree`: Specifies the polynomial degree for the 'best fit line'. Specifying `0` removes the best fit line and simply connects the scatter plot points. +- `x/y_max`: Clips the x or y axis data using the specified value as the upper bound. + +An example command executing the script may look something like this: +```bash +DeepSpeedExamples/benchmarks/inference/mii$ python3 src/plot_th_lat.py --data_dirs ./results/results-* --model_name +``` -- `plot_th_lat.py`: This script generates charts for throughput and latency across different model sizes and client counts. -- `plot_effective_throughput.py`: Use this to chart effective throughput. -- `plot_latency_percentile.py`: This script will plot the 50th, 90th, and 95th percentile latencies. +Or each result directory can be enumerated explicitly: +```bash +DeepSpeedExamples/benchmarks/inference/mii$ python3 src/plot_th_lat.py --data_dirs ./results/results-1 ./results/results-2 ./results/results-3 --model_name +``` ## Running an End-to-End Example -To quickly experience the end-to-end process of running our benchmark and getting results, you can use the `run_example.sh`. This script is designed to execute the benchmark with a specific configuration. The plots below will be generated in the charts directory. These plots show the performance as depicted in figure 8 of our blog [post.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen#f-other-hardware-platforms) +To quickly experience the end-to-end process of running our benchmark and +getting results, you can use the `run_example.sh`. This script is designed to +execute the benchmark with a specific configuration. The plots below will be +generated in the `./plots/` directory. These plots show the performance as +depicted in figure 8 of our blog +[post.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen#f-other-hardware-platforms) ```bash bash run_example.sh @@ -43,4 +108,4 @@ bash run_example.sh
*Figure 1: Throughput-latency curve and effective throughput of Llama 2 7b using A6000. Runs the client with 60 generation steps and input prompt length of 2600.*
- \ No newline at end of file + diff --git a/benchmarks/inference/mii/plot_config.yaml b/benchmarks/inference/mii/plot_config.yaml new file mode 100644 index 000000000..48a5a3171 --- /dev/null +++ b/benchmarks/inference/mii/plot_config.yaml @@ -0,0 +1,7 @@ +label: "vLLM" +color: "purple" +marker: "o" +linestyle: "--" +polyfit_degree: 0 +x_max : 30 +y_max : 10 diff --git a/benchmarks/inference/mii/plot_effective_throughput.py b/benchmarks/inference/mii/plot_effective_throughput.py deleted file mode 100644 index 350c269c3..000000000 --- a/benchmarks/inference/mii/plot_effective_throughput.py +++ /dev/null @@ -1,176 +0,0 @@ -import argparse -from pathlib import Path -import glob -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from postprocess_results import read_json, get_tokenizer - -RAGGED_BATCH_SIZE = 768 -SLA_PROMPT_TOKENS_PER_SEC = 512 -SLA_GEN_TOKENS_PER_SEC = [1, 2, 3, 4, 6, 8] -EMA_SPAN = 16 - -tp_sizes_all = { - "7b": [1], - "70b": [4, 8] -} - -tp_sizes_test = { - "7b": [1] -} - -prompt_gen_pairs_all = [ - (1200, 60), - (1200, 128), - (2600, 60), - (2600, 128), -] - -prompt_gen_pairs_test = [ - (2600, 60) -] - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--test", action="store_true") - parser.add_argument("--no_vllm", action="store_true") - parser.add_argument("--log_dir", type=Path, default=".") - parser.add_argument("--out_dir", type=Path, default="charts/goodtput") - args = parser.parse_args() - return args - - -def check_token_latency_step(response_details, token_index): - P50_token_latency = np.percentile([r.token_gen_time[token_index] for r in response_details if len(r.token_gen_time) > token_index], 50) - P90_token_latency = np.percentile([r.token_gen_time[token_index] for r in response_details if len(r.token_gen_time) > token_index], 90) - P99_token_latency = np.percentile([r.token_gen_time[token_index] for r in response_details if len(r.token_gen_time) > token_index], 99) - - return P50_token_latency, P90_token_latency, P99_token_latency - - -def validate_token_cum_latency_SLA(response_detail, sla_token_gen): - cumsum_latencies = np.cumsum(np.array(response_detail.token_gen_time[1:])) - return all([cumsum_latencies[i] <= (1 / sla_token_gen) * (i + 1) for i in range(len(cumsum_latencies))]) - - -def validate_token_ema_latency_SLA(response_detail, sla_token_gen, ema_span): - ema_latency = pd.Series(response_detail.token_gen_time[1:]).ewm(span=ema_span).mean().values.tolist() - return all([t < 1. / sla_token_gen for t in ema_latency]) - - -def validate_prompt_latency_SLA(response_detail, sla_token_gen, f): - tokenizer = get_tokenizer() - prompt_length = len(tokenizer.tokenize(response_detail.prompt)) - prompt_latency_SLA = prompt_length / SLA_PROMPT_TOKENS_PER_SEC - if prompt_latency_SLA < response_detail.token_gen_time[0]: - return False - - if len(response_detail.token_gen_time) == 1: - return True - - return f[0](response_detail, sla_token_gen, *f[1]) - - -def calc_throughput(response_details): - start_time = min([r.start_time for r in response_details]) - end_time = max([r.end_time for r in response_details]) - return len(response_details) / (end_time - start_time) - - -def extract_values(file_pattern, sla_token_gen, validate_func): - files = glob.glob(file_pattern) - print(f"Found {len(files)} files") - goodputs = {} - good_ratios = {} - for f in files: - prof_args, response_details = read_json(f) - client_num = prof_args["client_num"] - num_req_ok = len([r for r in response_details if validate_prompt_latency_SLA(r, sla_token_gen, validate_func)]) - goodputs[client_num] = calc_throughput(response_details) * (num_req_ok / len(response_details)) - good_ratios[client_num] = num_req_ok / len(response_details) - - return goodputs, good_ratios - - -def display_results(model_size, tp, bs, sla_token_gen, prompt, gen, log_dir, out_dir): - if not log_dir.exists(): - print(f"Log directory {log_dir} does not exist") - return - - if not out_dir.exists(): - out_dir.mkdir(parents=True, exist_ok=True) - - print(f"model: {model_size} Prompt: {prompt}, Generation: {gen}, TP: {tp} sla_token_gen: {sla_token_gen}") - - mii_file_pattern = f"{log_dir}/logs.llama2-{model_size}-tp{tp}-b{bs}/llama2-{model_size}-tp{tp}-b{bs}_c*_p{prompt}_g{gen}.json" - if not args.no_vllm: - vllm_file_pattern = f"{log_dir}/logs.vllm-llama2-{model_size}-tp{tp}/vllm-llama2-{model_size}-tp{tp}_c*_p{prompt}_g{gen}.json" - - validate_funcs = [ - (validate_token_cum_latency_SLA, (), "cum"), - (validate_token_ema_latency_SLA, (EMA_SPAN, ), f"ema{EMA_SPAN}"), - ] - - for f in validate_funcs: - - mii_goodputs, mii_good_ratios = extract_values(mii_file_pattern, sla_token_gen, f) - client_num_list = sorted(list(mii_goodputs.keys())) - mii_goodputs_list = [mii_goodputs[client_num] for client_num in client_num_list] - - if not args.no_vllm: - vllm_goodputs, vllm_good_ratios = extract_values(vllm_file_pattern, sla_token_gen, f) - vllm_goodputs_list = [vllm_goodputs[client_num] for client_num in client_num_list] - - # print(f"MII {mii_goodputs_list} ratio={mii_good_ratios}") - # print(f"vLLM {vllm_goodputs_list} ratio={vllm_good_ratios}") - - # Plotting the scatter plot - plt.figure(figsize=(7, 4)) - plt.scatter(client_num_list, mii_goodputs_list, label=f"DeepSpeed-FastGen", marker="o", color="blue") - if not args.no_vllm: - plt.scatter(client_num_list, vllm_goodputs_list, label=f"vLLM", marker="x", color="orange") - - fit_x_list = np.arange(min(client_num_list), max(client_num_list), 0.1) - mii_fit_model = np.polyfit(client_num_list, mii_goodputs_list, 4) - mii_model_fn = np.poly1d(mii_fit_model) - plt.plot(fit_x_list, mii_model_fn(fit_x_list), color="blue", alpha=0.5, linestyle="--") - - if not args.no_vllm: - vllm_fit_model = np.polyfit(client_num_list, vllm_goodputs_list, 4) - vllm_model_fn = np.poly1d(vllm_fit_model) - plt.plot(fit_x_list, vllm_model_fn(fit_x_list), color="orange", alpha=0.5, linestyle="--") - - title = f"Effective throughput (SLA prompt: {SLA_PROMPT_TOKENS_PER_SEC} tokens/s, generation: {sla_token_gen} tokens/s)\n" \ - + f'Llama 2 {model_size.upper()} Prompt: {prompt}, Generation: {gen}, TP: {tp}' - plt.title(title, fontsize=10) - plt.xlabel('Number of clients', fontsize=10) - plt.ylabel('Effective throughput (queries/s)', fontsize=10) - # plt.rcParams['figure.subplot.bottom'] = 0.30 - plt.ylim(bottom=-0.05) - plt.legend() - plt.grid(True) - # plt.show() - out_file = out_dir / f"goodput_llama{model_size}_SLAp{SLA_PROMPT_TOKENS_PER_SEC}g{sla_token_gen}_tp{tp}_b{bs}_p{prompt}g{gen}_{f[2]}.png" - plt.savefig(out_file) - plt.clf() - print(f"Saved {out_file}") - - -if __name__ == "__main__": - args = get_args() - - if args.test: - tp_sizes = tp_sizes_test - prompt_gen_pairs = prompt_gen_pairs_test - else: - tp_sizes = tp_sizes_all - prompt_gen_pairs = prompt_gen_pairs_all - - for model_size, tps in tp_sizes.items(): - for tp in tps: - for prompt, gen in prompt_gen_pairs: - for sla_token_gen in SLA_GEN_TOKENS_PER_SEC: - display_results(model_size, tp, RAGGED_BATCH_SIZE, sla_token_gen, prompt, gen, args.log_dir, args.out_dir) - diff --git a/benchmarks/inference/mii/plot_latency_percentile.py b/benchmarks/inference/mii/plot_latency_percentile.py deleted file mode 100644 index c91c78bf1..000000000 --- a/benchmarks/inference/mii/plot_latency_percentile.py +++ /dev/null @@ -1,110 +0,0 @@ -import argparse -import glob -from pathlib import Path -import matplotlib.pyplot as plt -import numpy as np -import itertools - -from postprocess_results import read_json, get_token_latency - -bs = 768 -SKIP_HEAD_TOKEN_NUM = 2 -SKIP_REQUEST_NUM = 100 - -tp_sizes = { - "70b": [4], -} - -prompt_gen_pairs = [ - (2600, 128), -] - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--log_dir", type=Path, default=".") - parser.add_argument("--out_dir", type=Path, default="charts/percentile_token_latency") - args = parser.parse_args() - return args - - -def extract_values(file_pattern): - files = glob.glob(file_pattern) - - latencies = {} - for f in files: - prof_args, response_details = read_json(f) - client_num = prof_args["client_num"] - - response_details.sort(key=lambda r: r.start_time) - response_details = response_details[SKIP_REQUEST_NUM:-SKIP_REQUEST_NUM] - token_latencies = [r.token_gen_time[SKIP_HEAD_TOKEN_NUM:-1] for r in response_details] - - flat_latency_list = list(itertools.chain(*token_latencies)) - latencies[client_num] = flat_latency_list - return latencies - - -def output_charts(model_size, tp, bs, prompt, gen, log_dir, out_dir): - if not log_dir.exists(): - print(f"Log directory {log_dir} does not exist") - return - - if not out_dir.exists(): - out_dir.mkdir(parents=True, exist_ok=True) - - mii_file_pattern = f"{log_dir}/logs.llama2-{model_size}-tp{tp}-b{bs}/llama2-{model_size}-tp{tp}-b{bs}_c*_p{prompt}_g{gen}.json" - vllm_file_pattern = f"{log_dir}/logs.vllm-llama2-{model_size}-tp{tp}/vllm-llama2-{model_size}-tp{tp}_c*_p{prompt}_g{gen}.json" - - mii_latencies = extract_values(mii_file_pattern) - vllm_latencies = extract_values(vllm_file_pattern) - client_num_list = sorted(list(mii_latencies.keys())) - - for client_num in client_num_list: - plt.figure(figsize=(6, 4)) - - percentile = 95 - - P50_vllm_val = np.percentile(vllm_latencies[client_num], 50) - P50_mii_val = np.percentile(mii_latencies[client_num], 50) - P90_vllm_val = np.percentile(vllm_latencies[client_num], 90) - P90_mii_val = np.percentile(mii_latencies[client_num], 90) - P95_vllm_val = np.percentile(vllm_latencies[client_num], 95) - P95_mii_val = np.percentile(mii_latencies[client_num], 95) - - # print(f"P50_vllm_val={P50_vllm_val}") - # print(f"P50_mii_val={P50_mii_val}") - # print(f"P90_vllm_val={P90_vllm_val}") - # print(f"P90_mii_val={P90_mii_val}") - # print(f"P95_vllm_val={P95_vllm_val}") - # print(f"P95_mii_val={P95_mii_val}") - - out_file = out_dir / f"p{percentile}_token_latency_llama{model_size}_c{client_num}_tp{tp}_p{prompt}g{gen}.png" - - x1 = [1, 2, 3] - y1 = [P50_vllm_val, P90_vllm_val, P95_vllm_val] - - x2 = [1.3, 2.3, 3.3] - y2 = [P50_mii_val, P90_mii_val, P95_mii_val] - - label_x = ['P50', 'P90', 'P95'] - - plt.bar(x1, y1, width=0.3, label='vLLM', align="center", color="orange") - plt.bar(x2, y2, width=0.3, label="DeepSpeed-FastGen", align="center", color="blue") - plt.ylabel('Latency', fontsize=14) - plt.legend(loc=2) - - plt.xticks([1.15, 2.15, 3.15], label_x) - - plt.savefig(out_file) - print(f"Saved {out_file}") - - -if __name__ == "__main__": - args = get_args() - - for model_size, tps in tp_sizes.items(): - for tp in tps: - for prompt, gen in prompt_gen_pairs: - output_charts(model_size, tp, bs, prompt, gen, args.log_dir, args.out_dir) - diff --git a/benchmarks/inference/mii/plot_repl_scale.py b/benchmarks/inference/mii/plot_repl_scale.py deleted file mode 100644 index 394c54588..000000000 --- a/benchmarks/inference/mii/plot_repl_scale.py +++ /dev/null @@ -1,95 +0,0 @@ -import glob -import matplotlib.pyplot as plt -import argparse -from pathlib import Path -import numpy as np - -from postprocess_results import read_json, get_summary - -bs = 768 - -REPLICA_NUMS = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] - -tp_sizes = { - "70b": [4], -} - -prompt_gen_pairs = [ - (2600, 60), -] - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--log_dir", type=Path, default=".") - parser.add_argument("--out_dir", type=Path, default="charts/repl_scale") - args = parser.parse_args() - return args - - -def extract_values(file_pattern): - files = glob.glob(file_pattern) - - clients = [] - throughputs = [] - latencies = [] - for f in files: - prof_args, response_details = read_json(f) - summary = get_summary(prof_args, response_details) - clients.append(prof_args["client_num"]) - throughputs.append(summary.throughput) - latencies.append(summary.latency) - - return clients, throughputs, latencies - - -def output_charts(model_size, tp, bs, prompt, gen, log_dir, out_dir): - if not log_dir.exists(): - print(f"Log directory {log_dir} does not exist") - return - - if not out_dir.exists(): - out_dir.mkdir(parents=True, exist_ok=True) - - throughputs = {} - for repl in REPLICA_NUMS: - mii_file_pattern = f"{log_dir}/logs.llama2-{model_size}-tp{tp}-b{bs}_repl{repl}/llama2-{model_size}-tp{tp}-b{bs}_repl{repl}_c*_p{prompt}_g{gen}.json" - print(f"Looking for {mii_file_pattern}") - clients, mii_throughputs, mii_latencies = extract_values(mii_file_pattern) - - for c, th in zip(clients, mii_throughputs): - client_per_repl = c // repl - if client_per_repl not in throughputs: - throughputs[client_per_repl] = [] - print(f"Throughput for {client_per_repl} clients: {th}") - throughputs[client_per_repl].append(th) - - for c in throughputs: - - # Plotting the scatter plot - plt.figure(figsize=(6, 4)) - - plt.bar(REPLICA_NUMS, throughputs[c], color="blue", alpha=0.9) - - fit_x_list = np.arange(min(REPLICA_NUMS), max(REPLICA_NUMS), 0.1) - mii_fit_model = np.polyfit(REPLICA_NUMS, throughputs[c], 1) - mii_model_fn = np.poly1d(mii_fit_model) - plt.plot(fit_x_list, mii_model_fn(fit_x_list), color="blue", linestyle="--") - - plt.title(f'Model Llama 2 {model_size.upper()}, Prompt: {prompt}, Generation: {gen}, TP: {tp}') - plt.xlabel('Number of replicas', fontsize=14) - plt.ylabel('Throughput (queries/s)', fontsize=14) - plt.grid(True) - plt.tight_layout() - # plt.show() - out_file = out_dir / f"repl_scale_llama{model_size}_tp{tp}_p{prompt}g{gen}.png" - plt.savefig(out_file) - - -if __name__ == "__main__": - args = get_args() - - for model_size, tps in tp_sizes.items(): - for tp in tps: - for prompt, gen in prompt_gen_pairs: - output_charts(model_size, tp, bs, prompt, gen, args.log_dir, args.out_dir) - diff --git a/benchmarks/inference/mii/plot_th_lat.py b/benchmarks/inference/mii/plot_th_lat.py deleted file mode 100644 index e99dc5a3e..000000000 --- a/benchmarks/inference/mii/plot_th_lat.py +++ /dev/null @@ -1,116 +0,0 @@ -import glob -import matplotlib.pyplot as plt -import argparse -from pathlib import Path -import numpy as np -import pdb -from postprocess_results import read_json, get_summary - -bs = 768 - -tp_sizes_test = { - "7b": [1] -} - -tp_sizes_all = { - "7b": [1], - "70b": [4, 8], -} - -prompt_gen_pairs_test = [ - (2600, 60) -] - -prompt_gen_pairs_all = [ - (1200, 60), - (1200, 128), - (2600, 60), - (2600, 128), -] - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--test", action="store_true") - parser.add_argument("--no_vllm", action="store_true") - parser.add_argument("--log_dir", type=Path, default=".") - parser.add_argument("--out_dir", type=Path, default="charts/throughput_latency") - args = parser.parse_args() - return args - - -def extract_values(file_pattern): - files = glob.glob(file_pattern) - - print(f"Found {len(files)}") - print('\n'.join(files)) - - clients = [] - throughputs = [] - latencies = [] - for f in files: - prof_args, response_details = read_json(f) - summary = get_summary(prof_args, response_details) - clients.append(prof_args["client_num"]) - throughputs.append(summary.throughput) - latencies.append(summary.latency) - - return clients, throughputs, latencies - - -def output_charts(model_size, tp, bs, prompt, gen, log_dir, out_dir): - if not log_dir.exists(): - print(f"Log directory {log_dir} does not exist") - return - - if not out_dir.exists(): - out_dir.mkdir(parents=True, exist_ok=True) - - mii_file_pattern = f"{log_dir}/logs.llama2-{model_size}-tp{tp}-b{bs}/llama2-{model_size}-tp{tp}-b{bs}_c*_p{prompt}_g{gen}.json" - if not args.no_vllm: - vllm_file_pattern = f"{log_dir}/logs.vllm-llama2-{model_size}-tp{tp}/vllm-llama2-{model_size}-tp{tp}_c*_p{prompt}_g{gen}.json" - - _, mii_throughputs, mii_latencies = extract_values(mii_file_pattern) - if not args.no_vllm: - _, vllm_throughputs, vllm_latencies = extract_values(vllm_file_pattern) - - # Plotting the scatter plot - plt.figure(figsize=(6, 4)) - - if not args.no_vllm: - plt.scatter(vllm_throughputs, vllm_latencies, label=f"vLLM", marker="x", color="orange") - fit_vllm_x_list = np.arange(min(vllm_throughputs), max(vllm_throughputs), 0.01) - vllm_vllm_model = np.polyfit(vllm_throughputs, vllm_latencies, 3) - vllm_model_fn = np.poly1d(vllm_vllm_model) - plt.plot(fit_vllm_x_list, vllm_model_fn(fit_vllm_x_list), color="orange", alpha=0.5, linestyle="--") - - plt.scatter(mii_throughputs, mii_latencies, label=f"DeepSpeed FastGen", marker="o", color="blue") - fit_mii_x_list = np.arange(min(mii_throughputs), max(mii_throughputs), 0.01) - mii_fit_model = np.polyfit(mii_throughputs, mii_latencies, 3) - mii_model_fn = np.poly1d(mii_fit_model) - plt.plot(fit_mii_x_list, mii_model_fn(fit_mii_x_list), color="blue", alpha=0.5, linestyle="--") - - plt.title(f'Model Llama 2 {model_size.upper()}, Prompt: {prompt}, Generation: {gen}, TP: {tp}') - plt.xlabel('Throughput (queries/s)', fontsize=14) - plt.ylabel('Latency', fontsize=14) - plt.legend() - plt.grid(True) - plt.tight_layout() - out_file = out_dir / f"th_lat_curve_llama{model_size}_tp{tp}_p{prompt}g{gen}.png" - print(f"Saving {out_file}") - plt.savefig(out_file) - - -if __name__ == "__main__": - args = get_args() - if args.test: - tp_sizes = tp_sizes_test - prompt_gen_pairs = prompt_gen_pairs_test - else: - tp_sizes = tp_sizes_all - prompt_gen_pairs = prompt_gen_pairs_test_all - - for model_size, tps in tp_sizes.items(): - for tp in tps: - for prompt, gen in prompt_gen_pairs: - output_charts(model_size, tp, bs, prompt, gen, args.log_dir, args.out_dir) - diff --git a/benchmarks/inference/mii/plot_tp_sizes.py b/benchmarks/inference/mii/plot_tp_sizes.py deleted file mode 100644 index 546310258..000000000 --- a/benchmarks/inference/mii/plot_tp_sizes.py +++ /dev/null @@ -1,98 +0,0 @@ -import glob -import matplotlib.pyplot as plt -import argparse -from pathlib import Path -import numpy as np - -from postprocess_results import read_json, get_summary - -bs = 768 - -tp_sizes = { - # "7b": [1], - "13b": [1, 2, 4], - # "70b": [4, 8], -} - -prompt_gen_pairs = [ - (1200, 60), - (1200, 128), - (2600, 60), - (2600, 128), - (2600, 256), -] - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--log_dir", type=Path, default="logs.release") - parser.add_argument("--out_dir", type=Path, default="charts/tp_sizes") - args = parser.parse_args() - return args - - -def extract_values(file_pattern): - files = glob.glob(file_pattern) - - print(f"Found {len(files)}") - print('\n'.join(files)) - - clients = [] - throughputs = [] - latencies = [] - for f in files: - prof_args, response_details = read_json(f) - summary = get_summary(prof_args, response_details) - clients.append(prof_args["client_num"]) - throughputs.append(summary.throughput) - latencies.append(summary.latency) - - return clients, throughputs, latencies - - -def output_charts(model_size, tps, bs, prompt, gen, log_dir, out_dir): - if not log_dir.exists(): - print(f"Log directory {log_dir} does not exist") - return - - if not out_dir.exists(): - out_dir.mkdir(parents=True, exist_ok=True) - - # Plotting the scatter plot - plt.figure(figsize=(6, 4)) - - colors = ["orange", "green", "brown"] - - for tp, color in zip(tps, colors): - mii_file_pattern = f"{log_dir}/logs.llama2-{model_size}-tp{tp}-b{bs}/llama2-{model_size}-tp{tp}-b{bs}_c*_p{prompt}_g{gen}.json" - _, mii_throughputs, mii_latencies = extract_values(mii_file_pattern) - - if len(mii_throughputs) == 0: - continue - - n_params = int(model_size[:-1]) - tflops_per_query = n_params * (prompt + gen) * 2 * 1e-3 - mii_tflops = [th * tflops_per_query / tp for th in mii_throughputs] - - plt.scatter(mii_tflops, mii_latencies, label=f"TP={tp}", marker="o", color=color) - fit_mii_x_list = np.arange(min(mii_tflops), max(mii_tflops), 0.01) - mii_fit_model = np.polyfit(mii_tflops, mii_latencies, 3) - mii_model_fn = np.poly1d(mii_fit_model) - plt.plot(fit_mii_x_list, mii_model_fn(fit_mii_x_list), color=color, alpha=0.5, linestyle="--") - - plt.title(f'Model Llama 2 {model_size.upper()}, Prompt: {prompt}, Generation: {gen}, TP: {tps}') - plt.xlabel('TFLOPs (per GPU)', fontsize=14) - plt.ylabel('Latency', fontsize=14) - plt.legend() - plt.grid(True) - # plt.show() - out_file = out_dir / f"tp_sizes_llama{model_size}_tp{'_'.join([str(tp) for tp in tps])}_p{prompt}g{gen}.png" - plt.savefig(out_file) - - -if __name__ == "__main__": - args = get_args() - - for model_size, tps in tp_sizes.items(): - for prompt, gen in prompt_gen_pairs: - output_charts(model_size, tps, bs, prompt, gen, args.log_dir, args.out_dir) - diff --git a/benchmarks/inference/mii/postprocess_results.py b/benchmarks/inference/mii/postprocess_results.py deleted file mode 100644 index cb2000d5f..000000000 --- a/benchmarks/inference/mii/postprocess_results.py +++ /dev/null @@ -1,112 +0,0 @@ -import argparse -from pathlib import Path -import json -import numpy as np -from statistics import mean -from functools import reduce -from dataclasses import dataclass -from typing import List - -from transformers import AutoTokenizer - - -tokenizer = None - - -@dataclass -class ResponseDetails: - generated_tokens: List[str] - prompt: str - start_time: float - end_time: float - model_time: float - token_gen_time: List[float] - - -@dataclass -class ProfilingSummary: - throughput: float - latency: float - token_gen_latency: float - first_token_latency: float - tokens_per_sec: float - - -def parse_args(): - parser = argparse.ArgumentParser(description="Postprocess results") - parser.add_argument('-i', '--input_path', type=Path, default="results.json") - - args = parser.parse_args() - return args - - -def get_tokenizer(): - global tokenizer - if tokenizer is None: - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - return tokenizer - - -def read_json(file_path): - with open(file_path, 'r') as f: - data = json.load(f) - - args = data["args"] - - response_details = [] - for response in data["response_details"]: - response_details.append(ResponseDetails(**response)) - - return args, response_details - - -def get_summary(args, response_details): - client_num = args["client_num"] - - # Calculate latency and throughput using P95 latency - latency = mean([r.end_time - r.start_time for r in response_details]) - throughput = client_num / latency - - tokens_per_sec = mean([(len(get_tokenizer().tokenize(r.prompt)) + len(r.generated_tokens)) / (r.end_time - r.start_time) for r in response_details]) - first_token_latency = mean([r.token_gen_time[0] for r in response_details]) - - token_gen_latency_flat = reduce(list.__add__, [r.token_gen_time[1:-1] for r in response_details if len(r.token_gen_time) > 2]) - token_gen_latency = mean([t for t in token_gen_latency_flat]) - - return ProfilingSummary(throughput, latency, token_gen_latency, first_token_latency, tokens_per_sec) - - -def get_token_latency(response_details, percentile=None, variance=False, cumulative=False): - req_latencies = [r.token_gen_time for r in response_details] - if cumulative: - req_latencies = [np.cumsum(np.array(r.token_gen_time)).tolist() for r in response_details] - max_gen_length = max([len(r.generated_tokens) for r in response_details]) - latency = [] - for i in range(max_gen_length): - if variance: - token_latency_step = np.var([latency[i] for latency in req_latencies if len(latency) > i]) - if percentile is None: - token_latency_step = [latency[i] for latency in req_latencies if len(latency) > i] - else: - token_latency_step = np.percentile([latency[i] for latency in req_latencies if len(latency) > i], percentile) - - latency.append(token_latency_step) - - return latency - - -def get_token_acc_latency(response_details, percentile=99): - return get_token_latency(response_details, percentile, cumulative=True) - - -if __name__ == "__main__": - args = parse_args() - prof_args, response_details = read_json(args.input_path) - - ps = get_summary(prof_args, response_details) - print(f"Deployment: {prof_args['deployment_name']} Clients: {prof_args['client_num']}, " - + f"Query throughput: {ps.throughput:.3f} queries/s, " - + f"Token throughput (total): {ps.tokens_per_sec:.3f} tokens/s, " - + f"Query latency: {ps.latency:.3f} s, " - + f"Token generation latency: {ps.token_gen_latency:.3f} s/token, " - + f"First token received: {ps.first_token_latency:.3f} s") diff --git a/benchmarks/inference/mii/requirements.txt b/benchmarks/inference/mii/requirements.txt new file mode 100644 index 000000000..9f338ace5 --- /dev/null +++ b/benchmarks/inference/mii/requirements.txt @@ -0,0 +1,6 @@ +transformers +matplotlib +deepspeed-mii>=0.2.0 +vllm>=0.2.7 +numpy +tabulate diff --git a/benchmarks/inference/mii/run_all.sh b/benchmarks/inference/mii/run_all.sh index ca504a6c9..7c9311aea 100644 --- a/benchmarks/inference/mii/run_all.sh +++ b/benchmarks/inference/mii/run_all.sh @@ -1,25 +1,15 @@ -RAGGED_BATCH_SIZE=768 -PARAM_SIZES=(7b 13b 70b) +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 -declare -A TP_SIZES -TP_SIZES["7b"]="1" -TP_SIZES["13b"]="1:2:4" -TP_SIZES["70b"]="4:8" +# DeepSpeed Team -for PARAM_SIZE in ${PARAM_SIZES[@]}; do - - IFS=':' read -ra TP_VALUES <<< ${TP_SIZES[${PARAM_SIZE}]} - for TP in ${TP_VALUES[@]}; do - DEPLOYMENT_NAME=llama2-${PARAM_SIZE}-tp${TP}-b${RAGGED_BATCH_SIZE} - python server.py --model_name meta-llama/Llama-2-${PARAM_SIZE}-hf -d ${DEPLOYMENT_NAME} -m ${TP} -b ${RAGGED_BATCH_SIZE} start +MODELS=(meta-llama/Llama-2-7b-hf meta-llama/Llama-2-13b-hf meta-llama/Llama-2-70b-hf tiiuae/falcon-40B tiiuae/falcon-180B microsoft/phi-2 mistralai/Mixtral-8x7B-v0.1) - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=2600 MAX_NEW_TOKENS=60 bash ./run_benchmark_client.sh - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=2600 MAX_NEW_TOKENS=128 bash ./run_benchmark_client.sh - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=1200 MAX_NEW_TOKENS=60 bash ./run_benchmark_client.sh - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=1200 MAX_NEW_TOKENS=128 bash ./run_benchmark_client.sh - - echo "Stopping server" - python server.py -d ${DEPLOYMENT_NAME} stop - sleep 120 - done +for MODEL in ${MODELS[@]}; do + python ./run_benchmark.py --model ${MODEL} --stream --backend fastgen + python ./run_benchmark.py --model ${MODEL} --stream --backend vllm done + +# Extra runs for Mixtral with non-default settings +python ./run_benchmark.py --model mistralai/Mixtral-8x7B-v0.1 --stream --tp_size 4 --mean_prompt_length 500 --mean_max_new_tokens 150 500 1024 --backend fastgen +python ./run_benchmark.py --model mistralai/Mixtral-8x7B-v0.1 --stream --tp_size 4 --mean_prompt_length 500 --mean_max_new_tokens 150 500 1024 --backend vllm \ No newline at end of file diff --git a/benchmarks/inference/mii/run_all_replica.sh b/benchmarks/inference/mii/run_all_replica.sh deleted file mode 100644 index b3fba0408..000000000 --- a/benchmarks/inference/mii/run_all_replica.sh +++ /dev/null @@ -1,25 +0,0 @@ -RAGGED_BATCH_SIZE=768 -PARAM_SIZES=(7b) -REPLICA_NUMS=(1) - -declare -A TP_SIZES -TP_SIZES["7b"]="4" -TP_SIZES["13b"]="1" -TP_SIZES["70b"]="4" - -for PARAM_SIZE in ${PARAM_SIZES[@]}; do - IFS=':' read -ra TP_VALUES <<< ${TP_SIZES[${PARAM_SIZE}]} - for TP in ${TP_VALUES[@]}; do - for REPL in ${REPLICA_NUMS[@]}; do - DEPLOYMENT_NAME=llama2-${PARAM_SIZE}-tp${TP}-b${RAGGED_BATCH_SIZE}_repl${REPL} - python server.py --model_name meta-llama/Llama-2-${PARAM_SIZE}-hf -d ${DEPLOYMENT_NAME} -m ${TP} -r ${REPL} -b ${RAGGED_BATCH_SIZE} start - - REQUEST_NUM=$((256 * ${REPL})) - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=2600 MAX_NEW_TOKENS=60 CLIENT_NUMS=$((16 * ${REPL})) REQUEST_NUM=$((256 * ${REPL})) bash ./run_bench_client_num.sh - - echo "Stopping server" - python server.py -d ${DEPLOYMENT_NAME} stop - sleep 120 - done - done -done diff --git a/benchmarks/inference/mii/run_all_vllm.sh b/benchmarks/inference/mii/run_all_vllm.sh deleted file mode 100644 index 572377f13..000000000 --- a/benchmarks/inference/mii/run_all_vllm.sh +++ /dev/null @@ -1,26 +0,0 @@ -RAGGED_BATCH_SIZE=768 -PARAM_SIZES=(7b 13b 70b) - -declare -A TP_SIZES -TP_SIZES["7b"]="1" -TP_SIZES["13b"]="1:2:4" -TP_SIZES["70b"]="4:8" - -for PARAM_SIZE in ${PARAM_SIZES[@]}; do - - IFS=':' read -ra TP_VALUES <<< ${TP_SIZES[${PARAM_SIZE}]} - for TP in ${TP_VALUES[@]}; do - DEPLOYMENT_NAME=vllm-llama2-${PARAM_SIZE}-tp${TP} - python -m vllm.entrypoints.api_server --host 127.0.0.1 --port 26500 --tensor-parallel-size ${TP} --model meta-llama/Llama-2-${PARAM_SIZE}-hf & - sleep 60 - - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=2600 MAX_NEW_TOKENS=60 VLLM="--vllm" bash ./run_benchmark_client.sh - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=2600 MAX_NEW_TOKENS=128 VLLM="--vllm" bash ./run_benchmark_client.sh - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=1200 MAX_NEW_TOKENS=60 VLLM="--vllm" bash ./run_benchmark_client.sh - DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=1200 MAX_NEW_TOKENS=128 VLLM="--vllm" bash ./run_benchmark_client.sh - - echo "Stopping server" - pkill -u ${USER} -f vllm.entrypoints.api_server - sleep 30 - done -done diff --git a/benchmarks/inference/mii/run_aml.sh b/benchmarks/inference/mii/run_aml.sh new file mode 100644 index 000000000..90ad50e2c --- /dev/null +++ b/benchmarks/inference/mii/run_aml.sh @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Run benchmark against AML endpoint +python ./run_benchmark.py \ + --model \ + --deployment_name \ + --aml_api_url \ + --aml_api_key \ + --mean_prompt_length 2600 \ + --mean_max_new_tokens 60 \ + --num_requests 256 \ + --backend aml + +### Gernerate the plots +python ./src/plot_th_lat.py + +echo "Find figures in ./plots/ and log outputs in ./results/" \ No newline at end of file diff --git a/benchmarks/inference/mii/run_benchmark.py b/benchmarks/inference/mii/run_benchmark.py new file mode 100644 index 000000000..0a2e0e457 --- /dev/null +++ b/benchmarks/inference/mii/run_benchmark.py @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from src.client import run_client +from src.server import start_server, stop_server +from src.utils import ( + get_args_product, + parse_args, + print_summary, + results_exist, + save_json_results, + CLIENT_PARAMS, + SERVER_PARAMS, +) + + +def run_benchmark() -> None: + args = parse_args(server_args=True, client_args=True) + + for server_args in get_args_product(args, which=SERVER_PARAMS): + if server_args.backend != "aml" and not server_args.client_only: + start_server(server_args) + + for client_args in get_args_product(server_args, which=CLIENT_PARAMS): + if results_exist(client_args) and not args.overwrite_results: + print( + f"Found existing results and skipping current setting. To ignore existing results, use --overwrite_results" + ) + continue + + if client_args.num_requests is None: + client_args.num_requests = client_args.num_clients * 4 + 32 + response_details = run_client(client_args) + print_summary(client_args, response_details) + save_json_results(client_args, response_details) + + if server_args.backend != "aml" and not server_args.client_only: + stop_server(server_args) + + +if __name__ == "__main__": + run_benchmark() diff --git a/benchmarks/inference/mii/run_benchmark_client.py b/benchmarks/inference/mii/run_benchmark_client.py deleted file mode 100644 index caf20351e..000000000 --- a/benchmarks/inference/mii/run_benchmark_client.py +++ /dev/null @@ -1,285 +0,0 @@ -import os -import time -import random -import argparse -import queue -import multiprocessing -import threading -from statistics import mean -from dataclasses import dataclass, asdict -from typing import List, Iterable -from pathlib import Path -from datetime import datetime -import numpy as np - -from transformers import AutoTokenizer -from random_query_generator import RandomQueryGenerator -from sample_input import all_text -import time -import json -import asyncio -import requests - -from postprocess_results import get_summary, ResponseDetails - -MAX_PROMPT_LENGTH = 4000 -PROMPT_LENGTH_VAR = 0.3 -MAX_NEW_TOKENS_VAR = 0.3 - -def parse_args(): - parser = argparse.ArgumentParser(description="Benchmark MII services") - parser.add_argument("-k", - "--max_new_tokens", - type=int, - default=60, - help="min and max num tokens argument for huggingface") - parser.add_argument("-d", - "--deployment_name", - type=str, - default="benchmark_deployment") - parser.add_argument("-n", - "--num_queries", - type=int, - help="number of queries to run", - default=10) - parser.add_argument("-w", - "--warmup", - type=int, - help="number of queries for warming up", - default=1) - parser.add_argument("-c", - "--client_num", - type=int, - help="number of parallel client processes", - default=2) - parser.add_argument("-l", - "--prompt_length", - type=int, - default=2600) - parser.add_argument('--use_thread', action='store_true', - help='use thread to run parallel clients, otherwise use multiprocessing', - default=False) - parser.add_argument('--stream', action='store_true', default=True) - parser.add_argument('--vllm', action='store_true', default=False) - parser.add_argument('-o', '--out_json_path', type=Path, default=None) - - args = parser.parse_args() - return args - - -def call_mii(client, input_tokens, max_new_tokens, stream): - output_tokens = [] - token_gen_time = [] - time_last_token = 0 - - def callback(response): - nonlocal time_last_token - # print(f"Received: {response[0].generated_text} time_last_token={time_last_token}") - output_tokens.append(response[0].generated_text) - time_now = time.time() - token_gen_time.append(time_now - time_last_token) - time_last_token = time_now - - time_last_token = start_time = time.time() - token_gen_time = [] - if stream: - output_tokens = [] - client.generate( - input_tokens, max_new_tokens=max_new_tokens, - streaming_fn=callback) - else: - result = client.generate( - input_tokens, max_new_tokens=max_new_tokens) - output_tokens = result[0].generated_text - - return ResponseDetails( - generated_tokens=output_tokens, - prompt=input_tokens, - start_time=start_time, - end_time=time.time(), - model_time=0, - token_gen_time=token_gen_time) - - -def call_vllm(input_tokens, max_new_tokens, stream=True): - api_url = "http://localhost:26500/generate" - headers = {"User-Agent": "Benchmark Client"} - pload = { - "prompt": input_tokens, - "n": 1, - "use_beam_search": False, - "temperature": 1.0, - "top_p": 0.9, - "max_tokens": max_new_tokens, - "ignore_eos": False, - "stream": stream, - } - def clear_line(n: int = 1) -> None: - LINE_UP = '\033[1A' - LINE_CLEAR = '\x1b[2K' - for _ in range(n): - print(LINE_UP, end=LINE_CLEAR, flush=True) - - def get_streaming_response(response: requests.Response, time_last_token) -> Iterable[List[str]]: - for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, - delimiter=b"\0"): - if chunk: - data = json.loads(chunk.decode("utf-8")) - output = data["text"][0] - time_now = time.time() - yield output, time_now - time_last_token - time_last_token = time_now - - def get_response(response: requests.Response) -> List[str]: - data = json.loads(response.content) - output = data["text"] - return output - - start_time = time.time() - response = requests.post(api_url, headers=headers, json=pload, stream=stream) - if stream: - token_gen_time = [] - for h, t in get_streaming_response(response, start_time): - output = h - token_gen_time.append(t) - - return ResponseDetails( - generated_tokens=output, - prompt=input_tokens, - start_time=start_time, - end_time=time.time(), - model_time=0, - token_gen_time=token_gen_time) - else: - output = get_response(response) - raise NotImplementedError("Not implemented for non-streaming") - - -def _run_parallel(deployment_name, warmup, barrier, query_queue, result_queue, client_num, stream, vllm): - pid = os.getpid() - session_id = f"test_session_p{pid}_t{threading.get_ident()}" - - event_loop = asyncio.new_event_loop() - asyncio.set_event_loop(event_loop) - if not vllm: - import mii - client = mii.client(deployment_name) - - barrier.wait() - - for _ in range(warmup): - print(f"warmup queue size: {query_queue.qsize()} ({pid})", flush=True) - input_tokens, req_max_new_tokens = query_queue.get(timeout=1.0) - - if vllm: - call_vllm(input_tokens, req_max_new_tokens, stream) - else: - call_mii(client, input_tokens, req_max_new_tokens, stream) - - barrier.wait() - - time.sleep(random.uniform(0, client_num) * 0.01) - try: - while not query_queue.empty(): - print(f"queue size: {query_queue.qsize()} ({pid})", flush=True) - input_tokens, req_max_new_tokens = query_queue.get(timeout=1.0) - - # Set max_new_tokens following normal distribution - if vllm: - r = call_vllm(input_tokens, req_max_new_tokens) - else: - r = call_mii(client, input_tokens, req_max_new_tokens, stream) - - result_queue.put(r) - except queue.Empty: - print(f"queue is empty ({pid})") - - print(f"Worker ({pid}) finished. session_id: {session_id}") - - -def run_client(client_num, deployment_name, prompt_length, max_new_tokens, num_queries, warmup, stream, vllm, use_thread=False): - """ - Run MII client for benchmarking. The scenario is a bit complicated: - 1. The main process puts `num_queries` queries into the input queue - 2. Each client runs `warmup` iterations () taking the queries from the input queue - 3. --- barrier --- - 4. The main process marks the start time - 5a. All clients send `num_queries' query in total and put the results into the result queue - 5b. The main process takes the results from the result queue (in parallel with 5a) - 6. The main process marks the end time after receiving `num_queries' results - """ - - if use_thread: - runnable_cls = threading.Thread - barrier_cls = threading.Barrier - queue_cls = queue.Queue - else: - runnable_cls = multiprocessing.Process - barrier_cls = multiprocessing.Barrier - queue_cls = multiprocessing.Queue - - barrier = barrier_cls(client_num + 1) - query_queue = queue_cls() - result_queue = queue_cls() - - processes = [runnable_cls(target=_run_parallel, - args=(deployment_name, warmup, barrier, query_queue, result_queue, client_num, stream, vllm)) - for i in range(client_num)] - for p in processes: - p.start() - - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - query_generator = RandomQueryGenerator(all_text, tokenizer, seed=42) - MAX_PROMPT_LENGTH = 4000 - request_text = query_generator.get_random_request_text(prompt_length, prompt_length*PROMPT_LENGTH_VAR, MAX_PROMPT_LENGTH, num_queries + warmup*client_num) - - for t in request_text: - req_max_new_tokens = int(np.random.normal(max_new_tokens, MAX_NEW_TOKENS_VAR*max_new_tokens)) - query_queue.put((t, req_max_new_tokens)) - - # Tokenizers must be initialized after fork. - # So we need to fork before putting inputs to the queue. - # We need this barrier to stop child processse from taking inputs before the main process puts them - barrier.wait() - # This barrier is to make sure that all clients have finished warmup - barrier.wait() - - response_details = [] - while len(response_details) < num_queries: - res = result_queue.get() - # vLLM returns concatinated tokens - if vllm: - all_tokens = tokenizer.tokenize(res.generated_tokens) - res.generated_tokens = all_tokens[len(tokenizer.tokenize(res.prompt)):] - response_details.append(res) - - return response_details - -if __name__ == "__main__": - args = parse_args() - print(args) - - if args.out_json_path is not None and not args.out_json_path.parent.exists(): - raise ValueError(f"Parent directory of {args.out_json_path}") - - response_details = run_client(args.client_num, args.deployment_name, - args.prompt_length, - args.max_new_tokens, args.num_queries, args.warmup, - args.stream, args.vllm, args.use_thread) - - args_dict = vars(args) - ps = get_summary(args_dict, response_details) - print(f"Deployment: {args.deployment_name} Clients: {args.client_num}, " - + f"Prompt (mean): {args.prompt_length} tokens, " - + f"Generation (mean): {args.max_new_tokens} tokens, " - + f"Query throughput: {ps.throughput:.3f} queries/s, " - + f"Token throughput (total): {ps.tokens_per_sec:.3f} tokens/s, " - + f"Query latency: {ps.latency:.3f} s, " - + f"Token generation latency: {ps.token_gen_latency:.3f} s/token, " - + f"First token received: {ps.first_token_latency:.3f} s") - - if args.out_json_path is not None: - with open(args.out_json_path, "w") as f: - args_dict["out_json_path"] = str(args.out_json_path) # Path is not JSON serializable - data = {"args": args_dict, "time": str(datetime.now()), "response_details": [asdict(r) for r in response_details]} - json.dump(data, f, indent=2) diff --git a/benchmarks/inference/mii/run_benchmark_client.sh b/benchmarks/inference/mii/run_benchmark_client.sh deleted file mode 100644 index 318e9092e..000000000 --- a/benchmarks/inference/mii/run_benchmark_client.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -DEPLOYMENT_NAME=${DEPLOYMENT_NAME:-llama2-7b} -VLLM=${VLLM:-""} - -CLIENT_NUMS=${CLIENT_NUMS:-1 2 4 6 8 12 16 20 24 28 32} -MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-60} -PROMPT_LENGTH=${PROMPT_LENGTH:-3072} -REQUEST_NUM=${REQUEST_NUM:-512} - -LOG_DIR=logs.${DEPLOYMENT_NAME} -mkdir -p ${LOG_DIR} - -for client_num in ${CLIENT_NUMS[@]}; do - RESULT_FILE=${DEPLOYMENT_NAME}_c${client_num}_p${PROMPT_LENGTH}_g${MAX_NEW_TOKENS}.json - - python run_benchmark_client.py -w 1 \ - -d ${DEPLOYMENT_NAME} -n ${REQUEST_NUM} -c ${client_num} \ - -k ${MAX_NEW_TOKENS} -l ${PROMPT_LENGTH} \ - -o ${LOG_DIR}/${RESULT_FILE} \ - ${VLLM} --stream \ - 2>&1 | tee ${LOG_DIR}/bench_client_num_c${client_num}_p${PROMPT_LENGTH}_g${MAX_NEW_TOKENS}.log -done diff --git a/benchmarks/inference/mii/run_example.sh b/benchmarks/inference/mii/run_example.sh index ece8393ed..07af03260 100644 --- a/benchmarks/inference/mii/run_example.sh +++ b/benchmarks/inference/mii/run_example.sh @@ -1,19 +1,20 @@ -### Run the server -RAGGED_BATCH_SIZE=768 -PARAM_SIZES=(7b) -DEPLOYMENT_NAME=llama2-7b-tp1-b768 -python server.py --model_name meta-llama/Llama-2-7b-hf -d llama2-7b-tp1-b768 -m 1 -b 768 start +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 -### This command will run the client with 60 generation steps and input prompt length of 2600 -DEPLOYMENT_NAME=${DEPLOYMENT_NAME} PROMPT_LENGTH=2600 MAX_NEW_TOKENS=60 bash ./run_benchmark_client.sh +# DeepSpeed Team -### Stop the server -echo "Stopping server" -python server.py -d ${DEPLOYMENT_NAME} stop -sleep 120 +# Run benchmark +python ./run_benchmark.py \ + --model meta-llama/Llama-2-7b-hf \ + --tp_size 1 \ + --num_replicas 1 \ + --max_ragged_batch_size 768 \ + --mean_prompt_length 2600 \ + --mean_max_new_tokens 60 \ + --stream \ + --backend fastgen \ ### Gernerate the plots -python plot_th_lat.py --log_dir . --test --no_vllm -python plot_effective_throughput.py --log_dir . --test --no_vllm +python ./src/plot_th_lat.py -echo "Find the plots in the charts directory and the logs inside logs.llama2-7b-tp1-b768" +echo "Find figures in ./plots/ and log outputs in ./results/" \ No newline at end of file diff --git a/benchmarks/inference/mii/run_fp6.sh b/benchmarks/inference/mii/run_fp6.sh new file mode 100644 index 000000000..42c4fdbf8 --- /dev/null +++ b/benchmarks/inference/mii/run_fp6.sh @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +MODELS=(NousResearch/Llama-2-70b-hf) + +for MODEL in ${MODELS[@]}; do + python ./run_benchmark.py --model ${MODEL} --num_requests 128 --stream --backend fastgen --fp6 --tp_size 1 +done \ No newline at end of file diff --git a/benchmarks/inference/mii/server.py b/benchmarks/inference/mii/server.py deleted file mode 100644 index 2e6164187..000000000 --- a/benchmarks/inference/mii/server.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team -import mii -import argparse - -from mii.constants import DeploymentType - -from deepspeed.inference import RaggedInferenceEngineConfig, DeepSpeedTPConfig -from deepspeed.inference.v2.ragged import DSStateManagerConfig - -def start_server(model_name, - deployment_name, - task, - tensor_parallel, - replica_num, - max_ragged_batch_size): - tp_config = DeepSpeedTPConfig(tp_size=tensor_parallel) - mgr_config = DSStateManagerConfig(max_ragged_batch_size=max_ragged_batch_size, max_ragged_sequence_count=max_ragged_batch_size) - inference_config = RaggedInferenceEngineConfig(tensor_parallel=tp_config, - state_manager=mgr_config) - - mii.serve( - model_name, - deployment_name=deployment_name, - tensor_parallel=tensor_parallel, - task=task, - inference_engine_config=inference_config, - replica_num=replica_num - ) - -def stop_server(deployment_name): - mii.client(deployment_name).terminate_server() - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--model_name", - type=str, - default="meta-llama/Llama-2-7b-hf", - help="Name of the model in the model_files to benchmark") - parser.add_argument("-d", - "--deployment_name", - type=str, - default="benchmark_deployment") - parser.add_argument("-t", "--task", type=str, - help="Task type. Currently only text-generation is supported", - default="text-generation") - parser.add_argument("-m", - "--tensor_parallel", - type=int, - help="Degree of tensor (model) parallelism", - default=1) - parser.add_argument("-b", - "--ragged_batch_size", - type=int, - help="Max batch size for ragged batching", - default=768) - parser.add_argument("-r", - "--replica_num", - type=int, - help="Number of replicas for load balancing", - default=1) - parser.add_argument("cmd", help="start, stop, or restart") - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - - if args.cmd == "start": - start_server(args.model_name, - args.deployment_name, - args.task, - args.tensor_parallel, - args.replica_num, - args.ragged_batch_size) - elif args.cmd == "stop": - print("running stop") - stop_server(args.deployment_name) - else: - raise ValueError(f"Unknown command: {args.cmd}") diff --git a/benchmarks/inference/mii/src/__init__.py b/benchmarks/inference/mii/src/__init__.py new file mode 100644 index 000000000..208299fb8 --- /dev/null +++ b/benchmarks/inference/mii/src/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/benchmarks/inference/mii/src/client.py b/benchmarks/inference/mii/src/client.py new file mode 100644 index 000000000..85f5207ea --- /dev/null +++ b/benchmarks/inference/mii/src/client.py @@ -0,0 +1,401 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse +import asyncio +import json +import multiprocessing +import os +import queue +import random +import requests +import threading +import time +from typing import List, Iterable, Union + +import numpy as np +from transformers import AutoTokenizer + +try: + from .postprocess_results import ResponseDetails + from .random_query_generator import RandomQueryGenerator + from .sample_input import all_text + from .utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS +except ImportError: + from postprocess_results import ResponseDetails + from random_query_generator import RandomQueryGenerator + from sample_input import all_text + from utils import parse_args, print_summary, get_args_product, CLIENT_PARAMS + + +def call_fastgen( + input_tokens: str, max_new_tokens: int, args: argparse.Namespace +) -> ResponseDetails: + import mii + + client = mii.client(args.deployment_name) + + output_tokens = [] + token_gen_time = [] + time_last_token = 0 + + def callback(response): + nonlocal time_last_token + # print(f"Received: {response[0].generated_text} time_last_token={time_last_token}") + output_tokens.append(response[0].generated_text) + time_now = time.time() + token_gen_time.append(time_now - time_last_token) + time_last_token = time_now + + time_last_token = start_time = time.time() + token_gen_time = [] + if args.stream: + output_tokens = [] + client.generate( + input_tokens, max_new_tokens=max_new_tokens, streaming_fn=callback + ) + else: + result = client.generate(input_tokens, max_new_tokens=max_new_tokens) + output_tokens = result[0].generated_text + + return ResponseDetails( + generated_tokens=output_tokens, + prompt=input_tokens, + start_time=start_time, + end_time=time.time(), + model_time=0, + token_gen_time=token_gen_time, + ) + + +def call_vllm( + input_tokens: str, max_new_tokens: int, args: argparse.Namespace +) -> ResponseDetails: + if not args.stream: + raise NotImplementedError("Not implemented for non-streaming") + + api_url = "http://localhost:26500/generate" + headers = {"User-Agent": "Benchmark Client"} + pload = { + "prompt": input_tokens, + "n": 1, + "use_beam_search": False, + "temperature": 1.0, + "top_p": 0.9, + "max_tokens": max_new_tokens, + "ignore_eos": False, + "stream": args.stream, + } + + def clear_line(n: int = 1) -> None: + LINE_UP = "\033[1A" + LINE_CLEAR = "\x1b[2K" + for _ in range(n): + print(LINE_UP, end=LINE_CLEAR, flush=True) + + def get_streaming_response( + response: requests.Response, time_last_token + ) -> Iterable[List[str]]: + for chunk in response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b"\0" + ): + if chunk: + data = json.loads(chunk.decode("utf-8")) + output = data["text"][0] + time_now = time.time() + yield output, time_now - time_last_token + time_last_token = time_now + + # For non-streaming, but currently non-streaming is not fully implemented + def get_response(response: requests.Response) -> List[str]: + data = json.loads(response.content) + output = data["text"] + return output + + token_gen_time = [] + start_time = time.time() + response = requests.post(api_url, headers=headers, json=pload, stream=args.stream) + for h, t in get_streaming_response(response, start_time): + output = h + token_gen_time.append(t) + + return ResponseDetails( + generated_tokens=output, + prompt=input_tokens, + start_time=start_time, + end_time=time.time(), + model_time=0, + token_gen_time=token_gen_time, + ) + + +# client talks with openai api +def call_openai( + input_tokens: str, max_new_tokens: int, args: argparse.Namespace +) -> ResponseDetails: + + api_url = args.openai_api_url + headers = { + "User-Agent": "Benchmark Client", + "Content-Type": "application/json", + "Authorization": f"Bearer {args.openai_api_key}" + } + + pload = { + "prompt": input_tokens, + "model": args.model, + "n": 1, + "use_beam_search": False, + "temperature": 1.0, + "top_p": 0.9, + "max_tokens": max_new_tokens, + "ignore_eos": False, + "stream": args.stream, + } + + def clear_line(n: int = 1) -> None: + LINE_UP = "\033[1A" + LINE_CLEAR = "\x1b[2K" + for _ in range(n): + print(LINE_UP, end=LINE_CLEAR, flush=True) + + def get_streaming_response( + response: requests.Response, time_last_token + ) -> Iterable[List[str]]: + for chunk in response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b"data:" + ): + if chunk: + plain=chunk.decode("utf-8") + if plain.strip() == "[DONE]": + continue + data = json.loads(plain) + output = data["choices"][0]["text"] + time_now = time.time() + yield output, time_now - time_last_token + time_last_token = time_now + + # For non-streaming, but currently non-streaming is not fully implemented + def get_response(response: requests.Response) -> List[str]: + data = json.loads(response.content) + output = data["choices"][0]["text"] + return output + + token_gen_time = [] + start_time = time.time() + #response = requests.post(api_url, headers=headers, json=pload, stream=False) + response = requests.post(api_url, headers=headers, json=pload, stream=args.stream) + if args.stream: + output = "" + for h, t in get_streaming_response(response, start_time): + output += h + token_gen_time.append(t) + else: + output = get_response(response) + + return ResponseDetails( + generated_tokens=output, + prompt=input_tokens, + start_time=start_time, + end_time=time.time(), + model_time=0, + token_gen_time=token_gen_time, + ) + + +def call_aml( + input_tokens: str, + max_new_tokens: int, + args: argparse.Namespace, + start_time: Union[None, float] = None, +) -> ResponseDetails: + if args.stream: + raise NotImplementedError("Not implemented for streaming") + + headers = { + "Content-Type": "application/json", + "Authorization": ("Bearer " + args.aml_api_key), + "azureml-model-deployment": args.deployment_name, + } + pload = { + "input_data": { + "input_string": [ + input_tokens, + ], + "parameters": { + "max_tokens": max_new_tokens, + "return_full_text": False, + }, + } + } + + def get_response(response: requests.Response) -> List[str]: + data = json.loads(response.content) + try: + output = data[0]["0"] + except (KeyError, TypeError): + try: + output = data[0] + except (KeyError, TypeError): + output = data + return output + + token_gen_time = [] + response = None + if start_time is None: + start_time = time.time() + while True: + try: # Sometimes the AML endpoint will return an error, so we send the request again + response = requests.post(args.aml_api_url, headers=headers, json=pload, timeout=180) + output = get_response(response) + break + except Exception as e: + print(f"Connection failed with {e}. Retrying AML request") + # make sure response exist before we call it + if response: + print(f"{response.status_code}:{response.content}") + + return ResponseDetails( + generated_tokens=output, + prompt=input_tokens, + start_time=start_time, + end_time=time.time(), + model_time=0, + token_gen_time=token_gen_time, + ) + + +def _run_parallel( + barrier: Union[threading.Barrier, multiprocessing.Barrier], + query_queue: Union[queue.Queue, multiprocessing.Queue], + result_queue: Union[queue.Queue, multiprocessing.Queue], + args: argparse.Namespace, +): + pid = os.getpid() + session_id = f"test_session_p{pid}_t{threading.get_ident()}" + + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + + backend_call_fns = {"fastgen": call_fastgen, "vllm": call_vllm, "aml": call_aml, "openai": call_openai} + call_fn = backend_call_fns[args.backend] + + barrier.wait() + + for _ in range(args.warmup): + print(f"warmup queue size: {query_queue.qsize()} ({pid})", flush=True) + input_tokens, req_max_new_tokens = query_queue.get(timeout=1.0) + _ = call_fn(input_tokens, req_max_new_tokens, args) + + barrier.wait() + + time.sleep(random.uniform(0, args.num_clients) * 0.01) + try: + while True: + print(f"queue size: {query_queue.qsize()} ({pid})", flush=True) + input_tokens, req_max_new_tokens = query_queue.get(timeout=1.0) + + r = call_fn(input_tokens, req_max_new_tokens, args) + + result_queue.put(r) + except queue.Empty: + print(f"queue is empty ({pid})") + + print(f"Worker ({pid}) finished. session_id: {session_id}") + + +def run_client(args): + """ + Run MII client for benchmarking. The scenario is a bit complicated: + 1. The main process puts `num_requests` queries into the input queue + 2. Each client runs `warmup` iterations () taking the queries from the input queue + 3. --- barrier --- + 4. The main process marks the start time + 5a. All clients send `num_requests' query in total and put the results into the result queue + 5b. The main process takes the results from the result queue (in parallel with 5a) + 6. The main process marks the end time after receiving `num_requests' results + """ + + if args.use_thread: + runnable_cls = threading.Thread + barrier_cls = threading.Barrier + queue_cls = queue.Queue + else: + runnable_cls = multiprocessing.Process + barrier_cls = multiprocessing.Barrier + queue_cls = multiprocessing.Queue + + barrier = barrier_cls(args.num_clients + 1) + query_queue = queue_cls() + result_queue = queue_cls() + + processes = [ + runnable_cls( + target=_run_parallel, + args=( + barrier, + query_queue, + result_queue, + args, + ), + ) + for i in range(args.num_clients) + ] + for p in processes: + p.start() + + tokenizer = AutoTokenizer.from_pretrained(args.model) + + # make sure max_prompt_length is longer than the target prompt length + args.max_prompt_length = max(args.max_prompt_length, int(args.mean_prompt_length * 3)) + # check if the all_text is longer than the max prompt length, if not expand it + global all_text + while len(tokenizer.tokenize(all_text)) < args.max_prompt_length: + all_text += all_text + + query_generator = RandomQueryGenerator(all_text, tokenizer, seed=42) + request_text = query_generator.get_random_request_text( + args.mean_prompt_length, + args.mean_prompt_length * args.prompt_length_var, + args.max_prompt_length, + args.num_requests + args.warmup * args.num_clients, + ) + + for t in request_text: + # Set max_new_tokens following normal distribution + req_max_new_tokens = int( + np.random.normal( + args.mean_max_new_tokens, + args.max_new_tokens_var * args.mean_max_new_tokens, + ) + ) + query_queue.put((t, req_max_new_tokens)) + + # Tokenizers must be initialized after fork. + # So we need to fork before putting inputs to the queue. + # We need this barrier to stop child processse from taking inputs before the main process puts them + barrier.wait() + # This barrier is to make sure that all clients have finished warmup + barrier.wait() + + response_details = [] + while len(response_details) < args.num_requests: + res = result_queue.get() + # vLLM returns concatinated tokens + if args.backend == "vllm": + all_tokens = tokenizer.tokenize(res.generated_tokens) + res.generated_tokens = all_tokens[len(tokenizer.tokenize(res.prompt)) :] + response_details.append(res) + + return response_details + + +if __name__ == "__main__": + args = parse_args(client_args=True) + + for client_args in get_args_product(args, which=CLIENT_PARAMS): + response_details = run_client(client_args) + + print_summary(client_args, response_details) diff --git a/benchmarks/inference/mii/src/defaults.py b/benchmarks/inference/mii/src/defaults.py new file mode 100644 index 000000000..89255dfa6 --- /dev/null +++ b/benchmarks/inference/mii/src/defaults.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +ARG_DEFAULTS = { + "model": "meta-llama/Llama-2-7b-hf", + "deployment_name": "benchmark-deployment", + "tp_size": 1, + "max_ragged_batch_size": 768, + "num_replicas": 1, + "max_prompt_length": 4000, + "mean_prompt_length": 2600, + "mean_max_new_tokens": 60, +} + +MODEL_DEFAULTS = { + "meta-llama/Llama-2-7b-hf": { + "max_prompt_length": 4000, + "mean_prompt_length": (1200, 2600), + "mean_max_new_tokens": (60, 128), + "tp_size": 1, + }, + "meta-llama/Llama-13b-hf": { + "max_prompt_length": 4000, + "mean_prompt_length": (1200, 2600), + "mean_max_new_tokens": (60, 128), + "tp_size": (1, 2, 4), + }, + "meta-llama/Llama-2-70b-hf": { + "max_prompt_length": 4000, + "mean_prompt_length": (1200, 2600), + "mean_max_new_tokens": (60, 128), + "tp_size": (4, 8), + }, + "tiiuae/falcon-40B": { + "max_prompt_length": 2000, + "mean_prompt_length": (1200, 1900), + "mean_max_new_tokens": (60, 128), + "tp_size": (2, 4), + }, + "tiiuae/falcon-180B": { + "max_prompt_length": 2000, + "mean_prompt_length": (1200, 1900), + "mean_max_new_tokens": (60, 128), + "tp_size": 8, + }, + "microsoft/phi-2": { + "max_prompt_length": 2000, + "mean_prompt_length": (1200, 1900), + "mean_max_new_tokens": (60, 128), + "tp_size": 1, + }, + "mistralai/Mixtral-8x7B-v0.1": { + "max_prompt_length": 4000, + "mean_prompt_length": (1200, 2600), + "mean_max_new_tokens": (60, 128), + "tp_size": 4, + }, +} diff --git a/benchmarks/inference/mii/src/plot_effective_throughput.py b/benchmarks/inference/mii/src/plot_effective_throughput.py new file mode 100644 index 000000000..2370a2e1e --- /dev/null +++ b/benchmarks/inference/mii/src/plot_effective_throughput.py @@ -0,0 +1,215 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse +from pathlib import Path +import glob +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from postprocess_results import read_json, get_tokenizer, get_result_sets + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--backend", type=str, choices=["fastgen", "vllm", "openai"], default=["fastgen", "vllm"], \ + nargs="+", help="Specify the backends to generate plots for") + parser.add_argument("--log_dir", type=Path, default="./results") + parser.add_argument("--model", type=str) + parser.add_argument("--out_dir", type=Path, default="./plots/goodtput") + parser.add_argument("--sla_prompt_tokens_per_sec", type=int, default=512, help="SLA prompt tokens per second") + parser.add_argument("--sla_gen_tokens_per_sec", type=int, default=[1, 2, 3, 4, 6, 8], nargs="+", help="SLA generation tokens/s targets") + parser.add_argument("--ema_span", type=int, default=16, help="EMA span") + args = parser.parse_args() + return args + + +def check_token_latency_step(response_details, token_index): + P50_token_latency = np.percentile( + [ + r.token_gen_time[token_index] + for r in response_details + if len(r.token_gen_time) > token_index + ], + 50, + ) + P90_token_latency = np.percentile( + [ + r.token_gen_time[token_index] + for r in response_details + if len(r.token_gen_time) > token_index + ], + 90, + ) + P99_token_latency = np.percentile( + [ + r.token_gen_time[token_index] + for r in response_details + if len(r.token_gen_time) > token_index + ], + 99, + ) + + return P50_token_latency, P90_token_latency, P99_token_latency + + +def validate_token_cum_latency_SLA(response_detail, sla_token_gen): + cumsum_latencies = np.cumsum(np.array(response_detail.token_gen_time[1:])) + return all( + [ + cumsum_latencies[i] <= (1 / sla_token_gen) * (i + 1) + for i in range(len(cumsum_latencies)) + ] + ) + + +def validate_token_ema_latency_SLA(response_detail, sla_token_gen, ema_span): + ema_latency = ( + pd.Series(response_detail.token_gen_time[1:]) + .ewm(span=ema_span) + .mean() + .values.tolist() + ) + return all([t < 1.0 / sla_token_gen for t in ema_latency]) + + +def validate_prompt_latency_SLA(response_detail, sla_token_gen, f, sla_prompt_tokens_per_sec ): + tokenizer = get_tokenizer(args.model) + prompt_length = len(tokenizer.tokenize(response_detail.prompt)) + prompt_latency_SLA = prompt_length / sla_prompt_tokens_per_sec + if prompt_latency_SLA < response_detail.token_gen_time[0]: + return False + + if len(response_detail.token_gen_time) == 1: + return True + + return f[0](response_detail, sla_token_gen, *f[1]) + + +def calc_throughput(response_details): + start_time = min([r.start_time for r in response_details]) + end_time = max([r.end_time for r in response_details]) + return len(response_details) / (end_time - start_time) + + +def extract_values(file_pattern, sla_token_gen, validate_func, sla_prompt_tokens_per_sec): + files = glob.glob(file_pattern) + print(f"Found {len(files)} files") + goodputs = {} + good_ratios = {} + for f in files: + prof_args, response_details = read_json(f) + client_num = prof_args["num_clients"] + num_req_ok = len( + [ + r + for r in response_details + if validate_prompt_latency_SLA(r, sla_token_gen, validate_func, sla_prompt_tokens_per_sec) + ] + ) + goodputs[client_num] = calc_throughput(response_details) * ( + num_req_ok / len(response_details) + ) + good_ratios[client_num] = num_req_ok / len(response_details) + + return goodputs, good_ratios + + +def output_charts(args, model, tp_size, bs, replicas, sla_token_gen, prompt, gen, log_dir, out_dir): + if not log_dir.exists(): + print(f"Log directory {log_dir} does not exist") + return + + if not out_dir.exists(): + out_dir.mkdir(parents=True, exist_ok=True) + + print( + f"Model: {model} Prompt: {prompt}, Generation: {gen}, TP: {tp_size} sla_token_gen: {sla_token_gen}" + ) + + result_file_pattern = f"{model}-tp{tp_size}-bs{bs}-replicas{replicas}-prompt{prompt}-gen{gen}-clients*.json" + + validate_funcs = [ + (validate_token_cum_latency_SLA, (), "cum"), + (validate_token_ema_latency_SLA, (args.ema_span,), f"ema{args.ema_span}"), + ] + + plt_cfg = {'vllm': {'label': 'vLLM', 'marker': 'x', 'color': 'orange'},\ + 'fastgen': {'label': 'DeepSpeed-FastGen', 'marker': 'o', 'color': 'blue'}, \ + 'openai': {'label': 'openai-API', 'marker': '+', 'color': 'red'} + } + + for f in validate_funcs: + plt.figure() + + for backend in args.backend: + file_pattern = f"{log_dir}/{backend}/{result_file_pattern}" + goodputs, good_ratios = extract_values( + file_pattern, sla_token_gen, f, args.sla_prompt_tokens_per_sec + ) + client_num_list = sorted(list(goodputs.keys())) + goodputs_list = [goodputs[client_num] for client_num in client_num_list] + + # Plotting the scatter plot + plt.scatter( + client_num_list, + goodputs_list, + label=plt_cfg[backend]['label'], + marker=plt_cfg[backend]['marker'], + color=plt_cfg[backend]['color'], + ) + + fit_x_list = np.arange(min(client_num_list), max(client_num_list), 0.1) + fit_model = np.polyfit(client_num_list, goodputs_list, 4) + model_fn = np.poly1d(fit_model) + plt.plot( + fit_x_list, + model_fn(fit_x_list), + alpha=0.5, + linestyle="--", + color=plt_cfg[backend]['color'], + ) + + title = ( + f"Effective throughput (SLA prompt: {args.sla_prompt_tokens_per_sec} tokens/s, generation: {sla_token_gen} tokens/s)\n" + + f"Model: {model} Prompt: {prompt}, Generation: {gen}, TP: {tp_size}" + ) + plt.title(title, fontsize=10) + plt.xlabel("Number of clients", fontsize=10) + plt.ylabel("Effective throughput (queries/s)", fontsize=10) + plt.ylim(bottom=-0.05) + plt.legend() + plt.grid(True) + out_file = ( + out_dir + / f"{model}_SLAp{args.sla_prompt_tokens_per_sec}g{sla_token_gen}_tp{tp_size}_b{bs}_p{prompt}g{gen}_{f[2]}.png" + ) + plt.savefig(out_file) + plt.clf() + print(f"Saved {out_file}") + + +if __name__ == "__main__": + args = get_args() + + assert "aml" not in args.backend, "Effective throughput analysis is not supported for AML." + + result_params = get_result_sets(args) + + for model, tp_size, bs, replicas, prompt, gen in result_params: + for sla_token_gen in args.sla_gen_tokens_per_sec: + output_charts( + args=args, + model=model, + tp_size=tp_size, + bs=bs, + replicas=replicas, + sla_token_gen=sla_token_gen, + prompt=prompt, + gen=gen, + log_dir=args.log_dir, + out_dir=args.out_dir, + ) diff --git a/benchmarks/inference/mii/src/plot_latency_percentile.py b/benchmarks/inference/mii/src/plot_latency_percentile.py new file mode 100644 index 000000000..daeb8cc5a --- /dev/null +++ b/benchmarks/inference/mii/src/plot_latency_percentile.py @@ -0,0 +1,127 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse +import glob +import re +import os +from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np +import itertools + +from postprocess_results import read_json, get_token_latency, get_result_sets + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--backend", type=str, choices=["fastgen", "vllm"], default=["fastgen", "vllm"], \ + nargs="+", help="Specify the backends to generate plots for") + parser.add_argument("--log_dir", type=Path, default="./results") + parser.add_argument( + "--out_dir", type=Path, default="./plots/percentile_token_latency" + ) + parser.add_argument("--skip_head_token_num", type=int, default=1, help="Specify number of head tokens to skip") + parser.add_argument("--skip_request_num", type=int, default=1, help="Specify number of requests to skip") + args = parser.parse_args() + return args + + +def extract_values(args, file_pattern): + files = glob.glob(file_pattern) + + print(f"Found {len(files)}") + print("\n".join(files)) + + latencies = {} + for f in files: + prof_args, response_details = read_json(f) + client_num = prof_args["num_clients"] + + response_details.sort(key=lambda r: r.start_time) + + response_details = response_details[args.skip_request_num:-args.skip_request_num] + token_latencies = [ + r.token_gen_time[args.skip_head_token_num:-1] for r in response_details + ] + flat_latency_list = list(itertools.chain(*token_latencies)) + latencies[client_num] = flat_latency_list + return latencies + + +def output_charts(args, model, tp_size, bs, replicas, prompt, gen, log_dir, out_dir): + if not log_dir.exists(): + print(f"Log directory {log_dir} does not exist") + return + + if not out_dir.exists(): + out_dir.mkdir(parents=True, exist_ok=True) + + result_file_pattern = f"{model}-tp{tp_size}-bs{bs}-replicas{replicas}-prompt{prompt}-gen{gen}-clients*.json" + + plt_cfg = {'vllm': {'bar_x': [1, 2.5, 4], 'label': 'vLLM', 'color': 'orange'},\ + 'fastgen': {'bar_x': [1.3, 2.8, 4.3], 'label': 'DeepSpeed-FastGen', 'color': 'blue'}} + + latencies = {} + client_num_dict = {} + for backend in args.backend: + file_pattern = f"{log_dir}/{backend}/{result_file_pattern}" + latencies[backend] = extract_values(args, file_pattern) + client_num_dict[backend] = set(sorted(list(latencies[backend].keys()))) + + # Intersection of clients across all backends + client_num_set = set() + for backend in args.backend: + if not client_num_set: + client_num_set = client_num_dict[backend] + else: + client_num_set = client_num_set.intersection(client_num_dict[backend]) + + for client_num in client_num_set: + plt.figure() + percentile = 95 + + for backend in args.backend: + print(f"Generating data for plot, {backend=}") + P50_val = np.percentile(latencies[backend][client_num], 50) + P90_val = np.percentile(latencies[backend][client_num], 90) + P95_val = np.percentile(latencies[backend][client_num], 95) + y = [P50_val, P90_val, P95_val] + plt.bar(plt_cfg[backend]['bar_x'], y, width=0.3, label=plt_cfg[backend]['label'], align="center", color=plt_cfg[backend]['color']) + + out_file = ( + out_dir + / f"p{percentile}_token_latency_{model}_c{client_num}_tp{tp_size}_p{prompt}g{gen}.png" + ) + + plt.ylabel("Latency (s)", fontsize=14) + plt.legend(loc=2) + + label_x = ["P50", "P90", "P95"] + plt.xticks([1, 2.5, 4], label_x) + + plt.title(f"Model: {model}, Clients: {client_num}, Prompt: {prompt}, Gen: {gen}, TP: {tp_size}") + plt.savefig(out_file) + print(f"Saved {out_file}") + + +if __name__ == "__main__": + args = get_args() + + assert "aml" not in args.backend, "Percentile latency analysis is not supported for AML." + + result_params = get_result_sets(args) + + for model, tp_size, bs, replicas, prompt, gen in result_params: + output_charts( + args=args, + model=model, + tp_size=tp_size, + bs=bs, + replicas=replicas, + prompt=prompt, + gen=gen, + log_dir=args.log_dir, + out_dir=args.out_dir, + ) diff --git a/benchmarks/inference/mii/src/plot_repl_scale.py b/benchmarks/inference/mii/src/plot_repl_scale.py new file mode 100644 index 000000000..074bfb81a --- /dev/null +++ b/benchmarks/inference/mii/src/plot_repl_scale.py @@ -0,0 +1,117 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import glob +import matplotlib.pyplot as plt +import argparse +from pathlib import Path +import numpy as np +from collections import defaultdict + +from postprocess_results import read_json, get_summary, get_result_sets + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--backend", type=str, choices=["fastgen"], default=["fastgen"], \ + nargs=1, help="Specify the single backend to generate plots for") + parser.add_argument("--clients_per_replica", type=int, required=False, default=None, help="Optional \ + argument to specify explicit clients/replica to generate plot for") + parser.add_argument("--log_dir", type=Path, default="./results") + parser.add_argument("--out_dir", type=Path, default="./plots/repl_scale") + args = parser.parse_args() + return args + + +def extract_values(file_pattern): + files = glob.glob(file_pattern) + + clients = [] + throughputs = [] + latencies = [] + for f in files: + prof_args, response_details = read_json(f) + summary = get_summary(prof_args, response_details) + clients.append(prof_args["num_clients"]) + throughputs.append(summary.throughput) + latencies.append(summary.latency) + + return clients, throughputs, latencies + + +def output_charts(args, model, tp_size, bs, replica_nums, prompt, gen, log_dir, out_dir): + if not log_dir.exists(): + print(f"Log directory {log_dir} does not exist") + return + + if not out_dir.exists(): + out_dir.mkdir(parents=True, exist_ok=True) + + throughputs = {} + for repl in replica_nums: + result_file_pattern = f"{model}-tp{tp_size}-bs{bs}-replicas{repl}-prompt{prompt}-gen{gen}-clients*.json" + mii_file_pattern = f"{log_dir}/fastgen/{result_file_pattern}" + print(f"Looking for {mii_file_pattern}") + clients, mii_throughputs, mii_latencies = extract_values(mii_file_pattern) + + for c, th in zip(clients, mii_throughputs): + client_per_repl = c // repl + if client_per_repl not in throughputs: + throughputs[client_per_repl] = [] + print(f"Throughput for {client_per_repl} clients: {th}") + throughputs[client_per_repl].append(th) + + for c in throughputs: + if args.clients_per_replica != None and args.clients_per_replica != c: + continue + if len(throughputs[c]) == len(replica_nums): + print(f"Generating figure for {c} clients/replica.") + # Plotting the scatter plot + plt.figure() + + plt.bar(replica_nums, throughputs[c], color="blue", alpha=0.9) + + fit_x_list = np.arange(min(replica_nums), max(replica_nums), 0.1) + mii_fit_model = np.polyfit(replica_nums, throughputs[c], 1) + mii_model_fn = np.poly1d(mii_fit_model) + plt.plot(fit_x_list, mii_model_fn(fit_x_list), color="blue", linestyle="--") + + plt.title( + f"Model: {model}, Prompt: {prompt}, Generation: {gen}\n\ + TP: {tp_size}, Clients/Replica: {c}" + ) + plt.xlabel("Number of replicas", fontsize=14) + plt.ylabel("Throughput (queries/s)", fontsize=14) + plt.grid(True) + plt.tight_layout() + out_file = out_dir / f"repl_scale_{model}_tp{tp_size}_p{prompt}g{gen}_c_per_r{c}.png" + plt.savefig(out_file) + + +if __name__ == "__main__": + args = get_args() + + replica_sets = defaultdict(lambda: defaultdict(set)) + result_params = get_result_sets(args) + + # Find all replicas across same sets + for model, tp_size, bs, replicas, prompt, gen in result_params: + key = f'{model}_{tp_size}_{bs}_{prompt}_{gen}' + replica_sets[key]['config'].add((model, tp_size, bs, prompt, gen)) + replica_sets[key]['replicas'].add(int(replicas)) + + for replica_set in replica_sets.values(): + for model, tp_size, bs, prompt, gen in replica_set['config']: + replica_nums = sorted(replica_set['replicas']) + output_charts( + args=args, + model=model, + tp_size=tp_size, + bs=bs, + replica_nums=replica_nums, + prompt=prompt, + gen=gen, + log_dir=args.log_dir, + out_dir=args.out_dir, + ) diff --git a/benchmarks/inference/mii/src/plot_th_lat.py b/benchmarks/inference/mii/src/plot_th_lat.py new file mode 100644 index 000000000..18f115206 --- /dev/null +++ b/benchmarks/inference/mii/src/plot_th_lat.py @@ -0,0 +1,182 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse +import glob +import os +import re +import yaml +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + +from postprocess_results import read_json, get_summary, get_result_sets + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--data_dirs", type=str, nargs="+", \ + help="Specify the data directories to generate plots for") + parser.add_argument("--out_dir", type=Path, default="./plots/throughput_latency") + parser.add_argument("--model_name", type=str, default="", help="Optional model name override") + args = parser.parse_args() + return args + + +def extract_values(file_pattern): + files = glob.glob(file_pattern) + + print(f"Found {len(files)}") + print("\n".join(files)) + + clients = [] + throughputs = [] + latencies = [] + extra_args = {} + for f in files: + prof_args, response_details = read_json(f) + summary = get_summary(prof_args, response_details) + clients.append(prof_args["num_clients"]) + throughputs.append(summary.throughput) + latencies.append(summary.latency) + + return clients, throughputs, latencies, prof_args + + +def output_charts(model, tp_size, bs, replicas, prompt, gen, out_dir): + out_dir.mkdir(parents=True, exist_ok=True) + + result_file_pattern = f"{model}-tp{tp_size}-bs{bs}-replicas{replicas}-prompt{prompt}-gen{gen}-clients*.json" + + plt.figure() + + for data_dir in args.data_dirs: + file_pattern = f"{data_dir}/{result_file_pattern}" + _, throughputs, latencies, prof_args = extract_values(file_pattern) + + kwargs = {} + kwargs["label"] = str(data_dir) + kwargs["marker"] = "o" + kwargs["linestyle"] = "--" + + fit_kwargs = {} + fit_kwargs["linestyle"] = "--" + plot_fit_line = True + + polyfit_degree = 3 + plot_fn = plt.scatter + + plot_config = glob.glob(f"{data_dir}/plot_config.yaml") + + latencies = sorted(latencies) + throughputs = sorted(throughputs) + + if plot_config: + plot_config = plot_config[0] + plot_config = yaml.safe_load(Path(plot_config).read_text()) + plot_keys = plot_config.keys() + + # If x_max specified, clip data + if "x_max" in plot_keys: + for i, throughput in enumerate(throughputs): + if throughput > plot_config["x_max"]: + latencies = latencies[:i] + throughputs = throughputs[:i] + break + + # If y_max specified, clip data + if "y_max" in plot_keys: + for i, latency in enumerate(latencies): + if latency > plot_config["y_max"]: + latencies = latencies[:i] + throughputs = throughputs[:i] + break + + # Set polyfit degree + polyfit_degree = plot_config.get("polyfit_degree", polyfit_degree) + + # Select plot type + if polyfit_degree == 0: + plot_fit_line = False + + # Main plot kwargs + if "label" in plot_keys: + kwargs["label"] = plot_config["label"] + if "marker" in plot_keys: + kwargs["marker"] = plot_config["marker"] + if "color" in plot_keys: + kwargs["color"] = plot_config["color"] + if "linestyle" in plot_keys: + kwargs["linestyle"] = plot_config["linestyle"] + + # Fit line kwargs + if "color" in plot_keys: + fit_kwargs["color"] = plot_config["color"] + if "linestyle" in plot_keys: + fit_kwargs["linestyle"] = plot_config["linestyle"] + + if len(throughputs) > 0: + plot = plot_fn( + throughputs, + latencies, + **kwargs, + ) + + if plot_fn == plt.plot: + plot_color = plot[0].get_color() + else: + plot_color = plot.get_facecolor()[0] + + if not "color" in fit_kwargs.keys(): + fit_kwargs["color"] = plot_color + + fit_x_list = np.arange(min(throughputs), max(throughputs), 0.01) + data_model = np.polyfit(throughputs, latencies, polyfit_degree) + model_fn = np.poly1d(data_model) + x = fit_x_list if plot_fit_line else throughputs + y = model_fn(fit_x_list) if plot_fit_line else latencies + plt.plot( + x, + y, + alpha=0.5, + **fit_kwargs, + ) + + # Generic plot formatting + if args.model_name: + model_label = args.model_name + else: + model_label = model + + plt.title(f"Model: {model_label}, Prompt: {prompt}, Generation: {gen}, TP: {tp_size}") + plt.xlabel("Throughput (queries/s)", fontsize=14) + plt.ylabel("Latency (s)", fontsize=14) + plt.legend() + plt.grid(True) + plt.tight_layout() + out_file = ( + out_dir + / f"{model}-tp{tp_size}-bs{bs}-replicas{replicas}-prompt{prompt}-gen{gen}.png" + ) + print(f"Saving {out_file}") + plt.savefig(out_file) + + +if __name__ == "__main__": + args = get_args() + + result_params = get_result_sets(args) + + for model, tp_size, bs, replicas, prompt, gen in result_params: + output_charts( + model=model, + tp_size=tp_size, + bs=bs, + replicas=replicas, + prompt=prompt, + gen=gen, + out_dir=args.out_dir, + ) diff --git a/benchmarks/inference/mii/src/plot_tp_sizes.py b/benchmarks/inference/mii/src/plot_tp_sizes.py new file mode 100644 index 000000000..596a40de2 --- /dev/null +++ b/benchmarks/inference/mii/src/plot_tp_sizes.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import glob +import matplotlib.pyplot as plt +import argparse +from pathlib import Path +import numpy as np +import re +from collections import defaultdict + +from postprocess_results import read_json, get_summary, get_result_sets + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--backend", type=str, choices=["aml", "fastgen", "vllm"], default=["aml", "fastgen", "vllm"], \ + nargs=1, help="Specify the single backend to generate plots for") + parser.add_argument("--log_dir", type=Path, default="logs.release") + parser.add_argument("--out_dir", type=Path, default="./plots/tp_sizes") + args = parser.parse_args() + return args + + +def extract_values(file_pattern): + files = glob.glob(file_pattern) + + print(f"Found {len(files)}") + print("\n".join(files)) + + clients = [] + throughputs = [] + latencies = [] + for f in files: + prof_args, response_details = read_json(f) + summary = get_summary(prof_args, response_details) + clients.append(prof_args["num_clients"]) + throughputs.append(summary.throughput) + latencies.append(summary.latency) + + return clients, throughputs, latencies + + +def output_charts(args, model, tp_list, bs, replicas, prompt, gen, log_dir, out_dir): + if not log_dir.exists(): + print(f"Log directory {log_dir} does not exist") + return + + if not out_dir.exists(): + out_dir.mkdir(parents=True, exist_ok=True) + + # Plotting the scatter plot + plt.figure() + + for tp in tp_list: + result_file_pattern = f"{model}-tp{tp}-bs{bs}-replicas{replicas}-prompt{prompt}-gen{gen}-clients*.json" + file_pattern = f"{log_dir}/{args.backend[0]}/{result_file_pattern}" + _, throughputs, latencies = extract_values(file_pattern) + + if len(throughputs) == 0: + continue + + model_size = re.match('.*?(\d+[b|B|m|M])', model).groups()[0] + n_params = int(model_size[:-1]) + if model_size[-1].lower() == 'm': + # Scale n_params approriately for millions + n_params = n_params / 1000 + tflops_per_query = n_params * (int(prompt) + int(gen)) * 2 * 1e-3 + tflops = [th * tflops_per_query / tp for th in throughputs] + + plt.scatter( + tflops, latencies, label=f"TP={tp}", marker="o" + ) + fit_x_list = np.arange(min(tflops), max(tflops), 0.01) + fit_model = np.polyfit(tflops, latencies, 3) + model_fn = np.poly1d(fit_model) + plt.plot( + fit_x_list, + model_fn(fit_x_list), + alpha=0.5, + linestyle="--", + ) + + plt.title( + f"Model: {model}, Prompt: {prompt}, Generation: {gen}, TP: {tp_list}\n\ + Replicas: {replicas}, Backend: {args.backend[0]}" + ) + plt.xlabel("TFLOPs (per GPU)", fontsize=14) + plt.ylabel("Latency (s)", fontsize=14) + plt.legend() + plt.grid(True) + out_file = ( + out_dir + / f"tp_sizes_{model}_tp{'_'.join([str(tp) for tp in tp_list])}_p{prompt}g{gen}r{replicas}.png" + ) + plt.savefig(out_file) + + +if __name__ == "__main__": + args = get_args() + + tp_sets = defaultdict(lambda: defaultdict(set)) + result_params = get_result_sets(args) + + # Find all tp_sizes across same sets + for model, tp_size, bs, replicas, prompt, gen in result_params: + key = f'{model}_{bs}_{replicas}_{prompt}_{gen}' + tp_sets[key]['config'].add((model, bs, replicas, prompt, gen)) + tp_sets[key]['tp_list'].add(int(tp_size)) + + for tp_set in tp_sets.values(): + for model, bs, replicas, prompt, gen in tp_set['config']: + tp_list = sorted(tp_set['tp_list']) + output_charts( + args=args, + model=model, + tp_list=tp_list, + bs=bs, + replicas=replicas, + prompt=prompt, + gen=gen, + log_dir=args.log_dir, + out_dir=args.out_dir, + ) diff --git a/benchmarks/inference/mii/src/postprocess_results.py b/benchmarks/inference/mii/src/postprocess_results.py new file mode 100644 index 000000000..378925027 --- /dev/null +++ b/benchmarks/inference/mii/src/postprocess_results.py @@ -0,0 +1,200 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse +import json +import re +import os +from tabulate import tabulate +from dataclasses import dataclass +from functools import reduce +from pathlib import Path +from statistics import mean +from typing import List +from collections import defaultdict + +import numpy as np +from transformers import AutoTokenizer + + +tokenizer = None + + +@dataclass +class ResponseDetails: + generated_tokens: List[str] + prompt: str + start_time: float + end_time: float + model_time: float + token_gen_time: List[float] + + +@dataclass +class ProfilingSummary: + throughput: float + latency: float + token_gen_latency: float + first_token_latency: float + tokens_per_sec: float + + +def parse_args(): + parser = argparse.ArgumentParser(description="Postprocess results") + parser.add_argument("-i", "--input_path", type=Path, default="results.json") + + args = parser.parse_args() + return args + + +def get_tokenizer(model=None): + global tokenizer + if tokenizer is None: + if model==None: + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + else: + tokenizer = AutoTokenizer.from_pretrained(model) + return tokenizer + + +def read_json(file_path): + with open(file_path, "r") as f: + data = json.load(f) + + args = data["args"] + + response_details = [] + for response in data["response_details"]: + response_details.append(ResponseDetails(**response)) + + return args, response_details + + +def get_summary(args, response_details): + num_clients = args["num_clients"] + + # Calculate latency and throughput using P95 latency + latency = mean([r.end_time - r.start_time for r in response_details]) + throughput = num_clients / latency + + tokens_per_sec = mean( + [ + (len(get_tokenizer(args["model"]).tokenize(r.prompt)) + + len(get_tokenizer(args["model"]).tokenize(r.generated_tokens)) if type(r.generated_tokens) == str + else len(r.generated_tokens)) + / (r.end_time - r.start_time) + for r in response_details + ] + ) + + # For non-streaming results, we don't have any token_gen_time information + first_token_latency = 0.0 + token_gen_latency = 0.0 + if response_details[0].token_gen_time: + first_token_latency = mean([r.token_gen_time[0] for r in response_details]) + token_gen_latency_flat = reduce( + list.__add__, + [ + r.token_gen_time[1:-1] + for r in response_details + if len(r.token_gen_time) > 2 + ], + ) + token_gen_latency = mean([t for t in token_gen_latency_flat]) + + return ProfilingSummary( + throughput, latency, token_gen_latency, first_token_latency, tokens_per_sec + ) + + +def get_token_latency( + response_details, percentile=None, variance=False, cumulative=False +): + req_latencies = [r.token_gen_time for r in response_details] + if cumulative: + req_latencies = [ + np.cumsum(np.array(r.token_gen_time)).tolist() for r in response_details + ] + max_gen_length = max([len(r.generated_tokens) for r in response_details]) + latency = [] + for i in range(max_gen_length): + if variance: + token_latency_step = np.var( + [latency[i] for latency in req_latencies if len(latency) > i] + ) + if percentile is None: + token_latency_step = [ + latency[i] for latency in req_latencies if len(latency) > i + ] + else: + token_latency_step = np.percentile( + [latency[i] for latency in req_latencies if len(latency) > i], + percentile, + ) + + latency.append(token_latency_step) + + return latency + + +def get_token_acc_latency(response_details, percentile=99): + return get_token_latency(response_details, percentile, cumulative=True) + + +if __name__ == "__main__": + args = parse_args() + prof_args, response_details = read_json(args.input_path) + + ps = get_summary(prof_args, response_details) + print( + f"Deployment: {prof_args['deployment_name']} Clients: {prof_args['num_clients']}, " + + f"Query throughput: {ps.throughput:.3f} queries/s, " + + f"Token throughput (total): {ps.tokens_per_sec:.3f} tokens/s, " + + f"Query latency: {ps.latency:.3f} s, " + + f"Token generation latency: {ps.token_gen_latency:.3f} s/token, " + + f"First token received: {ps.first_token_latency:.3f} s" + ) + +def get_result_sets(args: argparse.Namespace) -> set(): + result_params = None + result_re = re.compile( + r"(.+)-tp(\d+)-bs(\d+)-replicas(\d+)-prompt(\d+)-gen(\d+)-clients.*.json" + ) + + data_sets = defaultdict(set) + + if hasattr(args, "data_dirs"): + data_set_dirs = args.data_dirs + elif hasattr(args, "backend"): + data_set_dirs = args.backend + + # Generate data sets + for data in data_set_dirs: + if hasattr(args, "log_dir"): + os_path = os.path.join(args.log_dir, data) + else: + os_path = os.path.join(data) + + for f in os.listdir(os_path): + match = result_re.match(f) + if match: + data_sets[data].add(match.groups()) + + # Intersection between all sets + for data_set in data_sets.values(): + if result_params == None: + result_params = data_set + else: + result_params = result_params.intersection(data_set) + + # Warning messages about skipped sets + for key, data_set in data_sets.items(): + difference = data_set.difference(result_params) + if difference: + print(f"WARNING: data {key} has result combinations that are not present in all data sets:") + print(tabulate(difference, headers=["model", "tp_size", "bs", "replicas", "prompt", "gen"])) + print("") + + return result_params diff --git a/benchmarks/inference/mii/random_query_generator.py b/benchmarks/inference/mii/src/random_query_generator.py similarity index 72% rename from benchmarks/inference/mii/random_query_generator.py rename to benchmarks/inference/mii/src/random_query_generator.py index b8442af4f..eca16d8ff 100644 --- a/benchmarks/inference/mii/random_query_generator.py +++ b/benchmarks/inference/mii/src/random_query_generator.py @@ -1,7 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import numpy as np import torch import random -import numpy as np -import time + class RandomQueryGenerator: def __init__(self, input_text, tokenizer, seed): @@ -14,9 +19,9 @@ def __init__(self, input_text, tokenizer, seed): def get_random_request_text(self, length, variance, max_length, batch): request_text = [] - tokenized_input = self.tokenizer.batch_encode_plus([self.input_text], - return_tensors="pt", - padding=False) + tokenized_input = self.tokenizer.batch_encode_plus( + [self.input_text], return_tensors="pt", padding=False + ) offset = list(range(512)) random.shuffle(offset) @@ -25,6 +30,6 @@ def get_random_request_text(self, length, variance, max_length, batch): # Set max_new_tokens following normal distribution with mean=max_new_tokens and std=0.3*max_new_tokens req_prompt_length = min(int(np.random.normal(length, variance)), max_length) - text = self.tokenizer.decode(text_ids[i:req_prompt_length+i]) + text = self.tokenizer.decode(text_ids[i : req_prompt_length + i]) request_text.append(text) return request_text diff --git a/benchmarks/inference/mii/sample_input.py b/benchmarks/inference/mii/src/sample_input.py similarity index 99% rename from benchmarks/inference/mii/sample_input.py rename to benchmarks/inference/mii/src/sample_input.py index 77d02af5f..bae18ce62 100644 --- a/benchmarks/inference/mii/sample_input.py +++ b/benchmarks/inference/mii/src/sample_input.py @@ -1,8 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team # This is a sample input consisting of: # Code & Text -all_text = '''Deep learning involves the use of neural networks, which are computational models inspired by the structure and functioning of the human brain. These networks consist of interconnected nodes called neurons. Each neuron takes input, performs a computation, and produces an output. +all_text = """Deep learning involves the use of neural networks, which are computational models inspired by the structure and functioning of the human brain. These networks consist of interconnected nodes called neurons. Each neuron takes input, performs a computation, and produces an output. During training, the neural network learns to make accurate predictions by adjusting its internal parameters. This adjustment is done using an optimization algorithm called gradient descent. Gradient descent calculates the gradients of a loss function, which measures the discrepancy between the predicted output of the network and the desired output. These gradients indicate the direction and magnitude of parameter updates that will minimize the loss. The learning rate is an important hyperparameter in gradient descent. It determines the step size taken during parameter updates. A higher learning rate can lead to faster convergence, but it risks overshooting the optimal solution. On the other hand, a lower learning rate may converge more slowly, but it can result in more precise updates. Activation functions are applied to the output of each neuron in a neural network. They introduce non-linearities, enabling the network to learn complex patterns and relationships in the data. Popular activation functions include the rectified linear unit (ReLU), sigmoid, and hyperbolic tangent (tanh). @@ -218,4 +222,4 @@ def top_p_sampling(self, logits, p=0.9): print("Top-k Sampling:", top_k_text) print("Top-p Sampling:", top_p_text) Make sure to adjust the server_url with the appropriate URL of your HTTP server, and ensure that the server is running and accessible before making requests through the API. - ''' \ No newline at end of file + """ diff --git a/benchmarks/inference/mii/src/server.py b/benchmarks/inference/mii/src/server.py new file mode 100644 index 000000000..6d3c1cd69 --- /dev/null +++ b/benchmarks/inference/mii/src/server.py @@ -0,0 +1,141 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse +import subprocess +import time + + +try: + from .utils import parse_args, SERVER_PARAMS +except ImportError: + from utils import parse_args, SERVER_PARAMS + + +def start_server(args: argparse.Namespace) -> None: + start_server_fns = { + "fastgen": start_fastgen_server, + "vllm": start_vllm_server, + "aml": start_aml_server, + "openai": start_openai_server, + } + start_fn = start_server_fns[args.backend] + start_fn(args) + + +def start_vllm_server(args: argparse.Namespace) -> None: + vllm_cmd = ( + "python", + "-m", + "vllm.entrypoints.api_server", + "--host", + "127.0.0.1", + "--port", + "26500", + "--tensor-parallel-size", + str(args.tp_size), + "--model", + args.model, + ) + p = subprocess.Popen( + vllm_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, close_fds=True + ) + start_time = time.time() + timeout_after = 60 * 5 # 5 minutes + while True: + line = p.stderr.readline().decode("utf-8") + if "Application startup complete" in line: + break + if "error" in line.lower(): + p.terminate() + stop_vllm_server(args) + raise RuntimeError(f"Error starting VLLM server: {line}") + if time.time() - start_time > timeout_after: + p.terminate() + stop_vllm_server(args) + raise TimeoutError("Timed out waiting for VLLM server to start") + time.sleep(0.01) + + +def start_fastgen_server(args: argparse.Namespace) -> None: + import mii + from deepspeed.inference import RaggedInferenceEngineConfig, DeepSpeedTPConfig + from deepspeed.inference.v2.ragged import DSStateManagerConfig + + tp_config = DeepSpeedTPConfig(tp_size=args.tp_size) + mgr_config = DSStateManagerConfig( + max_ragged_batch_size=args.max_ragged_batch_size, + max_ragged_sequence_count=args.max_ragged_batch_size, + ) + inference_config = RaggedInferenceEngineConfig( + tensor_parallel=tp_config, state_manager=mgr_config + ) + if args.fp6: + quantization_mode = 'wf6af16' + else: + quantization_mode = None + mii.serve( + args.model, + deployment_name=args.deployment_name, + tensor_parallel=args.tp_size, + inference_engine_config=inference_config, + replica_num=args.num_replicas, + quantization_mode=quantization_mode + ) + + +def start_aml_server(args: argparse.Namespace) -> None: + raise NotImplementedError( + "AML server start not implemented. Please use Azure Portal to start the server." + ) + +def start_openai_server(args: argparse.Namespace) -> None: + # openai api has no command to stop server + pass + +def stop_server(args: argparse.Namespace) -> None: + stop_server_fns = { + "fastgen": stop_fastgen_server, + "vllm": stop_vllm_server, + "aml": stop_aml_server, + "openai": stop_openai_server, + } + stop_fn = stop_server_fns[args.backend] + stop_fn(args) + + +def stop_vllm_server(args: argparse.Namespace) -> None: + vllm_cmd = ("pkill", "-f", "vllm.entrypoints.api_server") + p = subprocess.Popen(vllm_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p.wait() + + +def stop_fastgen_server(args: argparse.Namespace) -> None: + import mii + + mii.client(args.deployment_name).terminate_server() + + +def stop_aml_server(args: argparse.Namespace) -> None: + raise NotImplementedError( + "AML server stop not implemented. Please use Azure Portal to stop the server." + ) + +def stop_openai_server(args: argparse.Namespace) -> None: + # openai api has no command to stop server + pass + +if __name__ == "__main__": + args = parse_args(server_args=True) + + if args.cmd == "start": + start_server(args) + elif args.cmd == "stop": + stop_server(args) + elif args.cmd == "restart": + stop_server(args) + start_server(args) + else: + raise ValueError(f"Invalid command {args.cmd}") diff --git a/benchmarks/inference/mii/src/utils.py b/benchmarks/inference/mii/src/utils.py new file mode 100644 index 000000000..ac2065065 --- /dev/null +++ b/benchmarks/inference/mii/src/utils.py @@ -0,0 +1,281 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse +import copy +import itertools +import json +import os + +from dataclasses import asdict +from datetime import datetime +from pathlib import Path +from typing import Iterator, List + +try: + from .defaults import ARG_DEFAULTS, MODEL_DEFAULTS + from .postprocess_results import get_summary, ResponseDetails +except ImportError: + from defaults import ARG_DEFAULTS, MODEL_DEFAULTS + from postprocess_results import get_summary, ResponseDetails + +# For these arguments, users can provide multiple values when running the +# benchmark. The benchmark will iterate over all possible combinations. +SERVER_PARAMS = ["tp_size", "max_ragged_batch_size", "num_replicas"] +CLIENT_PARAMS = ["mean_prompt_length", "mean_max_new_tokens", "num_clients"] + +AML_REQUIRED_PARAMS = ["aml_api_url", "aml_api_key", "deployment_name", "model"] + + +def parse_args( + server_args: bool = False, client_args: bool = False +) -> argparse.Namespace: + if not (server_args or client_args): + raise ValueError("Must specify server_args or client_args or both") + + # Server args + server_parser = argparse.ArgumentParser(add_help=False) + server_parser.add_argument( + "--tp_size", type=int, nargs="+", default=None, help="Tensor parallelism size" + ) + server_parser.add_argument( + "--max_ragged_batch_size", + type=int, + nargs="+", + default=None, + help="Max batch size for ragged batching", + ) + server_parser.add_argument( + "--num_replicas", + type=int, + nargs="+", + default=None, + help="Number of FastGen model replicas", + ) + server_parser.add_argument( + "cmd", + type=str, + nargs="?", + choices=["start", "stop", "restart"], + help="Command for running server.py to manually start/stop/restart a server", + ) + server_parser.add_argument( + "--client_only", action="store_true", help="Run client only with server started" + ) + + + # Client args + client_parser = argparse.ArgumentParser(add_help=False) + client_parser.add_argument( + "--max_prompt_length", type=int, default=None, help="Max length a prompt can be" + ) + client_parser.add_argument( + "--mean_prompt_length", + type=int, + nargs="+", + default=None, + help="Mean prompt length in tokens", + ) + client_parser.add_argument( + "--mean_max_new_tokens", + type=int, + nargs="+", + default=None, + help="Mean number of new tokens to generate per prompt", + ) + client_parser.add_argument( + "--num_clients", + type=int, + nargs="+", + default=[1, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32], + help="Number of concurrent clients", + ) + client_parser.add_argument( + "--num_requests", + type=int, + default=None, + help="Number of requests to process by clients", + ) + client_parser.add_argument( + "--prompt_length_var", type=float, default=0.3, help="Variance of prompt length" + ) + client_parser.add_argument( + "--max_new_tokens_var", + type=float, + default=0.3, + help="Variance of max new tokens", + ) + client_parser.add_argument( + "--warmup", type=int, default=1, help="Number of warmup requests to process" + ) + client_parser.add_argument( + "--use_thread", action="store_true", help="Use threads instead of processes" + ) + client_parser.add_argument( + "--stream", action="store_true", help="Stream generated tokens" + ) + client_parser.add_argument( + "--out_json_dir", + type=Path, + default="./results/", + help="Directory to save result JSON files", + ) + client_parser.add_argument( + "--openai_api_url", + type=str, + default=None, + help="When using the openai API backend, this is the API URL that points to an openai api server", + ) + client_parser.add_argument( + "--openai_api_key", + type=str, + default=None, + help="When using the openai API backend, this is the API key for a given openai_api_url", + ) + client_parser.add_argument( + "--aml_api_url", + type=str, + default=None, + help="When using the AML backend, this is the API URL that points to an AML endpoint", + ) + client_parser.add_argument( + "--aml_api_key", + type=str, + default=None, + help="When using the AML backend, this is the API key for a given aml_api_url", + ) + + # Create the parser, inheriting from the server and/or client parsers + parents = [] + if server_args: + parents.append(server_parser) + if client_args: + parents.append(client_parser) + + # Common args + parser = argparse.ArgumentParser(parents=parents) + parser.add_argument( + "--model", type=str, default=None, help="HuggingFace.co model name" + ) + parser.add_argument( + "--deployment_name", + type=str, + default=None, + help="When using FastGen backend, specifies which model deployment to use. When using AML backend, specifies the name of the deployment", + ) + parser.add_argument( + "--backend", + type=str, + choices=["aml", "fastgen", "vllm", "openai"], + default="fastgen", + help="Which backend to benchmark", + ) + parser.add_argument( + "--overwrite_results", action="store_true", help="Overwrite existing results" + ) + parser.add_argument("--fp6", action="store_true", help="Enable FP6") + + # Parse arguments + args = parser.parse_args() + + # Verify that AML required parameters are defined before filling in defaults + if args.backend == "aml": + for k in AML_REQUIRED_PARAMS: + if getattr(args, k) is None: + raise ValueError(f"AML backend requires {k} to be specified") + + # Set default values for model-specific parameters + if args.model in MODEL_DEFAULTS: + for k, v in MODEL_DEFAULTS[args.model].items(): + if hasattr(args, k) and getattr(args, k) is None: + setattr(args, k, v) + + # Grab any remaining default values not specified for a model + for k, v in ARG_DEFAULTS.items(): + if hasattr(args, k) and getattr(args, k) is None: + setattr(args, k, v) + + # If we are not running the benchmark, we need to make sure to only have one + # value for the server args + if server_args and not client_args: + for k in SERVER_PARAMS: + if not isinstance(getattr(args, k), int): + setattr(args, k, getattr(args, k)[0]) + + return args + + +def get_args_product( + args: argparse.Namespace, which: List[str] = None +) -> Iterator[argparse.Namespace]: + if which is None: + return copy.deepcopy(args) + for k in which: + if isinstance(getattr(args, k), int): + setattr(args, k, [getattr(args, k)]) + arg_values_product = itertools.product(*[getattr(args, k) for k in which]) + for arg_values in arg_values_product: + args_copy = copy.deepcopy(args) + for k, v in zip(which, arg_values): + setattr(args_copy, k, v) + yield args_copy + + +def get_results_path(args: argparse.Namespace) -> Path: + return Path( + f"{args.out_json_dir}_{args.backend}/", + "-".join( + ( + args.model.replace("/", "_"), + f"tp{args.tp_size}", + f"bs{args.max_ragged_batch_size}", + f"replicas{args.num_replicas}", + f"prompt{args.mean_prompt_length}", + f"gen{args.mean_max_new_tokens}", + f"clients{args.num_clients}", + ) + ) + + ".json", + ) + + +def print_summary( + args: argparse.Namespace, response_details: List[ResponseDetails] +) -> None: + ps = get_summary(vars(args), response_details) + print( + f"Deployment: {args.deployment_name} Clients: {args.num_clients}, " + + f"Prompt (mean): {args.mean_prompt_length} tokens, " + + f"Generation (mean): {args.mean_max_new_tokens} tokens, " + + f"Query throughput: {ps.throughput:.3f} queries/s, " + + f"Token throughput (total): {ps.tokens_per_sec:.3f} tokens/s, " + + f"Query latency: {ps.latency:.3f} s, " + + f"Token generation latency: {ps.token_gen_latency:.3f} s/token, " + + f"First token received: {ps.first_token_latency:.3f} s" + ) + + +def save_json_results( + args: argparse.Namespace, response_details: List[ResponseDetails] +) -> None: + args_dict = vars(args) + # Remove AML key from args dictionary + if "aml_api_key" in args_dict: + args_dict["aml_api_key"] = None + out_json_path = get_results_path(args) + os.makedirs(out_json_path.parent, exist_ok=True) + + with open(out_json_path, "w") as f: + args_dict["out_json_dir"] = str(out_json_path) # Path is not JSON serializable + data = { + "args": args_dict, + "time": str(datetime.now()), + "response_details": [asdict(r) for r in response_details], + } + json.dump(data, f, indent=2) + + +def results_exist(args: argparse.Namespace) -> bool: + return get_results_path(args).exists() diff --git a/deepnvme/file_access/README.md b/deepnvme/file_access/README.md new file mode 100644 index 000000000..a50f6f438 --- /dev/null +++ b/deepnvme/file_access/README.md @@ -0,0 +1,116 @@ +# Using DeepNVMe for simple file reads and writes involving CPU/GPU tensors + +The purpose of this folder is to provide example codes that illustrate how to use [DeepNVMe](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-gds/README.md) for simple file operations of moving raw data bytes between persistent storage and CPU/GPU tensors. For each file operation, we provide an implementation using Python I/O functionality, and a DeepNVMe implementation using CPU bounce buffer (aio) and NVIDIA Magnum IOTM GPUDirect® Storage (GDS) as appropriate. + +The following table is a mapping of file operations to the corresponding Python and DeepNVMe implementations. + + +File Operation | Python | DeepNVMe (aio) | DeepNVMe (GDS) +|---|---|---|---| +Load CPU tensor from file | py_load_cpu_tensor.py | aio_load_cpu_tensor.py | - | +Load GPU tensor from file | py_load_gpu_tensor.py | aio_load_gpu_tensor.py | gds_load_gpu_tensor.py | +Store CPU tensor to file | py_store_cpu_tensor.py | aio_store_cpu_tensor.py | - | +Store GPU tensor to file | py_store_gpu_tensor.py | aio_store_gpu_tensor.py | gds_store_gpu_tensor.py | + +The Python implementations are the scripts with `py_` prefix. while the DeepNVMe implementations are those with`aio_` and `gds_`prefixes. + +## Requirements +Ensure your environment is properly configured to run these examples. First, you need to install DeepSpeed version >= 0.15.0. Next, ensure that the DeepNVMe operators are available in the DeepSpeed installation. The `async_io` operator is required for any DeepNVMe functionality, while the `gds` operator is required only for GDS functionality. You can confirm availability of each operator by inspecting the output of `ds_report` to check that compatible status is [OKAY]. Below is a snippet of `ds_report` output showing availability of both `async_io` and `gds` operators. + +
+ +
+
+ ds_report output showing availability of DeepNVMe operators (async_io and gds) in a DeepSpeed installation. +
+ + +If `async_io` opertator is unavailable, you will need to install the appropriate `libaio` library binaries for your Linux flavor. For example, Ubuntu users will need to run `apt install libaio-dev`. In general, you should carefully inspect `ds_report` output for helpful tips such as the following: + +```bash +[WARNING] async_io requires the dev libaio .so object and headers but these were not found. +[WARNING] async_io: please install the libaio-dev package with apt +[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found. +``` + +To enable `gds` operator, you will need to install NVIDIA GDS by consulting the appropriate guide for [bare-metal systems](https://docs.nvidia.com/gpudirect-storage/troubleshooting-guide/index.html) or Azure VMs (coming soon). + +## Tensor Load Examples +The tensor load example scripts share a common command-line interface, which is illustrated below using `py_read_load_cpu_tensor.py`. +```bash +$ python py_load_cpu_tensor.py --help +usage: py_load_cpu_tensor.py [-h] --input_file INPUT_FILE [--loop LOOP] [--validate] + +options: + -h, --help show this help message and exit + --input_file INPUT_FILE + File on NVMe device that will read as input. + --loop LOOP The number of times to repeat the operation (default 3). + --validate Run validation step that compares tensor value against Python file read +``` +Before running these example scripts ensure that the input file exists on an NVMe device. The `--validate` option is relevant only to the DeepNVme implementations. This option provides minimal correctness checking by comparing against a tensor loaded using Python. We also provide a bash script `run_load_tensor.sh`, which runs all the example tensor load scripts. + + +## Tensor Store Examples +The tensor store examples share a command-line interface, which is illustrated below using `py_store_cpu_tensor.py` +```bash +$ python py_store_cpu_tensor.py --help +usage: py_store_cpu_tensor.py [-h] --nvme_folder NVME_FOLDER [--mb_size MB_SIZE] [--loop LOOP] [--validate] + +options: + -h, --help show this help message and exit + --nvme_folder NVME_FOLDER + NVMe folder for file write. + --mb_size MB_SIZE Size of tensor to save in MB (default 1024). + --loop LOOP The number of times to repeat the operation (default 3). + --validate Run validation step that compares tensor value against Python file read + +``` +Before running these examples ensure that the output folder exists on an NVMe device and that you have write permission. The `--validate` option is relevant only to the DeepNVMe implementations. This option provides minimal correctness checking by comparing the output file against that created using Python. We also provide a bash script `run_store_tensor.sh`, which runs all the example tensor store scripts. + + +## Performance Advisory +Although this folder is primarily meant to help with integrating DeepNVMe into your Deep Learning applications, the example scripts also print out performance numbers of read and write throughput. So, we expect you will observe some performance advantage of DeepNVMe compared to Python. However, do note that it is likely that better performance can be realized by tuning DeepNVMe for your environment. Such tuning efforts will ideally generate more optimal values for configuring DeepNVMe. + +For reference, DeepNVMe configuration using hard-coded constants for `aio_` implementations is as follows: + +```python + aio_handle = AsyncIOBuilder().load().aio_handle(1024**2, 128, True, True, 1) +``` + +The corresponding DeepNVMe configuration for `gds_` implementations is as follows: + +```python + gds_handle = GDSBuilder().load().gds_handle(1024**2, 128, True, True, 1) +``` + +Despite the above caveat, it seems that some performance numbers would be useful here to help set the right expectations. The experiments were conducted on an Azure [NC80adis_H100_v5](https://learn.microsoft.com/en-us/azure/virtual-machines/ncads-h100-v5) series virtual machine (VM). This VM includes two 3.5TB local NVMe devices (labelled Microsoft NVMe Direct Disk v2) that we combined into a single RAID-0 volume. The software environment included Ubuntu 22.04.4 LTS, Linux kernel 6.5.0-26-generic, Pytorch 2.4, and CUDA 12.4. We ran experiments of 1GB data transfers using the unmodified scripts, i.e., without DeepNVMe tuning, and present the throughput results in the tables below. In summary, we observed that DeepNVMe significantly accelerates I/O operations compared to Python. DeepNVMe is 8-16X faster for loading tensor data, and 11X-119X faster for writing tensor data. + +Load 1GB CPU tensor (1GB file read) | GB/sec | Speedup over Python | +|---|---|---| +py_load_cpu_tensor.py | 1.5 | - | +aio_load_cpu_tensor.py | 12.3 | 8X | + +Load 1GB GPU tensor (1GB file read) | GB/sec | Speedup over Python | +|---|---|---| +py_load_gpu_tensor.py | 0.7| - | +aio_load_gpu_tensor.py | 9.9 | 14X | +gds_load_gpu_tensor.py | 11.1 | 16X | + + +Store 1GB CPU tensor (1GB file write) | GB/sec | Speedup over Python | +|---|---|---| +py_store_cpu_tensor.py | 0.7 | - | +aio_store_cpu_tensor.py | 8.1 | 11X | + + +Store 1GB GPU tensor (1GB file write) | GB/sec | Speedup over Python | +|---|---|---| +py_store_gpu_tensor.py | 0.5 | - | +aio_store_gpu_tensor.py | 8.3 | 18X | +gds_store_gpu_tensor.py | 8.6 | 19X | + + + +# Conclusion +We hope you find this document and example scripts useful for integrating DeepNVMe into your applications. diff --git a/deepnvme/file_access/aio_load_cpu_tensor.py b/deepnvme/file_access/aio_load_cpu_tensor.py new file mode 100644 index 000000000..27a1e61c5 --- /dev/null +++ b/deepnvme/file_access/aio_load_cpu_tensor.py @@ -0,0 +1,31 @@ +import torch +import os, timeit, functools +from deepspeed.ops.op_builder import AsyncIOBuilder +from utils import parse_read_arguments, GIGA_UNIT + +def file_read(inp_f, handle, bounce_buffer): + handle.sync_pread(bounce_buffer, inp_f) + return bounce_buffer.cpu() + +def main(): + args = parse_read_arguments() + input_file = args.input_file + file_sz = os.path.getsize(input_file) + cnt = args.loop + + aio_handle = AsyncIOBuilder().load().aio_handle() + bounce_buffer = torch.empty(os.path.getsize(input_file), dtype=torch.uint8).pin_memory() + + t = timeit.Timer(functools.partial(file_read, input_file, aio_handle, bounce_buffer)) + aio_t = t.timeit(cnt) + aio_gbs = (cnt*file_sz)/GIGA_UNIT/aio_t + print(f'aio load_cpu: {file_sz/GIGA_UNIT} GB, {aio_t/cnt} secs, {aio_gbs:5.2f} GB/sec') + + if args.validate: + from py_load_cpu_tensor import file_read as py_file_read + aio_tensor = file_read(input_file, aio_handle, bounce_buffer) + py_tensor = py_file_read(input_file) + print(f'Validation success = {aio_tensor.equal(py_tensor)}') + +if __name__ == "__main__": + main() diff --git a/deepnvme/file_access/aio_load_gpu_tensor.py b/deepnvme/file_access/aio_load_gpu_tensor.py new file mode 100644 index 000000000..aeecc6e5d --- /dev/null +++ b/deepnvme/file_access/aio_load_gpu_tensor.py @@ -0,0 +1,32 @@ +import torch +import os, timeit, functools +from deepspeed.ops.op_builder import AsyncIOBuilder +from utils import parse_read_arguments, GIGA_UNIT + +def file_read(inp_f, handle, bounce_buffer): + handle.sync_pread(bounce_buffer, inp_f) + return bounce_buffer.cuda() + + +def main(): + args = parse_read_arguments() + input_file = args.input_file + file_sz = os.path.getsize(input_file) + cnt = args.loop + + aio_handle = AsyncIOBuilder().load().aio_handle() + bounce_buffer = torch.empty(os.path.getsize(input_file), dtype=torch.uint8).pin_memory() + + t = timeit.Timer(functools.partial(file_read, input_file, aio_handle, bounce_buffer)) + aio_t = t.timeit(cnt) + aio_gbs = (cnt*file_sz)/GIGA_UNIT/aio_t + print(f'aio load_gpu: {file_sz/GIGA_UNIT} GB, {aio_t/cnt} secs, {aio_gbs:5.2f} GB/sec') + + if args.validate: + from py_load_cpu_tensor import file_read as py_file_read + aio_tensor = file_read(input_file, aio_handle, bounce_buffer).cpu() + py_tensor = py_file_read(input_file) + print(f'Validation success = {aio_tensor.equal(py_tensor)}') + +if __name__ == "__main__": + main() diff --git a/deepnvme/file_access/aio_store_cpu_tensor.py b/deepnvme/file_access/aio_store_cpu_tensor.py new file mode 100644 index 000000000..20c03792b --- /dev/null +++ b/deepnvme/file_access/aio_store_cpu_tensor.py @@ -0,0 +1,40 @@ +import torch +import os, timeit, functools, pathlib +from deepspeed.ops.op_builder import AsyncIOBuilder +from utils import parse_write_arguments, GIGA_UNIT + +def file_write(out_f, tensor, handle, bounce_buffer): + bounce_buffer.copy_(tensor) + handle.sync_pwrite(bounce_buffer, out_f) + +def main(): + args = parse_write_arguments() + cnt = args.loop + output_file = os.path.join(args.nvme_folder, f'test_ouput_{args.mb_size}MB.pt') + pathlib.Path(output_file).unlink(missing_ok=True) + file_sz = args.mb_size*(1024**2) + app_tensor = torch.empty(file_sz, dtype=torch.uint8, device='cpu', requires_grad=False) + + aio_handle = AsyncIOBuilder().load().aio_handle() + bounce_buffer = torch.empty(file_sz, dtype=torch.uint8, requires_grad=False).pin_memory() + + + t = timeit.Timer(functools.partial(file_write, output_file, app_tensor, aio_handle, bounce_buffer)) + + aio_t = t.timeit(cnt) + aio_gbs = (cnt*file_sz)/GIGA_UNIT/aio_t + print(f'aio store_cpu: {file_sz/GIGA_UNIT} GB, {aio_t/cnt} secs, {aio_gbs:5.2f} GB/sec') + + if args.validate: + import tempfile, filecmp + from py_store_cpu_tensor import file_write as py_file_write + py_ref_file = os.path.join(tempfile.gettempdir(), os.path.basename(output_file)) + py_file_write(py_ref_file, app_tensor) + filecmp.clear_cache() + print(f'Validation success = {filecmp.cmp(py_ref_file, output_file, shallow=False) }') + + pathlib.Path(output_file).unlink(missing_ok=True) + + +if __name__ == "__main__": + main() diff --git a/deepnvme/file_access/aio_store_gpu_tensor.py b/deepnvme/file_access/aio_store_gpu_tensor.py new file mode 100644 index 000000000..71a4aa7bb --- /dev/null +++ b/deepnvme/file_access/aio_store_gpu_tensor.py @@ -0,0 +1,40 @@ +import torch +import os, timeit, functools, pathlib +from deepspeed.ops.op_builder import AsyncIOBuilder +from utils import parse_write_arguments, GIGA_UNIT + +def file_write(out_f, tensor, handle, bounce_buffer): + bounce_buffer.copy_(tensor) + handle.sync_pwrite(bounce_buffer, out_f) + +def main(): + args = parse_write_arguments() + cnt = args.loop + output_file = os.path.join(args.nvme_folder, f'test_ouput_{args.mb_size}MB.pt') + pathlib.Path(output_file).unlink(missing_ok=True) + file_sz = args.mb_size*(1024**2) + app_tensor = torch.empty(file_sz, dtype=torch.uint8, device='cuda', requires_grad=False) + + aio_handle = AsyncIOBuilder().load().aio_handle() + bounce_buffer = torch.empty(file_sz, dtype=torch.uint8, requires_grad=False).pin_memory() + + + t = timeit.Timer(functools.partial(file_write, output_file, app_tensor, aio_handle, bounce_buffer)) + + aio_t = t.timeit(cnt) + aio_gbs = (cnt*file_sz)/GIGA_UNIT/aio_t + print(f'aio store_gpu: {file_sz/GIGA_UNIT} GB, {aio_t/cnt} secs, {aio_gbs:5.2f} GB/sec') + + if args.validate: + import tempfile, filecmp + from py_store_cpu_tensor import file_write as py_file_write + py_ref_file = os.path.join(tempfile.gettempdir(), os.path.basename(output_file)) + py_file_write(py_ref_file, app_tensor) + filecmp.clear_cache() + print(f'Validation success = {filecmp.cmp(py_ref_file, output_file, shallow=False) }') + + pathlib.Path(output_file).unlink(missing_ok=True) + + +if __name__ == "__main__": + main() diff --git a/deepnvme/file_access/gds_load_gpu_tensor.py b/deepnvme/file_access/gds_load_gpu_tensor.py new file mode 100644 index 000000000..dd6273707 --- /dev/null +++ b/deepnvme/file_access/gds_load_gpu_tensor.py @@ -0,0 +1,33 @@ +import torch +import os, timeit, functools +from utils import parse_read_arguments, GIGA_UNIT +from deepspeed.ops.op_builder import GDSBuilder + +def file_read(inp_f, handle, gpu_buffer): + handle.sync_pread(gpu_buffer, inp_f) + return gpu_buffer.cuda() + +def main(): + args = parse_read_arguments() + input_file = args.input_file + file_sz = os.path.getsize(input_file) + cnt = args.loop + + gds_handle = GDSBuilder().load().gds_handle() + gds_buffer = gds_handle.new_pinned_device_tensor(file_sz, torch.empty(0, dtype=torch.uint8, device='cuda', requires_grad=False)) + + t = timeit.Timer(functools.partial(file_read, input_file, gds_handle, gds_buffer)) + gds_t = t.timeit(cnt) + gds_gbs = (cnt*file_sz)/GIGA_UNIT/gds_t + print(f'gds load_gpu: {file_sz/GIGA_UNIT} GB, {gds_t/cnt} secs, {gds_gbs:5.2f} GB/sec') + + if args.validate: + from py_load_cpu_tensor import file_read as py_file_read + aio_tensor = file_read(input_file, gds_handle, gds_buffer).cpu() + py_tensor = py_file_read(input_file) + print(f'Validation success = {aio_tensor.equal(py_tensor)}') + + gds_handle.free_pinned_device_tensor(gds_buffer) + +if __name__ == "__main__": + main() diff --git a/deepnvme/file_access/gds_store_gpu_tensor.py b/deepnvme/file_access/gds_store_gpu_tensor.py new file mode 100644 index 000000000..06ba508ba --- /dev/null +++ b/deepnvme/file_access/gds_store_gpu_tensor.py @@ -0,0 +1,39 @@ +import torch +import os, timeit, functools, pathlib +from deepspeed.ops.op_builder import GDSBuilder +from utils import parse_write_arguments, GIGA_UNIT + +def file_write(out_f, tensor, handle, gpu_buffer): + gpu_buffer.copy_(tensor) + handle.sync_pwrite(gpu_buffer, out_f) + +def main(): + args = parse_write_arguments() + cnt = args.loop + output_file = os.path.join(args.nvme_folder, f'test_ouput_{args.mb_size}MB.pt') + pathlib.Path(output_file).unlink(missing_ok=True) + file_sz = args.mb_size*(1024**2) + app_tensor = torch.empty(file_sz, dtype=torch.uint8, device='cuda', requires_grad=False) + + gds_handle = GDSBuilder().load().gds_handle() + gds_buffer = gds_handle.new_pinned_device_tensor(file_sz, torch.empty(0, dtype=torch.uint8, device='cuda', requires_grad=False)) + + t = timeit.Timer(functools.partial(file_write, output_file, app_tensor, gds_handle, gds_buffer)) + + gds_t = t.timeit(cnt) + gds_gbs = (cnt*file_sz)/GIGA_UNIT/gds_t + print(f'gds store_gpu: {file_sz/GIGA_UNIT} GB, {gds_t/cnt} secs, {gds_gbs:5.2f} GB/sec') + + if args.validate: + import tempfile, filecmp + from py_store_cpu_tensor import file_write as py_file_write + py_ref_file = os.path.join(tempfile.gettempdir(), os.path.basename(output_file)) + py_file_write(py_ref_file, app_tensor) + filecmp.clear_cache() + print(f'Validation success = {filecmp.cmp(py_ref_file, output_file, shallow=False) }') + + gds_handle.free_pinned_device_tensor(gds_buffer) + pathlib.Path(output_file).unlink(missing_ok=True) + +if __name__ == "__main__": + main() diff --git a/deepnvme/file_access/media/deepnvme_ops_report.png b/deepnvme/file_access/media/deepnvme_ops_report.png new file mode 100644 index 000000000..c05e9b863 Binary files /dev/null and b/deepnvme/file_access/media/deepnvme_ops_report.png differ diff --git a/deepnvme/file_access/py_load_cpu_tensor.py b/deepnvme/file_access/py_load_cpu_tensor.py new file mode 100644 index 000000000..0650848f0 --- /dev/null +++ b/deepnvme/file_access/py_load_cpu_tensor.py @@ -0,0 +1,22 @@ +import torch +import os, timeit, functools +from utils import parse_read_arguments, GIGA_UNIT + +def file_read(inp_f): + with open(inp_f, 'rb') as f: + tensor = torch.frombuffer(f.read(), dtype=torch.uint8) + return tensor + +def main(): + args = parse_read_arguments() + input_file = args.input_file + file_sz = os.path.getsize(input_file) + cnt = args.loop + + t = timeit.Timer(functools.partial(file_read, input_file)) + py_t = t.timeit(cnt) + py_gbs = (cnt*file_sz)/GIGA_UNIT/py_t + print(f'py load_cpu: {file_sz/GIGA_UNIT} GB, {py_t/cnt} secs, {py_gbs:5.2f} GB/sec') + +if __name__ == "__main__": + main() diff --git a/deepnvme/file_access/py_load_gpu_tensor.py b/deepnvme/file_access/py_load_gpu_tensor.py new file mode 100644 index 000000000..976967dca --- /dev/null +++ b/deepnvme/file_access/py_load_gpu_tensor.py @@ -0,0 +1,22 @@ +import torch +import os, timeit, functools +from utils import parse_read_arguments, GIGA_UNIT + +def file_read(inp_f): + with open(inp_f, 'rb') as f: + tensor = torch.frombuffer(f.read(), dtype=torch.uint8) + return tensor.cuda() + +def main(): + args = parse_read_arguments() + input_file = args.input_file + file_sz = os.path.getsize(input_file) + cnt = args.loop + + t = timeit.Timer(functools.partial(file_read, input_file)) + py_t = t.timeit(cnt) + py_gbs = (cnt*file_sz)/GIGA_UNIT/py_t + print(f'py load_gpu: {file_sz/GIGA_UNIT} GB, {py_t/cnt} secs, {py_gbs:5.2f} GB/sec') + +if __name__ == "__main__": + main() diff --git a/deepnvme/file_access/py_store_cpu_tensor.py b/deepnvme/file_access/py_store_cpu_tensor.py new file mode 100644 index 000000000..50e477186 --- /dev/null +++ b/deepnvme/file_access/py_store_cpu_tensor.py @@ -0,0 +1,26 @@ +import torch +import os, timeit, functools +import pathlib +from utils import parse_write_arguments, GIGA_UNIT + +def file_write(out_f, tensor): + with open(out_f, 'wb') as f: + f.write(tensor.numpy(force=True)) + +def main(): + args = parse_write_arguments() + cnt = args.loop + output_file = os.path.join(args.nvme_folder, f'test_ouput_{args.mb_size}MB.pt') + pathlib.Path(output_file).unlink(missing_ok=True) + file_sz = args.mb_size*(1024**2) + cpu_tensor = torch.empty(file_sz, dtype=torch.uint8, device='cpu', requires_grad=False) + + t = timeit.Timer(functools.partial(file_write, output_file, cpu_tensor)) + + py_t = t.timeit(cnt) + py_gbs = (cnt*file_sz)/GIGA_UNIT/py_t + print(f'py store_cpu: {file_sz/GIGA_UNIT} GB, {py_t/cnt} secs, {py_gbs:5.2f} GB/sec') + pathlib.Path(output_file).unlink(missing_ok=True) + +if __name__ == "__main__": + main() diff --git a/deepnvme/file_access/py_store_gpu_tensor.py b/deepnvme/file_access/py_store_gpu_tensor.py new file mode 100644 index 000000000..a64209a12 --- /dev/null +++ b/deepnvme/file_access/py_store_gpu_tensor.py @@ -0,0 +1,27 @@ +import torch +import os, timeit, functools +import pathlib +from utils import parse_write_arguments, GIGA_UNIT + +def file_write(out_f, tensor): + with open(out_f, 'wb') as f: + f.write(tensor.numpy(force=True)) + +def main(): + args = parse_write_arguments() + cnt = args.loop + output_file = os.path.join(args.nvme_folder, f'test_ouput_{args.mb_size}MB.pt') + pathlib.Path(output_file).unlink(missing_ok=True) + file_sz = args.mb_size*(1024**2) + gpu_tensor = torch.empty(file_sz, dtype=torch.uint8, device='cuda', requires_grad=False) + + t = timeit.Timer(functools.partial(file_write, output_file, gpu_tensor)) + + py_t = t.timeit(cnt) + py_gbs = (cnt*file_sz)/GIGA_UNIT/py_t + print(f'py store_gpu: {file_sz/GIGA_UNIT} GB, {py_t/cnt} secs, {py_gbs:5.2f} GB/sec') + pathlib.Path(output_file).unlink(missing_ok=True) + + +if __name__ == "__main__": + main() diff --git a/deepnvme/file_access/run_load_tensor.sh b/deepnvme/file_access/run_load_tensor.sh new file mode 100644 index 000000000..e410c98b9 --- /dev/null +++ b/deepnvme/file_access/run_load_tensor.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +if [[ $# -ne 1 ]]; then + echo "Usage: $0 " + exit 1 +fi + +input_file=$1 +if ! [[ -f "$input_file" ]]; then + echo "Error: $input_file does not exist" + exit 1 +fi + + +echo "Running load tensor examples using $input_file" +for f in aio_load_cpu_tensor.py aio_load_gpu_tensor.py \ + gds_load_gpu_tensor.py \ + py_load_cpu_tensor.py py_load_gpu_tensor.py; do + cmd="python $f --input_file $input_file" + sync + echo $cmd + eval $cmd + sleep 2 +done + + diff --git a/deepnvme/file_access/run_store_tensor.sh b/deepnvme/file_access/run_store_tensor.sh new file mode 100644 index 000000000..a10b3c219 --- /dev/null +++ b/deepnvme/file_access/run_store_tensor.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +if [[ $# -ne 1 ]]; then + echo "Usage: $0 " + exit 1 +fi + +output_folder=$1 +if ! [[ -d "$output_folder" ]]; then + echo "Error: $output_folder does not exist" + exit 1 +fi + + +echo "Running store tensor examples using $output_folder" +for f in aio_store_cpu_tensor.py aio_store_gpu_tensor.py \ + gds_store_gpu_tensor.py \ + py_store_cpu_tensor.py py_store_gpu_tensor.py; do + cmd="python $f --nvme_folder $output_folder" + sync + echo $cmd + eval $cmd + sleep 2 +done + + diff --git a/deepnvme/file_access/utils.py b/deepnvme/file_access/utils.py new file mode 100644 index 000000000..e83168349 --- /dev/null +++ b/deepnvme/file_access/utils.py @@ -0,0 +1,57 @@ +import os +import argparse + +GIGA_UNIT = 1024**3 + +def parse_read_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--input_file', + default=None, + type=str, + required=True, + help='File on NVMe device that will read as input.') + parser.add_argument('--loop', + type=int, + default=3, + help='The number of times to repeat the operation (default 3).') + parser.add_argument('--validate', + action="store_true", + help="Run validation step that compares tensor value against Python file read") + + args = parser.parse_args() + print(f'args = {args}') + if not os.path.isfile(args.input_file): + print(f'Invalid input file path: {args.input_file}') + quit() + + return args + + + +def parse_write_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--nvme_folder', + default=None, + type=str, + required=True, + help='NVMe folder that will used for file write.') + parser.add_argument('--mb_size', + type=int, + default=1024, + help='Size of tensor to save in MB (default 1024).') + parser.add_argument('--loop', + type=int, + default=3, + help='The number of times to repeat the operation (default 3).') + parser.add_argument('--validate', + action="store_true", + help="Run validation step that compares tensor value against Python file read") + + args = parser.parse_args() + print(f'args = {args}') + if not os.path.isdir(args.nvme_folder): + print(f'Invalid output folder path: {args.output_folder}') + quit() + + return args + diff --git a/deepnvme/zero_inference/README.md b/deepnvme/zero_inference/README.md new file mode 100644 index 000000000..3214ad5ee --- /dev/null +++ b/deepnvme/zero_inference/README.md @@ -0,0 +1,28 @@ +# Using DeepNVMe for ZeRO-Inference +ZeRO-inference is an ideal use case for the DeepNVMe technology. When you have a model that exceeds the size of availabe GPU memory the [DeepNVMe](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-gds/README.md) library along with ZeRO-inference can be leveraged for high-throughput offline inference. + +Maximizing inference throughput (measured in tokens/sec) in this scenario has two parts. First offloading the model parameters to fast Non-Volatile Memory, either a single device or several devices RAIDed together to further increase the effective bandiwidth of the system. These parameters are then swapped into the GPU memory layer by layer to compute the forward pass for inference. This allows for the second part of the process, maximizing the batch size. By swapping in parameters layer by layer the remaining GPU memory can be used by the computational batch which leads to a maximizing of total inference throughput. + +## Testing Environment +The environment for these tests was a VM with NVIDIA Magnum IOTM GPUDirect® Storage (GDS) installed along with a single NVIDIA H100 GPU containing 96 GB of memory. The VM also had two NVMes each with a read bandwidth of ~6.5 GB/sec. The two NVMes were put into a RAID0 configuration, bringing the effective read bandwidth up to ~13 GB/sec. +
+ +
+ +## Initial Results +The following models were run from the folder DeepSpeedExamples/inference/huggingface/zero_inference using disk-offload of parameters via the following command: + +```bash +deepspeed --num_gpus 1 run_model.py --model $model_name --batch_size $bsz --prompt-len 512 --gen-len 32 --disk-offload $path_to_foler --use_gds +``` + +Where `--use_gds` is set to enable NVIDIA GDS and move parameters directly between the NVMe and GPU, otherwise an intermediate CPU bounce buffer will be used to move the parameters between the NVMe and GPU. + +All models tested were chosen so they could not fit into 96 GB of GPU memory. + +GDS | Mixtral-8x22B | Llama3-70B | Bloom-176B +|---|---|---|---| +False | 9.152(bsz=200) | 8.606(bsz=96) | 0.291(bsz=8) | +True | 9.233(bsz=200) | 8.876(bsz=96) | 0.293(bsz=8) | + +Throughput measured in tokens/sec. diff --git a/deepnvme/zero_inference/media/nvme_config.png b/deepnvme/zero_inference/media/nvme_config.png new file mode 100755 index 000000000..3c61cbb4c Binary files /dev/null and b/deepnvme/zero_inference/media/nvme_config.png differ diff --git a/deepnvme/zero_inference/media/zero_inf_mem_use_cpu.png b/deepnvme/zero_inference/media/zero_inf_mem_use_cpu.png new file mode 100755 index 000000000..7857265af Binary files /dev/null and b/deepnvme/zero_inference/media/zero_inf_mem_use_cpu.png differ diff --git a/deepnvme/zero_inference/media/zero_inf_mem_use_gds.png b/deepnvme/zero_inference/media/zero_inf_mem_use_gds.png new file mode 100755 index 000000000..fd0087ed6 Binary files /dev/null and b/deepnvme/zero_inference/media/zero_inf_mem_use_gds.png differ diff --git a/evaluation/inference/human_eval/README.md b/evaluation/inference/human_eval/README.md new file mode 100644 index 000000000..d3b254ea2 --- /dev/null +++ b/evaluation/inference/human_eval/README.md @@ -0,0 +1,45 @@ +# HumanEval Evaluation Script for DeepSpeed-FastGen + +## DISCLAIMER + +This human-eval evaluation will execute untrusted model-generated code. As per the OpenAI warning, we +strongly recommend you sandbox your environment as described in the [human-eval paper](https://arxiv.org/pdf/2107.03374.pdf). + +## Setup + +Running the human-eval evaluation requires installation of `human_eval` with the execution code enabled, +which requires local changes to `execution.py`. The following steps will setup `human-eval` for execution: + +```bash +git clone https://github.com/openai/human-eval.git +sed -i '/exec(check_program, exec_globals)/ s/^# //' human-eval/human_eval/execution.py +cd human-eval +python -m pip install -e . +``` + +This evaluation also requires the installation of DeepSpeed-MII: + +```bash +python -m pip install deepspeed-mii +``` + +Additional DeepSpeed-MII installation details can be found [here](https://github.com/microsoft/DeepSpeed-MII#installation). + +## Run the Evaluation + +The following command shows how to run a benchmark using the `codellama/CodeLlama-7b-Python-hf` model: + +```bash +python run_human_eval.py --model codellama/CodeLlama-7b-Python-hf --max-tokens 512 --num-samples-per-task 20 +``` + +## Run Evaluation on Samples + +Once samples have been generated, they can be evaluated independently using the `evaluate_functional_correctness` command. +For example, the following command will evaluate `mii_samples.jsonl`: + +```bash +evaluate_functional_correctness mii_samples.jsonl +``` + +The evaluation results will be saved to `mii_samples.jsonl_results.jsonl`. diff --git a/evaluation/inference/human_eval/run_human_eval.py b/evaluation/inference/human_eval/run_human_eval.py new file mode 100644 index 000000000..3acad8ece --- /dev/null +++ b/evaluation/inference/human_eval/run_human_eval.py @@ -0,0 +1,69 @@ +import os +import torch +import mii +import numpy +import argparse +from deepspeed.accelerator import get_accelerator +from transformers import pipeline +from human_eval.data import write_jsonl, read_problems +from human_eval.evaluation import evaluate_functional_correctness + +parser = argparse.ArgumentParser() +parser.add_argument("--model", "-m", type=str, default="codellama/CodeLlama-7b-Python-hf", help="evaluation model name") +parser.add_argument("--max-tokens", type=int, default=512, help="max new tokens") +parser.add_argument("--num-samples-per-task", type=int, default=20, help="number of samples to gen/eval per task") +parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank") +args = parser.parse_args() + +def generate_base_completion(pipe, problem_prompt: str) -> str: + return pipe(problem_prompt, do_sample=True)[0]["generated_text"] + +def generate_mii_completion(pipe, problem_prompt: str) -> str: + return pipe(problem_prompt, max_new_tokens=args.max_tokens)[0].generated_text + +def generate_samples(pipe, generation_function): + samples = [ + dict(task_id=task_id, completion=generation_function(pipe, problems[task_id]["prompt"])) for task_id in problems + for _ in range(args.num_samples_per_task) + ] + return samples + +print("Loading Problems") +problems = read_problems("human-eval/data/HumanEval.jsonl.gz") + +print("Initializing HuggingFace Pipeline") +device = torch.device(get_accelerator().device_name(args.local_rank)) +base_pipe = pipeline(model=args.model, + device=torch.device(get_accelerator().device_name(args.local_rank)), + max_length=args.max_tokens, + return_full_text=False) + +print("Generating Base Samples") +base_samples = generate_samples(base_pipe, generate_base_completion) + +print("Base Pipeline Teardown") +del base_pipe +torch.cuda.empty_cache() + +print("Initializing DeepSpeed-MII Pipeline") +mii_pipe = mii.pipeline(args.model) + +print("Generating MII Samples") +mii_samples = generate_samples(mii_pipe, generate_mii_completion) + +print("MII Pipeline Teardown") +mii_pipe.destroy() + +print("Writing Samples") +write_jsonl("base_samples.jsonl", base_samples) +write_jsonl("mii_samples.jsonl", mii_samples) + +print("Evaluating Samples") +base_results = evaluate_functional_correctness("base_samples.jsonl") +mii_results = evaluate_functional_correctness("mii_samples.jsonl") + +print(f"Base Results = {base_results}") +print(f"MII Results = {mii_results}") + +for key in base_results.keys(): + print(f"{key} - Base Result: {base_results[key]}, MII result: {mii_results[key]}") diff --git a/inference/huggingface/automatic-speech-recognition/test-wav2vec2.py b/inference/huggingface/automatic-speech-recognition/test-wav2vec2.py index f319928f2..18b5406bc 100644 --- a/inference/huggingface/automatic-speech-recognition/test-wav2vec2.py +++ b/inference/huggingface/automatic-speech-recognition/test-wav2vec2.py @@ -7,12 +7,14 @@ import deepspeed from deepspeed import module_inject from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2EncoderLayer +from deepspeed.accelerator import get_accelerator librispeech_eval = load_dataset("librispeech_asr", "clean", split="test") # Get local gpu rank from torch.distributed/deepspeed launcher local_rank = int(os.getenv('LOCAL_RANK', '0')) world_size = int(os.getenv('WORLD_SIZE', '1')) +device = torch.device(get_accelerator().device_name(local_rank)) print( "***************** Creating model in RANK ({0}) with WORLD_SIZE = {1} *****************" @@ -27,7 +29,7 @@ dtype=torch.float, injection_policy={Wav2Vec2EncoderLayer: ('attention.out_proj','feed_forward.output_dense')}, replace_with_kernel_inject=False) -model.to(f'cuda:{local_rank}') +model.to(device) def map_to_array(batch): speech, _ = sf.read(batch["file"]) batch["speech"] = speech @@ -38,7 +40,7 @@ def map_to_array(batch): def map_to_pred(batch): input_values = processor(batch["speech"], return_tensors="pt", padding="longest").input_values with torch.no_grad(): - logits = model(input_values.to(f'cuda:{local_rank}')).logits + logits = model(input_values.to(device)).logits predicted_ids = torch.argmax(logits, dim=-1) transcription = processor.batch_decode(predicted_ids) diff --git a/inference/huggingface/fill-mask/test-bert.py b/inference/huggingface/fill-mask/test-bert.py index d317710a2..fb2af691a 100644 --- a/inference/huggingface/fill-mask/test-bert.py +++ b/inference/huggingface/fill-mask/test-bert.py @@ -4,13 +4,14 @@ import torch import os import argparse +from deepspeed.accelerator import get_accelerator parser = argparse.ArgumentParser() parser.add_argument("--model", "-m", type=str, help="hf model name") -parser.add_argument("--dtype", type=str, default="fp16", help="fp16 or fp32") +parser.add_argument("--dtype", type=str, default="fp16", help="fp16 or fp32 or bf16") parser.add_argument("--local_rank", type=int, default=0, help="local rank") parser.add_argument("--trials", type=int, default=8, help="number of trials") -parser.add_argument("--kernel-inject", action="store_true", help="inject kernels on") +parser.add_argument("--kernel_inject", action="store_true", help="inject kernels on") parser.add_argument("--graphs", action="store_true", help="CUDA Graphs on") parser.add_argument("--triton", action="store_true", help="triton kernels on") parser.add_argument("--deepspeed", action="store_true", help="use deepspeed inference") @@ -26,11 +27,11 @@ pipe.model, mp_size=world_size, dtype=torch.float16 if args.triton else torch.float, - replace_with_kernel_inject=True, + replace_with_kernel_inject=args.kernel_inject, use_triton=args.triton, ) -pipe.device = torch.device(f'cuda:{local_rank}') +pipe.device = torch.device(get_accelerator().device_name(local_rank)) output = pipe("In Autumn the [MASK] fall from the trees.") if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: diff --git a/inference/huggingface/fill-mask/test-electra.py b/inference/huggingface/fill-mask/test-electra.py index 5c5448ace..28760f9f6 100644 --- a/inference/huggingface/fill-mask/test-electra.py +++ b/inference/huggingface/fill-mask/test-electra.py @@ -4,6 +4,7 @@ import torch import os from transformers.models.electra.modeling_electra import ElectraLayer +from deepspeed.accelerator import get_accelerator local_rank = int(os.getenv('LOCAL_RANK', '0')) world_size = int(os.getenv('WORLD_SIZE', '4')) @@ -21,7 +22,7 @@ dtype=torch.float, injection_policy={ElectraLayer: ('output.dense')} ) -pipe.device = torch.device(f'cuda:{local_rank}') +pipe.device = torch.device(get_accelerator().device_name(local_rank)) output = pipe(f"HuggingFace is creating a {pipe.tokenizer.mask_token} that the community uses to solve NLP tasks.") if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: diff --git a/inference/huggingface/fill-mask/test-roberta.py b/inference/huggingface/fill-mask/test-roberta.py index 16bcec041..c625e6cf4 100644 --- a/inference/huggingface/fill-mask/test-roberta.py +++ b/inference/huggingface/fill-mask/test-roberta.py @@ -4,6 +4,7 @@ import torch import os from transformers.models.roberta.modeling_roberta import RobertaLayer +from deepspeed.accelerator import get_accelerator local_rank = int(os.getenv('LOCAL_RANK', '0')) world_size = int(os.getenv('WORLD_SIZE', '4')) @@ -22,7 +23,7 @@ injection_policy={RobertaLayer: ('output.dense')} ) -pipe.device = torch.device(f'cuda:{local_rank}') +pipe.device = torch.device(get_accelerator().device_name(local_rank)) output = pipe("The invention of the revolutionized the way we communicate with each other.") if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: diff --git a/inference/huggingface/text-generation/README.md b/inference/huggingface/text-generation/README.md index 8019aa298..65e82bfe7 100644 --- a/inference/huggingface/text-generation/README.md +++ b/inference/huggingface/text-generation/README.md @@ -20,7 +20,7 @@ If you are using conda, the following works: conda create -c conda-forge -n deepspeed python=3.10 conda activate deepspeed pip install -r requirements.txt -deepspeed --num_gpus 1 inference-test.py --name bigscience/bloom-3b --batch_size 2 +deepspeed --num_gpus 1 inference-test.py --model bigscience/bloom-3b --batch_size 2 # Inference Test @@ -91,8 +91,9 @@ The DSPipeline class helps to load the model and run inference on it, given thes # DeepSpeed HuggingFace Compare The ds-hf-compare script can be used to compare the text generated outputs of DeepSpeed with kernel injection and HuggingFace inference of a model with the same parameters on a single GPU. +(p.s. kernel injection will not be used by default and is only enabled when the "--use_kernel" argument is provided.) ## Usage Examples can be run as follows: -
deepspeed --num_gpus 1 ds-hf-compare.py --model [model name/path] --dtype [data type] --num_inputs [number of test inputs] --print_outputs
+
deepspeed --num_gpus 1 ds-hf-compare.py --model [model name/path] --dtype [data type] --num_inputs [number of test inputs] --print_outputs --use_kernel[enable kernel injection]
 
\ No newline at end of file diff --git a/inference/huggingface/text-generation/arguments.py b/inference/huggingface/text-generation/arguments.py index b50198ff9..a6dade23f 100644 --- a/inference/huggingface/text-generation/arguments.py +++ b/inference/huggingface/text-generation/arguments.py @@ -7,7 +7,7 @@ parser.add_argument("--checkpoint_path", required=False, default=None, type=str, help="model checkpoint path") parser.add_argument("--save_mp_checkpoint_path", required=False, default=None, type=str, help="save-path to store the new model checkpoint") parser.add_argument("--batch_size", default=1, type=int, help="batch size") -parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8"], help="data-type") +parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8", "bfloat16"], help="data-type") parser.add_argument("--hf_baseline", action='store_true', help="disable DeepSpeed inference") parser.add_argument("--use_kernel", action='store_true', help="enable kernel-injection") parser.add_argument("--max_tokens", default=1024, type=int, help="maximum tokens used for the text-generation KV-cache") diff --git a/inference/huggingface/text-generation/ds-hf-compare.py b/inference/huggingface/text-generation/ds-hf-compare.py index 378a13940..bad82e9d8 100644 --- a/inference/huggingface/text-generation/ds-hf-compare.py +++ b/inference/huggingface/text-generation/ds-hf-compare.py @@ -3,18 +3,24 @@ from transformers import pipeline from difflib import SequenceMatcher from argparse import ArgumentParser +from deepspeed.accelerator import get_accelerator parser = ArgumentParser() parser.add_argument("--model", required=True, type=str, help="model_name") -parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8"], help="data-type") +parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8", "bfloat16"], help="data-type") parser.add_argument("--num_inputs", default=1, type=int, help="number of test inputs") parser.add_argument("--min_length", default=200, type=int, help="minimum tokens generated") parser.add_argument("--max_length", default=300, type=int, help="maximum tokens generated") parser.add_argument("--print_outputs", action='store_true', help="print generated text outputs") parser.add_argument("--local_rank", type=int, default=0, help="local rank") +parser.add_argument("--use_kernel", action='store_true', help="enable kernel-injection") args = parser.parse_args() +def print_0(output): + if args.local_rank == 0: + print(output) + def string_similarity(str1, str2): matcher = SequenceMatcher(None, str1, str2) similarity_ratio = matcher.ratio() @@ -69,36 +75,39 @@ def string_similarity(str1, str2): if args.num_inputs < len(test_inputs): inputs = test_inputs[:args.num_inputs] else: - print(f"Warning: num_inputs ({args.num_inputs}) is greater than the number of test inputs ({len(test_inputs)}). Using all test inputs.") + print_0(f"Warning: num_inputs ({args.num_inputs}) is greater than the number of test inputs ({len(test_inputs)}). Using all test inputs.") inputs = test_inputs data_type = getattr(torch, args.dtype) -pipe = pipeline('text-generation', args.model, torch_dtype=data_type, device=0) +pipe = pipeline('text-generation', args.model, torch_dtype=data_type, device=torch.device(get_accelerator().device_name(0))) base_out_list = [] match_count=0 mismatch_count=0 # Run the baseline model -for prompt in inputs: - base_out_list += pipe(prompt, do_sample=False, min_length=args.min_length, max_length=args.max_length) +if args.local_rank == 0: + for prompt in inputs: + base_out_list += pipe(prompt, do_sample=False, min_length=args.min_length, max_length=args.max_length) # Initialize the model with DeepSpeed -pipe.model = deepspeed.init_inference(pipe.model, dtype=data_type, replace_with_kernel_inject=True) +pipe.model = deepspeed.init_inference(pipe.model, dtype=data_type, replace_with_kernel_inject=args.use_kernel) # Run the DeepSpeed model and compare outputs for prompt, base_out in zip(inputs, base_out_list): ds_out = pipe(prompt, do_sample=False, min_length=args.min_length, max_length=args.max_length) - if args.print_outputs: - print(f"baseline output: {base_out}") - print(f"deepspeed output: {ds_out}") - print(f"{'-'*60}") - if base_out == ds_out[0]: - if args.print_outputs: print("outputs match") - match_count += 1 - else: - if args.print_outputs: print("outputs do not match") - mismatch_count += 1 - similarity = string_similarity(base_out['generated_text'], ds_out[0]['generated_text']) - if args.print_outputs: print(f"The similarity ratio is: {similarity*100}%") -print(f"Matches: {match_count}\nMismatches: {mismatch_count}") + if args.local_rank == 0: + if args.print_outputs: + print(f"baseline output: {base_out}") + print(f"deepspeed output: {ds_out}") + print(f"{'-'*60}") + if base_out == ds_out[0]: + if args.print_outputs: print("outputs match") + match_count += 1 + else: + if args.print_outputs: print("outputs do not match") + mismatch_count += 1 + similarity = string_similarity(base_out['generated_text'], ds_out[0]['generated_text']) + if args.print_outputs: print(f"The similarity ratio is: {similarity*100}%") + +print_0(f"Matches: {match_count}\nMismatches: {mismatch_count}") diff --git a/inference/huggingface/text-generation/inference-test.py b/inference/huggingface/text-generation/inference-test.py index 827d8db35..0ba3b20cd 100644 --- a/inference/huggingface/text-generation/inference-test.py +++ b/inference/huggingface/text-generation/inference-test.py @@ -6,6 +6,7 @@ import time from utils import DSPipeline, Performance from deepspeed.runtime.utils import see_memory_usage +from deepspeed.accelerator import get_accelerator from arguments import parser args = parser.parse_args() @@ -76,12 +77,12 @@ iters = 30 if args.test_performance else 2 #warmup times = [] for i in range(iters): - torch.cuda.synchronize() + get_accelerator().synchronize() start = time.time() outputs = pipe(inputs, num_tokens=args.max_new_tokens, do_sample=(not args.greedy)) - torch.cuda.synchronize() + get_accelerator().synchronize() end = time.time() times.append(end - start) print(f"generation time is {times[1]} sec") diff --git a/inference/huggingface/text-generation/utils.py b/inference/huggingface/text-generation/utils.py index 173eac039..bf727fefc 100644 --- a/inference/huggingface/text-generation/utils.py +++ b/inference/huggingface/text-generation/utils.py @@ -10,6 +10,7 @@ import torch from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizerFast +from deepspeed.accelerator import get_accelerator class DSPipeline(): ''' @@ -34,7 +35,7 @@ def __init__(self, elif device < 0: self.device = torch.device("cpu") else: - self.device = torch.device(f"cuda:{device}") + self.device = torch.device(get_accelerator().device_name(device)) # the Deepspeed team made these so it's super fast to load (~1 minute), rather than wait 10-20min loading time. self.tp_presharded_models = ["microsoft/bloom-deepspeed-inference-int8", "microsoft/bloom-deepspeed-inference-fp16"] @@ -110,7 +111,7 @@ def generate_outputs(self, if torch.is_tensor(input_tokens[t]): input_tokens[t] = input_tokens[t].to(self.device) - self.model.cuda().to(self.device) + self.model.to(self.device) if isinstance(self.tokenizer, LlamaTokenizerFast): # NOTE: Check if Llamma can work w/ **input_tokens diff --git a/inference/huggingface/zero_inference/README.md b/inference/huggingface/zero_inference/README.md index f6dd4850e..acca9404e 100644 --- a/inference/huggingface/zero_inference/README.md +++ b/inference/huggingface/zero_inference/README.md @@ -90,7 +90,7 @@ deepspeed --num_gpus 1 run_model.py --model bigscience/bloom-7b1 --batch-size 8 Here is an example of running `meta-llama/Llama-2-7b-hf` with Zero-Inference using 4-bit model weights and offloading kv cache to CPU: ```sh -deepspeed --num_gpus 1 run_model.py --model meta-llama/Llama-2-7b-hf` --batch-size 8 --prompt-len 512 --gen-len 32 --cpu-offload --quant-bits 4 --kv-offload +deepspeed --num_gpus 1 run_model.py --model meta-llama/Llama-2-7b-hf --batch-size 8 --prompt-len 512 --gen-len 32 --cpu-offload --quant-bits 4 --kv-offload ``` ## Performance Tuning Tips diff --git a/inference/huggingface/zero_inference/run_model.py b/inference/huggingface/zero_inference/run_model.py index fea8e0be1..d0e16eca3 100644 --- a/inference/huggingface/zero_inference/run_model.py +++ b/inference/huggingface/zero_inference/run_model.py @@ -19,7 +19,7 @@ from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM, BloomForCausalLM, OPTForCausalLM, LlamaForCausalLM, ) -from transformers.deepspeed import HfDeepSpeedConfig +from transformers.integrations.deepspeed import HfDeepSpeedConfig from utils import (GB, add_model_hooks, cache_bytes, get_filename, get_quant_config, hidden_bytes, meta_to_cpu, model_bytes, write_benchmark_log) @@ -87,7 +87,7 @@ def get_ds_model( }, "zero_optimization": { "stage": 3, - "stage3_prefetch_bucket_size": 2 * hidden_size * hidden_size, # 0, + "stage3_prefetch_bucket_size": 2 * hidden_size * hidden_size, "stage3_param_persistence_threshold": hidden_size, "stage3_max_live_parameters": 2 * hidden_size * hidden_size, }, @@ -105,17 +105,29 @@ def get_ds_model( ) if disk_offload: + if config.model_type == 'bloom': + buffer_count = 3 if args.use_gds else 5 + buffer_size = 8*GB if args.use_gds else 9*GB + + elif config.model_type == 'mixtral': + buffer_count = 10 + buffer_size = 1*GB + else: + buffer_count = 5 + buffer_size = 2*GB + ds_config["zero_optimization"]["offload_param"] = dict( device="nvme", pin_memory=pin_memory, nvme_path=offload_dir, - buffer_count=5, - buffer_size=9 * GB if config.model_type == 'bloom' else 2 * GB, + buffer_count=buffer_count, + buffer_size=buffer_size, ) ds_config["aio"] = { - "block_size": 1048576, - "queue_depth": 8, - "thread_count": 1, + "block_size": 1048576*16, + "queue_depth": 64, + "thread_count": 8, + "use_gds": args.use_gds, "single_submit": False, "overlap_events": True, } @@ -140,6 +152,10 @@ def get_ds_model( model = LlamaForCausalLM.from_pretrained( dummy_weights or model_name, torch_dtype=dtype, ) + elif config.model_type == "mixtral": + model = AutoModelForCausalLM.from_pretrained( + dummy_weights or model_name, torch_dtype=dtype, + ) else: raise ValueError(f"Unexpected model type: {config.model_type}") @@ -192,6 +208,8 @@ def run_generation( model = BloomForCausalLM(config) elif config.model_type == "llama": model = LlamaForCausalLM(config) + elif config.model_type == "mixtral": + model = AutoModelForCausalLM(config) else: raise ValueError(f"Unexpected model type: {config.model_type}") model.save_pretrained( @@ -354,6 +372,7 @@ def remove_model_hooks(module): parser.add_argument("--quant_group_size", type=int, default=64, help="model weight quantization group size") parser.add_argument("--pin_kv_cache", action="store_true", help="Allocate kv cache in pinned memory for offloading.") parser.add_argument("--async_kv_offload", action="store_true", help="Using non_blocking copy for kv cache offloading.") + parser.add_argument("--use_gds", action="store_true", help="Use NVIDIA GPU DirectStorage to transfer between NVMe and GPU.") args = parser.parse_args() deepspeed.init_distributed() diff --git a/inference/mii/README.md b/inference/mii/README.md index d701d5537..dfc9fda2b 100644 --- a/inference/mii/README.md +++ b/inference/mii/README.md @@ -2,4 +2,4 @@ Install the requirements by running `pip install -r requirements.txt`. -Once [DeepSpeed-MII](https://github.com/microsoft/deepspeed-mii) is installed you have two options for deployment: an interactive non-persistent pipeline or a persistent serving deployment. For details on these files please refer to the [Getting Started guide for MII](https://github.com/microsoft/deepspeed-mii#getting-started-with-mii). +Once [DeepSpeed-MII](https://github.com/microsoft/deepspeed-mii) is installed you have two options for deployment: an interactive non-persistent pipeline or a persistent serving deployment. See the scripts in [non-persistent](./non-persistent/) and [persistent](./persistent/) for examples. Details on the code implemented in these scripts can be found on our [Getting Started guide for MII](https://github.com/microsoft/deepspeed-mii#getting-started-with-mii). diff --git a/inference/mii/client.py b/inference/mii/client.py deleted file mode 100644 index 6d19fec3a..000000000 --- a/inference/mii/client.py +++ /dev/null @@ -1,6 +0,0 @@ -import mii - -client = mii.client("mistralai/Mistral-7B-v0.1") -output = client.generate("Deepspeed is", max_new_tokens=128) - -print(output) diff --git a/inference/mii/non-persistent/README.md b/inference/mii/non-persistent/README.md new file mode 100644 index 000000000..b9ca31acb --- /dev/null +++ b/inference/mii/non-persistent/README.md @@ -0,0 +1,28 @@ +# Non-Persistent Pipeline Examples + +The `pipeline.py` script can be used to run any of the [supported +models](https://github.com/microsoft/DeepSpeed-mii#supported-models). Provide +the HuggingFace model name, maximum generated tokens, and prompt(s). The +generated responses will be printed in the terminal: + +```shell +$ python pipeline.py --model "mistralai/Mistral-7B-v0.1" --max-new-tokens 128 --prompts "DeepSpeed is" "Seattle is" +``` + +Tensor-parallelism can be controlled using the `deepspeed` launcher and setting +`--num_gpus`: + +```shell +$ deepspeed --num_gpus 2 pipeline.py +``` + +## Model-Specific Examples + +For convenience, we also provide a set of scripts to quickly test the MII +Pipeline with some popular text-generation models: + +| Model | Launch command | +|-------|----------------| +| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b) | `$ python llama2.py` | +| [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b) | `$ python falcon.py` | +| [mistralai/Mixtral-8x7B-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) | `$ deepspeed --num_gpus 2 mixtral.py` | \ No newline at end of file diff --git a/inference/mii/non-persistent/falcon.py b/inference/mii/non-persistent/falcon.py new file mode 100644 index 000000000..7dfc05ecb --- /dev/null +++ b/inference/mii/non-persistent/falcon.py @@ -0,0 +1,6 @@ +import mii + +pipe = mii.pipeline("tiiuae/falcon-7b") +responses = pipe("DeepSpeed is", max_new_tokens=128, return_full_text=True) +if pipe.is_rank_0: + print(responses[0]) diff --git a/inference/mii/non-persistent/llama2.py b/inference/mii/non-persistent/llama2.py new file mode 100644 index 000000000..1c519204e --- /dev/null +++ b/inference/mii/non-persistent/llama2.py @@ -0,0 +1,6 @@ +import mii + +pipe = mii.pipeline("meta-llama/Llama-2-7b-hf") +responses = pipe("DeepSpeed is", max_new_tokens=128, return_full_text=True) +if pipe.is_rank_0: + print(responses[0]) diff --git a/inference/mii/non-persistent/mixtral.py b/inference/mii/non-persistent/mixtral.py new file mode 100644 index 000000000..a429ea5e1 --- /dev/null +++ b/inference/mii/non-persistent/mixtral.py @@ -0,0 +1,6 @@ +import mii + +pipe = mii.pipeline("mistralai/Mixtral-8x7B-v0.1") +responses = pipe("DeepSpeed is", max_new_tokens=128, return_full_text=True) +if pipe.is_rank_0: + print(responses[0]) diff --git a/inference/mii/non-persistent/pipeline.py b/inference/mii/non-persistent/pipeline.py new file mode 100644 index 000000000..c7baa6716 --- /dev/null +++ b/inference/mii/non-persistent/pipeline.py @@ -0,0 +1,19 @@ +import argparse +import mii + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, default="mistralai/Mistral-7B-v0.1") +parser.add_argument( + "--prompts", type=str, nargs="+", default=["DeepSpeed is", "Seattle is"] +) +parser.add_argument("--max-new-tokens", type=int, default=128) +args = parser.parse_args() + +pipe = mii.pipeline(args.model) +responses = pipe( + args.prompts, max_new_tokens=args.max_new_tokens, return_full_text=True +) + +if pipe.is_rank_0: + for r in responses: + print(r, "\n", "-" * 80, "\n") diff --git a/inference/mii/persistent/README.md b/inference/mii/persistent/README.md new file mode 100644 index 000000000..e9cb2dc20 --- /dev/null +++ b/inference/mii/persistent/README.md @@ -0,0 +1,28 @@ +# Persistent Deployment Examples + +The `serve.py` script can be used to create an inference server for any of the +[supported models](https://github.com/microsoft/DeepSpeed-mii#supported-models). +Provide the HuggingFace model name and tensor-parallelism (use the default +values and run `$ python serve.py` for a single-GPU +[mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) +deployment): + +```shell +$ python serve.py --model "mistralai/Mistral-7B-v0.1" tensor-parallel 1 +``` + +Connect to the persistent deployment and generate text with `client.py`. Provide +the HuggingFace model name, maximum generated tokens, and prompt(s) (or if you +are using the default values, run `$ python client.py`): + +```shell +$ python client.py --model "mistralai/Mistral-7B-v0.1" --max-new-tokens 128 --prompts "DeepSpeed is" "Seattle is" +``` + +Shutdown the persistent deployment with `terminate.py`. Provide the HuggingFace +model name (or if you are using the default values, run `$ python +terminate.py`): + +```shell +$ python terminate.py --model "mistralai/Mistral-7B-v0.1 +``` \ No newline at end of file diff --git a/inference/mii/persistent/client.py b/inference/mii/persistent/client.py new file mode 100644 index 000000000..561744a8f --- /dev/null +++ b/inference/mii/persistent/client.py @@ -0,0 +1,18 @@ +import argparse +import mii + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, default="mistralai/Mistral-7B-v0.1") +parser.add_argument( + "--prompts", type=str, nargs="+", default=["DeepSpeed is", "Seattle is"] +) +parser.add_argument("--max-new-tokens", type=int, default=128) +args = parser.parse_args() + +client = mii.client(args.model) +responses = client( + args.prompts, max_new_tokens=args.max_new_tokens, return_full_text=True +) + +for r in responses: + print(r, "\n", "-" * 80, "\n") diff --git a/inference/mii/persistent/serve.py b/inference/mii/persistent/serve.py new file mode 100644 index 000000000..dd31f983a --- /dev/null +++ b/inference/mii/persistent/serve.py @@ -0,0 +1,13 @@ +import argparse +import mii + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, default="mistralai/Mistral-7B-v0.1") +parser.add_argument("--tensor-parallel", type=int, default=1) +args = parser.parse_args() + +mii.serve(args.model, tensor_parallel=args.tensor_parallel) + +print(f"Serving model {args.model} on {args.tensor_parallel} GPU(s).") +print(f"Run `python client.py --model {args.model}` to connect.") +print(f"Run `python terminate.py --model {args.model}` to terminate.") diff --git a/inference/mii/persistent/terminate.py b/inference/mii/persistent/terminate.py new file mode 100644 index 000000000..3c430d934 --- /dev/null +++ b/inference/mii/persistent/terminate.py @@ -0,0 +1,11 @@ +import argparse +import mii + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, default="mistralai/Mistral-7B-v0.1") +args = parser.parse_args() + +client = mii.client(args.model) +client.terminate_server() + +print(f"Terminated server for model {args.model}.") diff --git a/inference/mii/pipeline.py b/inference/mii/pipeline.py deleted file mode 100644 index dcf9e8b03..000000000 --- a/inference/mii/pipeline.py +++ /dev/null @@ -1,6 +0,0 @@ -from mii import pipeline - -pipe = pipeline("mistralai/Mistral-7B-v0.1") -output = pipe(["Hello, my name is", "DeepSpeed is"], max_new_tokens=128) - -print(output) diff --git a/inference/mii/requirements.txt b/inference/mii/requirements.txt index 07d9f7e16..48f92a784 100644 --- a/inference/mii/requirements.txt +++ b/inference/mii/requirements.txt @@ -1 +1 @@ -mii>=0.1.0 +deepspeed-mii>=0.1.3 diff --git a/inference/mii/serve.py b/inference/mii/serve.py deleted file mode 100644 index 09c0c306c..000000000 --- a/inference/mii/serve.py +++ /dev/null @@ -1,3 +0,0 @@ -import mii - -mii.serve("mistralai/Mistral-7B-v0.1") diff --git a/inference/mii/terminate.py b/inference/mii/terminate.py deleted file mode 100644 index 2a7ed3211..000000000 --- a/inference/mii/terminate.py +++ /dev/null @@ -1,4 +0,0 @@ -import mii - -client = mii.client("mistralai/Mistral-7B-v0.1") -client.terminate_server() diff --git a/training/HelloDeepSpeed/train_bert.py b/training/HelloDeepSpeed/train_bert.py index a55215dbe..05e360d9c 100644 --- a/training/HelloDeepSpeed/train_bert.py +++ b/training/HelloDeepSpeed/train_bert.py @@ -465,7 +465,7 @@ def create_experiment_dir(checkpoint_dir: pathlib.Path, try: gitlog = sh.git.log("-1", format="%H", _tty_out=False, _fg=False) with (exp_dir / "githash.log").open("w") as handle: - handle.write(gitlog.stdout.decode("utf-8")) + handle.write(gitlog) except sh.ErrorReturnCode_128: logger.info("Seems like the code is not running from" " within a git repo, so hash will" @@ -476,7 +476,7 @@ def create_experiment_dir(checkpoint_dir: pathlib.Path, try: gitdiff = sh.git.diff(_fg=False, _tty_out=False) with (exp_dir / "gitdiff.log").open("w") as handle: - handle.write(gitdiff.stdout.decode("utf-8")) + handle.write(gitdiff) except sh.ErrorReturnCode_129: logger.info("Seems like the code is not running from" " within a git repo, so diff will" diff --git a/training/MoQ/huggingface-transformers/examples/research_projects/lxmert/requirements.txt b/training/MoQ/huggingface-transformers/examples/research_projects/lxmert/requirements.txt index 9028e302b..69bc6ba07 100644 --- a/training/MoQ/huggingface-transformers/examples/research_projects/lxmert/requirements.txt +++ b/training/MoQ/huggingface-transformers/examples/research_projects/lxmert/requirements.txt @@ -48,7 +48,7 @@ nbformat==5.0.7 nest-asyncio==1.4.0 notebook==6.1.5 numpy==1.19.2 -opencv-python==4.4.0.42 +opencv-python==4.10.0.84 packaging==20.3 pandas==1.1.2 pandocfilters==1.4.2 diff --git a/training/bing_bert/nvidia/modelingpreln.py b/training/bing_bert/nvidia/modelingpreln.py index a7e398e26..9856f0607 100755 --- a/training/bing_bert/nvidia/modelingpreln.py +++ b/training/bing_bert/nvidia/modelingpreln.py @@ -1041,7 +1041,7 @@ def forward(self, position_ids=None, inputs_embeds=None, pad_token_id=self.pad_token_id, - model_mbeddings=self.embeddings) + model_embeddings=self.embeddings) embedding_output = self.embeddings(input_ids, token_type_ids) encoded_layers = self.encoder( diff --git a/training/cifar/README.md b/training/cifar/README.md index 7c58f3b98..878b28157 100644 --- a/training/cifar/README.md +++ b/training/cifar/README.md @@ -1,21 +1,22 @@ Thanks Gopi Kumar for contributing this example, demonstrating how to apply DeepSpeed to CIFAR-10 model. -cifar10_tutorial.py +`cifar10_tutorial.py` Baseline CIFAR-10 model. -cifar10_deepspeed.py +`cifar10_deepspeed.py` DeepSpeed applied CIFAR-10 model. -ds_config.json - DeepSpeed configuration file. - -run_ds.sh +`run_ds.sh` Script for running DeepSpeed applied model. -run_ds_moe.sh +`run_ds_moe.sh` Script for running DeepSpeed model with Mixture of Experts (MoE) integration. -* To run baseline CIFAR-10 model - "python cifar10_tutorial.py" -* To run DeepSpeed CIFAR-10 model - "bash run_ds.sh" -* To run DeepSpeed CIFAR-10 model with Mixture of Experts (MoE) - "bash run_ds_moe.sh" -* To run with different data type (default='fp16') and zero stages (default=0) - "bash run_ds.sh --dtype={fp16|bf16} --stage={0|1|2|3}" +`run_ds_prmoe.sh` + Script for running DeepSpeed model with Pyramid Residual MoE (PR-MoE) integration. + +* To run baseline CIFAR-10 model - `python cifar10_tutorial.py` +* To run DeepSpeed CIFAR-10 model - `bash run_ds.sh` +* To run DeepSpeed CIFAR-10 model with Mixture of Experts (MoE) - `bash run_ds_moe.sh` +* To run DeepSpeed CIFAR-10 model with Pyramid Residual MoE (PR-MoE) - `bash run_ds_prmoe.sh` +* To run with different data type (default=`fp16`) and zero stages (default=`0`) - `bash run_ds.sh --dtype={fp16|bf16} --stage={0|1|2|3}` diff --git a/training/cifar/cifar10_deepspeed.py b/training/cifar/cifar10_deepspeed.py index da82e60db..9888544d5 100755 --- a/training/cifar/cifar10_deepspeed.py +++ b/training/cifar/cifar10_deepspeed.py @@ -1,112 +1,106 @@ +import argparse +import os + +import deepspeed import torch +import torch.nn as nn +import torch.nn.functional as F import torchvision import torchvision.transforms as transforms -import argparse -import deepspeed from deepspeed.accelerator import get_accelerator +from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer def add_argument(): + parser = argparse.ArgumentParser(description="CIFAR") - parser = argparse.ArgumentParser(description='CIFAR') - - #data - # cuda - parser.add_argument('--with_cuda', - default=False, - action='store_true', - help='use CPU in case there\'s no GPU support') - parser.add_argument('--use_ema', - default=False, - action='store_true', - help='whether use exponential moving average') - - # train - parser.add_argument('-b', - '--batch_size', - default=32, - type=int, - help='mini-batch size (default: 32)') - parser.add_argument('-e', - '--epochs', - default=30, - type=int, - help='number of total epochs (default: 30)') - parser.add_argument('--local_rank', - type=int, - default=-1, - help='local rank passed from distributed launcher') - - parser.add_argument('--log-interval', - type=int, - default=2000, - help="output logging information at a given interval") - - parser.add_argument('--moe', - default=False, - action='store_true', - help='use deepspeed mixture of experts (moe)') - - parser.add_argument('--ep-world-size', - default=1, - type=int, - help='(moe) expert parallel world size') - parser.add_argument('--num-experts', - type=int, - nargs='+', - default=[ - 1, - ], - help='number of experts list, MoE related.') + # For train. parser.add_argument( - '--mlp-type', - type=str, - default='standard', - help= - 'Only applicable when num-experts > 1, accepts [standard, residual]') - parser.add_argument('--top-k', - default=1, - type=int, - help='(moe) gating top 1 and 2 supported') + "-e", + "--epochs", + default=30, + type=int, + help="number of total epochs (default: 30)", + ) parser.add_argument( - '--min-capacity', - default=0, + "--local_rank", type=int, - help= - '(moe) minimum capacity of an expert regardless of the capacity_factor' + default=-1, + help="local rank passed from distributed launcher", ) parser.add_argument( - '--noisy-gate-policy', - default=None, + "--log-interval", + type=int, + default=2000, + help="output logging information at a given interval", + ) + + # For mixed precision training. + parser.add_argument( + "--dtype", + default="fp16", type=str, - help= - '(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter' + choices=["bf16", "fp16", "fp32"], + help="Datatype used for training", + ) + + # For ZeRO Optimization. + parser.add_argument( + "--stage", + default=0, + type=int, + choices=[0, 1, 2, 3], + help="Datatype used for training", ) + + # For MoE (Mixture of Experts). parser.add_argument( - '--moe-param-group', + "--moe", default=False, - action='store_true', - help= - '(moe) create separate moe param groups, required when using ZeRO w. MoE' + action="store_true", + help="use deepspeed mixture of experts (moe)", + ) + parser.add_argument( + "--ep-world-size", default=1, type=int, help="(moe) expert parallel world size" + ) + parser.add_argument( + "--num-experts", + type=int, + nargs="+", + default=[ + 1, + ], + help="number of experts list, MoE related.", ) parser.add_argument( - '--dtype', - default='fp16', + "--mlp-type", type=str, - choices=['bf16', 'fp16', 'fp32'], - help= - 'Datatype used for training' + default="standard", + help="Only applicable when num-experts > 1, accepts [standard, residual]", + ) + parser.add_argument( + "--top-k", default=1, type=int, help="(moe) gating top 1 and 2 supported" ) parser.add_argument( - '--stage', + "--min-capacity", default=0, type=int, - choices=[0, 1, 2, 3], - help= - 'Datatype used for training' + help="(moe) minimum capacity of an expert regardless of the capacity_factor", + ) + parser.add_argument( + "--noisy-gate-policy", + default=None, + type=str, + help="(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter", + ) + parser.add_argument( + "--moe-param-group", + default=False, + action="store_true", + help="(moe) create separate moe param groups, required when using ZeRO w. MoE", ) - # Include DeepSpeed configuration arguments + # Include DeepSpeed configuration arguments. parser = deepspeed.add_config_arguments(parser) args = parser.parse_args() @@ -114,110 +108,87 @@ def add_argument(): return args -deepspeed.init_distributed() - -######################################################################## -# The output of torchvision datasets are PILImage images of range [0, 1]. -# We transform them to Tensors of normalized range [-1, 1]. -# .. note:: -# If running on Windows and you get a BrokenPipeError, try setting -# the num_worker of torch.utils.data.DataLoader() to 0. - -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) -]) - -if torch.distributed.get_rank() != 0: - # might be downloading cifar data, let rank 0 download first - torch.distributed.barrier() - -trainset = torchvision.datasets.CIFAR10(root='./data', - train=True, - download=True, - transform=transform) - -if torch.distributed.get_rank() == 0: - # cifar data is downloaded, indicate other ranks can proceed - torch.distributed.barrier() - -trainloader = torch.utils.data.DataLoader(trainset, - batch_size=16, - shuffle=True, - num_workers=2) - -testset = torchvision.datasets.CIFAR10(root='./data', - train=False, - download=True, - transform=transform) -testloader = torch.utils.data.DataLoader(testset, - batch_size=4, - shuffle=False, - num_workers=2) - -classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', - 'ship', 'truck') - -######################################################################## -# Let us show some of the training images, for fun. - -import matplotlib.pyplot as plt -import numpy as np - -# functions to show an image - - -def imshow(img): - img = img / 2 + 0.5 # unnormalize - npimg = img.numpy() - plt.imshow(np.transpose(npimg, (1, 2, 0))) - plt.show() - - -# get some random training images -dataiter = iter(trainloader) -images, labels = next(dataiter) - -# show images -imshow(torchvision.utils.make_grid(images)) -# print labels -print(' '.join('%5s' % classes[labels[j]] for j in range(4))) - -######################################################################## -# 2. Define a Convolutional Neural Network -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Copy the neural network from the Neural Networks section before and modify it to -# take 3-channel images (instead of 1-channel images as it was defined). +def create_moe_param_groups(model): + """Create separate parameter groups for each expert.""" + parameters = {"params": [p for p in model.parameters()], "name": "parameters"} + return split_params_into_different_moe_groups_for_optimizer(parameters) -import torch.nn as nn -import torch.nn.functional as F -args = add_argument() +def get_ds_config(args): + """Get the DeepSpeed configuration dictionary.""" + ds_config = { + "train_batch_size": 16, + "steps_per_print": 2000, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001, + "betas": [0.8, 0.999], + "eps": 1e-8, + "weight_decay": 3e-7, + }, + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 0.001, + "warmup_num_steps": 1000, + }, + }, + "gradient_clipping": 1.0, + "prescale_gradients": False, + "bf16": {"enabled": args.dtype == "bf16"}, + "fp16": { + "enabled": args.dtype == "fp16", + "fp16_master_weights_and_grads": False, + "loss_scale": 0, + "loss_scale_window": 500, + "hysteresis": 2, + "min_loss_scale": 1, + "initial_scale_power": 15, + }, + "wall_clock_breakdown": False, + "zero_optimization": { + "stage": args.stage, + "allgather_partitions": True, + "reduce_scatter": True, + "allgather_bucket_size": 50000000, + "reduce_bucket_size": 50000000, + "overlap_comm": True, + "contiguous_gradients": True, + "cpu_offload": False, + }, + } + return ds_config class Net(nn.Module): - def __init__(self): + def __init__(self, args): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) - if args.moe: + self.moe = args.moe + if self.moe: fc3 = nn.Linear(84, 84) self.moe_layer_list = [] for n_e in args.num_experts: - # create moe layers based on the number of experts + # Create moe layers based on the number of experts. self.moe_layer_list.append( deepspeed.moe.layer.MoE( hidden_size=84, expert=fc3, num_experts=n_e, ep_size=args.ep_world_size, - use_residual=args.mlp_type == 'residual', + use_residual=args.mlp_type == "residual", k=args.top_k, min_capacity=args.min_capacity, - noisy_gate_policy=args.noisy_gate_policy)) + noisy_gate_policy=args.noisy_gate_policy, + ) + ) self.moe_layer_list = nn.ModuleList(self.moe_layer_list) self.fc4 = nn.Linear(84, 10) else: @@ -229,7 +200,7 @@ def forward(self, x): x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) - if args.moe: + if self.moe: for layer in self.moe_layer_list: x, _, _ = layer(x) x = self.fc4(x) @@ -238,214 +209,194 @@ def forward(self, x): return x -net = Net() +def test(model_engine, testset, local_device, target_dtype, test_batch_size=4): + """Test the network on the test data. + + Args: + model_engine (deepspeed.runtime.engine.DeepSpeedEngine): the DeepSpeed engine. + testset (torch.utils.data.Dataset): the test dataset. + local_device (str): the local device name. + target_dtype (torch.dtype): the target datatype for the test data. + test_batch_size (int): the test batch size. + + """ + # The 10 classes for CIFAR10. + classes = ( + "plane", + "car", + "bird", + "cat", + "deer", + "dog", + "frog", + "horse", + "ship", + "truck", + ) + # Define the test dataloader. + testloader = torch.utils.data.DataLoader( + testset, batch_size=test_batch_size, shuffle=False, num_workers=0 + ) -def create_moe_param_groups(model): - from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer + # For total accuracy. + correct, total = 0, 0 + # For accuracy per class. + class_correct = list(0.0 for i in range(10)) + class_total = list(0.0 for i in range(10)) + + # Start testing. + model_engine.eval() + with torch.no_grad(): + for data in testloader: + images, labels = data + if target_dtype != None: + images = images.to(target_dtype) + outputs = model_engine(images.to(local_device)) + _, predicted = torch.max(outputs.data, 1) + # Count the total accuracy. + total += labels.size(0) + correct += (predicted == labels.to(local_device)).sum().item() + + # Count the accuracy per class. + batch_correct = (predicted == labels.to(local_device)).squeeze() + for i in range(test_batch_size): + label = labels[i] + class_correct[label] += batch_correct[i].item() + class_total[label] += 1 + + if model_engine.local_rank == 0: + print( + f"Accuracy of the network on the {total} test images: {100 * correct / total : .0f} %" + ) + + # For all classes, print the accuracy. + for i in range(10): + print( + f"Accuracy of {classes[i] : >5s} : {100 * class_correct[i] / class_total[i] : 2.0f} %" + ) + + +def main(args): + # Initialize DeepSpeed distributed backend. + deepspeed.init_distributed() + _local_rank = int(os.environ.get("LOCAL_RANK")) + get_accelerator().set_device(_local_rank) + + ######################################################################## + # Step1. Data Preparation. + # + # The output of torchvision datasets are PILImage images of range [0, 1]. + # We transform them to Tensors of normalized range [-1, 1]. + # + # Note: + # If running on Windows and you get a BrokenPipeError, try setting + # the num_worker of torch.utils.data.DataLoader() to 0. + ######################################################################## + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) - parameters = { - 'params': [p for p in model.parameters()], - 'name': 'parameters' - } + if torch.distributed.get_rank() != 0: + # Might be downloading cifar data, let rank 0 download first. + torch.distributed.barrier() - return split_params_into_different_moe_groups_for_optimizer(parameters) + # Load or download cifar data. + trainset = torchvision.datasets.CIFAR10( + root="./data", train=True, download=True, transform=transform + ) + testset = torchvision.datasets.CIFAR10( + root="./data", train=False, download=True, transform=transform + ) + if torch.distributed.get_rank() == 0: + # Cifar data is downloaded, indicate other ranks can proceed. + torch.distributed.barrier() + + ######################################################################## + # Step 2. Define the network with DeepSpeed. + # + # First, we define a Convolution Neural Network. + # Then, we define the DeepSpeed configuration dictionary and use it to + # initialize the DeepSpeed engine. + ######################################################################## + net = Net(args) + + # Get list of parameters that require gradients. + parameters = filter(lambda p: p.requires_grad, net.parameters()) + + # If using MoE, create separate param groups for each expert. + if args.moe_param_group: + parameters = create_moe_param_groups(net) + + # Initialize DeepSpeed to use the following features. + # 1) Distributed model. + # 2) Distributed data loader. + # 3) DeepSpeed optimizer. + ds_config = get_ds_config(args) + model_engine, optimizer, trainloader, __ = deepspeed.initialize( + args=args, + model=net, + model_parameters=parameters, + training_data=trainset, + config=ds_config, + ) -parameters = filter(lambda p: p.requires_grad, net.parameters()) -if args.moe_param_group: - parameters = create_moe_param_groups(net) - -# Initialize DeepSpeed to use the following features -# 1) Distributed model -# 2) Distributed data loader -# 3) DeepSpeed optimizer -ds_config = { - "train_batch_size": 16, - "steps_per_print": 2000, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.001, - "betas": [ - 0.8, - 0.999 - ], - "eps": 1e-8, - "weight_decay": 3e-7 - } - }, - "scheduler": { - "type": "WarmupLR", - "params": { - "warmup_min_lr": 0, - "warmup_max_lr": 0.001, - "warmup_num_steps": 1000 - } - }, - "gradient_clipping": 1.0, - "prescale_gradients": False, - "bf16": { - "enabled": args.dtype == "bf16" - }, - "fp16": { - "enabled": args.dtype == "fp16", - "fp16_master_weights_and_grads": False, - "loss_scale": 0, - "loss_scale_window": 500, - "hysteresis": 2, - "min_loss_scale": 1, - "initial_scale_power": 15 - }, - "wall_clock_breakdown": False, - "zero_optimization": { - "stage": args.stage, - "allgather_partitions": True, - "reduce_scatter": True, - "allgather_bucket_size": 50000000, - "reduce_bucket_size": 50000000, - "overlap_comm": True, - "contiguous_gradients": True, - "cpu_offload": False - } -} - -model_engine, optimizer, trainloader, __ = deepspeed.initialize( - args=args, model=net, model_parameters=parameters, training_data=trainset, config=ds_config) - -local_device = get_accelerator().device_name(model_engine.local_rank) -local_rank = model_engine.local_rank - -# For float32, target_dtype will be None so no datatype conversion needed -target_dtype = None -if model_engine.bfloat16_enabled(): - target_dtype=torch.bfloat16 -elif model_engine.fp16_enabled(): - target_dtype=torch.half - -#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -#net.to(device) -######################################################################## -# 3. Define a Loss function and optimizer -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# Let's use a Classification Cross-Entropy loss and SGD with momentum. - -import torch.optim as optim - -criterion = nn.CrossEntropyLoss() -#optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) - -######################################################################## -# 4. Train the network -# ^^^^^^^^^^^^^^^^^^^^ -# -# This is when things start to get interesting. -# We simply have to loop over our data iterator, and feed the inputs to the -# network and optimize. - -for epoch in range(args.epochs): # loop over the dataset multiple times - - running_loss = 0.0 - for i, data in enumerate(trainloader): - # get the inputs; data is a list of [inputs, labels] - inputs, labels = data[0].to(local_device), data[1].to(local_device) - if target_dtype != None: - inputs = inputs.to(target_dtype) - outputs = model_engine(inputs) - loss = criterion(outputs, labels) - - model_engine.backward(loss) - model_engine.step() - - # print statistics - running_loss += loss.item() - if local_rank == 0 and i % args.log_interval == ( - args.log_interval - - 1): # print every log_interval mini-batches - print('[%d, %5d] loss: %.3f' % - (epoch + 1, i + 1, running_loss / args.log_interval)) - running_loss = 0.0 - -print('Finished Training') - -######################################################################## -# 5. Test the network on the test data -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# -# We have trained the network for 2 passes over the training dataset. -# But we need to check if the network has learnt anything at all. -# -# We will check this by predicting the class label that the neural network -# outputs, and checking it against the ground-truth. If the prediction is -# correct, we add the sample to the list of correct predictions. -# -# Okay, first step. Let us display an image from the test set to get familiar. - -dataiter = iter(testloader) -images, labels = next(dataiter) - -# print images -imshow(torchvision.utils.make_grid(images)) -print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4))) - -######################################################################## -# Okay, now let us see what the neural network thinks these examples above are: -if target_dtype != None: - images = images.to(target_dtype) -outputs = net(images.to(local_device)) - -######################################################################## -# The outputs are energies for the 10 classes. -# The higher the energy for a class, the more the network -# thinks that the image is of the particular class. -# So, let's get the index of the highest energy: -_, predicted = torch.max(outputs, 1) - -print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4))) - -######################################################################## -# The results seem pretty good. -# -# Let us look at how the network performs on the whole dataset. - -correct = 0 -total = 0 -with torch.no_grad(): - for data in testloader: - images, labels = data - if target_dtype != None: - images = images.to(target_dtype) - outputs = net(images.to(local_device)) - _, predicted = torch.max(outputs.data, 1) - total += labels.size(0) - correct += (predicted == labels.to(local_device)).sum().item() - -print('Accuracy of the network on the 10000 test images: %d %%' % - (100 * correct / total)) - -######################################################################## -# That looks way better than chance, which is 10% accuracy (randomly picking -# a class out of 10 classes). -# Seems like the network learnt something. -# -# Hmmm, what are the classes that performed well, and the classes that did -# not perform well: - -class_correct = list(0. for i in range(10)) -class_total = list(0. for i in range(10)) -with torch.no_grad(): - for data in testloader: - images, labels = data - if target_dtype != None: - images = images.to(target_dtype) - outputs = net(images.to(local_device)) - _, predicted = torch.max(outputs, 1) - c = (predicted == labels.to(local_device)).squeeze() - for i in range(4): - label = labels[i] - class_correct[label] += c[i].item() - class_total[label] += 1 - -for i in range(10): - print('Accuracy of %5s : %2d %%' % - (classes[i], 100 * class_correct[i] / class_total[i])) + # Get the local device name (str) and local rank (int). + local_device = get_accelerator().device_name(model_engine.local_rank) + local_rank = model_engine.local_rank + + # For float32, target_dtype will be None so no datatype conversion needed. + target_dtype = None + if model_engine.bfloat16_enabled(): + target_dtype = torch.bfloat16 + elif model_engine.fp16_enabled(): + target_dtype = torch.half + + # Define the Classification Cross-Entropy loss function. + criterion = nn.CrossEntropyLoss() + + ######################################################################## + # Step 3. Train the network. + # + # This is when things start to get interesting. + # We simply have to loop over our data iterator, and feed the inputs to the + # network and optimize. (DeepSpeed handles the distributed details for us!) + ######################################################################## + + for epoch in range(args.epochs): # loop over the dataset multiple times + running_loss = 0.0 + for i, data in enumerate(trainloader): + # Get the inputs. ``data`` is a list of [inputs, labels]. + inputs, labels = data[0].to(local_device), data[1].to(local_device) + + # Try to convert to target_dtype if needed. + if target_dtype != None: + inputs = inputs.to(target_dtype) + + outputs = model_engine(inputs) + loss = criterion(outputs, labels) + + model_engine.backward(loss) + model_engine.step() + + # Print statistics + running_loss += loss.item() + if local_rank == 0 and i % args.log_interval == ( + args.log_interval - 1 + ): # Print every log_interval mini-batches. + print( + f"[{epoch + 1 : d}, {i + 1 : 5d}] loss: {running_loss / args.log_interval : .3f}" + ) + running_loss = 0.0 + print("Finished Training") + + ######################################################################## + # Step 4. Test the network on the test data. + ######################################################################## + test(model_engine, testset, local_device, target_dtype) + + +if __name__ == "__main__": + args = add_argument() + main(args) diff --git a/training/cifar/run_ds_moe.sh b/training/cifar/run_ds_moe.sh index b7dcb7fa7..f87a29628 100755 --- a/training/cifar/run_ds_moe.sh +++ b/training/cifar/run_ds_moe.sh @@ -15,7 +15,6 @@ deepspeed --num_nodes=${NUM_NODES}\ cifar10_deepspeed.py \ --log-interval 100 \ --deepspeed \ - --deepspeed_config ds_config.json \ --moe \ --ep-world-size ${EP_SIZE} \ --num-experts ${EXPERTS} \ diff --git a/training/cifar/run_ds_prmoe.sh b/training/cifar/run_ds_prmoe.sh index 72731b0d5..d9755a331 100644 --- a/training/cifar/run_ds_prmoe.sh +++ b/training/cifar/run_ds_prmoe.sh @@ -12,7 +12,6 @@ EXPERTS='2 4' deepspeed --num_nodes=${NUM_NODES} --num_gpus=${NUM_GPUS} cifar10_deepspeed.py \ --log-interval 100 \ --deepspeed \ - --deepspeed_config ds_config.json \ --moe \ --ep-world-size ${EP_SIZE} \ --num-experts ${EXPERTS} \