deepspeed-chat: print mean stage1/2 loss periodically #780

Closed
Changes from all commits (29 commits)
9f72c16
deepspeed-chat: support any model in chatbot
mosheisland Sep 11, 2023
e19d179
deepspeed-chat: handle overflow for bf16_optimizer
mosheisland Sep 26, 2023
dcba0a7
deepspeed-chat: support explicit configuration of dropout
mosheisland Sep 12, 2023
7a655bc
deepspeed-chat: support periodic eval in stage2
mosheisland Sep 12, 2023
d28dbbd
deepspeed-chat: calculate loss in fp32 when using bf16
mosheisland Sep 12, 2023
0801740
deepspeed-chat: fix weight decay configuration
mosheisland Sep 15, 2023
73f3ace
deepspeed-chat: fix incorrect lr when using lora only
mosheisland Sep 15, 2023
759bf63
deepspeed-chat: train v_head when only optimizing lora
mosheisland Sep 12, 2023
1975eee
deepspeed-chat: fix bf16 stage2 accuracy for bloom-560m
mosheisland Sep 12, 2023
698f648
deepspeed-chat: fix training stage1 ppl calculation
mosheisland Sep 12, 2023
d2bca11
deepspeed-chat: fix rw_eval
mosheisland Sep 12, 2023
43ae1c5
deepspeed-chat: add end-of-text special token
mosheisland Sep 13, 2023
10b889e
deepspeed-chat: print average stage1/2 loss periodically
mosheisland Sep 13, 2023
2225fa1
deepspeed-chat: display reward ema in stage3
mosheisland Sep 13, 2023
7099b7f
deepspeed-chat: handle stage3 generate too short
mosheisland Sep 13, 2023
ec2c6c8
deepspeed-chat: support print answers interval
mosheisland Sep 13, 2023
09c48ee
deepspeed-chat: filter prompts too long
mosheisland Sep 13, 2023
07e4742
deepspeed-chat [internal]: support using torch adamw
mosheisland Sep 13, 2023
36894a0
deepspeed-chat [internal]: add bloom training scripts
mosheisland Sep 14, 2023
8aab515
deepspeed-chat [internal]: fix bad access to DeepSpeedEngine.model
mosheisland Sep 19, 2023
f71b14c
deepspeed-chat [internal]: add support for hpu
mosheisland Sep 18, 2023
e94eb22
deepspeed-chat [internal]: optimize stage2 for hpu
mosheisland Sep 18, 2023
26d2ef8
deepspeed-chat: print mean stage1/2 loss periodically
mosheisland Sep 13, 2023
891c93f
Merge branch 'master' into 12_print_step_1_2_loss_periodically
tjruwase Oct 31, 2023
05a56d4
Merge branch 'master' into 12_print_step_1_2_loss_periodically
tjruwase Nov 1, 2023
0d6a7b5
Merge branch 'master' into 12_print_step_1_2_loss_periodically
mosheisland Nov 5, 2023
e4c5460
Fix merge conflicts
loadams Nov 4, 2024
d378c25
Merge master from upstream
loadams Nov 4, 2024
fa87108
Un-add files
loadams Nov 4, 2024
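The change this PR is named for, printing a running mean of the training loss in steps 1 and 2, lands in the step1/step2 main.py training scripts, which are not among the files reproduced below. A minimal sketch of the pattern, assuming a DeepSpeed engine and an illustrative print_loss_interval argument (the exact flag name and wiring in the PR may differ):

from dschat.utils.utils import print_rank_0


def train_epoch(engine, dataloader, global_rank, print_loss_interval=10):
    """Run one epoch and report the mean loss every print_loss_interval steps."""
    running_loss = 0.0
    for step, batch in enumerate(dataloader):
        loss = engine(**batch, use_cache=False).loss
        engine.backward(loss)
        engine.step()

        # Accumulate in fp32 on the host so bf16 rounding does not skew the mean.
        running_loss += loss.float().detach().item()
        if (step + 1) % print_loss_interval == 0:
            mean_loss = running_loss / print_loss_interval
            print_rank_0(f"step {step + 1}: mean loss = {mean_loss:.4f}", global_rank)
            running_loss = 0.0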
65 changes: 52 additions & 13 deletions applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py
@@ -7,15 +7,15 @@
import time
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.accelerator import get_accelerator

from dschat.utils.utils import print_rank_0


def print_all_ranks(tag, value, rank):
world_size = torch.distributed.get_world_size()
all_tensor = torch.zeros(world_size, dtype=torch.float32).to(
get_accelerator().current_device_name())
all_tensor = torch.zeros(world_size,
dtype=torch.float32,
device=value.device)
all_tensor[rank] = value
torch.distributed.all_reduce(all_tensor, op=torch.distributed.ReduceOp.SUM)
print_rank_0(f'{tag} {all_tensor}', rank)
@@ -55,7 +55,8 @@ def __init__(self, rlhf_engine, args):
self.end_of_conversation_token_id = self.tokenizer(
args.end_of_conversation_token)['input_ids'][-1]
self.z3_enabled = args.actor_zero_stage == 3
self.compute_fp32_loss = self.args.compute_fp32_loss
self.calculate_fp32_loss = (self.args.dtype
== "bf16") and self.args.bf16_to_fp32_loss

# In case the generated experience is not valid (too short), we use the last valid
# generated experience. Alternatively, we can skip the step (on all workers).
@@ -70,6 +71,8 @@ def __init__(self, rlhf_engine, args):
self.gamma = 1.0
self.lam = 0.95
self.generate_time = 0.0
self.first_generate = True
self.actor_model_hpu_graph_wrapped_fwd_fn = None

def _generate_sequence(self, prompts, mask, step):

@@ -84,13 +87,32 @@ def _generate_sequence(self, prompts, mask, step):
kwargs = dict()

with torch.no_grad():
seq = self.actor_model.module.generate(
prompts,
attention_mask=mask,
max_length=max_min_length,
pad_token_id=self.tokenizer.pad_token_id,
synced_gpus=self.z3_enabled,
**kwargs)
if is_hpu():
import habana_frameworks.torch.hpu as thpu
if is_hpu() and self.args.enable_hpu_graphs:
orig_actor_model_fwd_fn = self.actor_model.module.forward
if self.first_generate:
self.actor_model.module.forward = thpu.wrap_in_hpu_graph_func(
self.actor_model.module.forward)
self.first_generate = False
else:
self.actor_model.module.forward = self.actor_model_hpu_graph_wrapped_fwd_fn
seq = self.actor_model.module.generate(
prompts,
attention_mask=mask,
max_length=max_min_length,
min_length=max_min_length,
lazy_mode=True)
self.actor_model_hpu_graph_wrapped_fwd_fn = self.actor_model.module.forward
self.actor_model.module.forward = orig_actor_model_fwd_fn
else:
seq = self.actor_model.module.generate(
prompts,
attention_mask=mask,
max_length=max_min_length,
pad_token_id=self.tokenizer.pad_token_id,
synced_gpus=self.z3_enabled,
**kwargs)

        # Filter out sequences with no answer (or a very short one). This happens when users directly use the pre-training ckpt without supervised finetuning
        # NOTE: this will cause each GPU to have a different number of examples
@@ -149,19 +171,25 @@ def generate_experience(self, prompts, mask, step):

pad_token_id = self.tokenizer.pad_token_id
attention_mask = seq.not_equal(pad_token_id).long()

hpu_mark_step()
with torch.no_grad():
output = self.actor_model(seq, attention_mask=attention_mask)
hpu_mark_step()
output_ref = self.ref_model(seq, attention_mask=attention_mask)
hpu_mark_step()
reward_score = self.reward_model.forward_value(
seq, attention_mask,
prompt_length=self.prompt_length)['chosen_end_scores'].detach(
)
hpu_mark_step()
values = self.critic_model.forward_value(
seq, attention_mask, return_value_only=True).detach()[:, :-1]
hpu_mark_step()

logits = output.logits
logits_ref = output_ref.logits
if self.compute_fp32_loss:
if self.calculate_fp32_loss:
logits = logits.to(torch.float)
logits_ref = logits_ref.to(torch.float)

@@ -221,25 +249,34 @@ def train_rlhf(self, inputs):
advantages, returns = self.get_advantages_and_returns(
old_values, old_rewards, start)

hpu_mark_step()
### process the new outputs
batch = {'input_ids': seq, "attention_mask": attention_mask}
actor_prob = self.actor_model(**batch, use_cache=False).logits
hpu_mark_step()
actor_log_prob = gather_log_probs(actor_prob[:, :-1, :], seq[:, 1:])
hpu_mark_step()
actor_loss = self.actor_loss_fn(actor_log_prob[:, start:],
log_probs[:, start:], advantages,
action_mask[:, start:])
hpu_mark_step()
self.actor_model.backward(actor_loss)
hpu_mark_step()

if not self.args.align_overflow:
self.actor_model.step()
hpu_mark_step()

value = self.critic_model.forward_value(**batch,
return_value_only=True,
use_cache=False)[:, :-1]
hpu_mark_step()
critic_loss = self.critic_loss_fn(value[:, start:], old_values[:,
start:],
returns, action_mask[:, start:])
hpu_mark_step()
self.critic_model.backward(critic_loss)
hpu_mark_step()

if self.args.align_overflow:
actor_overflow = self.actor_model.optimizer.check_overflow(
@@ -263,8 +300,10 @@ def train_rlhf(self, inputs):
"OVERFLOW: actor and critic overflow, skipping both actor and critic steps",
rank)
self.actor_model.step()
hpu_mark_step()

self.critic_model.step()
hpu_mark_step()

return actor_loss, critic_loss

@@ -296,7 +335,7 @@ def critic_loss_fn(self, values, old_values, returns, mask):
old_values - self.cliprange_value,
old_values + self.cliprange_value,
)
if self.compute_fp32_loss:
if self.calculate_fp32_loss:
values = values.float()
values_clipped = values_clipped.float()
vf_loss1 = (values - returns)**2
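The ppo_trainer.py hunks above call two helpers, is_hpu() and hpu_mark_step(), whose definitions are outside the shown diff. A plausible minimal implementation, inferred from how they are used here (an assumption, not the PR's actual code):

from deepspeed.accelerator import get_accelerator


def is_hpu():
    # True when the active DeepSpeed accelerator is a Habana Gaudi (HPU) device.
    return get_accelerator().device_name() == "hpu"


def hpu_mark_step():
    # On HPU, mark_step() flushes the lazy-mode graph accumulated so far;
    # on CUDA and other back-ends this is intentionally a no-op.
    if is_hpu():
        import habana_frameworks.torch.core as htcore
        htcore.mark_step()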
35 changes: 33 additions & 2 deletions applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py
@@ -7,6 +7,7 @@
import deepspeed
from deepspeed.ops.adam import FusedAdam
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.accelerator import get_accelerator
from transformers import AutoModelForCausalLM, get_scheduler

from dschat.utils.ds_utils import get_train_ds_config, get_eval_ds_config
@@ -104,8 +105,23 @@ def _init_actor(self, actor_model_name_or_path):
actor_model = make_model_gradient_checkpointing_compatible(
actor_model)

# TODO SW-146776: remove this WA once SW-141762 is resolved
if is_hpu():
import habana_frameworks.torch.core as htcore
actor_model.to(dtype=torch.bfloat16,
device=get_accelerator().device())

# Optimizer
AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam
if self.args.offload:
AdamOptimizer = DeepSpeedCPUAdam
elif self.args.no_fused_kernels or is_hpu():
AdamOptimizer = torch.optim.AdamW
else:
AdamOptimizer = FusedAdam
print_rank_0(
f'Using {AdamOptimizer.__name__} optimizer for actor model',
self.args.global_rank)

optim_params = get_optimizer_grouped_parameters(
actor_model, self.args.actor_weight_decay,
self.args.actor_lora_learning_rate)
@@ -234,8 +250,23 @@ def _init_critic(self, critic_model_name_or_path):
critic_model = make_model_gradient_checkpointing_compatible(
critic_model)

# TODO SW-146776: remove this WA once SW-141762 is resolved
if is_hpu():
critic_model.to(dtype=torch.bfloat16,
device=get_accelerator().device())

# Optimizer
AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam
# TODO SW-147425: change the file to use HPEX optimizer instead of AdamW on hpu
if self.args.offload:
AdamOptimizer = DeepSpeedCPUAdam
elif self.args.no_fused_kernels or is_hpu():
AdamOptimizer = torch.optim.AdamW
else:
AdamOptimizer = FusedAdam
print_rank_0(
f'Using {AdamOptimizer.__name__} optimizer for critic model',
self.args.global_rank)

optim_params = get_optimizer_grouped_parameters(
critic_model, self.args.critic_weight_decay,
self.args.critic_lora_learning_rate)
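Both _init_actor and _init_critic now branch on args.no_fused_kernels, but the command-line plumbing for that flag is outside the shown hunks. A hedged guess at the corresponding argparse declaration (the flag name is taken from the diff; the help text is an assumption):

parser.add_argument(
    "--no_fused_kernels",
    action="store_true",
    help="Use torch.optim.AdamW instead of DeepSpeed's FusedAdam "
    "(torch AdamW is also selected automatically when running on HPU).")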
6 changes: 4 additions & 2 deletions applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py
@@ -298,8 +298,10 @@ def create_prompt_dataset(local_rank,
eval_fname = f"{output_path}/evaldata_{fname}.pt"

cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname)
buf_create_cache = torch.ByteTensor([not cache_found]).to(
get_accelerator().current_device_name())
device = torch.device(get_accelerator().device_name(
torch.distributed.get_rank()))
buf_create_cache = get_accelerator().ByteTensor([not cache_found],
device=device)
torch.distributed.all_reduce(buf_create_cache)

if local_rank <= 0 and (buf_create_cache.item() != 0 or reload):
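The data_utils.py change keeps the existing cache-coordination pattern: each rank contributes a 0/1 "cache missing" flag, the flags are summed with an all-reduce, and rank 0 rebuilds the dataset cache if any rank reported it missing. A stripped-down sketch of that pattern with plain torch.distributed (not the repository's exact helper):

import os

import torch
import torch.distributed as dist


def cache_needs_rebuild(train_fname, eval_fname, device):
    """Return True if any rank in the job is missing the cached dataset files."""
    cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname)
    flag = torch.tensor([0 if cache_found else 1], dtype=torch.uint8, device=device)
    # A nonzero total after the sum means at least one rank could not find the
    # cache, so rank 0 has to (re)build it before everyone loads it.
    dist.all_reduce(flag)
    return flag.item() != 0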
70 changes: 68 additions & 2 deletions applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py
@@ -82,6 +82,70 @@ def causal_lm_forward(
model.forward = causal_lm_forward


def configure_dropout(model_config, dropout):
if dropout is not None:
for key in ('dropout', 'attention_dropout', 'hidden_dropout',
'activation_dropout'):
if hasattr(model_config, key):
print(f"Setting model_config.{key} to {dropout}")
setattr(model_config, key, dropout)


def causal_lm_model_to_fp32_loss(model):
""" Convert CausalLM model to calculate loss in fp32 """

def causal_lm_forward(
input_ids=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**deprecated_arguments,
):
output = model.__original_forward__(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
labels=None,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

return_dict = isinstance(output, dict)
lm_logits = output.logits if return_dict else output[0]
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].float().contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size),
shift_labels.view(batch_size * seq_length))

if not return_dict:
# re-pack output with fp32 loss
return ((loss, ) + output) if loss is not None else output

output.loss = loss
return output

model.__original_forward__ = model.forward
model.forward = causal_lm_forward


def create_hf_model(model_class,
model_name_or_path,
tokenizer,
@@ -122,7 +186,8 @@ def create_critic_model(model_name_or_path,
rlhf_training=False,
dropout=None,
zero_stage=0,
compute_fp32_loss=False):
loss_to_fp32=False,
optimized_reward_loss_calc=False):
    # The OPT model family always puts a padding token at the beginning of the sequence;
    # we did not see this in other models, but we are not sure whether it is a general rule.

@@ -139,7 +204,8 @@
critic_model,
tokenizer,
num_padding_at_beginning=num_padding_at_beginning,
compute_fp32_loss=compute_fp32_loss)
loss_to_fp32=loss_to_fp32,
opt_loss_calc=optimized_reward_loss_calc)

if rlhf_training:
# load critic model from checkpoint
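To tie the model_utils.py additions together, a hedged usage sketch: when training in bf16 with the fp32-loss option enabled, the wrapper is applied to the freshly built causal-LM model, and the renamed create_critic_model keywords are passed through. The keyword names follow the diff; the args.* flag names and the create_hf_model call are assumptions based on the surrounding repository:

from transformers import AutoModelForCausalLM

from dschat.utils.model.model_utils import (causal_lm_model_to_fp32_loss,
                                            create_critic_model,
                                            create_hf_model)

# Assumes args, tokenizer and ds_config are already set up by the training script.
model = create_hf_model(AutoModelForCausalLM,
                        args.model_name_or_path,
                        tokenizer,
                        ds_config,
                        dropout=args.dropout)
if args.dtype == "bf16" and args.bf16_to_fp32_loss:
    # Keep bf16 activations, but recompute the LM loss in fp32.
    causal_lm_model_to_fp32_loss(model)

critic_model = create_critic_model(
    args.critic_model_name_or_path,
    tokenizer,
    ds_config,
    num_padding_at_beginning=args.num_padding_at_beginning,
    dropout=args.critic_dropout,
    loss_to_fp32=(args.dtype == "bf16" and args.bf16_to_fp32_loss),
    optimized_reward_loss_calc=args.optimized_reward_loss_calc)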