Commit 885013f

refactor scaler and autocast, fix fp32 fp16 bf16 for fsdp
Mddct committed Apr 6, 2024
1 parent 2f1f4df commit 885013f
Showing 1 changed file with 29 additions and 23 deletions.
52 changes: 29 additions & 23 deletions wenet/utils/train_utils.py
@@ -15,7 +15,6 @@
 
 from contextlib import nullcontext
 import copy
-from functools import partial
 from typing import Optional
 
 import deepspeed
@@ -472,8 +471,8 @@ def init_scaler(args):
     elif args.train_engine == 'torch_fsdp':
         # why bf16 don't need scaler:
         # https://discuss.pytorch.org/t/why-bf16-do-not-need-loss-scaling/176596
-        scaler = sharded_grad_scaler.ShardedGradScaler(
-            enabled=args.dtype in ['fp16'])
+        if args.dtype in ['fp16']:
+            scaler = sharded_grad_scaler.ShardedGradScaler(enabled=True)
     return scaler
 
 
@@ -552,25 +551,24 @@ def batch_forward(model, batch, scaler, info_dict):
     else:  # fp32
         dtype = None
 
-    if train_engine == "deepspeed":
-        # deepspeed
-        with torch.cuda.amp.autocast(enabled=dtype is not None,
-                                     dtype=dtype,
-                                     cache_enabled=False):
-            loss_dict = model(batch, device)
-    else:
-        # torch_ddp
-        # autocast context
-        # The more details about amp can be found in
-        # https://pytorch.org/docs/stable/notes/amp_examples.html
-        if dtype is not None and info_dict.get("tag", "train") == "train":
-            assert scaler is not None
-        autocast = partial(torch.cuda.amp.autocast,
-                           dtype=dtype) if dtype is not None else nullcontext
-        with autocast():
-            loss_dict = model(batch, device)
-    info_dict['loss_dict'] = loss_dict
+    # autocast context
+    # The more details about amp can be found in
+    # https://pytorch.org/docs/stable/notes/amp_examples.html
+    autocast = {
+        "deepspeed":
+        torch.cuda.amp.autocast(enabled=dtype is not None,
+                                dtype=dtype,
+                                cache_enabled=False),
+        "torch_ddp":
+        torch.cuda.amp.autocast(enabled=scaler is not None),
+        "torch_fsdp":
+        torch.cuda.amp.autocast(enabled=True, dtype=dtype)
+        if dtype is not None else nullcontext
+    }[train_engine]
+    with autocast():
+        loss_dict = model(batch, device)
+
+    info_dict['loss_dict'] = loss_dict
     return info_dict
 
 
@@ -590,8 +588,11 @@ def batch_backward(model, scaler, info_dict):
     else:
         scaled_loss = loss / accum_grad
         if scaler is not None:
+            # fp16 (amp and fsdp)
             scaler.scale(scaled_loss).backward()
         else:
+            # float32 (ddp and fsdp)
+            # bf16 (fsdp)
             scaled_loss.backward()
     info_dict['loss_dict']['loss'] = scaled_loss
     for loss_name, loss_value in info_dict['loss_dict'].items():
@@ -628,11 +629,13 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
         grad_norm = model.get_global_grad_norm()
     elif (batch_idx + 1) % accum_grad == 0:
         # Use mixed precision training
+        # fp16 (ddp fsdp)
         if scaler is not None:
             scaler.unscale_(optimizer)
-            if isinstance(scaler, torch.cuda.amp.GradScaler):
+            if train_engine == "torch_ddp":
                 grad_norm = clip_grad_norm_(model.parameters(), clip)
             else:
+                # fsdp
                 grad_norm = model.clip_grad_norm_(clip)
             # Must invoke scaler.update() if unscale_() is used in
             # the iteration to avoid the following error:
@@ -644,7 +647,10 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
             scaler.step(optimizer)
             scaler.update()
         else:
-            grad_norm = clip_grad_norm_(model.parameters(), clip)
+            if train_engine == "torch_ddp":
+                grad_norm = clip_grad_norm_(model.parameters(), clip)
+            else:
+                grad_norm = model.clip_grad_norm_(clip)
             if torch.isfinite(grad_norm):
                 optimizer.step()
         optimizer.zero_grad()

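Taken together, the hunks above settle on one precision policy per dtype: fp32 runs without autocast or a scaler, bf16 runs under autocast but needs no loss scaling (see the discuss.pytorch.org link in init_scaler), and fp16 runs under autocast with a GradScaler, swapped for ShardedGradScaler when the model is wrapped with FSDP. Below is a minimal standalone sketch of that policy using only public torch APIs; select_precision and its docstring are illustrative names, not code from wenet/utils/train_utils.py.

from contextlib import nullcontext

import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def select_precision(dtype: str, train_engine: str):
    """Return (autocast context, scaler or None) for one training step."""
    torch_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(dtype)
    if torch_dtype is None:
        # fp32: full precision, no autocast, no loss scaling
        return nullcontext(), None
    autocast_ctx = torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype)
    if dtype == "fp16":
        # fp16 has a narrow dynamic range, so gradients must be scaled;
        # FSDP shards gradients, hence the sharded variant of GradScaler.
        scaler = (ShardedGradScaler(enabled=True)
                  if train_engine == "torch_fsdp"
                  else torch.cuda.amp.GradScaler(enabled=True))
    else:
        # bf16 keeps fp32's exponent range, so no scaler is needed
        scaler = None
    return autocast_ctx, scaler

In batch_backward and update_parameter_and_lr this maps onto the two paths shown in the diff: with a scaler, scaler.scale(loss).backward(), scaler.unscale_, clip, scaler.step, scaler.update; without one, a plain loss.backward() followed by optimizer.step().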

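The clipping changes in update_parameter_and_lr follow the same split: under torch_ddp the usual torch.nn.utils.clip_grad_norm_ over model.parameters() is enough, but under FSDP each rank holds only a shard of every gradient, so the wrapped module's own clip_grad_norm_ must be called to obtain a correct global norm. A sketch of that dispatch, assuming an FSDP-wrapped model; clip_gradients is an illustrative helper, not a wenet function.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.utils import clip_grad_norm_


def clip_gradients(model: torch.nn.Module, clip: float) -> torch.Tensor:
    if isinstance(model, FSDP):
        # FSDP computes the norm over all shards (reduced across ranks)
        return model.clip_grad_norm_(clip)
    # DDP / single process: all gradients are local, the utility suffices
    return clip_grad_norm_(model.parameters(), clip)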