Merge branch 'master' into 6_fix_wd_config
lekurile authored Oct 4, 2023
2 parents 281fc2a + e6f400a commit c63d79a
Showing 90 changed files with 7,102 additions and 18 deletions.
@@ -29,7 +29,7 @@
from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer
from utils.ds_utils import get_train_ds_config
from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
from utils.model.model_utils import create_hf_model
from utils.model.model_utils import create_hf_model, causal_lm_model_to_fp32_loss
from utils.perf import print_throughput


@@ -178,6 +178,12 @@ def parse_args():
help=
"Initial LoRA learning rate (after the potential warmup period) to use."
)
## low precision
parser.add_argument(
'--compute_fp32_loss',
action='store_true',
help='Relevant for low precision dtypes (fp16, bf16, etc.). '
'If specified, loss is calculated in fp32.')
## Tensorboard logging
parser.add_argument('--enable_tensorboard',
action='store_true',
@@ -234,6 +240,12 @@ def main():
ds_config,
dropout=args.dropout)

if args.compute_fp32_loss:
print_rank_0(
f"Using model {model.__class__.__name__} with loss in fp32",
args.global_rank)
causal_lm_model_to_fp32_loss(model)

if args.lora_dim > 0:
model = convert_linear_layer_to_lora(model, args.lora_module_name,
args.lora_dim)
@@ -177,6 +177,12 @@ def parse_args():
help=
"Initial LoRA learning rate (after the potential warmup period) to use."
)
## low precision
parser.add_argument(
'--compute_fp32_loss',
action='store_true',
help='Relevant for low precision dtypes (fp16, bf16, etc.). '
'If specified, loss is calculated in fp32.')
## Tensorboard logging
parser.add_argument('--enable_tensorboard',
action='store_true',
@@ -226,7 +232,9 @@ def main():
tokenizer,
ds_config,
args.num_padding_at_beginning,
dropout=args.dropout)
dropout=args.dropout,
zero_stage=args.zero_stage,
compute_fp32_loss=args.compute_fp32_loss)

if args.lora_dim > 0:
rm_model = convert_linear_layer_to_lora(rm_model,
@@ -325,6 +325,13 @@ def parse_args():
'--enable_mixed_precision_lora',
action='store_true',
help='Enable Mixed Precision ZeRO++ for training and generation.')
## low precision
parser.add_argument(
'--compute_fp32_loss',
action='store_true',
help='Relevant for low precision dtypes (fp16, bf16, etc.). '
'If specified, loss is calculated in fp32. '
'This applies to both actor and critic models.')
## Tensorboard logging
parser.add_argument('--enable_tensorboard',
action='store_true',
@@ -572,13 +579,13 @@ def main():
average_reward / inner_iter,
global_step=step)
writer.add_scalar('actor_loss',
actor_loss,
actor_loss.item(),
global_step=step)
writer.add_scalar('actor_loss_sum',
actor_loss_sum,
global_step=step)
writer.add_scalar('critic_loss',
critic_loss,
critic_loss.item(),
global_step=step)
writer.add_scalar('critic_loss_sum',
critic_loss_sum,
@@ -60,6 +60,7 @@ def __init__(self, rlhf_engine, args):
self.end_of_conversation_token_id = self.tokenizer(
args.end_of_conversation_token)['input_ids'][-1]
self.z3_enabled = args.actor_zero_stage == 3
self.compute_fp32_loss = self.args.compute_fp32_loss

# These values can be changed
self.kl_ctl = 0.1
@@ -139,6 +140,9 @@ def generate_experience(self, prompts, mask, step):

logits = output.logits
logits_ref = output_ref.logits
if self.compute_fp32_loss:
logits = logits.to(torch.float)
logits_ref = logits_ref.to(torch.float)

self.generate_time = generate_end - generate_start

@@ -271,6 +275,9 @@ def critic_loss_fn(self, values, old_values, returns, mask):
old_values - self.cliprange_value,
old_values + self.cliprange_value,
)
if self.compute_fp32_loss:
values = values.float()
values_clipped = values_clipped.float()
vf_loss1 = (values - returns)**2
vf_loss2 = (values_clipped - returns)**2
vf_loss = 0.5 * torch.sum(
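As context for the critic_loss_fn change above, a minimal toy sketch (not part of this diff; the numbers and cliprange_value are illustrative, and bf16 is used only so it runs on CPU) of casting to fp32 before the squared-error reduction:

import torch

cliprange_value = 0.2  # illustrative value, not taken from the trainer config
values = torch.tensor([0.50, 0.10, -0.30], dtype=torch.bfloat16)
old_values = torch.tensor([0.40, 0.20, -0.10], dtype=torch.bfloat16)
returns = torch.tensor([0.60, 0.00, -0.20], dtype=torch.bfloat16)
mask = torch.ones_like(returns, dtype=torch.float32)

values_clipped = torch.clamp(values,
                             old_values - cliprange_value,
                             old_values + cliprange_value)

# With --compute_fp32_loss the values are promoted before squaring and
# reducing, so the loss itself is accumulated in full precision.
values = values.float()
values_clipped = values_clipped.float()
vf_loss1 = (values - returns)**2
vf_loss2 = (values_clipped - returns)**2
vf_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
print(vf_loss.dtype)  # torch.float32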
61 changes: 59 additions & 2 deletions applications/DeepSpeed-Chat/training/utils/model/model_utils.py
@@ -25,6 +25,61 @@ def configure_dropout(model_config, dropout):
setattr(model_config, key, dropout)


def causal_lm_model_to_fp32_loss(model):
""" Convert CausalLM model to calculate loss in fp32 """

def causal_lm_forward(
input_ids=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**deprecated_arguments,
):
output = model.__original_forward__(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
labels=None,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

return_dict = isinstance(output, dict)
lm_logits = output.logits if return_dict else output[0]
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].float().contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size),
shift_labels.view(batch_size * seq_length))

if not return_dict:
# re-pack output with fp32 loss
return ((loss, ) + output) if loss is not None else output

output.loss = loss
return output

model.__original_forward__ = model.forward
model.forward = causal_lm_forward


def create_hf_model(model_class,
model_name_or_path,
tokenizer,
@@ -64,7 +119,8 @@ def create_critic_model(model_name_or_path,
num_padding_at_beginning=0,
rlhf_training=False,
dropout=None,
zero_stage=0):
zero_stage=0,
compute_fp32_loss=False):
# The OPT model family always puts a padding token at the beginning of the sequence;
# we have not seen this in other models, but are not sure whether it is a general rule

@@ -80,7 +136,8 @@
critic_model = RewardModel(
critic_model,
tokenizer,
num_padding_at_beginning=num_padding_at_beginning)
num_padding_at_beginning=num_padding_at_beginning,
compute_fp32_loss=compute_fp32_loss)

if rlhf_training:
# load critic model from checkpoint
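For reference, a small usage sketch of the new causal_lm_model_to_fp32_loss helper (not part of this commit; the facebook/opt-125m checkpoint and the CUDA device are assumptions made purely for illustration):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.model.model_utils import causal_lm_model_to_fp32_loss

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.float16).to("cuda")

causal_lm_model_to_fp32_loss(model)  # patches model.forward in place

batch = tokenizer("DeepSpeed-Chat can compute the loss in fp32.",
                  return_tensors="pt").to("cuda")
out = model(input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["input_ids"])
print(out.logits.dtype)  # torch.float16 -- activations stay in half precision
print(out.loss.dtype)    # torch.float32 -- loss is computed from an fp32 copy of the logits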
12 changes: 10 additions & 2 deletions applications/DeepSpeed-Chat/training/utils/model/reward_model.py
@@ -10,7 +10,11 @@
## https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
class RewardModel(nn.Module):

def __init__(self, base_model, tokenizer, num_padding_at_beginning=0):
def __init__(self,
base_model,
tokenizer,
num_padding_at_beginning=0,
compute_fp32_loss=False):
super().__init__()
self.config = base_model.config
self.num_padding_at_beginning = num_padding_at_beginning
@@ -27,6 +31,7 @@ def __init__(self, base_model, tokenizer, num_padding_at_beginning=0):
self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
self.rwtranrsformer = base_model
self.PAD_ID = tokenizer.pad_token_id
self.compute_fp32_loss = compute_fp32_loss

def gradient_checkpointing_enable(self):
self.rwtranrsformer.gradient_checkpointing_enable()
@@ -73,7 +78,7 @@ def forward(self,
rejected_rewards = rewards[bs:]

# Compute pairwise loss. Only backprop on the tokens that differ, before padding
loss = 0
loss = 0.
for i in range(bs):
chosen_id = chosen_ids[i]
rejected_id = rejected_ids[i]
@@ -104,6 +109,9 @@ def forward(self,
chosen_reward[c_ind - 1]) #use the end score for reference
rejected_mean_scores.append(rejected_reward[r_ind - 1])

if self.compute_fp32_loss:
c_truncated_reward = c_truncated_reward.float()
r_truncated_reward = r_truncated_reward.float()
loss += -torch.nn.functional.logsigmoid(c_truncated_reward -
r_truncated_reward).mean()

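As a toy illustration of the reward-model change above (the values are made up; bf16 keeps the example CPU-friendly), casting the truncated reward slices before the pairwise log-sigmoid keeps the loss in fp32:

import torch

c_truncated_reward = torch.tensor([1.2, 0.8, 0.5], dtype=torch.bfloat16)
r_truncated_reward = torch.tensor([0.9, 0.7, 0.6], dtype=torch.bfloat16)

# With compute_fp32_loss=True the subtraction, log-sigmoid, and mean all run
# in full precision instead of accumulating in the model's half precision.
loss = -torch.nn.functional.logsigmoid(
    c_truncated_reward.float() - r_truncated_reward.float()).mean()
print(loss.dtype)  # torch.float32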
3 changes: 2 additions & 1 deletion applications/DeepSpeed-Chat/training/utils/utils.py
@@ -69,7 +69,8 @@ def load_hf_tokenizer(model_name_or_path, fast_tokenizer=True):
model_json = os.path.join(model_name_or_path, "config.json")
if os.path.exists(model_json):
model_json_file = json.load(open(model_json))
model_name = model_json_file["_name_or_path"]
model_name = model_json_file.get("_name_or_path",
model_name_or_path)
tokenizer = get_tokenizer(model_name,
fast_tokenizer=fast_tokenizer)
else:
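The load_hf_tokenizer change is a small robustness fix; a sketch of the .get() fallback (the dictionary contents and path below are hypothetical):

model_json_file = {"model_type": "opt"}          # a config.json lacking "_name_or_path"
model_name_or_path = "/models/local-checkpoint"  # hypothetical local path
model_name = model_json_file.get("_name_or_path", model_name_or_path)
print(model_name)  # falls back to the local path instead of raising KeyError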