From 144e913779df2a907f30f64150ff557186369eb8 Mon Sep 17 00:00:00 2001
From: Konstantinos Fertakis
Date: Tue, 1 Oct 2024 17:26:37 +0100
Subject: [PATCH 1/4] enable reward model offloading option

---
 .../DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py | 17 ++++-------------
 .../training/step3_rlhf_finetuning/main.py    |  4 ++++
 2 files changed, 8 insertions(+), 13 deletions(-)

diff --git a/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py b/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py
index 5b6778cc2..f518076c7 100755
--- a/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py
+++ b/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py
@@ -268,23 +268,14 @@ def _init_reward(self, critic_model_name_or_path):
 
         # If critic is ZeRO-3 then we use it for everything, otherwise assume we have enough memory
         zero_stage = 0
 
-        ds_config = get_eval_ds_config(offload=self.args.offload,
-                                       dtype=self.args.dtype,
-                                       stage=zero_stage)
-        ds_config[
-            'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
-        ds_config[
-            'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
-            ) * self.args.gradient_accumulation_steps
-
-        ds_eval_config = get_eval_ds_config(offload=False,
+        ds_config = get_eval_ds_config(offload=self.args.offload_reward_model,
                                             dtype=self.args.dtype,
                                             stage=zero_stage)
 
         # We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
-        ds_eval_config[
+        ds_config[
             'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
-        ds_eval_config[
+        ds_config[
             'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
             ) * self.args.gradient_accumulation_steps
 
@@ -292,7 +283,7 @@ def _init_reward(self, critic_model_name_or_path):
         reward_model = create_critic_model(
             model_name_or_path=critic_model_name_or_path,
             tokenizer=self.tokenizer,
-            ds_config=ds_eval_config,
+            ds_config=ds_config,
             num_padding_at_beginning=self.args.num_padding_at_beginning,
             rlhf_training=True,
             dropout=self.args.critic_dropout,

diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
index a5be5671b..f3db70e05 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
@@ -246,6 +246,10 @@ def parse_args():
         '--offload_reference_model',
         action='store_true',
         help='Enable ZeRO Offload techniques for reference model')
+    parser.add_argument(
+        '--offload_reward_model',
+        action='store_true',
+        help='Enable ZeRO Offload techniques for reward model')
     parser.add_argument(
         '--actor_zero_stage',
         type=int,

From 82c59abfc44c4681ffd86d3b3c18a895d605c6e8 Mon Sep 17 00:00:00 2001
From: Konstantinos Fertakis
Date: Mon, 14 Oct 2024 16:28:20 +0100
Subject: [PATCH 2/4] fixed code formatting

---
 applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py b/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py
index f518076c7..0e67efcf9 100755
--- a/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py
+++ b/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py
@@ -269,8 +269,8 @@ def _init_reward(self, critic_model_name_or_path):
         zero_stage = 0
 
         ds_config = get_eval_ds_config(offload=self.args.offload_reward_model,
-                                            dtype=self.args.dtype,
-                                            stage=zero_stage)
+                                       dtype=self.args.dtype,
+                                       stage=zero_stage)
 
         # We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
         ds_config[

From 2babdd3c260d47362dc09873f6a1e2cc16af8feb Mon Sep 17 00:00:00 2001
From: Konstantinos Fertakis
Date: Tue, 15 Oct 2024 10:13:04 +0100
Subject: [PATCH 3/4] more formatting fixes

---
 .../DeepSpeed-Chat/dschat/utils/model/reward_model.py     | 4 ++--
 .../DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py | 7 +++----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py b/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py
index 60d063b18..4f29d0dd8 100644
--- a/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py
+++ b/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py
@@ -98,8 +98,8 @@ def forward(self,
             else:
                 # Check if there is any padding otherwise take length of sequence
                 r_inds = (rejected_id == self.PAD_ID).nonzero()
-                r_ind = r_inds[self.num_padding_at_beginning].item(
-                ) if len(r_inds) > self.num_padding_at_beginning else seq_len
+                r_ind = r_inds[self.num_padding_at_beginning].item() if len(
+                    r_inds) > self.num_padding_at_beginning else seq_len
                 end_ind = max(c_ind, r_ind)
                 divergence_ind = check_divergence[0]
                 assert divergence_ind > 0

diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
index f3db70e05..1378dc4e6 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
@@ -246,10 +246,9 @@ def parse_args():
         '--offload_reference_model',
         action='store_true',
         help='Enable ZeRO Offload techniques for reference model')
-    parser.add_argument(
-        '--offload_reward_model',
-        action='store_true',
-        help='Enable ZeRO Offload techniques for reward model')
+    parser.add_argument('--offload_reward_model',
+                        action='store_true',
+                        help='Enable ZeRO Offload techniques for reward model')
     parser.add_argument(
         '--actor_zero_stage',
         type=int,

From 9a44a43b10b0234a1462c6b44cdc5d554d4b531b Mon Sep 17 00:00:00 2001
From: Logan Adams
Date: Tue, 29 Oct 2024 15:56:20 -0700
Subject: [PATCH 4/4] Pre-commit formatting fix

---
 .../DeepSpeed-Chat/dschat/utils/model/reward_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py b/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py
index 4f29d0dd8..60d063b18 100644
--- a/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py
+++ b/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py
@@ -98,8 +98,8 @@ def forward(self,
             else:
                 # Check if there is any padding otherwise take length of sequence
                 r_inds = (rejected_id == self.PAD_ID).nonzero()
-                r_ind = r_inds[self.num_padding_at_beginning].item() if len(
-                    r_inds) > self.num_padding_at_beginning else seq_len
+                r_ind = r_inds[self.num_padding_at_beginning].item(
+                ) if len(r_inds) > self.num_padding_at_beginning else seq_len
                 end_ind = max(c_ind, r_ind)
                 divergence_ind = check_divergence[0]
                 assert divergence_ind > 0
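
What the new flag does, in brief: --offload_reward_model is forwarded to
get_eval_ds_config as its offload argument, so the frozen reward model's
parameters can be placed in host memory via ZeRO parameter offload instead
of pinning GPU memory. Below is a minimal sketch of the offload-relevant
config shape; it is a hypothetical, simplified stand-in for the real
dschat.utils.ds_utils.get_eval_ds_config (which emits a richer config whose
exact keys and values may differ), not a quotation of it.

    # Hypothetical, simplified stand-in for get_eval_ds_config; only the
    # offload-relevant keys are shown.
    def sketch_eval_ds_config(offload: bool, dtype: str, stage: int = 0) -> dict:
        return {
            "zero_optimization": {
                "stage": stage,
                # "cpu" places the inference-only parameters in host memory;
                # "none" keeps them resident on the GPU.
                "offload_param": {"device": "cpu" if offload else "none"},
            },
            # _init_reward later injects train_batch_size and
            # train_micro_batch_size_per_gpu into this dict purely to satisfy
            # the DeepSpeed engine's sanity checks, as the diff above shows.
            ("bf16" if dtype == "bf16" else "fp16"): {"enabled": True},
        }

Usage mirrors the existing --offload_reference_model switch on the step-3
command line, e.g. (the remaining required step-3 arguments are unchanged
by this series and elided here):

    deepspeed applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py \
        --per_device_training_batch_size 4 \
        --gradient_accumulation_steps 1 \
        --offload_reference_model \
        --offload_reward_model \
        ...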