diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index b625ed3ff2..33b23038cf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -457,9 +457,6 @@ def load_checkpoint_if_available( if "cur_epoch" in saved_params: params["start_epoch"] = saved_params["cur_epoch"] - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - return saved_params @@ -674,13 +671,7 @@ def maybe_log_weights(tag: str): global_step=params.batch_idx_train, ) - cur_batch_idx = params.get("cur_batch_idx", 0) - for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -728,7 +719,6 @@ def maybe_log_weights(tag: str): params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0 ): - params.cur_batch_idx = batch_idx save_checkpoint_with_global_batch_idx( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, @@ -738,7 +728,6 @@ def maybe_log_weights(tag: str): sampler=train_dl.sampler, rank=rank, ) - del params.cur_batch_idx remove_checkpoints( out_dir=params.exp_dir, topk=params.keep_last_k, @@ -893,13 +882,14 @@ def remove_short_and_long_utt(c: Cut): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) + if params.start_batch <= 0: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 46d2cb86d4..4d290e39fa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -503,9 +503,6 @@ def load_checkpoint_if_available( if "cur_epoch" in saved_params: params["start_epoch"] = saved_params["cur_epoch"] - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - return saved_params @@ -724,13 +721,7 @@ def train_one_epoch( tot_loss = MetricsTracker() - cur_batch_idx = params.get("cur_batch_idx", 0) - for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -765,7 +756,6 @@ def train_one_epoch( params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0 ): - params.cur_batch_idx = batch_idx save_checkpoint_with_global_batch_idx( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, @@ -777,7 +767,6 @@ def train_one_epoch( scaler=scaler, rank=rank, ) - del params.cur_batch_idx remove_checkpoints( out_dir=params.exp_dir, topk=params.keep_last_k, @@ -944,7 +933,7 @@ def remove_short_and_long_utt(c: Cut): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: + if params.start_batch <= 0 and not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index ff9b8d8087..914b9b5eba 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -1058,14 +1058,15 @@ def run(rank, world_size, args): # It's time consuming to include `giga_train_dl` here # for dl in [train_dl, giga_train_dl]: for dl in [train_dl]: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=dl, - optimizer=optimizer, - sp=sp, - params=params, - warmup=0.0 if params.start_epoch == 0 else 1.0, - ) + if params.start_batch <= 0: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=dl, + optimizer=optimizer, + sp=sp, + params=params, + warmup=0.0 if params.start_epoch == 0 else 1.0, + ) scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 893a6a749a..325b013230 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -525,9 +525,6 @@ def load_checkpoint_if_available( if "cur_epoch" in saved_params: params["start_epoch"] = saved_params["cur_epoch"] - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - return saved_params @@ -757,13 +754,7 @@ def train_one_epoch( tot_loss = MetricsTracker() - cur_batch_idx = params.get("cur_batch_idx", 0) - for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -805,7 +796,6 @@ def train_one_epoch( params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0 ): - params.cur_batch_idx = batch_idx save_checkpoint_with_global_batch_idx( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, @@ -818,7 +808,6 @@ def train_one_epoch( scaler=scaler, rank=rank, ) - del params.cur_batch_idx remove_checkpoints( out_dir=params.exp_dir, topk=params.keep_last_k, @@ -993,7 +982,7 @@ def remove_short_and_long_utt(c: Cut): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: + if params.start_batch <= 0 and not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 44abdcd490..3bfe22155f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -550,9 +550,6 @@ def load_checkpoint_if_available( if "cur_epoch" in saved_params: params["start_epoch"] = saved_params["cur_epoch"] - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - return saved_params @@ -782,13 +779,7 @@ def train_one_epoch( tot_loss = MetricsTracker() - cur_batch_idx = params.get("cur_batch_idx", 0) - for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -834,7 +825,6 @@ def train_one_epoch( params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0 ): - params.cur_batch_idx = batch_idx save_checkpoint_with_global_batch_idx( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, @@ -847,7 +837,6 @@ def train_one_epoch( scaler=scaler, rank=rank, ) - del params.cur_batch_idx remove_checkpoints( out_dir=params.exp_dir, topk=params.keep_last_k, @@ -1025,7 +1014,7 @@ def remove_short_and_long_utt(c: Cut): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: + if params.start_batch <= 0 and not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index 596f8f7d93..a4595211c3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -507,9 +507,6 @@ def load_checkpoint_if_available( if "cur_epoch" in saved_params: params["start_epoch"] = saved_params["cur_epoch"] - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - return saved_params @@ -763,13 +760,7 @@ def train_one_epoch( tot_loss = MetricsTracker() - cur_batch_idx = params.get("cur_batch_idx", 0) - for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -811,7 +802,6 @@ def train_one_epoch( params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0 ): - params.cur_batch_idx = batch_idx save_checkpoint_with_global_batch_idx( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, @@ -824,7 +814,6 @@ def train_one_epoch( scaler=scaler, rank=rank, ) - del params.cur_batch_idx remove_checkpoints( out_dir=params.exp_dir, topk=params.keep_last_k, @@ -999,7 +988,7 @@ def remove_short_and_long_utt(c: Cut): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: + if params.start_batch <= 0 and not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl,