deepspeed-chat: fix weight decay configuration
Current default name used to detect LN layers is "LayerNorm.weight".
This does not work for the following models:
- opt: uses "layer_norm"
- llama: uses "norm" and "layernorm"
- bloom: uses "layernorm" and "ln_f"

Therefore, modify the default names to accommodate the above.
Also, compare names in lowercase to capture models that use different capitalization.

Change-Id: I5b805df2663c62daf3d9c8a31a973742e344e76b
Signed-off-by: Moshe Island <[email protected]>
mosheisland committed Oct 5, 2023
1 parent bfad08f commit 5bba361
applications/DeepSpeed-Chat/training/utils/utils.py (9 additions, 6 deletions)

@@ -174,15 +174,18 @@ def get_optimizer_grouped_parameters(
     model,
     weight_decay,
     lora_lr=5e-4,
-    no_decay_name_list=["bias", "LayerNorm.weight"],
+    no_decay_name_list=[
+        "bias", "layer_norm.weight", "layernorm.weight", "norm.weight",
+        "ln_f.weight"
+    ],
     lora_name_list=["lora_right_weight", "lora_left_weight"],
 ):
     optimizer_grouped_parameters = [
         {
             "params": [
                 p for n, p in model.named_parameters()
-                if (not any(nd in n for nd in no_decay_name_list)
-                    and p.requires_grad and not any(nd in n
+                if (not any(nd in n.lower() for nd in no_decay_name_list)
+                    and p.requires_grad and not any(nd in n.lower()
                                                     for nd in lora_name_list))
             ],
             "weight_decay":
@@ -191,8 +194,8 @@ def get_optimizer_grouped_parameters(
         {
             "params": [
                 p for n, p in model.named_parameters()
-                if (not any(nd in n for nd in no_decay_name_list)
-                    and p.requires_grad and any(nd in n
+                if (not any(nd in n.lower() for nd in no_decay_name_list)
+                    and p.requires_grad and any(nd in n.lower()
                                                 for nd in lora_name_list))
             ],
             "weight_decay":
@@ -203,7 +206,7 @@ def get_optimizer_grouped_parameters(
         {
             "params": [
                 p for n, p in model.named_parameters()
-                if (any(nd in n
+                if (any(nd in n.lower()
                         for nd in no_decay_name_list) and p.requires_grad)
             ],
             "weight_decay":
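
A minimal sketch (not part of the commit) of how the lowercased substring check assigns parameters to the no-decay group. The parameter names below are illustrative examples of the OPT-, LLaMA- and BLOOM-style naming schemes mentioned in the commit message, not values taken from the change itself.

    # Illustrative sketch: which parameter names land in the no-decay group
    # after the fix. Names are example strings, not read from a real model.
    no_decay_name_list = [
        "bias", "layer_norm.weight", "layernorm.weight", "norm.weight",
        "ln_f.weight"
    ]

    example_param_names = [
        "model.decoder.layers.0.self_attn_layer_norm.weight",  # opt-style LN
        "model.layers.0.input_layernorm.weight",               # llama-style LN
        "model.norm.weight",                                    # llama final norm
        "transformer.ln_f.weight",                              # bloom final LN
        "model.layers.0.self_attn.q_proj.weight",               # regular weight
    ]

    for name in example_param_names:
        # Lowercase the parameter name before the substring test, as in the fix.
        no_decay = any(nd in name.lower() for nd in no_decay_name_list)
        print(f"{name}: {'no weight decay' if no_decay else 'weight decay'}")

Lowercasing both the candidate name and the configured patterns means a single pattern list covers models that spell LayerNorm parameters with different capitalization, instead of listing every cased variant explicitly.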
