Skip to content

Commit

Permalink
Follow upstream fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
stceum committed Jan 27, 2024
1 parent b563ff9 commit b5f0068
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions applications/DeepSpeed-Chat/training/step2_dpo_finetuning/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,26 @@ def main():
args.global_rank)
causal_lm_model_to_fp32_loss(model)

# Copied from ../step2_reward_model_finetuning/main.py.
# Model bigscience/bloom-560m has large variance at ln_f.weight parameter
# This makes bf16 finetuning hard.
# In general, since we are replacing the model head, it makes sense to reset
# the LN that precedes it.
force_optimize_params = []
if "bigscience/bloom-" in args.model_name_or_path:
zero_init_enabled = (args.zero_stage == 3)
params = [
model.rwtranrsformer.ln_f.weight, model.rwtranrsformer.ln_f.bias
]
with deepspeed.zero.GatheredParameters(params,
modifier_rank=0,
enabled=zero_init_enabled):
if deepspeed.comm.get_rank() == 0 or not zero_init_enabled:
torch.nn.init.ones_(model.rwtransformer.ln_f.weight)
torch.nn.init.zeros_(model.rwtransformer.ln_f.bias)
force_optimize_params.extend(
['rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias'])

if args.lora_dim > 0:
model = convert_linear_layer_to_lora(model, args.lora_module_name,
args.lora_dim)
Expand Down Expand Up @@ -372,7 +392,7 @@ def evaluation(model, ref_model, tokenizer, eval_dataloader):
logits = args.beta * ((chosen_logps - ref_chosen_logps) -
(rejected_logps - ref_rejected_logps))
loss = (- torch.nn.functional.logsigmoid(logits) * (1 - args.label_smoothing) - \
torch.nn.functional.logsigmoid(-logits) * args.label_smoothing).mean(0)
torch.nn.functional.logsigmoid(-logits) * args.label_smoothing).mean(0)
losses += loss.float()
losses = losses / (step + 1)
try:
Expand Down Expand Up @@ -419,7 +439,7 @@ def evaluation(model, ref_model, tokenizer, eval_dataloader):
# Train!
print_rank_0("***** Running training *****", args.global_rank)
print_rank_0(
f"***** Evaluating rewards, Epoch {0}/{args.num_train_epochs} *****",
f"***** Evaluating rewards, Epoch {1}/{args.num_train_epochs} *****",
args.global_rank)
chosen_rewards, rejected_rewards, eval_loss = evaluation(
model, ref_model, tokenizer, eval_dataloader)
Expand Down Expand Up @@ -466,7 +486,7 @@ def evaluation(model, ref_model, tokenizer, eval_dataloader):
logits = args.beta * ((chosen_logps - ref_chosen_logps) -
(rejected_logps - ref_rejected_logps))
loss = (- torch.nn.functional.logsigmoid(logits) * (1 - args.label_smoothing) - \
torch.nn.functional.logsigmoid(-logits) * args.label_smoothing).mean(0)
torch.nn.functional.logsigmoid(-logits) * args.label_smoothing).mean(0)
if args.print_loss:
print(
f"Epoch: {epoch}, Step: {step}, Rank: {torch.distributed.get_rank()}, loss = {loss}"
Expand All @@ -478,7 +498,7 @@ def evaluation(model, ref_model, tokenizer, eval_dataloader):
print_throughput(model.model, args, end - start,
args.global_rank)

# Evaluate perplexity on the validation set.
# Evaluate rewards on the validation set.
print_rank_0(
f"***** Evaluating rewards, Epoch {epoch+1}/{args.num_train_epochs} *****",
args.global_rank)
Expand Down

0 comments on commit b5f0068

Please sign in to comment.