diff --git a/applications/DeepSpeed-Chat/README.md b/applications/DeepSpeed-Chat/README.md index 91e4854fb..fa8fa9d38 100644 --- a/applications/DeepSpeed-Chat/README.md +++ b/applications/DeepSpeed-Chat/README.md @@ -251,9 +251,9 @@ bash training_scripts/opt/single_gpu/run_1.3b.sh ### 🐼 Adding and using your own datasets in DeepSpeed-Chat -In addition to the datasets used in our example scripts, you can also add and use your own datasets. To do so, first you need to add a new Class in [training/utils/data/raw_datasets.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py) to define the format when using your data. You need to make sure to follow the APIs and format defined in the PromptRawDataset class to ensure a consistent data format that DeepSpeed-Chat relies on. You can look at the existing classes to learn how to do so. +In addition to the datasets used in our example scripts, you can also add and use your own datasets. To do so, first you need to add a new Class in [dschat/utils/data/raw_datasets.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/utils/data/raw_datasets.py) to define the format when using your data. You need to make sure to follow the APIs and format defined in the PromptRawDataset class to ensure a consistent data format that DeepSpeed-Chat relies on. You can look at the existing classes to learn how to do so. -Second, you need to add an if condition in function get_raw_dataset in [training/utils/data/data_utils.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/data/data_utils.py) corresponding to your new dataset. The dataset_name string in the if condition should be the dataset name you will provide as a arg for the training scripts. Last, you need to add your new dataset's dataset_name into your "--data_path" arg in your training scripts. +Second, you need to add an if condition in function get_raw_dataset in [dschat/utils/data/data_utils.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py) corresponding to your new dataset. The dataset_name string in the if condition should be the dataset name you will provide as a arg for the training scripts. Last, you need to add your new dataset's dataset_name into your "--data_path" arg in your training scripts. If you have downloaded huggingface datasets manually, you can add your local path into "--data_path", such as "--data_path ./relative/Dahoas/rm-static" and "--data_path /absolute/Dahoas/rm-static". Remember you should not make `data/` in your local path, it may cause an exception to `load_dataset`. One thing to note is that some datasets may only have one response instead of two responses. For those datasets, you can only use them in step 1. And in such case, you should add the dataset_name as part of the "--sft_only_data_path" arg instead of the "--data_path" arg. One thing to note is that: If you plan to only do step 1 SFT, adding more single-response datasets is definitely beneficial. However, if you do plan to do steps 2 and 3, then adding too many single-response datasets during SFT could backfire: these data could be different from the data used for steps 2/3, generating different distributions which could cause training instability/worse model quality during step 2/3. 
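To make the two steps above concrete, here is a minimal, hedged sketch of what a new dataset class and its registration could look like. It is an illustration only: the class name, the dataset id `my_org/my-rm-data`, and the column names are placeholders, and the constructor/method signatures are assumed to follow the `PromptRawDataset` interface used by the existing classes in `raw_datasets.py`, which should be checked directly.

```python
# Hypothetical sketch -- class name, dataset id, and column names are placeholders.
# In practice this class would live in dschat/utils/data/raw_datasets.py, where
# PromptRawDataset is already defined; the import below only keeps the sketch
# self-contained. Verify the exact signatures against the existing classes.
from dschat.utils.data.raw_datasets import PromptRawDataset


class MyPairwiseDataset(PromptRawDataset):

    def __init__(self, output_path, seed, local_rank, dataset_name):
        # The base class is assumed to load the Hugging Face dataset into
        # self.raw_datasets, as the existing classes do.
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "my_org/my-rm-data"
        self.dataset_name_clean = "my_org_my_rm_data"

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["test"]

    def get_prompt(self, sample):
        return "Human: " + sample["prompt"] + " Assistant:"

    def get_chosen(self, sample):
        return " " + sample["chosen"]

    def get_rejected(self, sample):
        return " " + sample["rejected"]

    def get_prompt_and_chosen(self, sample):
        return self.get_prompt(sample) + self.get_chosen(sample)

    def get_prompt_and_rejected(self, sample):
        return self.get_prompt(sample) + self.get_rejected(sample)


# In get_raw_dataset() inside dschat/utils/data/data_utils.py, a matching branch
# would be added, for example:
#
#     elif "my_org/my-rm-data" in dataset_name:
#         return raw_datasets.MyPairwiseDataset(output_path, seed, local_rank,
#                                               dataset_name)
#
# and the same name would then be passed via "--data_path my_org/my-rm-data".
```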
That is part of the reason why we focused on trying the datasets with two responses and the preference, and always split a dataset into all 3 steps. diff --git a/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py b/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py index 5b6778cc2..0e67efcf9 100755 --- a/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py +++ b/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py @@ -268,23 +268,14 @@ def _init_reward(self, critic_model_name_or_path): # If critic is ZeRO-3 then we use it for everything, otherwise assume we have enough memory zero_stage = 0 - ds_config = get_eval_ds_config(offload=self.args.offload, + ds_config = get_eval_ds_config(offload=self.args.offload_reward_model, dtype=self.args.dtype, stage=zero_stage) - ds_config[ - 'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size - ds_config[ - 'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size( - ) * self.args.gradient_accumulation_steps - - ds_eval_config = get_eval_ds_config(offload=False, - dtype=self.args.dtype, - stage=zero_stage) # We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine. - ds_eval_config[ + ds_config[ 'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size - ds_eval_config[ + ds_config[ 'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size( ) * self.args.gradient_accumulation_steps @@ -292,7 +283,7 @@ def _init_reward(self, critic_model_name_or_path): reward_model = create_critic_model( model_name_or_path=critic_model_name_or_path, tokenizer=self.tokenizer, - ds_config=ds_eval_config, + ds_config=ds_config, num_padding_at_beginning=self.args.num_padding_at_beginning, rlhf_training=True, dropout=self.args.critic_dropout, diff --git a/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py b/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py index 62c153644..53479a1eb 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py @@ -149,9 +149,13 @@ def __len__(self): def __getitem__(self, idx): if self.train_phase == 1: return { - "input_ids": self.chosen_dataset[idx]["input_ids"], - "attention_mask": self.chosen_dataset[idx]["attention_mask"], - "labels": self.chosen_dataset[idx]["input_ids"] + "input_ids": + self.chosen_dataset[idx]["input_ids"], + "attention_mask": + self.chosen_dataset[idx]["attention_mask"], + "labels": + torch.where(self.chosen_dataset[idx]["attention_mask"].bool(), + self.chosen_dataset[idx]["input_ids"], -100) } elif self.train_phase == 2: return self.chosen_dataset[idx]["input_ids"], self.chosen_dataset[idx]["attention_mask"], \ diff --git a/applications/DeepSpeed-Chat/dschat/utils/ds_utils.py b/applications/DeepSpeed-Chat/dschat/utils/ds_utils.py index 9c15e5143..0cf1c28ab 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/ds_utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/ds_utils.py @@ -33,6 +33,7 @@ def get_train_ds_config(offload, dtype_config = {"enabled": True} zero_opt_dict = { "stage": stage, + "overlap_comm": True, "offload_param": { "device": device }, diff --git a/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py b/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py index 97d3bff15..050819a22 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py +++ 
b/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py @@ -10,7 +10,7 @@ AutoModel, ) from huggingface_hub import snapshot_download -from transformers.deepspeed import HfDeepSpeedConfig +from transformers.integrations.deepspeed import HfDeepSpeedConfig from dschat.utils.model.reward_model import RewardModel from dschat.utils.utils import load_state_dict_into_model, print_rank_0 diff --git a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py index c37d1f4cd..d9527af54 100755 --- a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py @@ -270,6 +270,7 @@ def main(): args.seed, tokenizer, args.max_seq_len, + end_of_conversation_token=tokenizer.eos_token, sft_only_data_path=args.sft_only_data_path) # DataLoaders creation: if args.local_rank == -1: diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/README.md b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/README.md index ede072a79..3c62b9f82 100644 --- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/README.md +++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/README.md @@ -6,7 +6,7 @@ Finetuning the Reward Model (RM) is more or less similar to Step-1 Supervised F For SFT finetuning, the data is the concatenation of a query and an answer. However, for RM finetuning, each batch of data consists of two query-answer pairs, i.e., the same query with a high-score answer and a low-score answer. This also leads to the second difference as describe below. -👉**The training objective difference** +👉 **The training objective difference** For RW, the training objective is the pairwise ranking score, i.e., for the two query-answer pairs, RM is supposed to give a higher score to the better answer. There are multiple ways to achieve this. In our implementation, we use either the end token of the sequence or the first padding token as the aggregated score and compare them. Others may also use the average score for the entire answer as an alternative. diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py index 6a8211988..c247b53e8 100644 --- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py @@ -258,8 +258,17 @@ def main(): # the LN that precedes it. 
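To make the Step-2 training objective described in the README above concrete, here is a minimal sketch of the pairwise ranking loss. It assumes each answer has already been reduced to a single scalar score (e.g. the score at the end-of-sequence token or first padding token, as described above); it is not the repository's exact implementation, which also handles padding and the divergence region between the two answers.

```python
import torch
import torch.nn.functional as F


def pairwise_ranking_loss(chosen_scores: torch.Tensor,
                          rejected_scores: torch.Tensor) -> torch.Tensor:
    # Minimizing -log(sigmoid(chosen - rejected)) pushes the reward model to
    # assign a higher score to the preferred answer than to the rejected one.
    return -F.logsigmoid(chosen_scores - rejected_scores).mean()


# Example with a batch of two query-answer pairs:
chosen = torch.tensor([1.2, 0.3])    # scores for the better answers
rejected = torch.tensor([0.4, 0.9])  # scores for the worse answers
loss = pairwise_ranking_loss(chosen, rejected)
```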
force_optimize_params = [] if "bigscience/bloom-" in args.model_name_or_path: - torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight) - torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias) + zero_init_enabled = (args.zero_stage == 3) + params = [ + rm_model.rwtranrsformer.ln_f.weight, + rm_model.rwtranrsformer.ln_f.bias + ] + with deepspeed.zero.GatheredParameters(params, + modifier_rank=0, + enabled=zero_init_enabled): + if deepspeed.comm.get_rank() == 0 or not zero_init_enabled: + torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight) + torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias) force_optimize_params.extend( ['rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias']) diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py index a5be5671b..1378dc4e6 100644 --- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py @@ -246,6 +246,9 @@ def parse_args(): '--offload_reference_model', action='store_true', help='Enable ZeRO Offload techniques for reference model') + parser.add_argument('--offload_reward_model', + action='store_true', + help='Enable ZeRO Offload techniques for reward model') parser.add_argument( '--actor_zero_stage', type=int, diff --git a/applications/DeepSpeed-VisualChat/utils/model/modeling_dsvl.py b/applications/DeepSpeed-VisualChat/utils/model/modeling_dsvl.py index eb9db9428..1407c1dfc 100755 --- a/applications/DeepSpeed-VisualChat/utils/model/modeling_dsvl.py +++ b/applications/DeepSpeed-VisualChat/utils/model/modeling_dsvl.py @@ -15,7 +15,7 @@ os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) import data.DST as DST # default special tokens from torch.utils.data import DataLoader -from transformers.deepspeed import HfDeepSpeedConfig +from transformers.integrations.deepspeed import HfDeepSpeedConfig import numpy as np from .vis_proj import VisProjection_vit, VisProjection_perceiver diff --git a/benchmarks/communication/README.md b/benchmarks/communication/README.md index 535b5d308..15ce1995b 100644 --- a/benchmarks/communication/README.md +++ b/benchmarks/communication/README.md @@ -1,6 +1,6 @@ # The DeepSpeed Communication Benchmarking Suite -The intent of these benchmarks is to measure communication latency/bw of deepspeed and/or pytorch distributed communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/) and [NCCL Tests](https://github.com/NVIDIA/nccl-tests) in that users can: +The intent of these benchmarks is to measure communication latency/bw of deepspeed and/or pytorch distributed communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/) , [NCCL Tests](https://github.com/NVIDIA/nccl-tests) and [oneCCL Benchmark](https://oneapi-src.github.io/oneCCL/benchmark.html) in that users can: - Easily debug which layer of the communication software stack hangs or performance degradations originate from. 
- Measure the expected communication performance of either DeepSpeed comms or pure PyTorch distributed

@@ -77,6 +77,14 @@ Finally, users can choose specific communication operations to run in `run_all.py`
 deepspeed run_all.py --scan --all-reduce --all-to-all --broadcast

+## CPU Support
+These benchmarks can also run on other devices, such as Intel CPUs, via oneCCL.
+To run on an Intel CPU, append the argument "--device cpu" to any of the Python scripts.
+For example, to run with a single large message size on an Intel CPU:
+
+deepspeed all_reduce.py --device cpu ++ # Adding Communication Benchmarks diff --git a/benchmarks/communication/all_gather.py b/benchmarks/communication/all_gather.py index 8aa33581d..76c4f3b1e 100644 --- a/benchmarks/communication/all_gather.py +++ b/benchmarks/communication/all_gather.py @@ -17,6 +17,9 @@ # Run all_gather and print metrics def timed_all_gather(input, output, start_event, end_event, args): + if args.device == "cpu": + print_rank_0(f"No Event support on CPU to measure time for now") + return if args.dist == 'torch': import torch.distributed as dist @@ -64,8 +67,15 @@ def run_all_gather(local_rank, args): global_rank = dist.get_rank() world_size = dist.get_world_size() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + if args.device == "xpu": + start_event = torch.xpu.Event(enable_timing=True) + end_event = torch.xpu.Event(enable_timing=True) + elif args.device == "cpu": + start_event = torch.cpu.Event() + end_event = torch.cpu.Event() + else: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) if args.scan: # Create list of message sizes diff --git a/benchmarks/communication/all_reduce.py b/benchmarks/communication/all_reduce.py index b9decfd98..41c3116ee 100644 --- a/benchmarks/communication/all_reduce.py +++ b/benchmarks/communication/all_reduce.py @@ -15,6 +15,9 @@ def timed_all_reduce(input, start_event, end_event, args): + if args.device == "cpu": + print_rank_0(f"No Event support on CPU to measure time for now") + return if args.dist == 'torch': import torch.distributed as dist elif args.dist == 'deepspeed': @@ -60,8 +63,15 @@ def run_all_reduce(local_rank, args): world_size = dist.get_world_size() global_rank = dist.get_rank() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + if args.device == "xpu": + start_event = torch.xpu.Event(enable_timing=True) + end_event = torch.xpu.Event(enable_timing=True) + elif args.device == "cpu": + start_event = torch.cpu.Event() + end_event = torch.cpu.Event() + else: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) if args.scan: M_LIST = [] diff --git a/benchmarks/communication/all_to_all.py b/benchmarks/communication/all_to_all.py index 7eccfa824..dc10b9ec9 100644 --- a/benchmarks/communication/all_to_all.py +++ b/benchmarks/communication/all_to_all.py @@ -15,6 +15,9 @@ def timed_all_to_all(input, output, start_event, end_event, args): + if args.device == "cpu": + print_rank_0(f"No Event support on CPU to measure time for now") + return if args.dist == 'torch': import torch.distributed as dist elif args.dist == 'deepspeed': @@ -59,8 +62,15 @@ def run_all_to_all(local_rank, args): # Prepare benchmark header print_header(args, 'all_to_all') - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + if args.device == "xpu": + start_event = torch.xpu.Event(enable_timing=True) + end_event = torch.xpu.Event(enable_timing=True) + elif args.device == "cpu": + start_event = torch.cpu.Event() + end_event = torch.cpu.Event() + else: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) if args.scan: M_LIST = [] diff --git a/benchmarks/communication/broadcast.py b/benchmarks/communication/broadcast.py index 860c9555b..d05303be1 100644 --- a/benchmarks/communication/broadcast.py +++ b/benchmarks/communication/broadcast.py @@ -15,6 
+15,9 @@ def timed_broadcast(input, start_event, end_event, args): + if args.device == "cpu": + print_rank_0(f"No Event support on CPU to measure time for now") + return if args.dist == 'torch': import torch.distributed as dist elif args.dist == 'deepspeed': @@ -60,8 +63,15 @@ def run_broadcast(local_rank, args): world_size = dist.get_world_size() global_rank = dist.get_rank() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + if args.device == "xpu": + start_event = torch.xpu.Event(enable_timing=True) + end_event = torch.xpu.Event(enable_timing=True) + elif args.device == "cpu": + start_event = torch.cpu.Event() + end_event = torch.cpu.Event() + else: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) if args.scan: M_LIST = [] diff --git a/benchmarks/communication/constants.py b/benchmarks/communication/constants.py index ae9fa261b..60df98ed2 100644 --- a/benchmarks/communication/constants.py +++ b/benchmarks/communication/constants.py @@ -12,4 +12,5 @@ DEFAULT_UNIT = 'Gbps' DEFAULT_DIST = 'deepspeed' DEFAULT_MAXSIZE = 24 +DEFAULT_DEVICE = 'cuda' TORCH_DISTRIBUTED_DEFAULT_PORT = 29500 diff --git a/benchmarks/communication/pt2pt.py b/benchmarks/communication/pt2pt.py index 57eab9a66..ec3252eb8 100644 --- a/benchmarks/communication/pt2pt.py +++ b/benchmarks/communication/pt2pt.py @@ -15,6 +15,9 @@ def timed_pt2pt(input, start_event, end_event, args): + if args.device == "cpu": + print_rank_0(f"No Event support on CPU to measure time for now") + return if args.dist == 'torch': import torch.distributed as dist elif args.dist == 'deepspeed': @@ -78,8 +81,15 @@ def run_pt2pt(local_rank, args): global_rank = dist.get_rank() world_size = dist.get_world_size() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + if args.device == "xpu": + start_event = torch.xpu.Event(enable_timing=True) + end_event = torch.xpu.Event(enable_timing=True) + elif args.device == "cpu": + start_event = torch.cpu.Event() + end_event = torch.cpu.Event() + else: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) if args.scan: # Create list of message sizes diff --git a/benchmarks/communication/utils.py b/benchmarks/communication/utils.py index a74d24e28..6f6dd83a1 100644 --- a/benchmarks/communication/utils.py +++ b/benchmarks/communication/utils.py @@ -108,6 +108,11 @@ def get_bw(comm_op, size, duration, args): n = dist.get_world_size() tput = 0 busbw = 0 + + if duration == 0: + print_rank_0("Error. Duration is 0.") + return tput, busbw + if comm_op == "all_to_all": tput = (size / duration) busbw = (size / duration) * ((n - 1) / n) @@ -235,4 +240,5 @@ def benchmark_parser(): default=.3, help='Proportion of max available GPU memory to use for single-size evals') parser.add_argument("--debug", action="store_true", help='Enables all_to_all debug prints') + parser.add_argument("--device", type=str, default=DEFAULT_DEVICE, help='target device') return parser diff --git a/benchmarks/inference/deepspeedometer/README.md b/benchmarks/inference/deepspeedometer/README.md new file mode 100644 index 000000000..b327916c5 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/README.md @@ -0,0 +1,88 @@ +# DeepSpeedometer + +NOTE: This is an experimental tool and is not currently being supported since it's not fully functional. 
Please use the MII benchmark, which can be found here:
+https://github.com/microsoft/DeepSpeedExamples/tree/master/benchmarks/inference/mii
+
+This benchmark is designed to measure the performance of LLM serving solutions. Using a different number of parallel clients sending requests to an inference server, we gather data to plot throughput-latency curves and find the saturation point of an inference server that demonstrates the maximum performance.
+
+## Installation
+
+To install the benchmark, clone this repository and install using `pip`:
+```shell
+git clone https://github.com/Microsoft/DeepSpeedExamples
+cd ./DeepSpeedExamples/benchmarks/deepspeedometer
+pip install .
+```
+
+## Usage
+
+To quickly test the benchmark code without creating an inference server, run the following:
+```
+python3 -m deepspeedometer.benchmark_runner --model facebook/opt-125m --api dummy
+```
+
+### Supported APIs
+
+The benchmark supports different APIs, each with its own client type. Depending on the client, you may need to run the benchmark against a locally hosted inference server or a remote inference server. Adding support for new serving solutions can be achieved by creating a new client class that defines a few basic methods. See the section below on adding new clients for more information.
+
+The clients (i.e., APIs) currently supported (and the configuration options for each) are listed below. You can see more information about the configuration options by looking at the `*ClientConfig` classes located in `clients/*.py`:
+
+1. `fastgen`: Runs a local model inference server with DeepSpeed's FastGen. Config options include:
+   - `model`: Which model to use for serving (required)
+   - `deployment_name`: Name of the deployment server
+   - `tp_size`: Tensor parallel size for each model replica
+   - `num_replicas`: Number of model replicas
+   - `max_ragged_batch_size`: Max number of requests running per model replica
+   - `quantization_mode`: Type of quantization to use
+2. `vllm`: Runs a local model inference server with vLLM.
+   - `model`: Which model to use for serving (required)
+   - `tp_size`: Tensor parallel size for the model
+   - `port`: Which port to use for the REST API
+3. `azureml`: Interfaces with a remote AzureML online endpoint/deployment.
+   - `api_url`: AzureML endpoint API URL (required)
+   - `api_key`: AzureML token key for connecting to the endpoint (required)
+   - `deployment_name`: Name of the deployment hosted in the given endpoint (required)
+
+### Benchmark Configuration
+
+The benchmark has many options for tailoring performance measurements to specific use cases. For additional information and default values, see the `BenchmarkConfig` class defined in `benchmark_runner.py`.
+
+- `api`: Which API to use
+- `warmup_requests`: Number of warm-up requests to run before measuring performance
+- `result_dir`: Directory where results will be written out (as JSON files)
+- `use_threading`: Whether to use threading for the benchmark clients. Default is to use multiprocessing
+- `config_file`: One or more config YAML files that contain values for any of the prompt configuration options (see the section below on prompt configuration)
+- `num_clients`: One or more integer values for the number of parallel clients to run
+- `num_requests_per_client`: Number of requests that will be run by each of the parallel clients
+- `min_requests`: Minimum number of requests to be sent during the benchmark. Useful when there is a low number of clients to ensure a good measurement.
+- `prompt_text_source`: Text file or string that will be sampled to generate request prompts
+- `early_stop_latency`: When running multiple values for `num_clients`, if the average latency per request exceeds this value (in seconds), the benchmark will not test a larger number of parallel clients
+- `force`: Force the overwrite of result files. By default, if a result file exists, the benchmark is skipped
+
+### Prompt Configuration
+
+These options allow users to modify the prompt input and generation behavior of the served models. Note that you can run multiple prompt configurations in a single command by using the `config_file` input as described in the Benchmark Configuration section.
+
+- `model`: Which model to use for tokenizing prompts (required)
+- `prompt_generator_seed`: Seed value for random number generation
+- `max_prompt_length`: The maximum prompt length allowed
+- `prompt_length`: Target mean prompt length
+- `prompt_length_var`: Variance of generated prompt lengths
+- `max_new_tokens`: Target mean number of generated tokens
+- `max_new_tokens_var`: Variance of generated tokens
+- `streaming`: Whether to enable streaming output for generated tokens
+
+#### About Prompt Generation
+
+To mimic real-world serving scenarios, this benchmark samples prompt length and generated token length values from a normal distribution. This distribution can be manipulated with the `prompt_length*` and `max_new_tokens*` values in the prompt configuration. To get all prompt lengths and generation lengths to match exactly, set the `*_var` values to 0.
+
+## Adding New Client APIs
+
+The DeepSpeedometer benchmark was designed to make it easy to add support for new inference server solutions. To do so:
+
+1. Create a new `*_client.py` file in the `clients/` directory.
+2. Define a `*Client` class that inherits from the `BaseClient` class in `clients/base.py`. This class should define 5 methods: `start_service`, `stop_service`, `prepare_request`, `send_request`, and `process_response`. Take a look at the type hints for these methods in the `BaseClient` class to understand the expected inputs and outputs for each method.
+3. Define a `*ClientConfig` class that inherits from the `BaseConfigModel` class. Place any configuration options (i.e., user-passed command line arguments) necessary for your defined `*Client` class in here.
+4. Import the newly added `*Client` and `*ClientConfig` into `clients/__init__.py` and add them to the `client_config_classes` and `client_classes` dictionaries.
+
+For the simplest example of adding a new client, take a look at the `clients/dummy_client.py` file, where we have defined a client that does not stand up a server and only returns a sample of the input prompt after a short sleep cycle. We use this as a lightweight class for unit testing.
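As a rough sketch of steps 1-3 (the `EchoClient` name and the `clients/echo_client.py` module are hypothetical; the method signatures simply mirror the `BaseClient` interface and the dummy client shown later in this diff), a minimal new client could look like this:

```python
# clients/echo_client.py -- hypothetical example, not part of the benchmark.
from typing import Any, Dict

from .base import BaseClient
from ..config import BaseConfigModel
from ..prompt import Prompt


class EchoClientConfig(BaseConfigModel):
    model: str  # kept for parity with the other client configs; unused here


class EchoClient(BaseClient):
    """Echoes the prompt back without standing up a server."""

    def start_service(self) -> None:
        pass  # nothing to launch

    def stop_service(self) -> None:
        pass

    def prepare_request(self, prompt: Prompt) -> Dict[str, Any]:
        return {"input_text": prompt.text}

    def send_request(self, request_kwargs: Dict[str, Any]) -> Any:
        return request_kwargs["input_text"]

    def process_response(self, raw_response: Any) -> str:
        return raw_response
```

Completing step 4, the client would then be registered by importing `EchoClient` and `EchoClientConfig` in `clients/__init__.py` and adding an `"echo"` entry to both the `client_classes` and `client_config_classes` dictionaries.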
diff --git a/benchmarks/inference/deepspeedometer/configs/128k-120.yaml b/benchmarks/inference/deepspeedometer/configs/128k-120.yaml new file mode 100644 index 000000000..574e8e05e --- /dev/null +++ b/benchmarks/inference/deepspeedometer/configs/128k-120.yaml @@ -0,0 +1,5 @@ +prompt_length: 128000 +prompt_length_var: 0.1 +max_prompt_length: 131072 +max_new_tokens: 120 +max_new_tokens_var: 0.3 diff --git a/benchmarks/inference/deepspeedometer/configs/1300-120.yaml b/benchmarks/inference/deepspeedometer/configs/1300-120.yaml new file mode 100644 index 000000000..874a24c27 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/configs/1300-120.yaml @@ -0,0 +1,4 @@ +prompt_length: 1300 +prompt_lenght_var: 0.3 +max_new_tokens: 120 +max_new_tokens_var: 0.3 diff --git a/benchmarks/inference/deepspeedometer/configs/2600-60.yaml b/benchmarks/inference/deepspeedometer/configs/2600-60.yaml new file mode 100644 index 000000000..f7674f819 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/configs/2600-60.yaml @@ -0,0 +1,4 @@ +prompt_length: 2600 +prompt_lenght_var: 0.3 +max_new_tokens: 60 +max_new_tokens_var: 0.3 diff --git a/benchmarks/inference/deepspeedometer/configs/500-500.yaml b/benchmarks/inference/deepspeedometer/configs/500-500.yaml new file mode 100644 index 000000000..72389b37d --- /dev/null +++ b/benchmarks/inference/deepspeedometer/configs/500-500.yaml @@ -0,0 +1,4 @@ +prompt_length: 500 +prompt_lenght_var: 0.3 +max_new_tokens: 500 +max_new_tokens_var: 0.3 diff --git a/benchmarks/inference/deepspeedometer/pyproject.toml b/benchmarks/inference/deepspeedometer/pyproject.toml new file mode 100644 index 000000000..c15a27035 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/pyproject.toml @@ -0,0 +1,32 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" +[project] +name = "deepspeedometer" +version = "0.0.1" +authors = [ + { name="Ammar Ahmad Awan", email="ammar.awan@microsoft.com" }, + { name="Arash Bakhitiari", email="abakhtiari@microsoft.com" }, + { name="Connor Holmes"}, + { name="Lev Kurilenko", email="lev.kurilenko@microsoft.com" }, + { name="Heyang Qin", email="heyangqin@microsoft.com" }, + { name="Masahiro Tanaka", email="mtanaka@microsoft.com" }, + { name="Michael Wyatt", email="michaelwyatt@microsoft.com" }, +] +description = "LLM benchmarking tool" +readme = "README.md" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Python :: 3", +] +dependencies = [ + "loguru", + "pydantic>=2.0.0", + "torch", + "tqdm", + "transformers", +] + +[project.urls] +Homepage = "https://github.com/Microsoft/DeepSpeedExamples/tree/master/benchmarks/inference/deepspeedometer" +Issues = "https://github.com/Microsoft/DeepSpeedExamples/issues" diff --git a/benchmarks/inference/deepspeedometer/run_example.sh b/benchmarks/inference/deepspeedometer/run_example.sh new file mode 100644 index 000000000..42fef231d --- /dev/null +++ b/benchmarks/inference/deepspeedometer/run_example.sh @@ -0,0 +1 @@ +python -m src.deepspeedometer.benchmark_runner --model "facebook/opt-125m" --api dummy --config_file ./configs/1300-120.yaml diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/__init__.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/__init__.py new file mode 100644 index 000000000..32cb0a0f9 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/__init__.py @@ -0,0 +1,2 @@ +from .arg_parsing import parse_args_to_configs +from .benchmark_runner import BenchmarkRunner diff --git 
a/benchmarks/inference/deepspeedometer/src/deepspeedometer/arg_parsing.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/arg_parsing.py new file mode 100644 index 000000000..8be6d0d42 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/arg_parsing.py @@ -0,0 +1,51 @@ +import argparse +from typing import List, Tuple + +from .benchmark_runner import BenchmarkConfig +from .clients import client_config_classes +from .config import BaseConfigModel + + +def parse_args_to_configs(args: List[str]) -> Tuple[BenchmarkConfig, BaseConfigModel]: + def add_model(parser: argparse.ArgumentParser, model: BaseConfigModel): + """Adds fields from pydantic model to the parser.""" + for name, field in model.model_fields.items(): + field_type = field.annotation + + # Get information about number of arguments expected + nargs = None + if getattr(field.annotation, "_name", "") == "List": + nargs = "+" + field_type = field.annotation.__args__[0] + + # Add field to parser + parser.add_argument( + f"--{name}", + dest=name, + nargs=nargs, + type=field_type, + required=getattr(field, "required", False), + default=getattr(field, "default", None), + help=getattr(field, "description", ""), + ) + + # Parse benchmark config fields + parser = argparse.ArgumentParser(allow_abbrev=False) + add_model(parser, BenchmarkConfig) + benchmark_args, remaining_args = parser.parse_known_args(args) + benchmark_config = BenchmarkConfig(**vars(benchmark_args)) + unused_args = set(remaining_args) + + # Parse client config fields + client_config_class = client_config_classes[benchmark_config.api] + parser = argparse.ArgumentParser(allow_abbrev=False) + add_model(parser, client_config_class) + client_args, remaining_args = parser.parse_known_args(args) + client_config = client_config_class(**vars(client_args)) + + # Check for unused arguments + unused_args = unused_args.intersection(remaining_args) + if unused_args: + raise ValueError(f"Unused arguments: {unused_args}") + + return benchmark_config, client_config diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/benchmark_runner.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/benchmark_runner.py new file mode 100644 index 000000000..96dd3a0da --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/benchmark_runner.py @@ -0,0 +1,390 @@ +import itertools +import json +import multiprocessing +import os +import queue +import sys +import threading +import time +import yaml +from pathlib import Path +from typing import List, Iterable, Tuple + +from loguru import logger +from tqdm import tqdm + +from .clients import client_classes, BaseClient +from .config import BaseConfigModel +from .prompt import Prompt, PromptConfig, PromptGenerator +from .response import Response +from .sample_input import sample_input_text + + +class BenchmarkConfig(PromptConfig): + api: str = "azure_ml" + """ Which API to use for benchmarking. New APIs can be added by creating a new client class in the `clients` directory. """ + + warmup_requests: int = 1 + """ Number of requests to run (per client) as a warm-up before starting the benchmark. """ + + result_dir: Path = Path("./results") + """ Top directory where results will be saved. """ + + use_threading: bool = False + """ Whether to use threading or multiprocessing for parallel client requests. Default is multiprocessing. """ + + config_file: List[Path] = [] + """ Path to YAML file(s) containing benchmark configuration settings. 
""" + + num_clients: List[int] = [1, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32] + """ Number of clients to run in parallel. """ + + num_requests_per_client: int = 16 + """ Number of requests to run per client. """ + + min_requests: int = 128 + """ Minimum number of request to create (regardless of num_requests_per_client). """ + + prompt_text_source: str = sample_input_text + """ Text file or string to use for generated prompts. """ + + early_stop_latency: float = 10.0 + """ Maximum mean latency (in seconds) to allow before stopping the benchmark early. """ + + force: bool = False + """ Whether to overwrite existing result files. """ + + +class ClientLauncher: + def __init__( + self, + client_class: BaseClient, + client_config: BaseConfigModel, + warmup_requests: int, + use_threading: bool, + prompt_generator: PromptGenerator, + ): + self.client_class = client_class + self.client_config = client_config + self.client_obj = client_class(client_config) + self.warmup_requests = warmup_requests + self.prompt_generator = prompt_generator + + if use_threading: + self.runnable_cls = threading.Thread + self.barrier_cls = threading.Barrier + self.queue_cls = queue.Queue + else: + self.runnable_cls = multiprocessing.Process + self.barrier_cls = multiprocessing.Barrier + self.queue_cls = multiprocessing.Queue + + def run_parallel_clients(self, num_clients: int) -> None: + logger.info(f"Launching {num_clients} client(s)") + + total_requests = self.request_queue.qsize() + + self.barrier = self.barrier_cls(num_clients + 1) + processes = [ + self.runnable_cls( + target=self._run_client, + args=( + i, + self.barrier, + self.request_queue, + self.response_queue, + self.client_class, + self.client_config, + self.warmup_requests, + ), + ) + for i in range(num_clients) + ] + for p in processes: + p.start() + + self.barrier.wait() # Barrier 1 for master process + + self._progress_bar(total_requests - self.warmup_requests * num_clients) + + self.barrier.wait() # Barrier 2 for master process + + def _progress_bar(self, total_requests: int) -> None: + pbar = tqdm(total=total_requests) + num_responses = 0 + while num_responses != total_requests: + num_responses = self.response_queue.qsize() + pbar.update(num_responses - pbar.n) + time.sleep(1) + pbar.close() + + @staticmethod + def _run_client( + client_id: int, + barrier: multiprocessing.Barrier, + request_queue: multiprocessing.Queue, + response_queue: multiprocessing.Queue, + client_class: BaseClient, + client_config: BaseConfigModel, + warmup_requests: int, + ): + client = client_class(client_config) + + for _ in range(warmup_requests): + prompt = request_queue.get(timeout=1.0) + _ = client.send_request(prompt.request_kwargs) + + barrier.wait() # Barrier 1 for client process + try: + while True: + prompt = request_queue.get(timeout=1.0) + start_time = time.time() + raw_response = client.send_request(prompt.request_kwargs) + end_time = time.time() + request_time = end_time - start_time + response = Response( + prompt_text=prompt.text, + prompt_tokens=prompt.num_tokens, + raw_response=raw_response, + request_time=request_time, + client_id=client_id, + ) + response_queue.put_nowait(response) + except queue.Empty: + pass + + barrier.wait() # Barrier 2 for client process + + def add_request(self, prompt: Prompt) -> None: + request_kwargs = self.client_obj.prepare_request(prompt) + prompt.request_kwargs = request_kwargs + self.request_queue.put(prompt) + + def get_response(self) -> Response: + response = self.response_queue.get(timeout=1.0) + processed_response = 
self.client_obj.process_response(response.raw_response) + response.generated_output = processed_response + response.generated_tokens = self.prompt_generator.count_tokens( + processed_response + ) + return response + + def clear_queues(self) -> None: + self.request_queue = self.queue_cls() + self.response_queue = self.queue_cls() + + def start_service(self) -> None: + self.client_obj.start_service() + + def stop_service(self) -> None: + self.client_obj.stop_service() + + +class BenchmarkRunner: + def __init__( + self, benchmark_config: BaseConfigModel, client_config: BaseConfigModel + ) -> None: + logger.info("Initializing Benchmark Runner") + self.config = benchmark_config + self.client_config = client_config + self.client_class = client_classes[self.config.api] + self.prompt_generator = PromptGenerator( + self.config.model, self.config.prompt_text_source + ) + self.client_launcher = ClientLauncher( + client_class=self.client_class, + client_config=self.client_config, + warmup_requests=self.config.warmup_requests, + use_threading=self.config.use_threading, + prompt_generator=self.prompt_generator, + ) + self.all_responses = [] + + def _benchmark_settings(self) -> Iterable[Tuple[List[int], PromptConfig]]: + prompt_config_keys = list(PromptConfig.model_fields.keys()) + + configs_list = [] + for f in self.config.config_file: + logger.info(f"Generating benchmark run settings from config file: {f}") + with open(f, "r") as fh: + file_config = yaml.safe_load(fh) + + # Get any prompt config values stored in config files + for key in prompt_config_keys + ["num_clients"]: + if key not in file_config: + file_config[key] = getattr(self.config, key) + configs_list.append(file_config) + + if not configs_list: + logger.info(f"Generating benchmark run settings from command line args") + configs_list.append( + { + key: getattr(self.config, key) + for key in prompt_config_keys + ["num_clients"] + } + ) + + all_config_product = [] + for config in configs_list: + # Ensure all config values are iterable types (i.e., list or tuple) + for k, v in config.items(): + if not isinstance(v, list) or isinstance(v, tuple): + config[k] = [v] + + # We treat num_clients differently to enable early stopping + num_clients = config.pop("num_clients") + + # Generate all possible combinations of prompt config values + for vals in itertools.product(*[config[k] for k in prompt_config_keys]): + config_product = {k: v for k, v in zip(prompt_config_keys, vals)} + config_product["num_clients"] = num_clients + all_config_product.append(config_product) + + logger.info(f"Generated {len(all_config_product)} benchmark run setting(s)") + + for config in all_config_product: + num_clients = config.pop("num_clients") + prompt_config = PromptConfig(**config) + yield num_clients, prompt_config + + def _generate_requests(self, prompt_config: PromptConfig, num_clients: int) -> None: + logger.info("Generating Prompts") + + warmup_prompts = self.config.warmup_requests * num_clients + workload_prompts = max( + self.config.min_requests, self.config.num_requests_per_client * num_clients + ) + for prompt in self.prompt_generator( + config=prompt_config, num_prompts=warmup_prompts + workload_prompts + ): + self.client_launcher.add_request(prompt) + + logger.info( + f"Generated {warmup_prompts} warmup and {workload_prompts} workload prompts." 
+ ) + + def _get_output_dir(self) -> Path: + return self.config.result_dir / self.config.api / self.config.model + + def _get_output_path(self, prompt_config: PromptConfig, num_clients: int) -> Path: + output_dir = self._get_output_dir() + output_file = f"prompt{prompt_config.prompt_length}_gen{prompt_config.max_new_tokens}_clients{num_clients}.json" + return output_dir / output_file + + def _process_responses( + self, prompt_config: PromptConfig, num_clients: int + ) -> List[Response]: + output_path = self._get_output_path( + prompt_config=prompt_config, num_clients=num_clients + ) + + logger.info(f"Saving results to {output_path}") + + all_responses = [] + while True: + try: + all_responses.append(self.client_launcher.get_response()) + except queue.Empty: + break + + os.makedirs(output_path.parent, exist_ok=True) + with open(output_path, "w") as fh: + json.dump([r.to_dict() for r in all_responses], fh, indent=2) + + logger.info(f"Saved {len(all_responses)} responses to {output_path}") + + return all_responses + + def _print_result_summary( + self, all_responses: List[Response], num_clients: int + ) -> None: + num_responses = int(len(all_responses)) + mean_latency = sum([r.request_time for r in all_responses]) / num_responses + query_throughput = num_clients / mean_latency + mean_prompt_length = int( + sum([r.prompt_tokens for r in all_responses]) / num_responses + ) + mean_gen_length = int( + sum([r.generated_tokens for r in all_responses]) / num_responses + ) + logger.info( + f"Result summary - # Requests: {num_responses:d}, Mean Prompt Length: {mean_prompt_length:d} tokens, Mean Generation Length: {mean_gen_length:d} tokens, Mean Latency: {mean_latency:.2f} s, Throughput: {query_throughput:.2f} queries/s," + ) + + def _check_early_stop(self, all_responses: List[Response]) -> bool: + if not all_responses: + return False + mean_latency = sum([r.request_time for r in all_responses]) / len(all_responses) + if mean_latency >= self.config.early_stop_latency: + logger.info( + f"Mean latency of {mean_latency:.2f} exceeds early stopping threshold of {self.config.early_stop_latency}. Stopping early." + ) + return True + return False + + def _skip_existing_result( + self, prompt_config: PromptConfig, num_clients: int + ) -> bool: + output_path = self._get_output_path( + prompt_config=prompt_config, num_clients=num_clients + ) + if output_path.exists(): + if self.config.force: + logger.info( + f"Result already exists, but force flag is set. 
Overwriting benchmark with {num_clients} client(s) and prompt config: {prompt_config}" + ) + return False + else: + logger.info( + f"Result already exists, skipping benchmark with {num_clients} client(s) and prompt config: {prompt_config}" + ) + return True + return False + + def run(self) -> None: + # Start the client service + self.client_launcher.start_service() + + # Generate all benchmark settings from user config(s) + for num_clients_list, prompt_config in self._benchmark_settings(): + all_responses = [] + for num_clients in sorted(num_clients_list): + if self._skip_existing_result( + prompt_config=prompt_config, num_clients=num_clients + ): + continue + + if self._check_early_stop(all_responses): + break + + logger.info( + f"Running benchmark with {num_clients} client(s) and prompt config: {prompt_config}" + ) + # Clear out queues and generate request prompts + self.client_launcher.clear_queues() + self._generate_requests( + prompt_config=prompt_config, num_clients=num_clients + ) + + # Launch the clients and process requests + self.client_launcher.run_parallel_clients(num_clients=num_clients) + + # Process raw responses and save results to file + all_responses = self._process_responses( + prompt_config=prompt_config, num_clients=num_clients + ) + + self._print_result_summary( + all_responses=all_responses, num_clients=num_clients + ) + + # Stop the client service + self.client_launcher.stop_service() + + +if __name__ == "__main__": + from .arg_parsing import parse_args_to_configs + + benchmark_config, client_config = parse_args_to_configs(sys.argv[1:]) + benchmark_runner = BenchmarkRunner(benchmark_config, client_config) + benchmark_runner.run() diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/__init__.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/__init__.py new file mode 100644 index 000000000..ac1891112 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/__init__.py @@ -0,0 +1,22 @@ +from .base import BaseClient + +from .azure_ml_client import AzureMLClientConfig, AzureMLClient +from .dummy_client import DummyClientConfig, DummyClient +from .fastgen_client import FastGenClientConfig, FastGenClient +from .vllm_client import vLLMClientConfig, vLLMClient +from .openai_client import openaiClientConfig, openaiClient + +client_config_classes = { + "dummy": DummyClientConfig, + "azure_ml": AzureMLClientConfig, + "fastgen": FastGenClientConfig, + "vllm": vLLMClientConfig, + "openai": openaiClientConfig +} +client_classes = { + "dummy": DummyClient, + "azure_ml": AzureMLClient, + "fastgen": FastGenClient, + "vllm": vLLMClient, + "openai": openaiClient, +} diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/azure_ml_client.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/azure_ml_client.py new file mode 100644 index 000000000..5bedff692 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/azure_ml_client.py @@ -0,0 +1,79 @@ +import json +import requests +from typing import Any, Dict + +from loguru import logger + +from .base import BaseClient +from ..config import BaseConfigModel +from ..prompt import Prompt + + +class AzureMLClientConfig(BaseConfigModel): + api_url: str = "" + """ URL for the AzureML REST API. """ + + api_key: str = "" + """ REST API key for the AzureML deployment. """ + + deployment_name: str = "" + """ Name of the AzureML deployment. 
""" + + +class AzureMLClient(BaseClient): + def __init__(self, config: AzureMLClientConfig) -> None: + super().__init__(config) + self.api_url = config.api_url + self.api_key = config.api_key + self.deployment_name = config.deployment_name + + def start_service(self) -> None: + # Verify that the server exists, this could be extended to actually + # start an AML deployment. However currently we assume one exists. + test_prompt = Prompt("hello world", num_tokens=5, max_new_tokens=16) + _ = self.process_response(self.send_request(self.prepare_request(test_prompt))) + + def stop_service(self) -> None: + pass + + def prepare_request(self, prompt: Prompt) -> Dict[str, Any]: + # TODO: add support for OpenAI chat completion template + if prompt.streaming: + raise ValueError("AzureMLClient does not support streaming prompts.") + + headers = { + "Content-Type": "application/json", + "Authorization": ("Bearer " + self.api_key), + "azureml-model-deployment": self.deployment_name, + } + pload = { + "input_data": { + "input_string": [ + prompt.text, + ], + "parameters": { + "max_tokens": prompt.max_new_tokens, + "return_full_text": prompt.return_full_text, + }, + } + } + return {"url": self.api_url, "headers": headers, "json": pload, "timeout": 180} + + def send_request(self, request_kwargs: Dict[str, Any]) -> Any: + while True: + try: # Sometimes the AML endpoint will return an error, so we send the request again + response = requests.post(**request_kwargs) + output = json.loads(response.content) + assert ( + response.status_code == 200 + ), f"Status code: {response.status_code}" + assert output[0]["0"], f"Empty response" + break + except Exception as e: + logger.debug(f"Connection failed with {e}. Retrying AML request") + + return output + + def process_response(self, raw_response: Any) -> str: + response_text = raw_response[0]["0"] + return response_text diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/base.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/base.py new file mode 100644 index 000000000..40a38e057 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/base.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict + +from ..config import BaseConfigModel +from ..prompt import Prompt + + +class BaseClient(ABC): + def __init__(self, config: BaseConfigModel) -> None: + self.config = config + + @abstractmethod + def start_service(self) -> None: + pass + + @abstractmethod + def stop_service(self) -> None: + pass + + @abstractmethod + def prepare_request(self, prompt: Prompt) -> Dict[str, Any]: + pass + + @abstractmethod + def send_request(self, request_kwargs: Dict[str, Any]) -> Any: + pass + + @abstractmethod + def process_response(self, raw_response: Any) -> str: + pass diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/dummy_client.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/dummy_client.py new file mode 100644 index 000000000..f10b1e94e --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/dummy_client.py @@ -0,0 +1,45 @@ +import time +import random +from typing import Any, Dict + +from transformers import AutoTokenizer + +from .base import BaseClient +from ..config import BaseConfigModel +from ..prompt import Prompt + + +class DummyClientConfig(BaseConfigModel): + model: str + dummy_client_latency_time: float = 0.1 + + +class DummyClient(BaseClient): + def __init__(self, config: DummyClientConfig) -> 
None: + super().__init__(config) + self.tokenizer = AutoTokenizer.from_pretrained(self.config.model) + self.latency_time = config.dummy_client_latency_time + + def start_service(self) -> None: + pass + + def stop_service(self) -> None: + pass + + def prepare_request(self, prompt: Prompt) -> Dict[str, Any]: + return {"input_text": prompt.text, "max_new_tokens": prompt.max_new_tokens} + + def send_request(self, request_kwargs: Dict[str, Any]) -> Any: + time.sleep( + abs(random.uniform(self.latency_time - 0.1, self.latency_time + 0.2)) + ) + output_text = self.tokenizer.decode( + random.choices( + self.tokenizer.encode(request_kwargs["input_text"]), + k=request_kwargs["max_new_tokens"], + ) + ) + return output_text + + def process_response(self, raw_response: Any) -> str: + return raw_response diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/fastgen_client.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/fastgen_client.py new file mode 100644 index 000000000..c3f3a086f --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/fastgen_client.py @@ -0,0 +1,91 @@ +import time +from typing import Any, Dict, Optional + +from loguru import logger +from pydantic import Field + +from .base import BaseClient +from ..config import BaseConfigModel +from ..prompt import Prompt + + +class FastGenClientConfig(BaseConfigModel): + model: str = Field(..., description="HuggingFace.co model name") + deployment_name: str = "fastgen-benchmark-deployment" + tp_size: int = 1 + num_replicas: int = 1 + max_ragged_batch_size: int = 768 + quantization_mode: Optional[str] = None + + +class FastGenClient(BaseClient): + def __init__(self, config: FastGenClientConfig): + super().__init__(config) + try: + import mii + except ImportError as e: + logger.error( + "Please install the `deepspeed-mii` package to use this client." 
+ ) + raise e + + self.mii_client = mii.client(config.deployment_name) + self.streaming = config.streaming + + def start_service(self) -> None: + import mii + from deepspeed.inference import RaggedInferenceEngineConfig, DeepSpeedTPConfig + from deepspeed.inference.v2.ragged import DSStateManagerConfig + + tp_config = DeepSpeedTPConfig(tp_size=self.config.tp_size) + mgr_config = DSStateManagerConfig( + max_ragged_batch_size=self.config.max_ragged_batch_size, + max_ragged_sequence_count=self.config.max_ragged_batch_size, + ) + inference_config = RaggedInferenceEngineConfig( + tensor_parallel=tp_config, state_manager=mgr_config + ) + mii.serve( + self.config.model, + deployment_name=self.config.deployment_name, + tensor_parallel=self.config.tp_size, + inference_engine_config=inference_config, + replica_num=self.config.num_replicas, + quantization_mode=self.config.quantization_mode, + ) + + def stop_service(self) -> None: + import mii + + mii.client(self.config.deployment_name).terminate_server() + + def _streaming_callback(self, raw_response) -> None: + self.streaming_response_tokens.append(raw_response[0].generated_text) + time_now = time.time() + self.streaming_token_gen_time.append(time_now - time_last_token) + time_last_token = time_now + + def prepare_request(self, prompt: Prompt) -> Dict[str, Any]: + request_kwargs = { + "prompts": prompt.text, + "max_new_tokens": prompt.max_new_tokens, + } + if self.streaming: + self.streaming_response_tokens = [] + self.streaming_token_gen_time = [] + self.streaming_time_last_token = None + request_kwargs["streaming_fn"] = self._streaming_callback + return request_kwargs + + def send_request(self, request_kwargs: Dict[str, Any]) -> Any: + if self.streaming: + self.streaming_time_last_token = time.time() + response = self.mii_client(**request_kwargs) + if self.streaming: + response = self.streaming_response_tokens + + return response + + def process_response(self, raw_response: Any) -> str: + if not self.streaming: + return raw_response[0].generated_text diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/openai_client.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/openai_client.py new file mode 100644 index 000000000..76eadfc5c --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/openai_client.py @@ -0,0 +1,57 @@ +import os +import json +import requests +import subprocess +import time +from typing import Any, Dict + +from loguru import logger +from pydantic import Field + +from .base import BaseClient +from ..config import BaseConfigModel +from ..prompt import Prompt + + +# client to test any openai API +class openaiClientConfig(BaseConfigModel): + model: str = Field(..., description="HuggingFace.co model name") + url: str = "http://127.0.0.1:26500/v1/completions" + + +class openaiClient(BaseClient): + def __init__(self, config: openaiClientConfig): + super().__init__(config) + + def start_service(self) -> None: + pass + + def stop_service(self) -> None: + pass + + def prepare_request(self, prompt: Prompt) -> Dict[str, Any]: + api_url = self.config.url + headers = { + "User-Agent": "Benchmark Client", + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + pload = { + "prompt": prompt.text, + "model": self.config.model, + "n": 1, + "use_beam_search": False, + "temperature": 1.0, + "top_p": 0.9, + "max_tokens": prompt.max_new_tokens, + "ignore_eos": False, + } + return {"url": api_url, "headers": headers, "json": pload, 
"timeout": 180} + + def send_request(self, request_kwargs: Dict[str, Any]) -> Any: + response = requests.post(**request_kwargs) + output = json.loads(response.content) + return output + + def process_response(self, raw_response: Any) -> str: + return raw_response["choices"][0]["text"] diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/vllm_client.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/vllm_client.py new file mode 100644 index 000000000..563c66e9d --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/clients/vllm_client.py @@ -0,0 +1,88 @@ +import json +import requests +import subprocess +import time +from typing import Any, Dict + +from loguru import logger +from pydantic import Field + +from .base import BaseClient +from ..config import BaseConfigModel +from ..prompt import Prompt + + +class vLLMClientConfig(BaseConfigModel): + model: str = Field(..., description="HuggingFace.co model name") + tp_size: int = 1 + port: int = 26500 + + +class vLLMClient(BaseClient): + def __init__(self, config: vLLMClientConfig): + super().__init__(config) + try: + import vllm + except ImportError as e: + logger.error("Please install the `vllm` package to use this client.") + raise e + + def start_service(self) -> None: + vllm_cmd = ( + "python", + "-m", + "vllm.entrypoints.api_server", + "--host", + "127.0.0.1", + "--port", + str(self.config.port), + "--tensor-parallel-size", + str(self.config.tp_size), + "--model", + self.config.model, + ) + p = subprocess.Popen( + vllm_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, close_fds=True + ) + start_time = time.time() + timeout_after = 60 * 5 # 5 minutes + while True: + line = p.stderr.readline().decode("utf-8") + if "Application startup complete" in line: + break + if "error" in line.lower(): + p.terminate() + # self.stop_service(config) + raise RuntimeError(f"Error starting VLLM server: {line}") + if time.time() - start_time > timeout_after: + p.terminate() + # self.stop_service(config) + raise TimeoutError("Timed out waiting for VLLM server to start") + time.sleep(0.01) + + def stop_service(self) -> None: + vllm_cmd = ("pkill", "-f", "vllm.entrypoints.api_server") + p = subprocess.Popen(vllm_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p.wait() + + def prepare_request(self, prompt: Prompt) -> Dict[str, Any]: + api_url = "http://localhost:26500/generate" + headers = {"User-Agent": "Benchmark Client"} + pload = { + "prompt": prompt.text, + "n": 1, + "use_beam_search": False, + "temperature": 1.0, + "top_p": 0.9, + "max_tokens": prompt.max_new_tokens, + "ignore_eos": False, + } + return {"url": api_url, "headers": headers, "json": pload, "timeout": 180} + + def send_request(self, request_kwargs: Dict[str, Any]) -> Any: + response = requests.post(**request_kwargs) + output = json.loads(response.content) + return output + + def process_response(self, raw_response: Any) -> str: + return raw_response["text"] diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/config.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/config.py new file mode 100644 index 000000000..d524eb2cf --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/config.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel, ConfigDict + + +class BaseConfigModel(BaseModel): + model_config = ConfigDict( + validate_default=True, + validate_assignment=False, + use_enum_values=True, + populate_by_name=True, + extra="forbid", + arbitrary_types_allowed=True, + 
protected_namespaces=(), + ) diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/prompt.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/prompt.py new file mode 100644 index 000000000..58bd82d0a --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/prompt.py @@ -0,0 +1,117 @@ +import os +from dataclasses import dataclass +from typing import Iterable, Optional +from typing_extensions import Self + +import numpy as np +import torch +from loguru import logger +from pydantic import model_validator +from transformers import AutoTokenizer + +from .config import BaseConfigModel + +# Avoids a warning from transformers +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@dataclass +class Prompt: + text: str + num_tokens: int + max_new_tokens: int + streaming: bool = False + return_full_text: bool = False + request_kwargs: dict = None + + +class PromptConfig(BaseConfigModel): + model: str + """ Names of the model used to benchmark. Used to load the model/tokenizer from HuggingFace.co. """ + + prompt_generator_seed: Optional[int] = None + """ Seed value for prompt generator. """ + + max_prompt_length: int = 4000 + """ Maximum prompt length for any request. """ + + prompt_length: int = 2600 + """ Mean prompt length for requests. """ + + prompt_length_var: float = 0.3 + """ Variance of prompt length. """ + + max_new_tokens: int = 60 + """ Mean number of new tokens to generate in each request. """ + + max_new_tokens_var: float = 0.3 + """ Variance of new tokens to generate. """ + + streaming: bool = False + """ Whether to enable streaming mode for the client. """ + + @model_validator(mode="after") + def set_max_prompt_length(self) -> Self: + if self.prompt_length > self.max_prompt_length: + logger.warning( + f"Prompt length {self.prompt_length} is greater than max prompt length {self.max_prompt_length}. Setting max prompt length to {self.prompt_length}." + ) + self.max_prompt_length = max(self.max_prompt_length, self.prompt_length) + return self + + +class PromptGenerator: + def __init__(self, model: str, prompt_text_source: str) -> None: + self.tokenizer = AutoTokenizer.from_pretrained(model) + if os.path.isfile(prompt_text_source): + with open(prompt_text_source, "r") as f: + prompt_text_source = f.read() + self.input_text = prompt_text_source + self.tokenized_input = self.tokenizer.encode( + self.input_text, return_tensors="pt", padding=False + )[0] + + def count_tokens(self, text: str) -> int: + return len(self.tokenizer.encode(text)) + + def __call__(self, config: PromptConfig, num_prompts: int) -> Iterable[Prompt]: + tokenized_input = self.tokenized_input + if len(tokenized_input) < config.max_prompt_length: + tokenized_input = torch.cat( + [ + tokenized_input + for _ in range(config.max_prompt_length // len(tokenized_input) + 1) + ] + ).flatten() + + if config.prompt_generator_seed is not None: + np.random.seed(config.prompt_generator_seed) + + for _ in range(num_prompts): + # Take the absolute value here because sometimes the normal + # distribution will return a negative value. This is technically + # wrong, but works out OK for most scenarios. 
+ prompt_length = min( + abs( + int( + np.random.normal( + config.prompt_length, + config.prompt_length * config.prompt_length_var, + ) + ) + ), + config.max_prompt_length, + ) + max_new_tokens = abs( + int( + np.random.normal( + config.max_new_tokens, + config.max_new_tokens * config.max_new_tokens_var, + ) + ) + ) + yield Prompt( + text=self.tokenizer.decode(tokenized_input[:prompt_length]), + num_tokens=prompt_length, + max_new_tokens=max_new_tokens, + ) diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/response.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/response.py new file mode 100644 index 000000000..3842ce5d7 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/response.py @@ -0,0 +1,16 @@ +from dataclasses import asdict, dataclass +from typing import Any + + +@dataclass +class Response: + prompt_text: str = "" + prompt_tokens: int = 0 + generated_output: str = "" + generated_tokens: int = 0 + request_time: float = 0 + raw_response: Any = None + client_id: int = 0 + + def to_dict(self) -> dict: + return asdict(self) diff --git a/benchmarks/inference/deepspeedometer/src/deepspeedometer/sample_input.py b/benchmarks/inference/deepspeedometer/src/deepspeedometer/sample_input.py new file mode 100644 index 000000000..0754da724 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/src/deepspeedometer/sample_input.py @@ -0,0 +1,225 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This is a sample input consisting of: +# Code & Text + +sample_input_text = """Deep learning involves the use of neural networks, which are computational models inspired by the structure and functioning of the human brain. These networks consist of interconnected nodes called neurons. Each neuron takes input, performs a computation, and produces an output. + During training, the neural network learns to make accurate predictions by adjusting its internal parameters. This adjustment is done using an optimization algorithm called gradient descent. Gradient descent calculates the gradients of a loss function, which measures the discrepancy between the predicted output of the network and the desired output. These gradients indicate the direction and magnitude of parameter updates that will minimize the loss. + The learning rate is an important hyperparameter in gradient descent. It determines the step size taken during parameter updates. A higher learning rate can lead to faster convergence, but it risks overshooting the optimal solution. On the other hand, a lower learning rate may converge more slowly, but it can result in more precise updates. + Activation functions are applied to the output of each neuron in a neural network. They introduce non-linearities, enabling the network to learn complex patterns and relationships in the data. Popular activation functions include the rectified linear unit (ReLU), sigmoid, and hyperbolic tangent (tanh). + By adjusting the parameters of the neural network during training, deep learning models learn to represent and generalize from complex data patterns. They have achieved remarkable success in various tasks, including image recognition, speech recognition, and natural language processing. + Here are the key fundamentals of deep learning for training large language models: + Neural Networks: At the heart of deep learning are artificial neural networks, which are inspired by the structure and functioning of biological neurons in the human brain. 
These networks consist of interconnected layers of artificial neurons called nodes or units. The nodes receive input, perform computations, and pass the results to the next layer. + Representation Learning: Deep learning models excel at learning meaningful representations of data. In the context of language, the models can automatically learn hierarchical representations of text, capturing complex relationships and semantic structures. + Feedforward and Backpropagation: Deep learning models typically use feedforward neural networks, where information flows from the input layer through intermediate hidden layers to the output layer. The network makes predictions based on the input data, and the prediction error is then backpropagated through the network. Backpropagation calculates gradients that indicate how each parameter in the network should be adjusted to minimize the error. + Activation Functions: Activation functions introduce non-linearities to neural networks, enabling them to learn complex patterns. Common activation functions include the rectified linear unit (ReLU), sigmoid, and hyperbolic tangent (tanh). These functions determine the output of each neuron based on its weighted inputs. + Loss Functions: During training, a loss function is used to measure the discrepancy between the predicted output of the neural network and the desired output. In language modeling tasks, common loss functions include cross-entropy loss, which quantifies the difference in probability distributions. + Optimization Algorithms: Optimization algorithms determine how the network's parameters are updated based on the calculated gradients during backpropagation. Stochastic Gradient Descent (SGD) is a widely used algorithm that iteratively updates the parameters in the direction that minimizes the loss. Variants of SGD, such as Adam or RMSprop, adaptively adjust the learning rate to accelerate convergence. + Regularization Techniques: Deep learning models are prone to overfitting, where they memorize the training data but fail to generalize well to unseen examples. Regularization techniques such as dropout and weight decay are commonly used to prevent overfitting and improve generalization by adding constraints to the model's parameters. + Training on Large-Scale Datasets: Deep learning models, including large language models, require substantial amounts of labeled training data to learn effectively. Large-scale datasets are crucial to expose the model to diverse language patterns and ensure it captures a broad understanding of language. + Parallel Computing: Training large language models is computationally demanding. To accelerate the training process, parallel computing techniques, such as using multiple GPUs or distributed computing systems, are employed. These techniques allow for efficient processing of large datasets and speeding up the training iterations. + Transfer Learning and Fine-tuning: Transfer learning is a technique where a pre-trained model, trained on a large-scale dataset, is used as a starting point for a new task or dataset. Fine-tuning involves adjusting the pre-trained model's parameters on the new dataset to adapt it to the specific task at hand. This approach significantly reduces the training time and data requirements for new models. 
+ The training process of a large language model typically involves the following steps: + Data Collection: A diverse and comprehensive dataset is collected, which typically consists of a vast range of text from sources like books, websites, articles, and other textual resources. The quality and variety of the dataset are crucial to ensure the model learns a broad understanding of language. + Preprocessing: The collected text data is preprocessed to clean and normalize it. This step involves removing irrelevant characters or symbols, converting the text to a consistent format, and organizing it into smaller units such as sentences or paragraphs. + Tokenization: The preprocessed text is divided into individual tokens, which can be as small as words or even subword units. Tokenization helps in representing and processing the text efficiently during training. + Architecture Design: The model architecture, often based on the transformer architecture, is defined. Transformers are neural network models that excel in capturing long-range dependencies in sequential data, making them well-suited for language modeling tasks. + Model Initialization: The model parameters are randomly initialized to start the training process. These parameters will be adjusted iteratively during training to optimize the model's performance. + Training Loop: The model is trained using a large-scale computational infrastructure. The training loop typically involves several iterations over the dataset, known as epochs. During each epoch, the model processes the input data, generates predictions, and compares them with the expected output. The discrepancy between the predicted and expected output is used to compute a loss, which quantifies the model's performance. + Backpropagation and Optimization: Backpropagation is employed to calculate the gradients of the model's parameters with respect to the loss. These gradients indicate the direction and magnitude of the parameter updates needed to minimize the loss. Optimization algorithms, such as stochastic gradient descent (SGD) or its variants, are then used to update the model's parameters based on the computed gradients. + Iterative Refinement: Steps 6 and 7 are repeated for multiple epochs, gradually refining the model's performance. The model's ability to generate coherent and contextually relevant responses improves as it learns from the dataset. + Evaluation: The trained model is evaluated on a separate dataset to assess its performance and identify areas for improvement. Various metrics, such as perplexity or accuracy, can be used to evaluate the model's language generation capabilities. + Fine-tuning and Iteration: Based on the evaluation results, the model may undergo fine-tuning or further iterations of training to enhance its performance. This process helps in addressing specific limitations or biases and aligning the model's output more closely with desired expectations. + It's important to note that training a large language model from scratch is a computationally intensive process that requires substantial computational resources, including powerful hardware like GPUs or specialized hardware accelerators, and large-scale distributed systems to handle the massive amount of data and model parameters involved. 
+ Here are ten highly recommended books that can help you learn deep learning: + "Deep Learning" by Ian Goodfellow, Yoshua Bengio, and Aaron Courville: + This comprehensive book covers the fundamental concepts of deep learning, including neural networks, optimization algorithms, and regularization techniques. It also explores advanced topics like generative models and deep reinforcement learning. + "Deep Learning with Python" by François Chollet: + Written by the creator of the Keras deep learning library, this book provides a practical introduction to deep learning with Python. It covers essential concepts, tools, and techniques, and includes hands-on examples and case studies. + "Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow" by Aurélien Géron: + This book offers a hands-on approach to learning machine learning and deep learning using popular Python libraries such as Scikit-Learn, Keras, and TensorFlow. It covers various algorithms and provides practical examples and exercises. + "Deep Learning for Computer Vision" by Rajalingappaa Shanmugamani: + Focusing on deep learning techniques for computer vision tasks, this book explores topics such as convolutional neural networks (CNNs), image classification, object detection, and image generation. It includes code examples using Python and popular deep learning frameworks. + "Deep Learning: A Practitioner's Approach" by Josh Patterson and Adam Gibson: + This book offers a practical guide to implementing deep learning solutions using the Deeplearning4j library. It covers key concepts, architectures, and techniques, and includes code examples and case studies. + "Grokking Deep Learning" by Andrew Trask: + Geared towards beginners, this book provides an intuitive and accessible introduction to deep learning concepts. It covers neural networks, backpropagation, gradient descent, and other fundamental topics with clear explanations and visualizations. + "Deep Learning for Natural Language Processing" by Palash Goyal, Sumit Pandey, and Karan Jain: + Focusing on deep learning techniques for natural language processing (NLP), this book explores topics like word embeddings, recurrent neural networks (RNNs), and sequence-to-sequence models. It includes code examples using Python and popular NLP libraries. + "Deep Reinforcement Learning" by Pieter Abbeel and John Schulman: + This book provides an in-depth exploration of deep reinforcement learning, a subfield that combines deep learning with reinforcement learning. It covers topics like Q-learning, policy gradients, and deep Q-networks (DQNs) and provides practical examples. + "Deep Learning for Time Series Forecasting" by N.D. Lewis: + Focusing on deep learning techniques for time series data, this book covers topics such as recurrent neural networks (RNNs), long short-term memory (LSTM) networks, and attention models. It includes code examples using Python and popular deep learning frameworks. + "Interpretable Deep Learning" by Christoph Molnar: + This book delves into the challenges and techniques for interpreting and understanding deep learning models. It covers model visualization, feature importance, and other methods for explaining and interpreting deep learning predictions. + These books cover a range of deep learning topics and provide valuable insights and practical guidance for learning and applying deep learning techniques. Choose the ones that align with your interests and learning style to enhance your understanding of deep learning. 
+ Here are 10 popular GitHub projects that can be useful for building large language models (LLMs) or working with natural language processing (NLP) tasks: + TensorFlow: An open-source deep learning framework that provides tools and resources for building and training LLMs. It offers extensive support for various neural network architectures and has a large community. + PyTorch: Another popular deep learning framework that provides a dynamic computational graph and a wide range of tools for building LLMs. It is known for its user-friendly interface and flexibility. + Hugging Face Transformers: A library that provides pre-trained models and a high-level API for natural language understanding (NLU) tasks, including LLMs. It supports popular models like GPT, BERT, and RoBERTa. + Fairseq: A library developed by Facebook AI Research that focuses on sequence modeling tasks, including LLMs. It offers pre-trained models and tools for training and evaluating models using sequence-to-sequence architectures. + AllenNLP: A powerful NLP research library that simplifies the process of building and evaluating deep learning models. It offers pre-built components for common NLP tasks and supports LLMs with various architectures. + OpenAI GPT-3: Although not available on GitHub, OpenAI's GPT-3 language model is widely recognized and can be accessed via the OpenAI API. It offers state-of-the-art language generation capabilities and can be used for various NLP tasks. + BERT: A pre-trained language model developed by Google Research that has achieved exceptional results on various NLP benchmarks. The official implementation is available on GitHub and can be fine-tuned for specific tasks. + spaCy: A popular Python library for NLP tasks that provides efficient and scalable tools for tokenization, named entity recognition, part-of-speech tagging, and more. It integrates well with deep learning frameworks. + FastText: A library developed by Facebook Research that provides efficient tools for text classification and word representation learning. It offers pre-trained word embeddings and supports training LLMs for classification tasks. + NLTK (Natural Language Toolkit): A comprehensive library for NLP tasks in Python. It provides various modules for tokenization, stemming, tagging, parsing, and more. Although it doesn't focus explicitly on LLMs, it is widely used for preprocessing text data in NLP pipelines. + These projects offer a range of resources, pre-trained models, and tools that can assist you in building and working with large language models. Make sure to review the documentation and examples provided by each project to understand their capabilities and how they can be integrated into your workflow. + Here are some popular backend libraries that are commonly used for deep learning: + TensorFlow: Developed by Google's Brain Team, TensorFlow is one of the most widely used deep learning frameworks. It provides a flexible and comprehensive ecosystem for building and deploying machine learning models. TensorFlow offers high-level APIs for easy model construction, as well as lower-level APIs for fine-grained control. It supports distributed computing and has extensive community support. + PyTorch: Developed by Facebook's AI Research lab, PyTorch is known for its simplicity and dynamic computational graph. It allows for intuitive model construction and debugging. PyTorch is widely used in both research and industry due to its flexibility, support for dynamic networks, and strong GPU acceleration capabilities. 
+ Keras: Initially developed as a user-friendly deep learning library, Keras is now integrated as the official high-level API in TensorFlow. It provides a user-friendly and modular interface for building neural networks. Keras abstracts away many complexities and allows users to build models with just a few lines of code. It supports multiple backends, including TensorFlow and Theano. + Theano: Although its development has been discontinued, Theano was one of the first widely-used deep learning libraries. It allows for efficient mathematical operations on multi-dimensional arrays and supports GPU acceleration. Theano was influential in shaping the deep learning landscape and served as a precursor to subsequent frameworks. + Caffe: Developed by the Berkeley Vision and Learning Center (BVLC), Caffe is a popular deep learning framework known for its efficiency and simplicity. It is particularly suitable for convolutional neural networks (CNNs) and image-related tasks. Caffe has a clean and expressive architecture description language that makes it easy to define and train deep models. + MXNet: MXNet is an open-source deep learning framework developed by Apache. It offers a flexible and efficient interface for building and deploying neural networks. MXNet provides a hybrid frontend that allows users to seamlessly switch between symbolic and imperative programming. It is known for its scalability and supports multiple programming languages. + Chainer: Chainer is a flexible deep learning framework that focuses on dynamic neural networks. It allows for intuitive model construction using imperative programming, making it easy to define complex architectures and manipulate data within the network. Chainer is known for its "define-by-run" approach, which facilitates dynamic computations. + Microsoft Cognitive Toolkit (CNTK): CNTK is a deep learning framework developed by Microsoft. It provides a highly efficient and scalable implementation of deep neural networks. CNTK supports both declarative and imperative programming models, making it suitable for both research and production-level deployments. + Deeplearning4j: Deeplearning4j is an open-source deep learning library that focuses on scalability and performance. It is designed to integrate with the Java ecosystem and supports distributed computing. Deeplearning4j provides tools for building various types of neural networks and offers integration with other popular libraries like Hadoop and Spark. + PaddlePaddle: PaddlePaddle (PArallel Distributed Deep LEarning) is a deep learning framework developed by Baidu. It emphasizes scalability and supports large-scale distributed training. PaddlePaddle provides a rich set of built-in models and algorithms, making it accessible to both beginners and advanced users. + Each of these backend libraries offers unique features, performance characteristics, and levels of abstraction. The choice of a backend library depends on factors such as your programming language preferences, the complexity of your models, the availability of community support, and the specific requirements of your deep learning project. 
+ Here's an example code snippet that demonstrates how to create a GPT-Neox20B model using the Hugging Face Transformers library and start fine-tuning it with sample data from the '/tmp/wikitext' directory: + + import torch + from transformers import GPTNeoForCausalLM, GPT2Tokenizer, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments + + # Load the GPT-Neo model and tokenizer + model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B") + tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B") + + # Set the path to the training data + data_path = "/tmp/wikitext" + + # Define the dataset and data collator + dataset = TextDataset(tokenizer=tokenizer, file_path=data_path, block_size=128) + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + # Define the training arguments + training_args = TrainingArguments( + output_dir="./output_dir", # Directory to save the fine-tuned model and logs + overwrite_output_dir=True, + num_train_epochs=3, # Number of training epochs + per_device_train_batch_size=8, + save_steps=500, # Number of updates steps before saving the model + save_total_limit=2, # Limit the total number of saved models + prediction_loss_only=True, + ) + + # Initialize the Trainer + trainer = Trainer( + model=model, + args=training_args, + data_collator=data_collator, + train_dataset=dataset, + ) + + # Start fine-tuning + trainer.train() + + # Save the fine-tuned model + trainer.save_model("./fine-tuned-model") + + To deploy the GPT-Neo model on an HTTP server, you can use a framework like Flask or FastAPI. Here's a basic example using Flask: + + Install Flask: + + pip install flask + + Create a Python file, let's call it app.py, and import the necessary libraries: + + from flask import Flask, request, jsonify + from transformers import GPTNeoForCausalLM, GPT2Tokenizer + + # Initialize Flask app + app = Flask(__name__) + + # Load the pretrained GPT-Neo model and tokenizer + model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B") + tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B") + Define an endpoint for model inference: + + python + Copy code + @app.route("/generate", methods=["POST"]) + def generate_text(): + data = request.get_json() + input_text = data["input_text"] + + # Tokenize the input text + input_ids = tokenizer.encode(input_text, return_tensors="pt") + + # Generate text with the model + output = model.generate(input_ids, max_length=100, num_return_sequences=1) + generated_text = tokenizer.decode(output[0], skip_special_tokens=True) + + return jsonify({"generated_text": generated_text}) + Run the Flask app: + + python + Copy code + if __name__ == "__main__": + app.run(host="0.0.0.0", port=5000) + The code that includes the text generation sampling functions and provides a commented example: + + import requests + import numpy as np + + class TextGeneratorAPI: + def __init__(self, server_url): + self.server_url = server_url + + def generate_text(self, input_text, sampling_algorithm="greedy", temperature=0.7): + url = f"{self.server_url}/generate" + payload = { + "input_text": input_text, + "sampling_algorithm": sampling_algorithm, + "temperature": temperature + } + response = requests.post(url, json=payload) + generated_text = response.json()["generated_text"] + return generated_text + + def greedy_sampling(self, logits): + return np.argmax(logits) + + def random_sampling(self, logits): + probabilities = np.exp(logits / temperature) + probabilities = probabilities / 
np.sum(probabilities) + return np.random.choice(len(logits), p=probabilities) + + def top_k_sampling(self, logits, k=10): + indices = np.argsort(logits)[-k:] + probabilities = np.exp(logits[indices] / temperature) + probabilities = probabilities / np.sum(probabilities) + return np.random.choice(indices, p=probabilities) + + def top_p_sampling(self, logits, p=0.9): + sorted_logits = np.sort(logits)[::-1] + cumulative_probs = np.cumsum(np.exp(sorted_logits) / temperature) + indices = np.arange(len(sorted_logits)) + selected_indices = indices[cumulative_probs <= p] + probabilities = np.exp(logits[selected_indices] / temperature) + probabilities = probabilities / np.sum(probabilities) + return np.random.choice(selected_indices, p=probabilities) + In this updated code, the TextGeneratorAPI class includes the additional sampling functions: greedy_sampling, random_sampling, top_k_sampling, and top_p_sampling. These functions take logits (output of the model) as input and return the index of the selected token based on the respective sampling algorithm. + The greedy_sampling function selects the token with the highest probability (argmax) as the next token. The random_sampling function applies a temperature scaling to the logits and then samples from the resulting probability distribution. The top_k_sampling function selects from the top-k tokens with the highest probabilities. The top_p_sampling function selects from the tokens with cumulative probabilities below a certain threshold (top-p). + You can now use the updated TextGeneratorAPI class with the sampling functions. Here's an example: + + api = TextGeneratorAPI(server_url="http://localhost:5000") + + input_text = "Once upon a time" + + # Generate text using different sampling algorithms and temperatures + greedy_text = api.generate_text(input_text, sampling_algorithm="greedy") + random_text = api.generate_text(input_text, sampling_algorithm="random") + top_k_text = api.generate_text(input_text, sampling_algorithm="top_k", temperature=0.8) + top_p_text = api.generate_text(input_text, sampling_algorithm="top_p", temperature=0.9) + + print("Greedy Sampling:", greedy_text) + print("Random Sampling:", random_text) + print("Top-k Sampling:", top_k_text) + print("Top-p Sampling:", top_p_text) + Make sure to adjust the server_url with the appropriate URL of your HTTP server, and ensure that the server is running and accessible before making requests through the API. 
+ """ diff --git a/benchmarks/inference/deepspeedometer/tests/README.md b/benchmarks/inference/deepspeedometer/tests/README.md new file mode 100644 index 000000000..15a5f49f9 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/tests/README.md @@ -0,0 +1,3 @@ +To run the unit tests: + +`python3 -m pytest .` \ No newline at end of file diff --git a/benchmarks/inference/deepspeedometer/tests/__init__.py b/benchmarks/inference/deepspeedometer/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/benchmarks/inference/deepspeedometer/tests/conftest.py b/benchmarks/inference/deepspeedometer/tests/conftest.py new file mode 100644 index 000000000..e2f779c44 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/tests/conftest.py @@ -0,0 +1,95 @@ +import pytest + + +@pytest.fixture(scope="function", params=["facebook/opt-125m"]) +def model(request): + return request.param + + +@pytest.fixture(scope="function", params=["dummy"]) +def api(request): + return request.param + + +@pytest.fixture(scope="function", params=[""]) +def result_dir(request, tmpdir): + if request.param: + return str(request.param) + return str(tmpdir) + + +@pytest.fixture(scope="function", params=[5]) +def num_requests_per_client(request): + return str(request.param) + + +@pytest.fixture(scope="function", params=[16]) +def min_requests(request): + return str(request.param) + + +@pytest.fixture(scope="function", params=[(1, 2)]) +def num_clients(request): + if isinstance(request.param, tuple) or isinstance(request.param, list): + return [str(num) for num in request.param] + else: + return [str(request.param)] + + +@pytest.fixture(scope="function", params=[0]) +def num_config_files(request): + return request.param + + +@pytest.fixture(scope="function") +def config_files(num_config_files, tmp_path): + config_files = [] + for i in range(num_config_files): + config_file = tmp_path / f"config_{i}.yaml" + config_file.touch() + config_files.append(str(config_file)) + return config_files + + +@pytest.fixture(scope="function", params=[""]) +def prompt_length_var(request): + return str(request.param) + + +@pytest.fixture(scope="function", params=[""]) +def max_new_tokens_var(request): + return str(request.param) + + +@pytest.fixture(scope="function") +def benchmark_args( + model, + api, + result_dir, + num_requests_per_client, + min_requests, + num_clients, + config_files, + prompt_length_var, + max_new_tokens_var, +): + args = [] + if model: + args.extend(["--model", model]) + if api: + args.extend(["--api", api]) + if result_dir: + args.extend(["--result_dir", result_dir]) + if num_requests_per_client: + args.extend(["--num_requests_per_client", num_requests_per_client]) + if min_requests: + args.extend(["--min_requests", min_requests]) + if num_clients: + args.extend(["--num_clients"] + num_clients) + if config_files: + args.extend(["--config_file"] + config_files) + if prompt_length_var: + args.extend(["--prompt_length_var", prompt_length_var]) + if max_new_tokens_var: + args.extend(["--max_new_tokens_var", max_new_tokens_var]) + return args diff --git a/benchmarks/inference/deepspeedometer/tests/test_benchmark.py b/benchmarks/inference/deepspeedometer/tests/test_benchmark.py new file mode 100644 index 000000000..2b067d39e --- /dev/null +++ b/benchmarks/inference/deepspeedometer/tests/test_benchmark.py @@ -0,0 +1,17 @@ +import pytest + +from deepspeedometer import parse_args_to_configs, BenchmarkRunner + + +def test_benchmark_runner(benchmark_args, num_clients): + benchmark_config, client_config = 
parse_args_to_configs(benchmark_args) + benchmark_runner = BenchmarkRunner(benchmark_config, client_config) + benchmark_runner.run() + + expected_results = sum(1 for _ in benchmark_runner._benchmark_settings()) * len( + num_clients + ) + actual_results = len(list(benchmark_runner._get_output_dir().glob("*.json"))) + assert ( + expected_results == actual_results + ), f"Number of result files ({actual_results}) does not match expected number ({expected_results})." diff --git a/benchmarks/inference/deepspeedometer/tests/test_config.py b/benchmarks/inference/deepspeedometer/tests/test_config.py new file mode 100644 index 000000000..d20e0981a --- /dev/null +++ b/benchmarks/inference/deepspeedometer/tests/test_config.py @@ -0,0 +1,32 @@ +import pytest + +import yaml + +import pydantic + +from deepspeedometer import BenchmarkRunner, parse_args_to_configs + + +def test_config(benchmark_args): + benchmark_config, client_config = parse_args_to_configs(benchmark_args) + + +@pytest.mark.parametrize("model", [""]) +def test_config_required_fail(benchmark_args): + with pytest.raises(pydantic.ValidationError): + benchmark_config, client_config = parse_args_to_configs(benchmark_args) + + +@pytest.mark.parametrize("num_config_files", [1]) +def test_config_file(benchmark_args, config_files, num_clients): + # Create a config that would generate 6 benchmark settings + config = {"max_prompt_length": [500, 1300, 2600], "num_clients": [1, 2]} + with open(config_files[0], "w") as f: + yaml.dump(config, f) + + benchmark_config, client_config = parse_args_to_configs(benchmark_args) + benchmark_runner = BenchmarkRunner(benchmark_config, client_config) + benchmark_settings = sum(1 for _ in benchmark_runner._benchmark_settings()) * len( + num_clients + ) + assert benchmark_settings == 6 diff --git a/benchmarks/inference/deepspeedometer/tests/test_early_stop.py b/benchmarks/inference/deepspeedometer/tests/test_early_stop.py new file mode 100644 index 000000000..2a63ba206 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/tests/test_early_stop.py @@ -0,0 +1,23 @@ +import pytest + +from deepspeedometer import parse_args_to_configs, BenchmarkRunner + + +@pytest.mark.parametrize("num_clients", [(1, 2, 4)], indirect=True) +def test_early_stop(benchmark_args): + benchmark_args += [ + "--early_stop_latency", + "1", + "--dummy_client_latency_time", + "2.0", + ] + print(benchmark_args) + benchmark_config, client_config = parse_args_to_configs(benchmark_args) + benchmark_runner = BenchmarkRunner(benchmark_config, client_config) + benchmark_runner.run() + + expected_results = 1 + actual_results = len(list(benchmark_runner._get_output_dir().glob("*.json"))) + assert ( + expected_results == actual_results + ), f"Number of result files ({actual_results}) does not match expected number ({expected_results})." 
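The tests above drive the benchmark entirely through `parse_args_to_configs` and `BenchmarkRunner`, which is also the most direct way to run deepspeedometer programmatically. Below is a minimal sketch based on those tests and the fixtures in `conftest.py`; the model name and the `dummy` API are just the placeholder values used in the fixtures, and the flag names are assumed to match the ones `conftest.py` passes (real backends would use whichever `--api` names the client classes register under, which is not shown here).

```python
# Minimal programmatic run of deepspeedometer, mirroring the test fixtures above.
# Assumes the CLI flags used in conftest.py (--model, --api, --num_clients, ...).
from deepspeedometer import BenchmarkRunner, parse_args_to_configs

args = [
    "--model", "facebook/opt-125m",   # placeholder model from the test fixtures
    "--api", "dummy",                 # dummy client: no real inference server needed
    "--num_clients", "1", "2",        # sweep over two client counts
    "--num_requests_per_client", "5",
    "--result_dir", "./results",
]

benchmark_config, client_config = parse_args_to_configs(args)
runner = BenchmarkRunner(benchmark_config, client_config)
# Writes one JSON result file per benchmark setting, as checked in test_benchmark.py.
runner.run()
```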
diff --git a/benchmarks/inference/deepspeedometer/tests/test_prompt.py b/benchmarks/inference/deepspeedometer/tests/test_prompt.py new file mode 100644 index 000000000..997a82dd5 --- /dev/null +++ b/benchmarks/inference/deepspeedometer/tests/test_prompt.py @@ -0,0 +1,15 @@ +import pytest + +from deepspeedometer import BenchmarkRunner, parse_args_to_configs + + +@pytest.mark.parametrize("prompt_length_var, max_new_tokens_var", [(0, 0)]) +def test_prompt_length(benchmark_args): + benchmark_config, client_config = parse_args_to_configs(benchmark_args) + benchmark_runner = BenchmarkRunner(benchmark_config, client_config) + num_clients, prompt_config = next(benchmark_runner._benchmark_settings()) + + for prompt in benchmark_runner.prompt_generator(prompt_config, num_prompts=10): + prompt_length = benchmark_runner.prompt_generator.count_tokens(prompt.text) + # Using pytest.approx here because often we will have 1-off errors due to tokenization special tokens + assert prompt_length == pytest.approx(benchmark_runner.config.prompt_length, 1) diff --git a/benchmarks/inference/mii/README.md b/benchmarks/inference/mii/README.md index d9e475cdb..726cad462 100644 --- a/benchmarks/inference/mii/README.md +++ b/benchmarks/inference/mii/README.md @@ -1,39 +1,104 @@ -# Benchmarking Scripts for DeepSpeed-FastGen +# Inference Benchmarking Scripts for vLLM, DeepSpeed-FastGen, and Azure ML endpoints ## Run the Benchmark -The benchmarking scripts use DeepSpeed-FastGen in the persistent mode. -You can start the server with the command below: +The benchmarking scripts use DeepSpeed-FastGen in the persistent mode. You can +run the benchmark using `run_benchmark.py`. This script will run several +combinations of inference servers and clients with different tensor parallel +size, number of model replicas (MII only), number of clients, prompt length, and +max new tokens values. By default, the benchmark will run with the `meta-llama/Llama-2-7b-hf` model. ```bash -python server.py [options] start +python run_benchmark.py ``` -Use the -h option to view all available options. To stop the server, use this command: +Use the -h option to view all available options. Several models have pre-defined +default values, including `meta-llama/Llama-2-{7|13|70}b-hf`, +`tiiuae/falcon-{40|180}B`, `microsoft/phi-2`, and `mistralai/Mixtral-8x7B-v0.1`. +These defaults can be overridden if provided to the `run_benchmark.py` script. +For example, to run `meta-llama/Llama-13b-hf` with a tensor parallel size of `1` +and `2` (instead of the default `1`, `2`, and `4`): -```bash -python server.py stop +```bash +python run_benchmark.py --tp_size 1 2 +``` + +By default the benchmark runs with DeepSpeed-MII as the backend inference +server. The benchmark also supports vLLM and Azure endpoints. To change the +backend to vLLM, provide the `--backend vllm` arg: + +```bash +python run_benchmark.py --backend vllm ``` -Once the server is up and running, initiate the client using the command below. The -h option will display all the possible options. +To benchmark against an Azure endpoint, provide the `--backend aml` as well as +the following values: +- `--aml_api_url`: API URL that points to an AML endpoint +- `--aml_api_key`: API key for the given AML endpoint +- `--deployment_name`: The name of the AML endpoint deployment you want to test against +- `--model`: The name of the HuggingFace-hosted model deployed on the AML endpoint. This is used to load a tokenizer and correctly calculate the number of tokens in the prompts and responses. 
```bash -python run_benchmark_client.py [options] +python run_benchmark.py --backend aml --model mistralai/Mixtral-8x7B-v0.1 --deployment_name mistralai-mixtral-8x7b-v01-4 --aml_api_url