Skip to content

Commit

Permalink
Merge branch 'master' into mrwyattii/generalize-mii-benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
mrwyattii authored Jan 19, 2024
2 parents 682e904 + 57dd8fb commit f961a39
Showing 1 changed file with 11 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,17 @@ def main():
# the LN that precedes it.
force_optimize_params = []
if "bigscience/bloom-" in args.model_name_or_path:
torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight)
torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias)
zero_init_enabled = (args.zero_stage == 3)
params = [
rm_model.rwtransformer.ln_f.weight,
rm_model.rwtransformer.ln_f.bias
]
with deepspeed.zero.GatheredParameters(params,
modifier_rank=0,
enabled=zero_init_enabled):
if deepspeed.comm.get_rank() == 0 or not zero_init_enabled:
torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight)
torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias)
force_optimize_params.extend(
['rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias'])

Expand Down

0 comments on commit f961a39

Please sign in to comment.