Support CUDA Graph for MoE models #1233

Open · buptzyb wants to merge 11 commits into main
Conversation


@buptzyb commented Oct 9, 2024

Description

Unlike non-MoE models such as llama2, MoE models have dynamically shaped activations in their FFN layers, so a single CUDA graph can only capture part of a transformer layer instead of covering the whole layer. We call this the "breaking-layer" CUDA graph mode. This PR adds breaking-layer CUDA graph support for MoE models on the TE side and fixes several related bugs in TE.
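
For readers unfamiliar with the mode, the sketch below illustrates the basic idea: capture a CUDA graph only for the static-shape part of a layer and run the dynamically shaped MoE FFN eagerly. This is a hedged illustration only, not code from this PR; te.MultiheadAttention, the sizes, moe_ffn, and breaking_layer_forward are illustrative choices, and the two-argument (modules, sample_args) form of make_graphed_callables is assumed.

# Illustrative sketch only: graph the static-shape attention block and keep the
# dynamic-shape MoE FFN outside the graph ("breaking-layer" mode).
import torch
import transformer_engine.pytorch as te

hidden, seq, batch = 1024, 128, 2
attn_block = te.MultiheadAttention(hidden, num_attention_heads=16).cuda()
sample_input = torch.randn(seq, batch, hidden, device="cuda")

# Shapes into the attention block are fixed, so it can be captured and replayed.
graphed_attn = te.make_graphed_callables(attn_block, (sample_input,))

def breaking_layer_forward(x, moe_ffn):
    x = graphed_attn(x)  # replayed from the captured CUDA graph
    x = moe_ffn(x)       # token routing gives dynamic shapes -> run eagerly
    return x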

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Add an is_initialized() method to CudaRNGStatesTracker to align with what is already done in MCore (see the sketch after this list).
  • Fix the wrong per_callable_module_params ordering in _make_graphed_callables when _order is given.
  • Fix the warmup argument mismatch in _make_graphed_callables when _order is given.
  • Fix an FP8 accuracy issue by adding an fp8_group argument to make_graphed_callables() and modifying the is_first_microbatch, skip_fp8_weight_update, and fp8_meta code.
  • Support CUDA graphs for MoE models by filtering graphed TE modules and model weights during warmup.
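
A minimal sketch of the kind of is_initialized() addition described in the first item, mirroring the MCore-style tracker. This is an assumption for illustration, not the actual TE implementation.

import torch

class CudaRNGStatesTracker:
    """Simplified tracker: stores named CUDA RNG states for later re-seeding."""

    def __init__(self):
        self._is_initialized = False
        self.states_ = {}

    def is_initialized(self):
        # Lets callers (e.g. graph-capture code) check whether any RNG state
        # has been registered before resetting or forking states.
        return self._is_initialized

    def add(self, name, seed):
        # Register a named CUDA RNG state seeded with `seed`, then restore the
        # original state so registration has no side effects.
        orig_state = torch.cuda.get_rng_state()
        torch.cuda.manual_seed(seed)
        self.states_[name] = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(orig_state)
        self._is_initialized = True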

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

buptzyb and others added 4 commits October 9, 2024 00:48
Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Yifei Song <[email protected]>
Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Yifei Song <[email protected]>

@timmoon10 (Collaborator) left a comment


Technically this seems mostly reasonable, although I have questions and stylistic suggestions. Have you tested that it works with Mcore?

@ptrendx @ksivaman @sbhavani What is our priority for this feature? The custom Mcore logic in make_graphed_callables is already messy and fragile, and this PR does exacerbate those problems.

transformer_engine/pytorch/module/layernorm_linear.py (outdated review thread, resolved)
Comment on lines +176 to +183
for m_chunk in range(num_model_chunks):
    for _ in range(num_microbatches):
        for l_no in range(num_layers):
            per_callable_module_params.append(
                tuple(callables[m_chunk * num_layers + l_no].parameters())
                if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module)
                else ()
            )
Collaborator

This change seems correct to me, but it's odd if the Mcore integration was working before. @ksivaman Have we run this with Mcore, or did we run with num_microbatches=1?

This changes the interpretation of per_callable_module_params from (num_chunks, layers_per_chunk, num_microbatches) to (num_chunks, num_microbatches, layers_per_chunk). This matches the interpretation of per_callable_* lists when capturing graphs:

per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + (
    fwd_idx[m_chunk] * num_layers + l_no
)
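
A standalone illustration with made-up sizes (not code from the PR) of why the list must be built in (chunk, microbatch, layer) order to agree with per_callable_fwd_idx:

# Made-up sizes; checks that the (chunk, microbatch, layer) build order matches
# the flattened index used at capture time.
num_model_chunks, num_microbatches, num_layers = 2, 3, 4

build_order = []
for m_chunk in range(num_model_chunks):
    for microbatch in range(num_microbatches):
        for l_no in range(num_layers):
            build_order.append((m_chunk, microbatch, l_no))

for m_chunk in range(num_model_chunks):
    for microbatch in range(num_microbatches):  # plays the role of fwd_idx[m_chunk]
        for l_no in range(num_layers):
            per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + (
                microbatch * num_layers + l_no
            )
            assert build_order[per_callable_fwd_idx] == (m_chunk, microbatch, l_no)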

transformer_engine/pytorch/graph.py (outdated review thread, resolved)
Comment on lines +282 to +284
for module in func.modules():
    if hasattr(module, "is_first_microbatch"):
        module.is_first_microbatch = True
Collaborator

TE modules don't set or read the is_first_microbatch attr. It's a kwarg in the forward function. Also, this assumes callables contains torch.nn.Modules.

Suggested change
for module in func.modules():
    if hasattr(module, "is_first_microbatch"):
        module.is_first_microbatch = True
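
For context, is_first_microbatch is normally passed to TE modules as a forward kwarg rather than stored as an attribute. A hedged sketch follows; te.Linear and the sizes are illustrative, and the FP8 weight caching controlled by this kwarg only takes effect under fp8_autocast.

import torch
import transformer_engine.pytorch as te

linear = te.Linear(1024, 1024).cuda()
x = torch.randn(16, 1024, device="cuda")

for i in range(4):
    # True on the first microbatch tells TE it may cast and cache FP8 weights;
    # later microbatches reuse the cached cast (when running under fp8_autocast).
    y = linear(x, is_first_microbatch=(i == 0))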


@yifeis-nv commented Oct 11, 2024


Thank you for the reminder. I also believe modifications are needed here. The issue we originally intended to address is that the FP8 weight caching behavior in MoE makes the first microbatch behave differently from the other microbatches. Without this piece of code, the warmup process updates is_first_microbatch to False, causing all captured graphs to exhibit non-first-microbatch behavior, which does not match our requirements. Therefore, we chose to reset this parameter after warmup.

In summary, our requirement is either to prevent is_first_microbatch from being updated during warmup, or to reset is_first_microbatch after warmup. The former may require adding a flag to tell MoE that it is currently in the warmup phase, while the latter might require making this code a bit more general. Do you have any input on a modification plan that could serve as a reference for us?
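
A minimal sketch of the second option (resetting the flag after warmup), with a hypothetical helper name; the isinstance guard reflects the point above that callables may contain plain functions rather than torch.nn.Modules.

import torch

def reset_first_microbatch_flags(callables):
    # Hypothetical helper, not code from this PR: restore first-microbatch
    # behavior after warmup so the captured graphs reproduce it.
    for func in callables:
        if not isinstance(func, torch.nn.Module):
            continue
        for module in func.modules():
            if hasattr(module, "is_first_microbatch"):
                module.is_first_microbatch = True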

Collaborator

I see, so this is Mcore-specific logic. It's uncomfortable that it's made its way into TE, but it's a tricky problem and I can't think of a better solution either, at least without significant changes in Mcore.

We should document what this is doing, especially for TE developers with no knowledge of Mcore.

transformer_engine/pytorch/graph.py (outdated review thread, resolved)
Comment on lines +545 to +551
if (
    not fp8_recipe.fp8_mha
    and not fp8_recipe.fp8_dpa
    and hasattr(m, "attention_dropout")
    and m.deterministic
):
    continue
Collaborator

Why are we skipping the FP8 scale update logic for this case?

Suggested change
if (
    not fp8_recipe.fp8_mha
    and not fp8_recipe.fp8_dpa
    and hasattr(m, "attention_dropout")
    and m.deterministic
):
    continue


This is for the deterministic test with FP8. Even without CUDA graphs, we find that FP8 DPA leads to random output even when set to deterministic mode. With CUDA graphs, we not only need to run DPA in BF16 but also skip the fp8_meta update for DPA. After that, we finally obtain stable output under the deterministic test.
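
As a point of reference (a hedged sketch, not code from this PR), keeping DPA/MHA out of FP8 while the GEMMs stay in FP8 is a recipe-level setting; the field values below are assumptions chosen for illustration.

from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,  # E4M3 forward, E5M2 backward for the GEMMs
    fp8_dpa=False,             # dot-product attention stays in higher precision
    fp8_mha=False,             # fused MHA stays in higher precision
)
# Used as te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe) around the model.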

Collaborator

It seems like we're covering up a correctness bug. This may also affect convergence since we are no longer doing amax reductions within MHA (e.g. for the FP8 cast before qkv and proj).


That makes sense, and I think we should 'continue' whenever fp8_mha and fp8_dpa are both disabled, regardless of whether the test is deterministic or not.

@timmoon10 (Collaborator) commented Oct 16, 2024

Actually, this change will cause correctness issues. We do a max all-reduce on the amaxes so that FP8 scaling factors are synchronized over the TP group. If we skip the amax reduction, then the MHA's TP communication will be wrong. In particular, we'll all-gather FP8 inputs to the qkv GEMM, but the scaling factors will be different for each TP rank. In pseudocode:

def mha_qkv(x_local, w_local, x_fp8_scale):
    x_local_fp8, x_amax = cast_to_fp8(x_local, x_fp8_scale)
    x_fp8 = all_gather(x_local_fp8)
    y_local = gemm(x_fp8, w_local, x_fp8_scale)
    max_all_reduce(x_amax)  # Without this, x_fp8_scale is different between ranks
    update_fp8_scale(x_fp8_scale, x_amax)
    return y_local

Turns out this isn't relevant, since MultiheadAttention is not a TransformerEngineBaseModule.

@timmoon10 (Collaborator) commented Oct 16, 2024

As far as I can tell, this logic is specific to DotProductAttention. We could make intent much more obvious with:

Suggested change
if (
    not fp8_recipe.fp8_mha
    and not fp8_recipe.fp8_dpa
    and hasattr(m, "attention_dropout")
    and m.deterministic
):
    continue
if (
    isinstance(m, DotProductAttention)
    and not fp8_recipe.fp8_mha
    and not fp8_recipe.fp8_dpa
):
    # Don't need to update FP8 meta for non-FP8 DPA
    continue

@timmoon10

/te-ci pytorch

buptzyb and others added 5 commits October 10, 2024 20:07
Signed-off-by: Robin Zhang <[email protected]>
Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Yifei Song <[email protected]>
Signed-off-by: Robin Zhang <[email protected]>
@yaox12 (Collaborator) commented Oct 11, 2024

/te-ci pytorch

@buptzyb (Author) commented Oct 11, 2024

Have you tested that it works with Mcore?

Yes. We also made some changes in Mcore, together with the TE changes in this PR, to enable MoE CUDA graphs. You can refer to issue 193 in our Megatron-LM repo.

Signed-off-by: Xin Yao <[email protected]>
@yaox12 (Collaborator) commented Oct 11, 2024

/te-ci pytorch
