forked from NVIDIA/NeMo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use model-cast-to-bfloat16 rather than AMP-to-bfloat16 for inference. (…
…NVIDIA#9198) * Fix the "cast ping pong" problem when we run AMP inference. This has been tested only for Parakeet-CTC-1.1B right now. This problem certainly exists elsewhere. Automatic mixed precision and inference do not play well together. First, automatic mixed precision was created back when neural networks were much simpler. In particular, they did not have softmax and layer norm as frequent operations. In the era of transformers, softmax and layer norm are very common. AMP will uncoditionally output fp32 outputs from these operations, even if their inputs are fp16. See here: https://pytorch.org/docs/stable/amp.html#cuda-ops-that-can-autocast-to-float32 This is no longer necessary, now that layer norm does accumulation in fp32 in pytorch, even if the input is fp16: pytorch/pytorch#66707 Do infernece by casting model to bfloat16, not by using AMP. Do feature preprocessing in float32 for accuracy. Warn if someone tries to input a non-float32 tensor. Always create the output in the type the rest of the model expects. Sort manifests by duration. Signed-off-by: Daniel Galvez <[email protected]> * Always cast softmax inputs to float32 when in training mode. While we don't need this for accurate results in b/float16, this is a safety precaution to make sure that training accuracy does not regress. Signed-off-by: Daniel Galvez <[email protected]> --------- Signed-off-by: Daniel Galvez <[email protected]>
- Loading branch information
Showing
10 changed files
with
120 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.