deepspeed-chat: support periodic eval in stage2
Browse files Browse the repository at this point in the history
Add support for periodic evaluation during reward model (RM) training.
Configurable via the added arguments --eval_interval and --eval_iters.
The default configuration is backward compatible.

In addition, also display the score of the rejected predictions.

Change-Id: Ib377fd731fe676c01114c087581a30777a3f3f49
Signed-off-by: Moshe Island <[email protected]>
mosheisland committed Oct 3, 2023
1 parent ca03bd7 commit 9cead57
Showing 1 changed file with 54 additions and 20 deletions.
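
Before the diff itself, a minimal sketch of the scheduling rule the new code follows may help: evaluation fires only on a gradient-accumulation boundary, every --eval_interval optimizer steps, and the default --eval_interval=0 disables it. The helper name should_evaluate below is illustrative only and does not exist in the repository.

def should_evaluate(total_micro_steps, gradient_accumulation_steps, eval_interval):
    """Return True when a periodic evaluation should run (illustrative helper)."""
    if eval_interval <= 0:
        # Default --eval_interval=0 keeps the old behaviour: no periodic eval.
        return False
    gas_boundary = total_micro_steps % gradient_accumulation_steps == 0
    total_steps = total_micro_steps // gradient_accumulation_steps
    return gas_boundary and total_steps % eval_interval == 0


# Example: with 4 micro-batches per optimizer step and --eval_interval=100,
# evaluation runs every 400 micro-batches.
assert should_evaluate(400, 4, 100)
assert not should_evaluate(399, 4, 100)
assert not should_evaluate(400, 4, 0)
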
@@ -177,6 +177,15 @@ def parse_args():
         help=
         "Initial LoRA learning rate (after the potential warmup period) to use."
     )
+    # Evaluation
+    parser.add_argument("--eval_interval",
+                        type=int,
+                        default=0,
+                        help="If > 0, perform evaluation at this interval")
+    parser.add_argument("--eval_iters",
+                        type=int,
+                        default=100,
+                        help="Maximum evaluation iterations")
     ## Tensorboard logging
     parser.add_argument('--enable_tensorboard',
                         action='store_true',
@@ -260,31 +269,35 @@ def main():
                                       sampler=eval_sampler,
                                       batch_size=args.per_device_eval_batch_size)
 
-    def evaluation_reward(model, eval_dataloader):
+    def evaluation_reward(model, dataloader, eval_iters):
         model.eval()
         correct_predictions = 0
         total_predictions = 0
-        scores = 0
-        for step, batch in enumerate(eval_dataloader):
-            batch = to_device(batch, device)
+        chosen_scores = 0.
+        rejected_scores = 0.
+        for _step, _batch in enumerate(dataloader):
+            _batch = to_device(_batch, device)
             with torch.no_grad():
-                outputs = model(**batch)
+                _outputs = model(**_batch)
 
-            chosen = outputs["chosen_mean_scores"]
-            rejected = outputs["rejected_mean_scores"]
+            chosen = _outputs["chosen_mean_scores"]
+            rejected = _outputs["rejected_mean_scores"]
             correct_predictions += (chosen > rejected).sum()
             total_predictions += chosen.shape[0]
-            scores += outputs["chosen_mean_scores"].mean().float()
-            if step == 99:  # For faster evaluation and debugging
+            chosen_scores += _outputs["chosen_mean_scores"].mean().float()
+            rejected_scores += _outputs["rejected_mean_scores"].mean().float()
+            if (_step + 1) == eval_iters:
                 break
-        acc = correct_predictions / total_predictions
-        scores = scores / (step + 1)
+        _acc = correct_predictions / total_predictions
+        chosen_scores = chosen_scores / (_step + 1)
+        rejected_scores = rejected_scores / (_step + 1)
         try:
-            acc = get_all_reduce_mean(acc).item()
-            scores = get_all_reduce_mean(scores).item()
+            _acc = get_all_reduce_mean(_acc).item()
+            chosen_scores = get_all_reduce_mean(chosen_scores).item()
+            rejected_scores = get_all_reduce_mean(rejected_scores).item()
         except:
             pass
-        return scores, acc
+        return chosen_scores, rejected_scores, _acc
 
     # Split weights in two groups, one with weight decay and the other not.
     optimizer_grouped_parameters = get_optimizer_grouped_parameters(
@@ -322,11 +335,14 @@ def evaluation_reward(model, eval_dataloader):
     print_rank_0(
         f"***** Evaluating reward, Epoch {0}/{args.num_train_epochs} *****",
         args.global_rank)
-    reward_score, acc = evaluation_reward(rm_model, eval_dataloader)
+    reward_score, reject_score, acc = evaluation_reward(
+        rm_model, eval_dataloader, args.eval_iters)
     print_rank_0(
-        f"chosen_last_scores (higher is better) : {reward_score}, acc (higher is better) : {acc}",
-        args.global_rank)
+        f"chosen_last_scores (higher is better) : {reward_score}, "
+        f"rejected_last_scores (lower is better) : {reject_score}, "
+        f"acc (higher is better) : {acc}", args.global_rank)
 
+    total_micro_steps = 0
     for epoch in range(args.num_train_epochs):
         print_rank_0(
             f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}",
@@ -340,17 +356,35 @@ def evaluation_reward(model, eval_dataloader):
             rm_model.backward(loss)
             rm_model.step()
             mean_loss += loss.item()
+            total_micro_steps += 1
+            gas_boundary = (total_micro_steps %
+                            args.gradient_accumulation_steps == 0)
+            total_steps = total_micro_steps // args.gradient_accumulation_steps
+            if args.eval_interval and gas_boundary and (
+                    total_steps % args.eval_interval == 0):
+                print_rank_0(f"Iter {total_steps}: Evaluating reward",
+                             args.global_rank)
+                reward_score, reject_score, acc = evaluation_reward(
+                    rm_model, eval_dataloader, args.eval_iters)
+                print_rank_0(
+                    f"Iter {total_steps}: c_scores: {reward_score}, r_scores: {reject_score}, "
+                    f"diff: {reward_score - reject_score}, acc: {acc}",
+                    args.global_rank)
+                rm_model.train()
+
         print_rank_0(
             f"Epoch {epoch+1}/{args.num_train_epochs} with loss {mean_loss/(step+1)}",
             args.global_rank)
         # Evaluate reward_loss on the validation set.
         print_rank_0(
             f"***** Evaluating reward, Epoch {epoch+1}/{args.num_train_epochs} *****",
             args.global_rank)
-        reward_score, acc = evaluation_reward(rm_model, eval_dataloader)
+        reward_score, reject_score, acc = evaluation_reward(
+            rm_model, eval_dataloader, args.eval_iters)
         print_rank_0(
-            f"chosen_last_scores (higher is better) : {reward_score}, acc (higher is better) : {acc}",
-            args.global_rank)
+            f"chosen_last_scores (higher is better) : {reward_score}, "
+            f"rejected_last_scores (lower is better) : {reject_score}, "
+            f"acc (higher is better) : {acc}", args.global_rank)
         rm_model.tput_timer.update_epoch_count()
 
     if args.output_dir is not None:
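To see the reported metrics outside of DeepSpeed, here is a small self-contained sketch under the assumption of a model whose forward pass returns "chosen_mean_scores" and "rejected_mean_scores", as in the diff. DummyRewardModel and evaluate are stand-ins rather than the project's code, and the sketch skips the cross-rank averaging that the real code performs with get_all_reduce_mean. It computes the same three quantities the commit reports (mean chosen score, mean rejected score, accuracy), capped at eval_iters batches.

import torch


class DummyRewardModel:
    """Stand-in that mimics the reward model's output dict."""

    def eval(self):
        pass

    def __call__(self, batch_size):
        return {
            "chosen_mean_scores": torch.randn(batch_size) + 0.5,
            "rejected_mean_scores": torch.randn(batch_size),
        }


def evaluate(model, num_batches, batch_size, eval_iters):
    model.eval()
    correct, total = 0, 0
    chosen_sum, rejected_sum = 0.0, 0.0
    steps = 0
    with torch.no_grad():
        for step in range(num_batches):
            out = model(batch_size)
            chosen = out["chosen_mean_scores"]
            rejected = out["rejected_mean_scores"]
            correct += (chosen > rejected).sum().item()
            total += chosen.shape[0]
            chosen_sum += chosen.mean().item()
            rejected_sum += rejected.mean().item()
            steps = step + 1
            if steps == eval_iters:  # cap evaluation cost, as --eval_iters does
                break
    return chosen_sum / steps, rejected_sum / steps, correct / total


if __name__ == "__main__":
    c, r, acc = evaluate(DummyRewardModel(), num_batches=500, batch_size=8, eval_iters=100)
    print(f"chosen: {c:.3f}, rejected: {r:.3f}, diff: {c - r:.3f}, acc: {acc:.3f}")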
