Skip to content

Commit

Permalink
[Benchmark] 给layer norm添加到bf16 o1白名单 (PaddlePaddle#110)
Browse files Browse the repository at this point in the history
PaddlePaddle#55713
修改了BF16默认黑白名单导致当前代码会报错。
因此需要手动将layer norm添加bf16 o1白名单

![bd3f6f1c3d576870635f19e8a49f04f6](https://github.com/PaddlePaddle/PaddleMIX/assets/50394665/33d20150-fa49-43c8-b421-e89618ed43ea)
  • Loading branch information
JunnYu authored Aug 31, 2023
1 parent 8069b5b commit 602795b
Showing 1 changed file with 6 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,9 @@ def fn(layer):
enable=args.mixed_precision in ["bf16", "fp16"] and args.train_text_encoder,
level=args.fp16_opt_level,
custom_black_list=["reduce_sum", "c_softmax_with_cross_entropy"],
custom_white_list=["lookup_table", "lookup_table_v2"] if args.fp16_opt_level == "O2" else None,
custom_white_list=["lookup_table", "lookup_table_v2"]
if args.fp16_opt_level == "O2"
else ["layer_norm"],
dtype="bfloat16" if args.mixed_precision == "bf16" else "float16",
):
encoder_hidden_states = text_encoder(batch["input_ids"], attention_mask=attention_mask)[0]
Expand All @@ -1032,7 +1034,9 @@ def fn(layer):
enable=args.mixed_precision in ["bf16", "fp16"],
level=args.fp16_opt_level,
custom_black_list=["reduce_sum", "c_softmax_with_cross_entropy"],
custom_white_list=["lookup_table", "lookup_table_v2"] if args.fp16_opt_level == "O2" else None,
custom_white_list=["lookup_table", "lookup_table_v2"]
if args.fp16_opt_level == "O2"
else ["layer_norm"],
dtype="bfloat16" if args.mixed_precision == "bf16" else "float16",
):
# Predict the noise residual and compute loss
Expand Down

0 comments on commit 602795b

Please sign in to comment.