
[Not for merge] Madam optimizer with OOM handling #8

Open · csukuangfj wants to merge 12 commits into master
Conversation

@csukuangfj (Collaborator) commented on Aug 14, 2021:

Put it here for discussion. Not ready for merge.

It contains:

  • the Madam optimizer from Dan
  • handling of OOM when using a larger max-duration (e.g., increasing it from 200 to 350)

tensorboard log: https://tensorboard.dev/experiment/WCQbgwK2T0OI9kCWjHPOSw/#scalars
tensorboard log: https://tensorboard.dev/experiment/BedA6yRKRyGpFB2wY709fQ/

@csukuangfj changed the title from "[Not for merge] Madam oom" to "[Not for merge] Madam optimizer with OOM handling" on Aug 14, 2021.
```python
            graph_compiler=graph_compiler,
            is_training=is_training,
        )
    except RuntimeError as ex:
```
@csukuangfj (Collaborator, Author):

This is where OOM is handled. I am testing it. Not sure whether it works or not.
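For readers following along, here is a minimal sketch of the pattern being tried here. The names `compute_loss` and `batch` are hypothetical placeholders; the actual training code in this PR wraps its own loss computation and batch handling:

```python
import torch


def train_one_batch(model, batch, compute_loss, optimizer):
    """Hypothetical wrapper: run the forward pass and skip the whole
    batch if a CUDA out-of-memory error is raised."""
    try:
        loss = compute_loss(
            model=model,
            batch=batch,
            is_training=True,
        )
    except RuntimeError as ex:
        if "out of memory" not in str(ex):
            raise
        # Drop any partially built gradients and cached blocks so the
        # next batch has a chance to fit.
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        return None  # signal the caller to skip this batch

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.detach()
```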

Collaborator:
Not sure if it's helpful -- there is a repo with some utilities for finding the optimal batch size in PyTorch, and it has some code to deal with OOM. Maybe there is something useful that can be borrowed: https://github.com/BlackHC/toma

Collaborator:

Hmm, after looking into the DDP implementation, it looks to me like the backward gradient reduction is not done in one pass after the backward is completed on individual machines, but during the .backward() of the model. So I think catching errors that happen during .backward() is not going to be possible, and likewise for errors that happen during the top-level model forward(), because DDP seems to have sync points there. Errors in the forward for the CTC or transformer decoder may be catchable, though.

Collaborator:

... but we can't do a new top-level model forward on new data. It might be possible to re-do a CTC or transformer forward with a subset of the minibatch, since its backward pass would be structurally similar to the one on the entire minibatch.
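As a concrete illustration of the "subset of the minibatch" idea (purely a sketch; `ctc_loss_fn` and the slicing over the batch dimension are hypothetical and ignore the supervision bookkeeping icefall actually needs):

```python
import torch


def ctc_loss_with_fallback(ctc_loss_fn, encoder_out, targets, keep_fraction=0.5):
    """Hypothetical sketch: compute the CTC loss on the full minibatch; on
    CUDA OOM, retry on the first keep_fraction of the utterances.  The
    already-computed encoder output is reused, so only the CTC part is
    recomputed, and its backward pass has the same structure as the
    full-batch one."""
    try:
        return ctc_loss_fn(encoder_out, targets)
    except RuntimeError as ex:
        if "out of memory" not in str(ex):
            raise
        torch.cuda.empty_cache()
        n = max(1, int(encoder_out.size(0) * keep_fraction))
        return ctc_loss_fn(encoder_out[:n], targets[:n])
```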

@csukuangfj (Collaborator, Author):

The current code is able to catch CUDA OOM exceptions during model.forward() when training with a single GPU, but it hangs when training with two GPUs using DDP. I am looking into where it hangs.

pytorch/pytorch#18853 (comment) says it is possible to catch CUDA OOM exceptions also for DDP training with a single GPU.

We are currently catching exceptions only for the forward pass, i.e., model.forward, get_tot_scores, and the attention decoder.

> So I think catching errors that happen during .backward() is not going to be possible,

Agreed.
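One way to avoid the hang (not what this PR currently does, just a sketch) is to make all DDP ranks agree on whether to skip a batch, e.g. by all-reducing an OOM flag, so that no rank is left waiting at a sync point. As discussed above, this cannot help if the OOM happens inside the DDP-wrapped forward or backward themselves:

```python
import torch
import torch.distributed as dist


def forward_agreeing_on_oom(compute_loss, device):
    """Hypothetical helper: run compute_loss(); if any DDP rank hits CUDA
    OOM, every rank learns about it and skips the batch together."""
    oom_flag = torch.zeros(1, device=device)
    loss = None
    try:
        loss = compute_loss()
    except RuntimeError as ex:
        if "out of memory" not in str(ex):
            raise
        oom_flag.fill_(1.0)
        torch.cuda.empty_cache()

    if dist.is_available() and dist.is_initialized():
        # Sum the flags across ranks; a non-zero result means some rank OOM'd.
        dist.all_reduce(oom_flag, op=dist.ReduceOp.SUM)

    if oom_flag.item() > 0:
        return None  # all ranks skip backward/step for this batch
    return loss
```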

```python
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
```
Collaborator:

I strongly suggest making this class local too (e.g. in a local data.py file) -- in the current repo layout, it will make it much easier to experiment with different types of data setups and augmentations.
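For illustration, the local module could be as small as this (the file name and class name below are hypothetical; one could also copy the class body in wholesale instead of subclassing):

```python
# data.py (hypothetical local module inside the recipe directory)
from icefall.dataset.librispeech import LibriSpeechAsrDataModule


class LocalLibriSpeechAsrDataModule(LibriSpeechAsrDataModule):
    """A per-recipe copy point: override dataloader or augmentation
    methods here instead of editing the shared icefall library."""
    pass
```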

```python
x = x.view(-1, self.size)
target = target.view(-1)
with torch.no_grad():
    true_dist = x.clone()
```
Collaborator:

Instead of x.clone(), torch.empty_like(x) would be more appropriate.
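A self-contained version of what the suggestion amounts to (the function and parameter names are illustrative; the class in the PR keeps its own self.size / self.smoothing attributes, and padding handling is omitted here):

```python
import torch


def smoothed_targets(x: torch.Tensor, target: torch.Tensor,
                     smoothing: float = 0.1) -> torch.Tensor:
    """Build the label-smoothed target distribution.  Every entry of the
    result is overwritten, so an uninitialized tensor is enough; there is
    no need to copy x's values with x.clone()."""
    num_classes = x.size(-1)
    with torch.no_grad():
        true_dist = torch.empty_like(x)
        true_dist.fill_(smoothing / (num_classes - 1))
        true_dist.scatter_(1, target.view(-1, 1), 1.0 - smoothing)
    return true_dist
```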

We have to disable batch norm layers. Otherwise, the process will hang indefinitely.

```python
# NOTE(fangjun): The process hangs when using DDP
# if we try to recover from CUDA OOM, so we disable
# batchnorm layer here.
# self.norm = nn.BatchNorm1d(channels)
```
@csukuangfj (Collaborator, Author):

After disabling batch norm, training with DDP can now recover from OOM in model.forward() as expected.

Collaborator:

Mm. Hopefully with the Madam optimizer the training will still be stable without the batchnorm; we'll have to see. Obviously we would have to compare the performance after this change.
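If removing the batchnorm does hurt, one possible alternative (not something this PR tries, just a sketch) is a normalization without cross-batch statistics, e.g. LayerNorm over the channel dimension:

```python
import torch
import torch.nn as nn


class ChannelwiseLayerNorm(nn.Module):
    """Hypothetical replacement for nn.BatchNorm1d(channels) on (N, C, T)
    input: LayerNorm over C, so there are no running batch statistics and
    nothing that needs to be synchronized across DDP ranks."""

    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.LayerNorm(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, C, T) -> (N, T, C), normalize over C, then back to (N, C, T).
        return self.norm(x.transpose(1, 2)).transpose(1, 2)
```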

@csukuangfj (Collaborator, Author):
I will port OOM handling to LF-MMI training as well.
