Merge branch 'master' into master
loadams authored Nov 4, 2024
2 parents 50aaf5c + eefb0ef commit d3cd036
Showing 75 changed files with 2,557 additions and 158 deletions.
17 changes: 4 additions & 13 deletions applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py
@@ -268,31 +268,22 @@ def _init_reward(self, critic_model_name_or_path):
# If critic is ZeRO-3 then we use it for everything, otherwise assume we have enough memory
zero_stage = 0

ds_config = get_eval_ds_config(offload=self.args.offload,
ds_config = get_eval_ds_config(offload=self.args.offload_reward_model,
dtype=self.args.dtype,
stage=zero_stage)
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
ds_config[
'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
) * self.args.gradient_accumulation_steps

ds_eval_config = get_eval_ds_config(offload=False,
dtype=self.args.dtype,
stage=zero_stage)

# We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
ds_eval_config[
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
ds_eval_config[
ds_config[
'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
) * self.args.gradient_accumulation_steps

# Model
reward_model = create_critic_model(
model_name_or_path=critic_model_name_or_path,
tokenizer=self.tokenizer,
ds_config=ds_eval_config,
ds_config=ds_config,
num_padding_at_beginning=self.args.num_padding_at_beginning,
rlhf_training=True,
dropout=self.args.critic_dropout,
10 changes: 7 additions & 3 deletions applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py
@@ -149,9 +149,13 @@ def __len__(self):
def __getitem__(self, idx):
if self.train_phase == 1:
return {
"input_ids": self.chosen_dataset[idx]["input_ids"],
"attention_mask": self.chosen_dataset[idx]["attention_mask"],
"labels": self.chosen_dataset[idx]["input_ids"]
"input_ids":
self.chosen_dataset[idx]["input_ids"],
"attention_mask":
self.chosen_dataset[idx]["attention_mask"],
"labels":
torch.where(self.chosen_dataset[idx]["attention_mask"].bool(),
self.chosen_dataset[idx]["input_ids"], -100)
}
elif self.train_phase == 2:
return self.chosen_dataset[idx]["input_ids"], self.chosen_dataset[idx]["attention_mask"], \
1 change: 1 addition & 0 deletions applications/DeepSpeed-Chat/dschat/utils/ds_utils.py
@@ -33,6 +33,7 @@ def get_train_ds_config(offload,
dtype_config = {"enabled": True}
zero_opt_dict = {
"stage": stage,
"overlap_comm": True,
"offload_param": {
"device": device
},
@@ -10,7 +10,7 @@
AutoModel,
)
from huggingface_hub import snapshot_download
from transformers.deepspeed import HfDeepSpeedConfig
from transformers.integrations.deepspeed import HfDeepSpeedConfig

from dschat.utils.model.reward_model import RewardModel
from dschat.utils.utils import load_state_dict_into_model, print_rank_0
@@ -191,7 +191,13 @@ def parse_args():
parser.add_argument(
"--add_eot_token",
action='store_true',
help="Add <|endoftext|> as additional special token to tokenizer")
help="Add `eot_token` as additional special token to tokenizer")
parser.add_argument(
"--eot_token",
type=str,
default="<|endoftext|>",
help="Specify the format of the `eot_token`",
)
## Print loss
parser.add_argument('--print_loss',
action='store_true',
@@ -234,8 +240,7 @@ def main():
torch.distributed.barrier()

# load_hf_tokenizer will get the correct tokenizer and set padding tokens based on the model family
args.end_of_conversation_token = "<|endoftext|>"
additional_special_tokens = args.end_of_conversation_token if args.add_eot_token else None
additional_special_tokens = args.eot_token if args.add_eot_token else None
tokenizer = load_hf_tokenizer(args.model_name_or_path,
fast_tokenizer=True,
add_special_tokens=additional_special_tokens)
@@ -270,6 +275,7 @@ def main():
args.seed,
tokenizer,
args.max_seq_len,
end_of_conversation_token=tokenizer.eos_token,
sft_only_data_path=args.sft_only_data_path)
# DataLoaders creation:
if args.local_rank == -1:
@@ -246,6 +246,9 @@ def parse_args():
'--offload_reference_model',
action='store_true',
help='Enable ZeRO Offload techniques for reference model')
parser.add_argument('--offload_reward_model',
action='store_true',
help='Enable ZeRO Offload techniques for reward model')
parser.add_argument(
'--actor_zero_stage',
type=int,
@@ -15,7 +15,7 @@
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
import data.DST as DST # default special tokens
from torch.utils.data import DataLoader
from transformers.deepspeed import HfDeepSpeedConfig
from transformers.integrations.deepspeed import HfDeepSpeedConfig
import numpy as np
from .vis_proj import VisProjection_vit, VisProjection_perceiver

10 changes: 9 additions & 1 deletion benchmarks/communication/README.md
@@ -1,6 +1,6 @@
# The DeepSpeed Communication Benchmarking Suite

The intent of these benchmarks is to measure communication latency/bw of deepspeed and/or pytorch distributed communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/) and [NCCL Tests](https://github.com/NVIDIA/nccl-tests) in that users can:
The intent of these benchmarks is to measure communication latency/bw of deepspeed and/or pytorch distributed communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/), [NCCL Tests](https://github.com/NVIDIA/nccl-tests), and [oneCCL Benchmark](https://oneapi-src.github.io/oneCCL/benchmark.html) in that users can:
- Easily debug which layer of the communication software stack a hang or performance degradation originates from.
- Measure the expected communication performance of either DeepSpeed comms or pure PyTorch distributed

@@ -77,6 +77,14 @@ Finally, users can choose specific communication operations to run in `run_all.p
deepspeed run_all.py --scan --all-reduce --all-to-all --broadcast
</pre>

## CPU Support
These benchmarks can also run on other devices, such as Intel CPUs, via oneCCL.
To run on an Intel CPU, append the argument `--device cpu` to any of the Python scripts.
For example, to run with a single large message size on an Intel CPU:
<pre>
deepspeed all_reduce.py --device cpu
</pre>
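
The per-operation scripts below select the timing Event type from this same `--device` argument (CUDA by default, XPU or CPU when requested). A condensed sketch of that pattern, shown here purely for illustration:
<pre>
import torch

def make_timing_events(device: str):
    # Mirrors the device check added to all_gather/all_reduce/all_to_all/broadcast/pt2pt.
    if device == "xpu":
        return torch.xpu.Event(enable_timing=True), torch.xpu.Event(enable_timing=True)
    elif device == "cpu":
        # CPU events are created, but the timed_* helpers skip timing on CPU for now.
        return torch.cpu.Event(), torch.cpu.Event()
    else:
        return torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
</pre>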


# Adding Communication Benchmarks

14 changes: 12 additions & 2 deletions benchmarks/communication/all_gather.py
@@ -17,6 +17,9 @@

# Run all_gather and print metrics
def timed_all_gather(input, output, start_event, end_event, args):
if args.device == "cpu":
print_rank_0(f"No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist

@@ -64,8 +67,15 @@ def run_all_gather(local_rank, args):
global_rank = dist.get_rank()
world_size = dist.get_world_size()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if args.device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif args.device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
# Create list of message sizes
14 changes: 12 additions & 2 deletions benchmarks/communication/all_reduce.py
@@ -15,6 +15,9 @@


def timed_all_reduce(input, start_event, end_event, args):
if args.device == "cpu":
print_rank_0(f"No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
@@ -60,8 +63,15 @@ def run_all_reduce(local_rank, args):
world_size = dist.get_world_size()
global_rank = dist.get_rank()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if args.device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif args.device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
M_LIST = []
14 changes: 12 additions & 2 deletions benchmarks/communication/all_to_all.py
@@ -15,6 +15,9 @@


def timed_all_to_all(input, output, start_event, end_event, args):
if args.device == "cpu":
print_rank_0(f"No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
@@ -59,8 +62,15 @@ def run_all_to_all(local_rank, args):
# Prepare benchmark header
print_header(args, 'all_to_all')

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if args.device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif args.device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
M_LIST = []
14 changes: 12 additions & 2 deletions benchmarks/communication/broadcast.py
@@ -15,6 +15,9 @@


def timed_broadcast(input, start_event, end_event, args):
if args.device == "cpu":
print_rank_0(f"No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
@@ -60,8 +63,15 @@ def run_broadcast(local_rank, args):
world_size = dist.get_world_size()
global_rank = dist.get_rank()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if args.device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif args.device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
M_LIST = []
1 change: 1 addition & 0 deletions benchmarks/communication/constants.py
@@ -12,4 +12,5 @@
DEFAULT_UNIT = 'Gbps'
DEFAULT_DIST = 'deepspeed'
DEFAULT_MAXSIZE = 24
DEFAULT_DEVICE = 'cuda'
TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
14 changes: 12 additions & 2 deletions benchmarks/communication/pt2pt.py
@@ -15,6 +15,9 @@


def timed_pt2pt(input, start_event, end_event, args):
if args.device == "cpu":
print_rank_0(f"No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
@@ -78,8 +81,15 @@ def run_pt2pt(local_rank, args):
global_rank = dist.get_rank()
world_size = dist.get_world_size()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if args.device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif args.device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
# Create list of message sizes
6 changes: 6 additions & 0 deletions benchmarks/communication/utils.py
@@ -108,6 +108,11 @@ def get_bw(comm_op, size, duration, args):
n = dist.get_world_size()
tput = 0
busbw = 0

if duration == 0:
print_rank_0("Error. Duration is 0.")
return tput, busbw

if comm_op == "all_to_all":
tput = (size / duration)
busbw = (size / duration) * ((n - 1) / n)
@@ -235,4 +240,5 @@ def benchmark_parser():
default=.3,
help='Proportion of max available GPU memory to use for single-size evals')
parser.add_argument("--debug", action="store_true", help='Enables all_to_all debug prints')
parser.add_argument("--device", type=str, default=DEFAULT_DEVICE, help='target device')
return parser
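
As a quick illustration of the bus-bandwidth arithmetic in the `get_bw` hunk above (a standalone sketch with made-up numbers, not the benchmark's own function): for `all_to_all`, throughput is `size / duration` and bus bandwidth scales that by `(n - 1) / n`.

```python
def all_to_all_bw(size_bytes: float, duration_s: float, world_size: int):
    """Same arithmetic as the all_to_all branch of get_bw(), with the new zero-duration guard."""
    if duration_s == 0:
        return 0.0, 0.0
    tput = size_bytes / duration_s
    busbw = tput * ((world_size - 1) / world_size)
    return tput, busbw

# 1 GB message, 10 ms, 8 ranks -> 100 GB/s algorithm bw, 87.5 GB/s bus bw
print(all_to_all_bw(1e9, 0.01, 8))
```
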
88 changes: 88 additions & 0 deletions benchmarks/inference/deepspeedometer/README.md
@@ -0,0 +1,88 @@
# DeepSpeedometer

NOTE: This is an experimental tool that is not fully functional and is not currently supported. Please use the MII benchmark instead, which can be found here:
https://github.com/microsoft/DeepSpeedExamples/tree/master/benchmarks/inference/mii

This benchmark is designed to measure the performance of LLM serving solutions. By varying the number of parallel clients sending requests to an inference server, we gather data to plot throughput-latency curves and find the saturation point at which the server reaches its maximum performance.

## Installation

To install the benchmark, clone this repository and install using `pip`:
```shell
git clone https://github.com/Microsoft/DeepSpeedExamples
cd ./DeepSpeedExamples/benchmarks/inference/deepspeedometer
pip install .
```

## Usage

To quickly test the benchmark code without creating an inference server, run the following:
```
python3 -m deepspeedometer.benchmark_runner --model facebook/opt-125m --api dummy
```

### Supported APIs

The benchmark supports different APIs, each with its own client type. Depending on the client, you may need to run the benchmark against a locally hosted or a remote inference server. Support for new serving solutions can be added by creating a new client class that defines a few basic methods. See the section below on adding new clients for more information.

The clients (i.e., APIs) currently supported (and the configuration options for each) are listed below. You can see more information about the configuration options by looking at the `*ClientConfig` classes located in `clients/*.py`:

1. `fastgen`: Runs a local model inference server with DeepSpeed's FastGen. Config options include:
- `model`: Which model to use for serving (required)
- `deployment_name`: Name of the deployment server
- `tp_size`: Tensor parallel size for each model replica
- `num_replicas`: Number of model replicas
- `max_ragged_batch_size`: Max number of requests running per model replica
- `quantization_mode`: Type of quantization to use
2. `vllm`: Runs a local model inference server with vLLM.
- `model`: Which model to use for serving (required)
- `tp_size`: Tensor parallel size for model
- `port`: Which port to use for REST API
3. `azureml`: Interfaces with remote AzureML online endpoint/deployment.
- `api_url`: AzureML endpoint API URL (required)
- `api_key`: AzureML token key for connecting to endpoint (required)
- `deployment_name`: Name of deployment hosted in given endpoint (required)

### Benchmark Configuration

The benchmark has many options for tailoring performance measurements to specific use cases. For additional information and default values, see the `BenchmarkConfig` class defined in `benchmark_runner.py`.

- `api`: Which API to use
- `warmup_requests`: Number of warm-up requests to run before measuring performance
- `result_dir`: Directory where results will be written out (as JSON files)
- `use_threading`: Whether to use threading for the benchmark clients. Default is to use multi-processing
- `config_file`: One or more config YAML files that contain values for any of the Prompt configuration options (see below section on prompt configuration)
- `num_clients`: One or more integer values for the number of parallel clients to run
- `num_requests_per_client`: Number of requests that will be run by each of the parallel clients
- `min_requests`: Minimum number of requests to send during the benchmark. Useful for ensuring a good measurement when the number of clients is low.
- `prompt_text_source`: Text file or string that will be sampled to generate request prompts
- `early_stop_latency`: When running multiple values for `num_clients`, if the average latency per request exceeds this value (in seconds), the benchmark will not test a larger number of parallel clients
- `force`: Force the overwrite of result files. By default, if a result file exists, the benchmark is skipped

### Prompt Configuration

These options allow users to modify the prompt input and generation behavior of the served models. Note that you can run multiple prompt configurations in a single command by using the `config_file` input as described in the Benchmark Configuration section.

- `model`: Which model to use for tokenizing prompts (required)
- `prompt_generator_seed`: Seed value for random number generation
- `max_prompt_length`: The maximum prompt length allowed
- `prompt_length`: Target mean prompt length
- `prompt_length_var`: Variance of generated prompt lengths
- `max_new_tokens`: Target mean number of generated tokens
- `max_new_tokens_var`: Variance of the number of generated tokens
- `streaming`: Whether to enable streaming output for generated tokens

#### About Prompt Generation

To mimic real-world serving scenarios, this benchmark samples prompt length and generated token length values from a normal distribution. This distribution can be manipulated with the `prompt_length*` and `max_new_tokens*` values in the prompt configuration. To get all prompt lengths and generation lengths to match exactly, set the `*_var` values to 0.
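
As a rough illustration of that sampling scheme (a standalone sketch, not the benchmark's own generator; how the variance is interpreted is up to the tool and is treated as an absolute variance here):

```python
import numpy as np

def sample_lengths(mean: float, var: float, max_len: int, n: int, seed: int = 42):
    """Draw n token lengths from a normal distribution and clamp to [1, max_len].

    With var == 0 every draw collapses to the mean, matching the note above that
    setting the *_var options to 0 makes all lengths match exactly.
    """
    rng = np.random.default_rng(seed)
    draws = rng.normal(loc=mean, scale=np.sqrt(var), size=n)
    return np.clip(np.rint(draws), 1, max_len).astype(int)

print(sample_lengths(mean=512, var=1024, max_len=2048, n=5))  # spread around 512
print(sample_lengths(mean=512, var=0, max_len=2048, n=5))     # all exactly 512
```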

## Adding New Client APIs

The DeepSpeedometer benchmark was designed to make it easy to add support for new inference server solutions. To do so:

1. Create a new `*_client.py` file in the `clients/` directory.
2. Define a `*Client` class that inherits from the `BaseClient` class in `clients/base.py`. This class should define 5 methods: `start_service`, `stop_service`, `prepare_request`, `send_request`, and `process_response`. Take a look at the type hints for these methods in the `BaseClient` class to understand the expected inputs and outputs for each method.
3. Define a `*ClientConfig` class that inherits from the `BaseConfigModel` class. Place any configuration options (i.e., user-passed command line arguments) necessary for your defined `*Client` class in here.
4. Import the newly added `*Client` and `*ClientConfig` into `clients/__init__.py` and add them to the `client_config_classes` and `client_classes` dictionaries.

For the simplest example of adding a new client, take a look at the `clients/dummy_client.py` file, where we have defined a client that does not stand up a server and only returns a sample of the input prompt after a short sleep cycle. We use this as a lightweight class for unit testing.
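
For a concrete picture of steps 2 and 3, here is a minimal dummy-style sketch. It is illustrative only: `EchoClient`, `EchoClientConfig`, the import paths, and the method signatures are assumptions; the authoritative signatures are the type hints in `clients/base.py`.

```python
import time
from typing import Any

# Assumed import paths; the real classes live in this package's clients/ and config modules.
from deepspeedometer.clients.base import BaseClient
from deepspeedometer.config import BaseConfigModel


class EchoClientConfig(BaseConfigModel):
    """Hypothetical config options exposed as command line arguments."""
    response_delay: float = 0.1  # seconds to sleep before "responding"


class EchoClient(BaseClient):
    """Stands up no server; echoes a slice of the prompt after a short sleep."""

    def __init__(self, config: EchoClientConfig):
        self.config = config

    def start_service(self) -> None:
        pass  # nothing to launch

    def stop_service(self) -> None:
        pass  # nothing to tear down

    def prepare_request(self, prompt: str, max_new_tokens: int) -> dict:
        return {"prompt": prompt, "max_new_tokens": max_new_tokens}

    def send_request(self, request: dict) -> Any:
        time.sleep(self.config.response_delay)
        return request["prompt"][: request["max_new_tokens"]]

    def process_response(self, raw_response: Any) -> str:
        return raw_response
```

Step 4 would then register `EchoClient` and `EchoClientConfig` in `clients/__init__.py` via the `client_classes` and `client_config_classes` dictionaries.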