[Not for merge] Madam optimizer with OOM handling #8
base: master
Conversation
    graph_compiler=graph_compiler,
    is_training=is_training,
)
except RuntimeError as ex:
This is where OOM is handled. I am testing it. Not sure whether it works or not.
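For reference, a minimal sketch of what such a handler might look like; compute_loss, train_dl, optimizer, and the skip-the-batch policy are placeholders for illustration, not necessarily what this PR does. It follows the common recipe of checking the error message, dropping references, and emptying the CUDA cache:

import logging

import torch

for batch in train_dl:  # the per-epoch training loop
    optimizer.zero_grad()
    try:
        loss, loss_info = compute_loss(
            model=model,
            batch=batch,
            graph_compiler=graph_compiler,
            is_training=is_training,
        )
    except RuntimeError as ex:
        if "out of memory" not in str(ex):
            raise  # not a CUDA OOM; let it propagate
        logging.warning(f"CUDA OOM, skipping this batch: {ex}")
        del ex  # the exception keeps the traceback (and its tensors) alive
        torch.cuda.empty_cache()
        continue
    loss.backward()
    optimizer.step()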
Not sure if it's helpful -- there is a repo that has some utilities for finding the optimal batch size in PyTorch, and it has some code to deal with OOM. Maybe there is something useful that can be borrowed: https://github.com/BlackHC/toma
Hmm, after looking into the DDP implementation, it looks to me like the backward gradient reduction is not done in one pass after the backward is completed on individual machines, but is done during the model's .backward(). So I think catching errors that happen during .backward() is not going to be possible, and likewise for errors that happen during the top-level model forward() function, because DDP seems to have sync points there. Errors in the forward for CTC or the transformer decoder may be catchable, though.
... but we can't do a new top-level model forward on new data. Might be possible to re-do a CTC or transformer forward with a subset of the minibatch, since its backward pass would be structurally similar to the one on the entire minibatch.
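A rough sketch of that idea, purely illustrative: decoder_forward, nnet_output, and supervisions are placeholders, and whether the retry plays well with DDP's sync points is exactly the open question here.

import torch

try:
    att_loss = decoder_forward(nnet_output, supervisions)
except RuntimeError as ex:
    if "out of memory" not in str(ex):
        raise
    torch.cuda.empty_cache()
    # Retry with the first half of the minibatch.  The encoder output is
    # already available, so only the decoder part is re-run; its backward
    # graph has the same structure as the full-batch one, just smaller.
    n = nnet_output.size(0) // 2
    att_loss = decoder_forward(nnet_output[:n], supervisions[:n])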
The current code is able to catch CUDA OOM exceptions during model.forward() when trained with a single GPU, but it hangs when trained with two GPUs using DDP. I am looking into where it hangs.
pytorch/pytorch#18853 (comment) says it is possible to catch CUDA OOM exceptions also for DDP training with a single GPU.
We are currently catching exceptions only for the forward pass, i.e., model.forward, get_tot_scores, and the attention decoder.
So I think catching errors that happen during .backward() is not going to be possible,
Agreed.
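To make the scope concrete, the structure is roughly as below. The names (feature, supervisions, token_ids, compute_ctc_loss, decoder_forward) are illustrative rather than the PR's exact code; the point is that the three forward-side stages sit inside the try and .backward() does not.

import torch

oom = False
try:
    nnet_output, memory, memory_mask = model(feature, supervisions)   # model.forward
    ctc_loss = compute_ctc_loss(nnet_output, supervisions)            # uses get_tot_scores
    att_loss = model.decoder_forward(memory, memory_mask, token_ids)  # attention decoder
    loss = ctc_loss + att_loss
except RuntimeError as ex:
    if "out of memory" not in str(ex):
        raise
    oom = True
    torch.cuda.empty_cache()

if not oom:
    # .backward() stays outside the try: with DDP, an OOM raised during the
    # backward pass cannot be recovered from anyway.
    loss.backward()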
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
I strongly suggest making this class local too (e.g. in a local data.py file) -- in the current repo layout, it will make it much easier to experiment with different types of data setups and augmentations.
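One way to read this suggestion (a hypothetical layout, not what the repo currently does) is to copy or subclass the data module into a recipe-local data.py and import it from there:

# data.py (hypothetical, recipe-local)
from icefall.dataset.librispeech import LibriSpeechAsrDataModule


class AsrDataModule(LibriSpeechAsrDataModule):
    """Recipe-local data module.

    Dataloader, sampler, and augmentation settings can be overridden here
    without touching the shared icefall package, so each experiment can
    carry its own data setup.
    """
    pass

train.py would then do `from data import AsrDataModule` instead of importing it from icefall.dataset.librispeech.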
x = x.view(-1, self.size)
target = target.view(-1)
with torch.no_grad():
    true_dist = x.clone()
Instead of x.clone(), torch.empty_like(x) would be more appropriate.
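For reference, a sketch of the usual label-smoothing target construction with torch.empty_like; the surrounding class is not shown in this diff, so size and smoothing are passed explicitly here, and padding/ignore-index handling is omitted:

import torch

def smoothed_targets(x: torch.Tensor, target: torch.Tensor,
                     size: int, smoothing: float) -> torch.Tensor:
    """Build the smoothed target distribution for a label-smoothing loss."""
    confidence = 1.0 - smoothing
    with torch.no_grad():
        # Every entry of true_dist is overwritten below, so there is no need
        # to copy x's values; an uninitialized tensor with the same shape,
        # dtype, and device is enough.
        true_dist = torch.empty_like(x)
        true_dist.fill_(smoothing / (size - 1))
        true_dist.scatter_(1, target.unsqueeze(1), confidence)
    return true_dist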
We have to disable batch norm layers. Otherwise, the process will hang indefinitely.
# NOTE(fangjun): The process hangs when using DDP
# if we try to recover from CUDA OOM, so we disable
# batchnorm layer here.
# self.norm = nn.BatchNorm1d(channels)
After disabling batch norm, training with DDP can now recover from OOM in model.forward() as expected.
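Since this restriction is easy to violate by accident, a small (hypothetical) guard could check for batchnorm before enabling OOM recovery under DDP:

import torch.nn as nn

def contains_batchnorm(model: nn.Module) -> bool:
    """Return True if any submodule of `model` is a batchnorm layer.

    OOM recovery appears to hang under DDP when the model uses batchnorm
    (DDP broadcasts buffers such as running_mean/running_var on every
    forward), so a training script could refuse to enable the recovery
    path in that case instead of hanging silently.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(m, bn_types) for m in model.modules())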
Mm. Hopefully with the Madam optimizer, the training will still be stable without the batchnorm. We'll have to see. Obviously, we would have to compare the performance after this change.
I will port OOM handling to LF-MMI training as well.
Put it here for discussion. Not ready for merge.
It contains OOM handling, which allows a larger max-duration (e.g., from 200 to 350).
Tensorboard log: https://tensorboard.dev/experiment/WCQbgwK2T0OI9kCWjHPOSw/#scalars
Tensorboard log: https://tensorboard.dev/experiment/BedA6yRKRyGpFB2wY709fQ/