Integrate MS-AMP Support for FP8 Precision #2224
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
""" | ||
|
||
margin: int = 0 | ||
interval: int = 1 | ||
fp8_format: str = "E4M3" | ||
amax_history_len: int = 1 | ||
amax_history_len: int = 1024 |
I noticed this default was different than the one NVIDIA has, so I set it to 1024. Asking what is critical about it.
Can we not use DeepSpeed with O3 FP8? Storing gradients in FP8 is already a great step, but storing weights in FP8 would be an even better step, especially when we also enable zero3 to split the model.
A small model already seeing improvements is really nice! Do you by chance have the time + compute to do a Llama/Mistral 7B run to check the difference there?
@casper-hansen as mentioned in the PR, we cannot until they update the deepspeed version they require.
Hello Zach, thank you for working on adding MS-AMP FP8 support 🔥🚀✨! The experiments you performed with bloomz-560M are already showing nice improvements. Looking forward to experiments at larger scales.
Left a couple of suggestions/comments.
```python
    def __post_init__(self):
        self.fp8_format = self.fp8_format.upper()
        if self.fp8_format not in ["E4M3", "HYBRID"]:
            raise ValueError("`fp8_format` must be 'E4M3' or 'HYBRID'.")
        if self.amax_compute_algo not in ["max", "most_recent"]:
            raise ValueError("`amax_compute_algo` must be 'max' or 'most_recent'")
        if self.enable_ms_amp and not is_msamp_available():
            self.enable_ms_amp = False
```
Raise a warning mentioning that msamp is not available but the flag is `True`.
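Something like this minimal sketch (the exact message wording is just an example):

```python
import warnings

# inside FP8RecipeKwargs.__post_init__
if self.enable_ms_amp and not is_msamp_available():
    warnings.warn(
        "`enable_ms_amp=True` was passed, but MS-AMP is not installed; "
        "falling back to `enable_ms_amp=False`."
    )
    self.enable_ms_amp = False
```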
```diff
@@ -1286,13 +1292,16 @@ def prepare(self, *args, device_placement=None):
         result = tuple(
             self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
         )
+        if self.mixed_precision == "fp8" and self.fp8_recipe_handler.enable_ms_amp:
+            # MS-AMP needs both model and optimizer
+            args = self._prepare_ms_amp(*result)
```
```diff
-        args = self._prepare_ms_amp(*result)
+        result = self._prepare_ms_amp(*result)
```
Given that the second pass takes the result.
Also, what happens if we run this with the latest DeepSpeed?
Thanks for adding MS-AMP support! This looks very good overall. I left a couple of comments. One thing that could be added is some docs about mixed-precision training in the How-to Guide and Concept Guide. Correct me if I'm wrong, but I wasn't able to find much documentation about that. This can be done in a follow-up PR. LMK what you think. The table you linked in this PR is a very good summary.
```python
        model: torch.nn.Module,
        device_placement: bool = None,
        evaluation_mode: bool = False,
        first_pass: bool = True,
```
Add it in the docstring. Moreover, maybe change the default value to `False` to have the same behavior as `_prepare_one`.
```python
        elif isinstance(obj, torch.optim.Optimizer):
            optimizer = self.prepare_optimizer(obj, device_placement=device_placement)
            return optimizer
        # Second pass of preparation: LR scheduler (which need the full list of optimizers)
        elif isinstance(obj, LRScheduler):
            scheduler = self.prepare_scheduler(obj)
            return scheduler
        # Second pass of preparation: FP8 with MS-AMP
        elif isinstance(obj, torch.nn.Module):
            return self.prepare_model(obj, device_placement=device_placement)
```
Should `first_pass` be set to `False` here, since the default value is `True`?
Closing as there's a very odd performance bug with ending accuracy.
Integrate MS-AMP into the `Accelerator`
What does this add?
This PR introduces an additional backend for FP8 support through MS-AMP, which has been shown to decrease memory usage and increase throughput when using FP8 precision.
Who is it for?
Individuals training with FP8 (H100s, 4090s, etc.)
Issues linked to
Azure/MS-AMP#128
What parts of the API does this impact?
User-facing:
Two new arguments were added to the `FP8RecipeKwargs`:

- `enable_ms_amp` (`bool`): Whether a user should use MS-AMP. True by default if it's available in the environment.
- `optimization_level` (`str`): Should be one of `"O1"` or `"O2"`. `"O3"` is for DeepSpeed and we need to wait for them to update to v0.9.3 of deepspeed to match what Accelerate supports.

General guideline to optimization levels:
- `"O1"`: Weight gradients and `all_reduce` communications are done in FP8, reducing GPU memory usage and communication bandwidth.
- `"O2"`: First-order optimizer states are in 8-bit and second-order states are in FP16. Only available when using Adam or AdamW. This maintains accuracy and can potentially save the highest memory.
- `"O3"`: Weights and master weights are stored in FP8. If `fp8` is selected and deepspeed is enabled, it will be used by default. (Not available currently.)
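For context, these levels map onto MS-AMP's own `opt_level` argument. A rough standalone sketch of the underlying MS-AMP API (based on their README; treat the exact call signature as an assumption):

```python
import torch
import msamp

model = torch.nn.Linear(64, 64).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# MS-AMP wraps the model/optimizer pair; "O2" additionally moves first-order
# optimizer states to low precision on top of the "O1" casts.
model, optimizer = msamp.initialize(model, optimizer, opt_level="O2")
```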
As a result, `"O2"` is the default. An overview of each optimization level and what it does can be found in the MS-AMP docs.

Internal structure:
With how `fp8` optimization works with MS-AMP, we get the best bang for our buck if we combine both MS-AMP and TransformersEngine. As a result, when preparing the model and optimizer we run through the same fix that exists for TPU optimizers, so that we can replace the new `te.Linear` layers with the equivalent `ms-amp` ones, which increase throughput without decreasing performance.

Basic Usage Example(s):
A user can either do:
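For example (a sketch; with MS-AMP installed, the new backend is picked up automatically when `mixed_precision="fp8"` is passed):

```python
import torch
from accelerate import Accelerator

model = torch.nn.Linear(64, 64)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# MS-AMP is enabled by default for fp8 if it is available in the environment
accelerator = Accelerator(mixed_precision="fp8")
model, optimizer = accelerator.prepare(model, optimizer)
```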
Or use the `FP8RecipeKwargs`:
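A sketch using the argument names described in this PR (`enable_ms_amp`, `optimization_level`); the final API may differ slightly:

```python
import torch
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

model = torch.nn.Linear(64, 64)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Explicitly opt in to MS-AMP and pick an optimization level
fp8_kwargs = FP8RecipeKwargs(enable_ms_amp=True, optimization_level="O2")
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[fp8_kwargs])
model, optimizer = accelerator.prepare(model, optimizer)
```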
Benchmarks
When running on `bloomz-560m` I saw the following speedups on the first 100 batches:

- Batch size: 8
- Max seq length: 256
- GPU used: single 4090

I also verified that the loss curves for the experiment were identical.