
Commit

deepspeed-chat: Support zero3 params initialization in the last LN (#839)

Zero3 requires that partitioned parameters be gathered before
they can be accessed.

We enable that mechanism for initialization of the last LN weight
and bias.
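
For context, a minimal sketch of the gather-then-modify pattern this change
adopts, using DeepSpeed's deepspeed.zero.GatheredParameters context manager;
the model and init_final_layernorm names below are illustrative stand-ins,
not code from this repository:

import torch
import deepspeed

def init_final_layernorm(model, zero_stage):
    # Under ZeRO-3 (stage 3) every parameter is partitioned across ranks,
    # so the full tensors must be gathered before they can be accessed.
    zero3 = (zero_stage == 3)
    params = [model.ln_f.weight, model.ln_f.bias]
    with deepspeed.zero.GatheredParameters(params,
                                           modifier_rank=0,
                                           enabled=zero3):
        # modifier_rank=0 means only rank 0's in-place writes are kept and
        # re-partitioned on exit; with the context disabled, every rank
        # initializes its own full copy.
        if not zero3 or deepspeed.comm.get_rank() == 0:
            torch.nn.init.ones_(model.ln_f.weight)
            torch.nn.init.zeros_(model.ln_f.bias)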

Co-authored-by: Olatunji Ruwase <[email protected]>
deepcharm and tjruwase authored Jan 17, 2024
1 parent 6c31d8d commit 57dd8fb
Showing 1 changed file with 11 additions and 2 deletions.
@@ -258,8 +258,17 @@ def main():
     # the LN that precedes it.
     force_optimize_params = []
     if "bigscience/bloom-" in args.model_name_or_path:
-        torch.nn.init.ones_(rm_model.rwtranrsformer.ln_f.weight)
-        torch.nn.init.zeros_(rm_model.rwtranrsformer.ln_f.bias)
+        zero_init_enabled = (args.zero_stage == 3)
+        params = [
+            rm_model.rwtranrsformer.ln_f.weight,
+            rm_model.rwtranrsformer.ln_f.bias
+        ]
+        with deepspeed.zero.GatheredParameters(params,
+                                               modifier_rank=0,
+                                               enabled=zero_init_enabled):
+            if deepspeed.comm.get_rank() == 0 or not zero_init_enabled:
+                torch.nn.init.ones_(rm_model.rwtranrsformer.ln_f.weight)
+                torch.nn.init.zeros_(rm_model.rwtranrsformer.ln_f.bias)
     force_optimize_params.extend(
         ['rwtranrsformer.ln_f.weight', 'rwtranrsformer.ln_f.bias'])
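
Design note: with modifier_rank=0, GatheredParameters broadcasts rank 0's
modifications back to the ZeRO-3 partitions when the context exits, so rank 0
alone runs the init; when zero_init_enabled is false the context is a no-op
and each rank initializes its own replica, which is why the rank check is
relaxed with "or not zero_init_enabled".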

