add timer for steps #2416

Merged 4 commits on Mar 18, 2024
Changes from 3 commits
4 changes: 2 additions & 2 deletions wenet/bin/train.py
@@ -114,8 +114,8 @@ def main():
 
     # Get executor
     tag = configs["init_infos"].get("tag", "init")
-    executor = Executor()
-    executor.step = configs["init_infos"].get('step', -1) + int("step_" in tag)
+    executor = Executor(global_step=configs["init_infos"].get('step', -1) +
+                        int("step_" in tag))
 
     # Init scaler, used for pytorch amp mixed precision training
     scaler = None
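Note on the resumed step: the expression above folds the checkpoint tag into the starting step, so a run resumed from a step-based checkpoint continues counting right after the saved step. A small illustrative sketch (the init_infos contents below are hypothetical, not taken from this PR):

# Fresh run: no saved step, tag falls back to "init".
init_infos = {}
tag = init_infos.get("tag", "init")
global_step = init_infos.get("step", -1) + int("step_" in tag)
assert global_step == -1  # the first accumulated update brings it to 0

# Resuming from a step checkpoint, e.g. tag "step_20000".
init_infos = {"tag": "step_20000", "step": 20000}
tag = init_infos.get("tag", "init")
global_step = init_infos.get("step", -1) + int("step_" in tag)
assert global_step == 20001  # training continues one past the saved step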
20 changes: 20 additions & 0 deletions wenet/utils/common.py
@@ -15,6 +15,7 @@
 """Unility functions for Transformer."""
 
 import math
+import time
 from typing import List, Tuple
 
 import torch
@@ -336,3 +337,22 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
     mask = (1.0 - mask) * get_dtype_min(dtype)
     return mask
+
+
+class StepTimer:
+    """Utility class for measuring steps/second."""
+
+    def __init__(self, step=0.0):
+        self.last_iteration = step
+        self.start()
+
+    def start(self):
+        self.last_time = time.time()
+
+    def steps_per_second(self, cur_step, restart=True):
+        value = ((float(cur_step) - self.last_iteration) /
+                 (time.time() - self.last_time))
+        if restart:
+            self.start()
+        self.last_iteration = float(cur_step)
+        return value
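A minimal usage sketch for the new StepTimer (standalone illustration, not part of the diff; the sleep stands in for real work):

import time

from wenet.utils.common import StepTimer

timer = StepTimer(step=0.0)  # the trainer seeds this with the resumed global step

for step in range(1, 101):
    time.sleep(0.01)  # stand-in for one training step
    if step % 50 == 0:
        # (steps since last call) / (seconds since last call); restart=True
        # (the default) resets the measurement window after each report.
        rate = timer.steps_per_second(step)
        print('step {}: {:.1f} steps/sec'.format(step, rate))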
21 changes: 17 additions & 4 deletions wenet/utils/executor.py
@@ -21,6 +21,7 @@
 # if your python version < 3.7 use the below one
 # from contextlib import suppress as nullcontext
 import torch
+from wenet.utils.common import StepTimer
 
 from wenet.utils.train_utils import (wenet_join, batch_forward, batch_backward,
                                      update_parameter_and_lr, log_per_step,
@@ -29,13 +30,17 @@
 
 class Executor:
 
-    def __init__(self):
-        self.step = 0
+    def __init__(self, global_step: int = 0):
+        self.step = global_step
+        self.train_step_timer = None
+        self.cv_step_timer = None
 
     def train(self, model, optimizer, scheduler, train_data_loader,
               cv_data_loader, writer, configs, scaler, group_join):
         ''' Train one epoch
         '''
+        if self.train_step_timer is None:
+            self.train_step_timer = StepTimer(self.step)
         model.train()
         info_dict = copy.deepcopy(configs)
         logging.info('using accumulate grad, new batch size is {} times'
@@ -95,13 +100,18 @@ def train(self, model, optimizer, scheduler, train_data_loader,
                         optimizer.param_groups[0]['lr']
                     })
                     save_model(model, info_dict)
-                log_per_step(writer, info_dict)
+                log_per_step(writer, info_dict, timer=self.train_step_timer)
                 self.step += 1 if (batch_idx +
                                    1) % info_dict["accum_grad"] == 0 else 0
+                self.train_step_timer.step = self.step
 
     def cv(self, model, cv_data_loader, configs):
         ''' Cross validation on
         '''
+        if self.cv_step_timer is None:
+            self.cv_step_timer = StepTimer(0.0)
+        else:
+            self.cv_step_timer.last_iteration = 0.0
         model.eval()
         info_dict = copy.deepcopy(configs)
         num_seen_utts, loss_dict, total_acc = 1, {}, []  # avoid division by 0
@@ -110,6 +120,7 @@ def cv(self, model, cv_data_loader, configs):
                 info_dict["tag"] = "CV"
                 info_dict["step"] = self.step
                 info_dict["batch_idx"] = batch_idx
+                info_dict["cv_step"] = batch_idx
 
                 num_utts = batch_dict["target_lengths"].size(0)
                 if num_utts == 0:
@@ -128,7 +139,9 @@ def cv(self, model, cv_data_loader, configs):
                         loss_dict[loss_name] = loss_dict.get(loss_name, 0) + \
                             loss_value * num_utts
 
-                log_per_step(writer=None, info_dict=info_dict)
+                log_per_step(writer=None,
+                             info_dict=info_dict,
+                             timer=self.cv_step_timer)
         for loss_name, loss_value in loss_dict.items():
             loss_dict[loss_name] = loss_dict[loss_name] / num_seen_utts
         loss_dict["acc"] = sum(total_acc) / len(total_acc)
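Why cv_step_timer.last_iteration is rewound on every validation pass: during CV the timer is fed the per-pass batch index (cv_step), which restarts from zero, so a stale last_iteration left over from the previous pass would make the first reported rate negative. A hedged sketch of that behaviour (illustrative only; run_cv_pass and its arguments are made-up names):

import time

from wenet.utils.common import StepTimer

cv_timer = StepTimer(0.0)

def run_cv_pass(num_batches, log_interval=10):
    # Mirrors Executor.cv(): the timer sees the per-pass batch index, not the
    # global step, so last_iteration must be rewound before every pass.
    cv_timer.last_iteration = 0.0
    for batch_idx in range(num_batches):
        time.sleep(0.005)  # stand-in for one CV batch
        if (batch_idx + 1) % log_interval == 0:
            print('CV batch {}: {:.1f} steps/sec'.format(
                batch_idx + 1, cv_timer.steps_per_second(batch_idx)))

run_cv_pass(30)
run_cv_pass(30)  # a second pass works because of the rewind above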
15 changes: 12 additions & 3 deletions wenet/utils/train_utils.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import copy
+from typing import Optional
 import deepspeed
 import json
 import logging
@@ -35,6 +36,7 @@
     convert_zero_checkpoint_to_fp32_state_dict)
 from wenet.dataset.dataset import Dataset
 from wenet.utils.checkpoint import save_checkpoint
+from wenet.utils.common import StepTimer
 from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
 from wenet.utils.ctc_utils import get_blank_id
 
@@ -558,7 +560,7 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
     return info_dict
 
 
-def log_per_step(writer, info_dict):
+def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None):
     tag = info_dict["tag"]
     step = info_dict["step"]
     batch_idx = info_dict["batch_idx"]
@@ -589,8 +591,15 @@ def log_per_step(writer, info_dict):
             writer.add_scalar('global_step/{}'.format(name), value, step + 1)
 
     if (batch_idx + 1) % log_interval == 0:
-        log_str = '{} Batch {}/{} loss {:.6f} '.format(
-            tag, epoch, batch_idx + 1, loss_dict['loss'] * accum_grad)
+        log_str = '{} | '.format(tag)
+        if timer is not None:
+            timer_step = step
+            if info_dict.get("cv_step", None) is not None:
+                timer_step = info_dict['cv_step']
+            steps_per_second = timer.steps_per_second(timer_step)
+            log_str += 'steps/sec {:.1f}| '.format(steps_per_second)
+        log_str += 'Batch {}/{} loss {:.6f} '.format(
+            epoch, batch_idx + 1, loss_dict['loss'] * accum_grad)
         for name, value in loss_dict.items():
             if name != 'loss' and value is not None:
                 log_str += '{} {:.6f} '.format(name, value)
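For reference, this is roughly how the new per-interval log line comes out; the snippet below just replays the string building from the diff with invented numbers, and the extra loss names follow common wenet conventions rather than anything in this PR:

tag, epoch, batch_idx, accum_grad = 'TRAIN', 0, 99, 4
loss_dict = {'loss': 21.28, 'loss_att': 18.90, 'loss_ctc': 26.84}
steps_per_second = 3.1

log_str = '{} | '.format(tag)
log_str += 'steps/sec {:.1f}| '.format(steps_per_second)
log_str += 'Batch {}/{} loss {:.6f} '.format(epoch, batch_idx + 1,
                                             loss_dict['loss'] * accum_grad)
for name, value in loss_dict.items():
    if name != 'loss' and value is not None:
        log_str += '{} {:.6f} '.format(name, value)
print(log_str)
# -> TRAIN | steps/sec 3.1| Batch 0/100 loss 85.120000 loss_att 18.900000 loss_ctc 26.840000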