Skip to content

Commit

Permalink
fix bf16 amp
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Apr 5, 2024
1 parent 6008bab commit 2f1f4df
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion wenet/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from contextlib import nullcontext
import copy
from functools import partial
from typing import Optional

import deepspeed
Expand Down Expand Up @@ -564,7 +565,8 @@ def batch_forward(model, batch, scaler, info_dict):
# https://pytorch.org/docs/stable/notes/amp_examples.html
if dtype is not None and info_dict.get("tag", "train") == "train":
assert scaler is not None
autocast = torch.cuda.amp.autocast if dtype is not None else nullcontext
autocast = partial(torch.cuda.amp.autocast,
dtype=dtype) if dtype is not None else nullcontext
with autocast():
loss_dict = model(batch, device)
info_dict['loss_dict'] = loss_dict
Expand Down

0 comments on commit 2f1f4df

Please sign in to comment.