Skip to content

Commit

Permalink
Merge branch 'master' into fix-eot_token_knob
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Oct 29, 2024
2 parents 1a9875d + aa4459f commit 06c42f8
Show file tree
Hide file tree
Showing 132 changed files with 4,953 additions and 1,693 deletions.
4 changes: 2 additions & 2 deletions applications/DeepSpeed-Chat/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,9 @@ bash training_scripts/opt/single_gpu/run_1.3b.sh


### 🐼 Adding and using your own datasets in DeepSpeed-Chat
In addition to the datasets used in our example scripts, you can also add and use your own datasets. To do so, first you need to add a new Class in [training/utils/data/raw_datasets.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py) to define the format when using your data. You need to make sure to follow the APIs and format defined in the PromptRawDataset class to ensure a consistent data format that DeepSpeed-Chat relies on. You can look at the existing classes to learn how to do so.
In addition to the datasets used in our example scripts, you can also add and use your own datasets. To do so, first you need to add a new Class in [dschat/utils/data/raw_datasets.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/utils/data/raw_datasets.py) to define the format when using your data. You need to make sure to follow the APIs and format defined in the PromptRawDataset class to ensure a consistent data format that DeepSpeed-Chat relies on. You can look at the existing classes to learn how to do so.

Second, you need to add an if condition in function get_raw_dataset in [training/utils/data/data_utils.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/data/data_utils.py) corresponding to your new dataset. The dataset_name string in the if condition should be the dataset name you will provide as a arg for the training scripts. Last, you need to add your new dataset's dataset_name into your "--data_path" arg in your training scripts.
Second, you need to add an if condition in function get_raw_dataset in [dschat/utils/data/data_utils.py](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py) corresponding to your new dataset. The dataset_name string in the if condition should be the dataset name you will provide as a arg for the training scripts. Last, you need to add your new dataset's dataset_name into your "--data_path" arg in your training scripts.
If you have downloaded huggingface datasets manually, you can add your local path into "--data_path", such as "--data_path ./relative/Dahoas/rm-static" and "--data_path /absolute/Dahoas/rm-static". Remember you should not make `data/` in your local path, it may cause an exception to `load_dataset`.

One thing to note is that some datasets may only have one response instead of two responses. For those datasets, you can only use them in step 1. And in such case, you should add the dataset_name as part of the "--sft_only_data_path" arg instead of the "--data_path" arg. One thing to note is that: If you plan to only do step 1 SFT, adding more single-response datasets is definitely beneficial. However, if you do plan to do steps 2 and 3, then adding too many single-response datasets during SFT could backfire: these data could be different from the data used for steps 2/3, generating different distributions which could cause training instability/worse model quality during step 2/3. That is part of the reason why we focused on trying the datasets with two responses and the preference, and always split a dataset into all 3 steps.
Expand Down
17 changes: 4 additions & 13 deletions applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,31 +268,22 @@ def _init_reward(self, critic_model_name_or_path):
# If critic is ZeRO-3 then we use it for everything, otherwise assume we have enough memory
zero_stage = 0

