
Integrate MS-AMP Support for FP8 Precision #2224

Closed
wants to merge 16 commits into from

Conversation

Collaborator

@muellerzr muellerzr commented Dec 6, 2023

Integrate MS-AMP to the Accelerator

What does this add?

This PR introduces an additional backend for FP8 support through MS-AMP, which has been shown to decrease memory usage and increase throughput when training in FP8 precision.

Who is it for?

Individuals training in FP8 (H100s, 4090s, etc.)

Issues linked to

Azure/MS-AMP#128

What parts of the API does this impact?

User-facing:

Two new arguments were added to the FP8RecipeKwargs:

  • enable_ms_amp (bool): Whether to use MS-AMP. True by default if it's available in the environment.
  • optimization_level (str): Should be one of "O1" or "O2". "O3" is for DeepSpeed, and we need to wait for them to update to DeepSpeed v0.9.3 to match what Accelerate supports.

General guideline to optimization levels:

  • O1: Weight gradients and all_reduce communications are done in FP8, reducing GPU
    memory usage and communication bandwidth.
  • O2: First-order optimizer states are in 8-bit and second-order states are in FP16.
    Only available when using Adam or AdamW. This maintains accuracy and can potentially save the most
    memory.
  • O3: Specifically for DeepSpeed; weights and master weights of models
    are stored in FP8. If FP8 is selected and DeepSpeed is enabled, this level will be used by default.
    (Not currently available.)

As a result, "O2" is the default. Here is an overview of each optimization level and what it does, taken from their docs:

| Optimization Level | Computation (GEMM) | Comm | Weight | Master Weight | Weight Gradient | Optimizer States |
|--------------------|--------------------|------|--------|---------------|-----------------|------------------|
| FP16 AMP           | FP16               | FP32 | FP32   | N/A           | FP32            | FP32+FP32        |
| Nvidia TE          | FP8                | FP32 | FP32   | N/A           | FP32            | FP32+FP32        |
| MS-AMP O1          | FP8                | FP8  | FP16   | N/A           | FP8             | FP32+FP32        |
| MS-AMP O2          | FP8                | FP8  | FP16   | N/A           | FP8             | FP8+FP16         |
| MS-AMP O3          | FP8                | FP8  | FP8    | FP16          | FP8             | FP8+FP16         |

Internal structure:

With how FP8 optimization works in MS-AMP, we get the best bang for our buck by combining MS-AMP and TransformerEngine. As a result, when preparing the model and optimizer we run through the same fix that exists for TPU optimizers, so that we can replace the new te.Linear layers with the equivalent MS-AMP ones, which increases throughput without decreasing performance. A sketch of this kind of layer swap follows below.

Basic Usage Example(s):

A user can either do:

accelerator = Accelerator(mixed_precision="fp8")

Or use the FP8RecipeKwargs:

# To disable MS-AMP if available in the environment
kwarg_handlers = [FP8RecipeKwargs(enable_ms_amp=False)]

# To change the optimization level
kwarg_handlers = [FP8RecipeKwargs(optimization_level="O1")]

accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=kwarg_handlers,
)
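
Putting this together, here is a minimal end-to-end sketch of a training step with the FP8 backend. The tiny model, optimizer, and dataloader are placeholders (not part of this PR), and FP8RecipeKwargs is assumed to be importable from accelerate.utils:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

# Placeholder model, optimizer, and data; real training code would use its own.
model = torch.nn.Linear(128, 128)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # O2 requires Adam/AdamW
dataset = TensorDataset(torch.randn(64, 128), torch.randn(64, 128))
train_dataloader = DataLoader(dataset, batch_size=8)

accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[FP8RecipeKwargs(optimization_level="O2")],
)
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

for inputs, targets in train_dataloader:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    accelerator.backward(loss)
    optimizer.step()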

Benchmarks

When running on bloomz-560m I saw the following speedups on the first 100 batches:

Batch size: 8
Max seq length: 256
GPU used: single 4090

| Model Configuration    | Time per Batch | Peak Memory |
|------------------------|----------------|-------------|
| FP16                   | 0.183s         | 20.62 GB    |
| BF16                   | 0.139s         | 15.41 GB    |
| Raw TransformersEngine | 0.129s         | 12.02 GB    |
| TE + MS-AMP            | 0.108s         | 10.66 GB    |

I also verified that the loss curves for the experiment were identical.
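
As a rough sketch only (not the script used for these numbers), per-batch time and peak memory can be measured like this, assuming the prepared model, optimizer, dataloader, and accelerator from the usage example above:

import time
import torch

torch.cuda.reset_peak_memory_stats()
step_times = []
for step, (inputs, targets) in enumerate(train_dataloader):
    if step == 100:  # first 100 batches
        break
    torch.cuda.synchronize()
    start = time.perf_counter()
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    accelerator.backward(loss)
    optimizer.step()
    torch.cuda.synchronize()
    step_times.append(time.perf_counter() - start)

print(f"Time per batch: {sum(step_times) / len(step_times):.3f}s")
print(f"Peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")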

@muellerzr muellerzr linked an issue Dec 6, 2023 that may be closed by this pull request
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

"""

margin: int = 0
interval: int = 1
fp8_format: str = "E4M3"
amax_history_len: int = 1
amax_history_len: int = 1024
Collaborator Author

I noticed this default was different from the one NVIDIA uses, so I set it to 1024. Asking what is critical about it.
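
For reference, these fields roughly mirror the parameters of TransformerEngine's delayed-scaling recipe; a sketch of how they would map, assuming TE's recipe API (the exact wiring inside Accelerate may differ):

from transformer_engine.common import recipe

# Hypothetical mapping from the FP8RecipeKwargs fields above to a TE recipe object.
fp8_recipe = recipe.DelayedScaling(
    margin=0,
    interval=1,
    fp8_format=recipe.Format.E4M3,
    amax_history_len=1024,  # the default this comment proposes, matching NVIDIA's
    amax_compute_algo="max",
)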

@casper-hansen

  • O3: Specifically for DeepSpeed; weights and master weights of models
    are stored in FP8. If FP8 is selected and DeepSpeed is enabled, this level will be used by default.
    (Not currently available.)

Can we not use DeepSpeed with O3 FP8? Storing gradients in FP8 is already a great step, but storing weights in FP8 would be an even better step, especially when we also enable ZeRO-3 to split the model.

bloomz-560m

A small model already seeing improvements is really nice! Do you by chance have the time and compute to do a Llama/Mistral 7B run to check the difference there?

@muellerzr
Collaborator Author

@casper-hansen as mentioned in the PR, we cannot until they update the DeepSpeed version they require.

Contributor

@pacman100 pacman100 left a comment


Hello Zach, thank you for working on adding MS-AMP FP8 support 🔥🚀✨! The experiments you performed with bloomz-560M are already showing nice improvements. Looking forward to experiments at larger scales.

Left a couple of suggestions/comments.


def __post_init__(self):
    self.fp8_format = self.fp8_format.upper()
    if self.fp8_format not in ["E4M3", "HYBRID"]:
        raise ValueError("`fp8_format` must be 'E4M3' or 'HYBRID'.")
    if self.amax_compute_algo not in ["max", "most_recent"]:
        raise ValueError("`amax_compute_algo` must be 'max' or 'most_recent'")
    if self.enable_ms_amp and not is_msamp_available():
        self.enable_ms_amp = False
Contributor

Raise a warning mentioning that MS-AMP is not available even though the flag is True.
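
A sketch of what that suggestion could look like inside __post_init__ (not the final implementation):

import warnings

if self.enable_ms_amp and not is_msamp_available():
    warnings.warn("`enable_ms_amp` was set to True, but MS-AMP is not available; disabling it.")
    self.enable_ms_amp = False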

@@ -1286,13 +1292,16 @@ def prepare(self, *args, device_placement=None):
        result = tuple(
            self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
        )
        if self.mixed_precision == "fp8" and self.fp8_recipe_handler.enable_ms_amp:
            # MS-AMP needs both model and optimizer
            args = self._prepare_ms_amp(*result)
Contributor


Suggested change
- args = self._prepare_ms_amp(*result)
+ result = self._prepare_ms_amp(*result)

Contributor


given that the second pass takes the result

@pacman100
Contributor

Also, what happens if we run this with the latest DeepSpeed?

Member

@SunMarc SunMarc left a comment


Thanks for adding MS-AMP support! This looks very good overall. I left a couple of comments. One thing that could be added is some docs about mixed-precision training in the How-to Guide and Concept guide. Correct me if I'm wrong, but I wasn't able to find much documentation about that. This can be done in a follow-up PR. LMK what you think. The table you linked in this PR is a very good summary.

    model: torch.nn.Module,
    device_placement: bool = None,
    evaluation_mode: bool = False,
    first_pass: bool = True,
Member


Add it to the docstring. Moreover, maybe change the default value to False to have the same behavior as _prepare_one.

    elif isinstance(obj, torch.optim.Optimizer):
        optimizer = self.prepare_optimizer(obj, device_placement=device_placement)
        return optimizer
    # Second pass of preparation: LR scheduler (which need the full list of optimizers)
    elif isinstance(obj, LRScheduler):
        scheduler = self.prepare_scheduler(obj)
        return scheduler
    # Second pass of preparation: FP8 with MS-AMP
    elif isinstance(obj, torch.nn.Module):
        return self.prepare_model(obj, device_placement=device_placement)
Member


Should first_pass be set to False here, since the default value is True?

@muellerzr muellerzr closed this Dec 7, 2023
@muellerzr
Collaborator Author

Closing, as there's a very odd performance bug with the final accuracy when training BERT with TransformerEngine. Revisiting this implementation in a separate PR that frames MS-AMP as an alternative framework instead of directly wrapping with TE.

Successfully merging this pull request may close these issues.

Feature Request: Support MS-AMP