Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
mrwyattii authored Feb 1, 2024
2 parents f2ad8d5 + 6863634 commit 553c692
Show file tree
Hide file tree
Showing 56 changed files with 1,576 additions and 995 deletions.
4 changes: 2 additions & 2 deletions applications/DeepSpeed-Chat/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,9 @@ bash training_scripts/opt/single_gpu/run_1.3b.sh


### 🐼 Adding and using your own datasets in DeepSpeed-Chat
In addition to the datasets used in our example scripts, you can also add and use your own datasets. To do so, first you need to add a new Class in [training/utils/data/raw_datasets.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py) to define the format when using your data. You need to make sure to follow the APIs and format defined in the PromptRawDataset class to ensure a consistent data format that DeepSpeed-Chat relies on. You can look at the existing classes to learn how to do so.
In addition to the datasets used in our example scripts, you can also add and use your own datasets. To do so, first you need to add a new Class in [dschat/utils/data/raw_datasets.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/utils/data/raw_datasets.py) to define the format when using your data. You need to make sure to follow the APIs and format defined in the PromptRawDataset class to ensure a consistent data format that DeepSpeed-Chat relies on. You can look at the existing classes to learn how to do so.

Second, you need to add an if condition in function get_raw_dataset in [training/utils/data/data_utils.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/data/data_utils.py) corresponding to your new dataset. The dataset_name string in the if condition should be the dataset name you will provide as a arg for the training scripts. Last, you need to add your new dataset's dataset_name into your "--data_path" arg in your training scripts.
Second, you need to add an if condition in function get_raw_dataset in [dschat/utils/data/data_utils.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py) corresponding to your new dataset. The dataset_name string in the if condition should be the dataset name you will provide as a arg for the training scripts. Last, you need to add your new dataset's dataset_name into your "--data_path" arg in your training scripts.
If you have downloaded huggingface datasets manually, you can add your local path into "--data_path", such as "--data_path ./relative/Dahoas/rm-static" and "--data_path /absolute/Dahoas/rm-static". Remember you should not make `data/` in your local path, it may cause an exception to `load_dataset`.

One thing to note is that some datasets may only have one response instead of two responses. For those datasets, you can only use them in step 1. And in such case, you should add the dataset_name as part of the "--sft_only_data_path" arg instead of the "--data_path" arg. One thing to note is that: If you plan to only do step 1 SFT, adding more single-response datasets is definitely beneficial. However, if you do plan to do steps 2 and 3, then adding too many single-response datasets during SFT could backfire: these data could be different from the data used for steps 2/3, generating different distributions which could cause training instability/worse model quality during step 2/3. That is part of the reason why we focused on trying the datasets with two responses and the preference, and always split a dataset into all 3 steps.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Finetuning the Reward Model (RM) is more or less similar to Step-1 Supervised F

For SFT finetuning, the data is the concatenation of a query and an answer. However, for RM finetuning, each batch of data consists of two query-answer pairs, i.e., the same query with a high-score answer and a low-score answer. This also leads to the second difference as describe below.

👉**The training objective difference**
👉 **The training objective difference**

For RW, the training objective is the pairwise ranking score, i.e., for the two query-answer pairs, RM is supposed to give a higher score to the better answer. There are multiple ways to achieve this. In our implementation, we use either the end token of the sequence or the first padding token as the aggregated score and compare them. Others may also use the average score for the entire answer as an alternative.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,17 @@ def main():
# the LN that precedes it.
force_optimize_params = []
if "bigscience/bloom-" in args.model_name_or_path:
torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight)
torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias)
zero_init_enabled = (args.zero_stage == 3)
params = [
rm_model.rwtranrsformer.ln_f.weight,
rm_model.rwtranrsformer.ln_f.bias
]
with deepspeed.zero.GatheredParameters(params,
modifier_rank=0,
enabled=zero_init_enabled):
if deepspeed.comm.get_rank() == 0 or not zero_init_enabled:
torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight)
torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias)
force_optimize_params.extend(
['rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias'])

Expand Down
14 changes: 9 additions & 5 deletions benchmarks/communication/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


# Run all_gather and print metrics
def timed_all_gather(input, output, args):
def timed_all_gather(input, output, start_event, end_event, args):
if args.dist == 'torch':
import torch.distributed as dist

Expand All @@ -33,11 +33,12 @@ def timed_all_gather(input, output, args):
sync_all()

# time the actual comm op trials times and average it
pre = time.perf_counter()
start_event.record()
for i in range(args.trials):
all_gather_func(output, input, group=None, async_op=args.async_op)
end_event.record()
sync_all()
duration = time.perf_counter() - pre
duration = start_event.elapsed_time(end_event) / 1000

# maintain and clean performance data
avg_duration = duration / args.trials
Expand All @@ -63,6 +64,9 @@ def run_all_gather(local_rank, args):
global_rank = dist.get_rank()
world_size = dist.get_world_size()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
# Create list of message sizes
M_LIST = []
Expand Down Expand Up @@ -92,7 +96,7 @@ def run_all_gather(local_rank, args):
else:
raise e
sync_all()
timed_all_gather(input, output, args)
timed_all_gather(input, output, start_event, end_event, args)
else:
# all_gather_into_tensor saves memory
if ((args.dist == 'torch' or args.dist == 'deepspeed') and dist.has_all_gather_into_tensor()):
Expand Down Expand Up @@ -126,7 +130,7 @@ def run_all_gather(local_rank, args):
raise e

sync_all()
timed_all_gather(input, output, args)
timed_all_gather(input, output, start_event, end_event, args)


if __name__ == "__main__":
Expand Down
14 changes: 9 additions & 5 deletions benchmarks/communication/all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from deepspeed.accelerator import get_accelerator


def timed_all_reduce(input, args):
def timed_all_reduce(input, start_event, end_event, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
Expand All @@ -27,11 +27,12 @@ def timed_all_reduce(input, args):
sync_all()

# time the actual comm op trials times and average it
pre = time.perf_counter()
start_event.record()
for i in range(args.trials):
dist.all_reduce(input, async_op=args.async_op)
end_event.record()
sync_all()
duration = time.perf_counter() - pre
duration = start_event.elapsed_time(end_event) / 1000

# maintain and clean performance data
avg_duration = duration / args.trials
Expand Down Expand Up @@ -59,6 +60,9 @@ def run_all_reduce(local_rank, args):
world_size = dist.get_world_size()
global_rank = dist.get_rank()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
M_LIST = []
for x in (2**p for p in range(1, args.maxsize)):
Expand All @@ -82,7 +86,7 @@ def run_all_reduce(local_rank, args):
else:
raise e
sync_all()
timed_all_reduce(input, args)
timed_all_reduce(input, start_event, end_event, args)
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
# Don't need output tensor, so we double mem_factor
Expand All @@ -104,7 +108,7 @@ def run_all_reduce(local_rank, args):
else:
raise e
sync_all()
timed_all_reduce(input, args)
timed_all_reduce(input, start_event, end_event, args)


if __name__ == "__main__":
Expand Down
14 changes: 9 additions & 5 deletions benchmarks/communication/all_to_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from deepspeed.accelerator import get_accelerator


def timed_all_to_all(input, output, args):
def timed_all_to_all(input, output, start_event, end_event, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
Expand All @@ -27,11 +27,12 @@ def timed_all_to_all(input, output, args):
sync_all()

# time the actual comm op trials times and average it
pre = time.perf_counter()
start_event.record()
for i in range(args.trials):
dist.all_to_all_single(output, input, async_op=args.async_op)
end_event.record()
sync_all()
duration = time.perf_counter() - pre
duration = start_event.elapsed_time(end_event) / 1000

# maintain and clean performance data
avg_duration = duration / args.trials
Expand All @@ -58,6 +59,9 @@ def run_all_to_all(local_rank, args):
# Prepare benchmark header
print_header(args, 'all_to_all')

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
M_LIST = []
for x in (2**p for p in range(1, args.maxsize)):
Expand All @@ -83,7 +87,7 @@ def run_all_to_all(local_rank, args):
else:
raise e
sync_all()
timed_all_to_all(input, output, args)
timed_all_to_all(input, output, start_event, end_event, args)
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
elements_per_gpu = max_numel(comm_op='all_to_all',
Expand Down Expand Up @@ -118,7 +122,7 @@ def run_all_to_all(local_rank, args):
print(f"Before AllToAll Input List at rank {global_rank}: {input}")
dist.barrier()

timed_all_to_all(input, output, args)
timed_all_to_all(input, output, start_event, end_event, args)

if args.debug:
for i in range(world_size):
Expand Down
14 changes: 9 additions & 5 deletions benchmarks/communication/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from deepspeed.accelerator import get_accelerator


def timed_broadcast(input, args):
def timed_broadcast(input, start_event, end_event, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
Expand All @@ -27,11 +27,12 @@ def timed_broadcast(input, args):
sync_all()

# time the actual comm op trials times and average it
pre = time.perf_counter()
start_event.record()
for i in range(args.trials):
dist.broadcast(input, 0, async_op=args.async_op)
end_event.record()
sync_all()
duration = time.perf_counter() - pre
duration = start_event.elapsed_time(end_event) / 1000

# maintain and clean performance data
avg_duration = duration / args.trials
Expand Down Expand Up @@ -59,6 +60,9 @@ def run_broadcast(local_rank, args):
world_size = dist.get_world_size()
global_rank = dist.get_rank()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
M_LIST = []
for x in (2**p for p in range(1, args.maxsize)):
Expand All @@ -82,7 +86,7 @@ def run_broadcast(local_rank, args):
else:
raise e
sync_all()
timed_broadcast(input, args)
timed_broadcast(input, start_event, end_event, args)
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
# Don't need output tensor, so we double mem_factor
Expand All @@ -102,7 +106,7 @@ def run_broadcast(local_rank, args):
sync_all()
return
sync_all()
timed_broadcast(input, args)
timed_broadcast(input, start_event, end_event, args)


if __name__ == "__main__":
Expand Down
14 changes: 9 additions & 5 deletions benchmarks/communication/pt2pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from deepspeed.accelerator import get_accelerator


def timed_pt2pt(input, args):
def timed_pt2pt(input, start_event, end_event, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
Expand All @@ -36,7 +36,7 @@ def timed_pt2pt(input, args):
sync_all()

# time the actual comm op trials times and average it
pre = time.perf_counter()
start_event.record()
for i in range(args.trials):
if dist.get_rank() == 0:
if args.async_op:
Expand All @@ -49,8 +49,9 @@ def timed_pt2pt(input, args):
else:
dist.recv(input, src=0)

end_event.record()
sync_all()
duration = time.perf_counter() - pre
duration = start_event.elapsed_time(end_event) / 1000

# maintain and clean performance data
avg_duration = duration / args.trials
Expand All @@ -77,6 +78,9 @@ def run_pt2pt(local_rank, args):
global_rank = dist.get_rank()
world_size = dist.get_world_size()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
# Create list of message sizes
M_LIST = []
Expand All @@ -101,7 +105,7 @@ def run_pt2pt(local_rank, args):
else:
raise e
sync_all()
timed_pt2pt(input, args)
timed_pt2pt(input, start_event, end_event, args)
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
# Don't need output tensor, so double mem_factor
Expand All @@ -121,7 +125,7 @@ def run_pt2pt(local_rank, args):
sync_all()
return
sync_all()
timed_pt2pt(input, args)
timed_pt2pt(input, start_event, end_event, args)


if __name__ == "__main__":
Expand Down
49 changes: 35 additions & 14 deletions benchmarks/inference/mii/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,59 @@

## Run the Benchmark

The benchmarking scripts use DeepSpeed-FastGen in the persistent mode.
You can start the server with the command below:
The benchmarking scripts use DeepSpeed-FastGen in the persistent mode. You can
run the benchmark using `run_benchmark.py`. This script will run several
combinations of inference servers and clients with different tensor parallel
size, number of model replicas (MII only), number of clients, prompt length, and
max new tokens values. By default, the benchmark will run with the `meta-llama/Llama-2-7b-hf` model.

```bash
python server.py [options] start
python run_benchmark.py
```

Use the -h option to view all available options. To stop the server, use this command:
Use the -h option to view all available options. Several models have pre-defined
default values, including `meta-llama/Llama-2-{7|13|70}b-hf`,
`tiiuae/falcon-{40|180}B`, `microsoft/phi-2`, and `mistralai/Mixtral-8x7B-v0.1`.
These defaults can be overridden if provided to the `run_benchmark.py` script.
For example, to run `meta-llama/Llama-13b-hf` with a tensor parallel size of `1`
and `2` (instead of the default `1`, `2`, and `4`):

```bash
python server.py stop
```bash
python run_benchmark.py --tp_size 1 2
```

Once the server is up and running, initiate the client using the command below. The -h option will display all the possible options.
By default the benchmark runs with DeepSpeed-MII as the backend inference
server. To change the backend to vLLM, provide the `--vllm` flag:

```bash
python run_benchmark_client.py [options]
python run_benchmark.py --vllm
```

The run_all.sh script performs benchmarks across various model sizes and client numbers. For VLLM benchmarks, use the run_all_vllm.sh script. Results are logged in a directory named logs.[BENCHMARK_PARAMETERS].
The run_all.sh script performs benchmarks across various models, client numbers,
tensor parallel sizes, etc. This script is intended to be run on a system with
8xA100 (80GB) GPUs available. It will run all the benchmarks (including vLLM)
and collect the data used in our [DeepSpeed-Fastgen
blogs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen).
Results are collected in `./results/`.

## Analyze the Benchmark Results

The scripts mentioned below were used for generating the plots featured in our blog. Specify the root directory for log files using --log_dir.
The scripts mentioned below were used for generating the plots featured in our
blog. Specify the root directory for log files using `--log_dir`. The generated
figures will be saved to `./plots/`

- `plot_th_lat.py`: This script generates charts for throughput and latency across different model sizes and client counts.
- `plot_effective_throughput.py`: Use this to chart effective throughput.
- `plot_latency_percentile.py`: This script will plot the 50th, 90th, and 95th percentile latencies.
- `src/plot_th_lat.py`: This script generates charts for throughput and latency across different model sizes and client counts.
- `src/plot_effective_throughput.py`: Use this to chart effective throughput.
- `src/plot_latency_percentile.py`: This script will plot the 50th, 90th, and 95th percentile latencies.

## Running an End-to-End Example

To quickly experience the end-to-end process of running our benchmark and getting results, you can use the `run_example.sh`. This script is designed to execute the benchmark with a specific configuration. The plots below will be generated in the charts directory. These plots show the performance as depicted in figure 8 of our blog [post.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen#f-other-hardware-platforms)
To quickly experience the end-to-end process of running our benchmark and
getting results, you can use the `run_example.sh`. This script is designed to
execute the benchmark with a specific configuration. The plots below will be
generated in the `./plots/` directory. These plots show the performance as
depicted in figure 8 of our blog
[post.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen#f-other-hardware-platforms)

```bash
bash run_example.sh
Expand Down
Loading

0 comments on commit 553c692

Please sign in to comment.