Fix loading sampler state dict. (#421)
* Fix loading sampler state dict.

* skip scan_pessimistic_batches_for_oom if params.start_batch > 0
csukuangfj authored Aug 6, 2022
1 parent 7157f62 commit 1f7832b
Showing 6 changed files with 21 additions and 74 deletions.
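
The deletions below all follow from one change of approach: a recipe no longer needs to remember its position in the epoch by hand (params.cur_batch_idx), because the dataloader's sampler state is written into every mid-epoch checkpoint and handed back when training resumes. A minimal sketch of that round trip, using the names these recipes use; the resume-side calls are an assumption based on this diff and the surrounding recipe code, not a quote of any one file:

    # Saving: the sampler's state goes into each mid-epoch checkpoint
    # (this mirrors the call sites in the hunks below; arguments not visible
    # there are filled in for illustration).
    save_checkpoint_with_global_batch_idx(
        out_dir=params.exp_dir,
        global_batch_idx=params.batch_idx_train,
        model=model,
        params=params,
        sampler=train_dl.sampler,
        rank=rank,
    )

    # Resuming: the saved sampler state is passed back to the datamodule, so
    # the dataloader itself starts at the right batch -- no manual skipping of
    # batch_idx < cur_batch_idx inside train_one_epoch().
    checkpoints = load_checkpoint_if_available(params=params, model=model)
    sampler_state_dict = checkpoints.get("sampler") if checkpoints else None
    train_dl = librispeech.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )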
26 changes: 8 additions & 18 deletions egs/librispeech/ASR/pruned_transducer_stateless/train.py
Original file line number Diff line number Diff line change
@@ -457,9 +457,6 @@ def load_checkpoint_if_available(
     if "cur_epoch" in saved_params:
         params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]
-
     return saved_params


@@ -674,13 +671,7 @@ def maybe_log_weights(tag: str):
                 global_step=params.batch_idx_train,
             )

-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
     for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

@@ -728,7 +719,6 @@ def maybe_log_weights(tag: str):
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
Expand All @@ -738,7 +728,6 @@ def maybe_log_weights(tag: str):
sampler=train_dl.sampler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@@ -893,13 +882,14 @@ def remove_short_and_long_utt(c: Cut):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    scan_pessimistic_batches_for_oom(
-        model=model,
-        train_dl=train_dl,
-        optimizer=optimizer,
-        sp=sp,
-        params=params,
-    )
+    if params.start_batch <= 0:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )

     for epoch in range(params.start_epoch, params.num_epochs):
         fix_random_seed(params.seed + epoch)
13 changes: 1 addition & 12 deletions egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -503,9 +503,6 @@ def load_checkpoint_if_available(
     if "cur_epoch" in saved_params:
         params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]
-
     return saved_params


@@ -724,13 +721,7 @@ def train_one_epoch(

     tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
     for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

@@ -765,7 +756,6 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
Expand All @@ -777,7 +767,6 @@ def train_one_epoch(
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@@ -944,7 +933,7 @@ def remove_short_and_long_utt(c: Cut):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    if not params.print_diagnostics:
+    if params.start_batch <= 0 and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
17 changes: 9 additions & 8 deletions egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -1058,14 +1058,15 @@ def run(rank, world_size, args):
     # It's time consuming to include `giga_train_dl` here
     # for dl in [train_dl, giga_train_dl]:
     for dl in [train_dl]:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-            warmup=0.0 if params.start_epoch == 0 else 1.0,
-        )
+        if params.start_batch <= 0:
+            scan_pessimistic_batches_for_oom(
+                model=model,
+                train_dl=dl,
+                optimizer=optimizer,
+                sp=sp,
+                params=params,
+                warmup=0.0 if params.start_epoch == 0 else 1.0,
+            )

     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
13 changes: 1 addition & 12 deletions egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -525,9 +525,6 @@ def load_checkpoint_if_available(
     if "cur_epoch" in saved_params:
         params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]
-
     return saved_params


@@ -757,13 +754,7 @@ def train_one_epoch(

     tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
     for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

@@ -805,7 +796,6 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
Expand All @@ -818,7 +808,6 @@ def train_one_epoch(
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@@ -993,7 +982,7 @@ def remove_short_and_long_utt(c: Cut):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    if not params.print_diagnostics:
+    if params.start_batch <= 0 and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
13 changes: 1 addition & 12 deletions egs/librispeech/ASR/pruned_transducer_stateless5/train.py
@@ -550,9 +550,6 @@ def load_checkpoint_if_available(
     if "cur_epoch" in saved_params:
         params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]
-
     return saved_params


@@ -782,13 +779,7 @@ def train_one_epoch(

     tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
     for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

@@ -834,7 +825,6 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
Expand All @@ -847,7 +837,6 @@ def train_one_epoch(
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@@ -1025,7 +1014,7 @@ def remove_short_and_long_utt(c: Cut):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    if not params.print_diagnostics:
+    if params.start_batch <= 0 and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
13 changes: 1 addition & 12 deletions egs/librispeech/ASR/pruned_transducer_stateless6/train.py
@@ -507,9 +507,6 @@ def load_checkpoint_if_available(
     if "cur_epoch" in saved_params:
         params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]
-
     return saved_params


@@ -763,13 +760,7 @@ def train_one_epoch(

     tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
     for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

@@ -811,7 +802,6 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
Expand All @@ -824,7 +814,6 @@ def train_one_epoch(
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@@ -999,7 +988,7 @@ def remove_short_and_long_utt(c: Cut):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    if not params.print_diagnostics:
+    if params.start_batch <= 0 and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
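
Each train.py above gets the same treatment; condensed, the behaviour of a resumed run after this commit looks like the sketch below (illustrative only: stateless has no print_diagnostics check, and stateless3 applies the guard inside its for dl in [train_dl] loop with a warmup argument):

    # The pessimistic-batch OOM scan only runs on a fresh start; when resuming
    # from a mid-epoch checkpoint (params.start_batch > 0) it is skipped.
    if params.start_batch <= 0 and not params.print_diagnostics:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            optimizer=optimizer,
            sp=sp,
            params=params,
        )

    # train_one_epoch() now simply enumerates the dataloader; where the epoch
    # resumes is decided by the restored sampler state, not by replaying a
    # cur_batch_idx value from the checkpoint.
    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1
        # ... forward/backward, optimizer step, periodic checkpointing ...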
