Add function display_and_save_batch in wenetspeech/pruned_transducer_stateless2/train.py #528

Merged: 4 commits, Aug 13, 2022
egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py: 72 changes (53 additions, 19 deletions)
@@ -701,25 +701,29 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
-            loss, loss_info = compute_loss(
-                params=params,
-                model=model,
-                graph_compiler=graph_compiler,
-                batch=batch,
-                is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step),
-            )
-        # summary stats
-        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
-        # NOTE: We use reduction==sum and loss is computed over utterances
-        # in the batch and there is no normalization to it so far.
-        scaler.scale(loss).backward()
-        scheduler.step_batch(params.batch_idx_train)
-        scaler.step(optimizer)
-        scaler.update()
-        optimizer.zero_grad()
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    graph_compiler=graph_compiler,
+                    batch=batch,
+                    is_training=True,
+                    warmup=(params.batch_idx_train / params.model_warm_step),
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params)
+            raise
 
         if params.print_diagnostics and batch_idx == 5:
             return
@@ -958,6 +962,35 @@ def remove_short_and_long_utt(c: Cut):
         cleanup_dist()
 
 
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    texts = batch["supervisions"]["text"]
+    num_tokens = sum(len(i) for i in texts)
+
+    logging.info(f"num tokens: {num_tokens}")
+
+
 def scan_pessimistic_batches_for_oom(
     model: nn.Module,
     train_dl: torch.utils.data.DataLoader,
@@ -998,6 +1031,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
+            display_and_save_batch(batch, params=params)
             raise


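For reference, here is a minimal sketch (not part of this PR) of how a dumped batch could be reloaded for offline inspection. It assumes a file written by display_and_save_batch exists under the experiment directory (params.exp_dir); the exp_dir value below is illustrative and should be adjusted to the directory used during training.

# Sketch: reload the most recent batch dumped by display_and_save_batch
# and print the same statistics that it logs.
import glob
import os

import torch

# Illustrative path; point this at the experiment directory used for training.
exp_dir = "pruned_transducer_stateless2/exp"

# display_and_save_batch names its dumps batch-<uuid>.pt inside params.exp_dir.
dumps = sorted(glob.glob(f"{exp_dir}/batch-*.pt"), key=os.path.getmtime)
if dumps:
    # Load the newest dump on CPU so no GPU is needed for inspection.
    # On newer PyTorch versions you may need weights_only=False, since the
    # saved object is a dict of tensors and Python objects, not a plain tensor.
    batch = torch.load(dumps[-1], map_location="cpu")
    features = batch["inputs"]
    texts = batch["supervisions"]["text"]
    print("features shape:", features.shape)
    print("num utterances:", len(texts))
    print("num tokens:", sum(len(t) for t in texts))
else:
    print(f"No dumped batches found in {exp_dir}")

The keys used here (batch["inputs"] and batch["supervisions"]["text"]) match the ones display_and_save_batch itself reads, so the printed statistics should mirror what was logged when the batch was saved.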