ds_config = get_eval_ds_config(offload=self.args.offload,
ds_config = get_eval_ds_config(offload=self.args.offload_reward_model,
dtype=self.args.dtype,
stage=zero_stage)
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
ds_config[
'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
) * self.args.gradient_accumulation_steps

ds_eval_config = get_eval_ds_config(offload=False,
dtype=self.args.dtype,
stage=zero_stage)

# We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
ds_eval_config[
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
ds_eval_config[
ds_config[
'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
) * self.args.gradient_accumulation_steps

# Model
reward_model = create_critic_model(
model_name_or_path=critic_model_name_or_path,
tokenizer=self.tokenizer,
ds_config=ds_eval_config,
ds_config=ds_config,
num_padding_at_beginning=self.args.num_padding_at_beginning,
rlhf_training=True,
dropout=self.args.critic_dropout,
Expand Down
10 changes: 7 additions & 3 deletions applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,13 @@ def __len__(self):
def __getitem__(self, idx):
if self.train_phase == 1:
return {
"input_ids": self.chosen_dataset[idx]["input_ids"],
"attention_mask": self.chosen_dataset[idx]["attention_mask"],
"labels": self.chosen_dataset[idx]["input_ids"]
"input_ids":
self.chosen_dataset[idx]["input_ids"],
"attention_mask":
self.chosen_dataset[idx]["attention_mask"],
"labels":
torch.where(self.chosen_dataset[idx]["attention_mask"].bool(),
self.chosen_dataset[idx]["input_ids"], -100)
}
elif self.train_phase == 2:
return self.chosen_dataset[idx]["input_ids"], self.chosen_dataset[idx]["attention_mask"], \
Expand Down
1 change: 1 addition & 0 deletions applications/DeepSpeed-Chat/dschat/utils/ds_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def get_train_ds_config(offload,
dtype_config = {"enabled": True}
zero_opt_dict = {
"stage": stage,
"overlap_comm": True,
"offload_param": {
"device": device
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
AutoModel,
)
from huggingface_hub import snapshot_download
from transformers.deepspeed import HfDeepSpeedConfig
from transformers.integrations.deepspeed import HfDeepSpeedConfig

from dschat.utils.model.reward_model import RewardModel
from dschat.utils.utils import load_state_dict_into_model, print_rank_0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def main():
args.seed,
tokenizer,
args.max_seq_len,
end_of_conversation_token=tokenizer.eos_token,
sft_only_data_path=args.sft_only_data_path)
# DataLoaders creation:
if args.local_rank == -1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Finetuning the Reward Model (RM) is more or less similar to Step-1 Supervised F

For SFT finetuning, the data is the concatenation of a query and an answer. However, for RM finetuning, each batch of data consists of two query-answer pairs, i.e., the same query with a high-score answer and a low-score answer. This also leads to the second difference as describe below.

👉**The training objective difference**
👉 **The training objective difference**

For RW, the training objective is the pairwise ranking score, i.e., for the two query-answer pairs, RM is supposed to give a higher score to the better answer. There are multiple ways to achieve this. In our implementation, we use either the end token of the sequence or the first padding token as the aggregated score and compare them. Others may also use the average score for the entire answer as an alternative.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,17 @@ def main():
# the LN that precedes it.
force_optimize_params = []
if "bigscience/bloom-" in args.model_name_or_path:
torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight)
torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias)
zero_init_enabled = (args.zero_stage == 3)
params = [
rm_model.rwtranrsformer.ln_f.weight,
rm_model.rwtranrsformer.ln_f.bias
]
with deepspeed.zero.GatheredParameters(params,
modifier_rank=0,
enabled=zero_init_enabled):
if deepspeed.comm.get_rank() == 0 or not zero_init_enabled:
torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight)
torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias)
force_optimize_params.extend(
['rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias'])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ def parse_args():
'--offload_reference_model',
action='store_true',
help='Enable ZeRO Offload techniques for reference model')
parser.add_argument('--offload_reward_model',
action='store_true',
help='Enable ZeRO Offload techniques for reward model')
parser.add_argument(
'--actor_zero_stage',
type=int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
import data.DST as DST # default special tokens
from torch.utils.data import DataLoader
from transformers.deepspeed import HfDeepSpeedConfig
from transformers.integrations.deepspeed import HfDeepSpeedConfig
import numpy as np
from .vis_proj import VisProjection_vit, VisProjection_perceiver

Expand Down
10 changes: 9 additions & 1 deletion benchmarks/communication/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# The DeepSpeed Communication Benchmarking Suite

The intent of these benchmarks is to measure communication latency/bw of deepspeed and/or pytorch distributed communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/) and [NCCL Tests](https://github.com/NVIDIA/nccl-tests) in that users can:
The intent of these benchmarks is to measure communication latency/bw of deepspeed and/or pytorch distributed communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/) , [NCCL Tests](https://github.com/NVIDIA/nccl-tests) and [oneCCL Benchmark](https://oneapi-src.github.io/oneCCL/benchmark.html) in that users can:
- Easily debug which layer of the communication software stack hangs or performance degradations originate from.
- Measure the expected communication performance of either DeepSpeed comms or pure PyTorch distributed

Expand Down Expand Up @@ -77,6 +77,14 @@ Finally, users can choose specific communication operations to run in `run_all.p
deepspeed run_all.py --scan --all-reduce --all-to-all --broadcast
</pre>

## CPU Support
Those benchmarks could also support other devices like Intel CPU via oneCCL.
Users just need to append one more argument "--device cpu" for all python scripts to run on Intel CPU.
For example, run with a single large message size on Intel CPU:
<pre>
deepspeed all_reduce.py --device cpu
</pre>


# Adding Communication Benchmarks

Expand Down
14 changes: 12 additions & 2 deletions benchmarks/communication/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

# Run all_gather and print metrics
def timed_all_gather(input, output, start_event, end_event, args):
if args.device == "cpu":
print_rank_0(f"No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist

Expand Down Expand Up @@ -64,8 +67,15 @@ def run_all_gather(local_rank, args):
global_rank = dist.get_rank()
world_size = dist.get_world_size()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if args.device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif args.device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
# Create list of message sizes
Expand Down
14 changes: 12 additions & 2 deletions benchmarks/communication/all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@


def timed_all_reduce(input, start_event, end_event, args):
if args.device == "cpu":
print_rank_0(f"No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
Expand Down Expand Up @@ -60,8 +63,15 @@ def run_all_reduce(local_rank, args):
world_size = dist.get_world_size()
global_rank = dist.get_rank()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if args.device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif args.device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
M_LIST = []
Expand Down
14 changes: 12 additions & 2 deletions benchmarks/communication/all_to_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@


def timed_all_to_all(input, output, start_event, end_event, args):
if args.device == "cpu":
print_rank_0(f"No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
Expand Down Expand Up @@ -59,8 +62,15 @@ def run_all_to_all(local_rank, args):
# Prepare benchmark header
print_header(args, 'all_to_all')

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if args.device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif args.device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
M_LIST = []
Expand Down
14 changes: 12 additions & 2 deletions benchmarks/communication/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@


def timed_broadcast(input, start_event, end_event, args):
if args.device == "cpu":
print_rank_0(f"No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
Expand Down Expand Up @@ -60,8 +63,15 @@ def run_broadcast(local_rank, args):
world_size = dist.get_world_size()
global_rank = dist.get_rank()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if args.device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif args.device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
M_LIST = []
Expand Down
1 change: 1 addition & 0 deletions benchmarks/communication/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
DEFAULT_UNIT = 'Gbps'
DEFAULT_DIST = 'deepspeed'
DEFAULT_MAXSIZE = 24
DEFAULT_DEVICE = 'cuda'
TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
14 changes: 12 additions & 2 deletions benchmarks/communication/pt2pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@


def timed_pt2pt(input, start_event, end_event, args):
if args.device == "cpu":
print_rank_0(f"No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
Expand Down Expand Up @@ -78,8 +81,15 @@ def run_pt2pt(local_rank, args):
global_rank = dist.get_rank()
world_size = dist.get_world_size()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if args.device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif args.device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
# Create list of message sizes
Expand Down
6 changes: 6 additions & 0 deletions benchmarks/communication/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def get_bw(comm_op, size, duration, args):
n = dist.get_world_size()
tput = 0
busbw = 0

if duration == 0:
print_rank_0("Error. Duration is 0.")
return tput, busbw

if comm_op == "all_to_all":
tput = (size / duration)
busbw = (size / duration) * ((n - 1) / n)
Expand Down Expand Up @@ -235,4 +240,5 @@ def benchmark_parser():
default=.3,
help='Proportion of max available GPU memory to use for single-size evals')
parser.add_argument("--debug", action="store_true", help='Enables all_to_all debug prints')
parser.add_argument("--device", type=str, default=DEFAULT_DEVICE, help='target device')
return parser
Loading

0 comments on commit 06c42f8

Please sign in to comment.