Support mixed precision training on the reworked model (#305)
* Add mixed precision support

* Minor fixes

* Minor fixes

* Minor fixes
pkufool authored Apr 11, 2022
1 parent 34aad74 commit 7012fd6
Showing 2 changed files with 80 additions and 38 deletions.
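
At a high level, the change follows the standard torch.cuda.amp recipe: the forward pass and loss run under autocast (so most matmuls and convolutions use float16), while a GradScaler scales the loss before backward and unscales gradients before the optimizer step. Below is a minimal sketch of that pattern, not code from this commit; the model, dataloader, and use_fp16 flag are placeholders.

    import torch
    from torch.cuda.amp import GradScaler, autocast

    def train(model, optimizer, train_dl, use_fp16=True):
        # One scaler for the whole run; it becomes a no-op when enabled=False.
        scaler = GradScaler(enabled=use_fp16)
        for features, targets in train_dl:
            # Forward pass (and loss) in mixed precision.
            with autocast(enabled=use_fp16):
                loss = model(features, targets)
            # Scale the loss so small fp16 gradients do not underflow.
            scaler.scale(loss).backward()
            # step() unscales the gradients and skips the update if they
            # contain inf/NaN; update() then adjusts the loss scale.
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
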
43 changes: 24 additions & 19 deletions egs/librispeech/ASR/pruned_transducer_stateless2/model.py
@@ -141,17 +141,21 @@ def forward(
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
-        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=self.simple_lm_proj(decoder_out),
-            am=self.simple_am_proj(encoder_out),
-            symbols=y_padded,
-            termination_symbol=blank_id,
-            lm_only_scale=lm_scale,
-            am_only_scale=am_scale,
-            boundary=boundary,
-            reduction="sum",
-            return_grad=True,
-        )
+        lm = self.simple_lm_proj(decoder_out)
+        am = self.simple_am_proj(encoder_out)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+                lm=lm.float(),
+                am=am.float(),
+                symbols=y_padded,
+                termination_symbol=blank_id,
+                lm_only_scale=lm_scale,
+                am_only_scale=am_scale,
+                boundary=boundary,
+                reduction="sum",
+                return_grad=True,
+            )
 
         # ranges : [B, T, prune_range]
         ranges = k2.get_rnnt_prune_ranges(
@@ -176,13 +180,14 @@ def forward(
         logits = self.joiner(am_pruned, lm_pruned,
                              project_input=False)
 
-        pruned_loss = k2.rnnt_loss_pruned(
-            logits=logits,
-            symbols=y_padded,
-            ranges=ranges,
-            termination_symbol=blank_id,
-            boundary=boundary,
-            reduction="sum",
-        )
+        with torch.cuda.amp.autocast(enabled=False):
+            pruned_loss = k2.rnnt_loss_pruned(
+                logits=logits.float(),
+                symbols=y_padded,
+                ranges=ranges,
+                termination_symbol=blank_id,
+                boundary=boundary,
+                reduction="sum",
+            )
 
         return (simple_loss, pruned_loss)
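
Both k2 loss calls in model.py are taken out of autocast and given float32 inputs, presumably because the RNN-T losses involve log-sum-exp reductions over long sequences that can overflow or underflow in float16. The same pattern works for any numerically sensitive block inside an autocast region; a generic sketch (loss_fn, logits, and targets are placeholders, not names from this diff):

    import torch

    def loss_in_fp32(loss_fn, logits, targets):
        # Locally disable autocast and upcast the inputs so this loss
        # (and its backward) is computed entirely in float32, even when
        # the surrounding forward pass runs under autocast.
        with torch.cuda.amp.autocast(enabled=False):
            return loss_fn(logits.float(), targets)
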
75 changes: 56 additions & 19 deletions egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -29,7 +29,16 @@
   --full-libri 1 \
   --max-duration 300
 
+# For mix precision training:
+./pruned_transducer_stateless2/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 0 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless2/exp \
+  --full-libri 1 \
+  --max-duration 550
+
 """
 
@@ -58,6 +67,7 @@
 from model import Transducer
 from optim import Eve, Eden
 from torch import Tensor
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
@@ -249,6 +259,13 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
     return parser
 
 
@@ -447,6 +464,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -460,6 +478,8 @@
       The optimizer used in the training.
     sampler:
       The sampler for the training dataset.
+    scaler:
+      The scaler used for mix precision training.
     """
     if rank != 0:
         return
@@ -471,6 +491,7 @@
         optimizer=optimizer,
         scheduler=scheduler,
         sampler=sampler,
+        scaler=scaler,
         rank=rank,
     )
 
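
The scaler is threaded through save_checkpoint so that its state (the current loss scale and growth tracker) survives a restart; the load side later in this diff reads it back from the "grad_scaler" key of the checkpoint. A sketch of what the underlying save/load amounts to, assuming a plain torch.save checkpoint dict (the helper names and extra keys here are placeholders):

    import torch
    from torch.cuda.amp import GradScaler

    def save(path, model, optimizer, scaler: GradScaler) -> None:
        # GradScaler exposes state_dict()/load_state_dict() like other
        # stateful torch objects, so it fits in the same checkpoint file.
        torch.save(
            {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "grad_scaler": scaler.state_dict(),
            },
            path,
        )

    def load(path, model, optimizer, scaler: GradScaler) -> None:
        checkpoint = torch.load(path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        if "grad_scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["grad_scaler"])
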
@@ -599,6 +620,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -622,6 +644,8 @@
       Dataloader for the training dataset.
     valid_dl:
       Dataloader for the validation dataset.
+    scaler:
+      The scaler used for mix precision training.
     tb_writer:
       Writer to write log messages to tensorboard.
     world_size:
@@ -644,22 +668,24 @@
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        loss, loss_info = compute_loss(
-            params=params,
-            model=model,
-            sp=sp,
-            batch=batch,
-            is_training=True,
-            warmup=(params.batch_idx_train / params.model_warm_step)
-        )
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=True,
+                warmup=(params.batch_idx_train / params.model_warm_step)
+            )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
 
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
-        loss.backward()
+        scaler.scale(loss).backward()
         scheduler.step_batch(params.batch_idx_train)
-        optimizer.step()
+        scaler.step(optimizer)
+        scaler.update()
         optimizer.zero_grad()
 
         if params.print_diagnostics and batch_idx == 5:
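
Note the changed order in the inner loop: scaler.scale(loss).backward() replaces the plain backward, scaler.step(optimizer) unscales the gradients and silently skips the parameter update when they contain inf or NaN, and scaler.update() then grows or shrinks the loss scale. If you want visibility into how often steps are skipped, one common trick (an assumption here, not part of this commit) is to watch the scale before and after update():

    import logging
    from torch.cuda.amp import GradScaler

    def step_and_report(scaler: GradScaler, optimizer) -> None:
        # update() lowers the loss scale only after a step whose unscaled
        # gradients contained inf/NaN, so a drop in the scale is a cheap
        # signal that the optimizer step was skipped.
        scale_before = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()
        if scaler.get_scale() < scale_before:
            logging.warning("Gradient overflow detected; step was skipped")
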
@@ -676,6 +702,7 @@ def train_one_epoch(
                 optimizer=optimizer,
                 scheduler=scheduler,
                 sampler=train_dl.sampler,
+                scaler=scaler,
                 rank=rank,
             )
             del params.cur_batch_idx
@@ -695,7 +722,9 @@
             )
 
             if tb_writer is not None:
-                tb_writer.add_scalar("train/learning_rate", cur_params.batch_idx_train)
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
 
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
@@ -850,6 +879,11 @@ def remove_short_and_long_utt(c: Cut):
             params=params,
         )
 
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
     for epoch in range(params.start_epoch, params.num_epochs):
         scheduler.step_epoch(epoch)
         fix_random_seed(params.seed + epoch)
@@ -869,6 +903,7 @@
             sp=sp,
             train_dl=train_dl,
             valid_dl=valid_dl,
+            scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
@@ -884,6 +919,7 @@
             optimizer=optimizer,
             scheduler=scheduler,
             sampler=train_dl.sampler,
+            scaler=scaler,
             rank=rank,
         )
 
@@ -913,14 +949,15 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        loss, _ = compute_loss(
-            params=params,
-            model=model,
-            sp=sp,
-            batch=batch,
-            is_training=True,
-            warmup = 0.0
-        )
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, _ = compute_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=True,
+                warmup = 0.0
+            )
         loss.backward()
         optimizer.step()
         optimizer.zero_grad()
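
Wrapping the forward/backward in scan_pessimistic_batches_for_oom in the same autocast context keeps the memory probe honest: the worst-case batches are now tested with the activation sizes fp16 training will actually see. Generically, such a probe is just one training step on a worst-case batch, as in this hedged sketch (the model and batch handling are placeholders):

    import torch

    def probe_for_oom(model, optimizer, batch, use_fp16: bool) -> None:
        # Run a single forward/backward under the same autocast setting as
        # real training so an out-of-memory error surfaces up front rather
        # than hours into the run.
        with torch.cuda.amp.autocast(enabled=use_fp16):
            loss = model(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
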