
[WIP][Need help and discussion] : basic llama tensor parallel #32597

Closed

Conversation

@SeungyounShin commented on Aug 11, 2024

What does this PR do?

This PR addresses an issue encountered when running the following command:

CUDA_VISIBLE_DEVICES=0,1 python3 examples/pytorch/tensor-parallel/run_tp_llama.py

The current implementation results in the following error:

    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  File "./transformers/src/transformers/models/llama/modeling_llama.py", line 276, in apply_rotary_pos_emb
    q_embed = (q * cos) + (rotate_half(q) * sin)
RuntimeError: The size of tensor a (512) must match the size of tensor b (256) at non-singleton dimension 2
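
The two tensors disagree on the sequence dimension by exactly the TP factor of 2 (512 vs. 256). Below is a minimal reconstruction of the failing broadcast; the batch, head, and head-dim sizes are illustrative, not taken from the actual run:

```python
import torch

# Illustrative repro: with a TP degree of 2, q keeps the full sequence
# length (512) while cos/sin were built for a halved sequence (256).
q = torch.randn(1, 16, 512, 128)   # (batch, n_heads, seq_len, head_dim)
cos = torch.randn(1, 1, 256, 128)  # seq_len off by the TP factor of 2

q * cos  # RuntimeError: ... tensor a (512) must match ... tensor b (256) at non-singleton dimension 2
```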

Problem Description and Discussion Points

  1. Sequence Length:
    The Tensor Parallel approach appears to require the sequence length to divide evenly across the TP ranks, which the existing implementation does not handle (inference mode is unaffected).

  2. Potential Solution - Accelerate:
    Given the benefits of Tensor Parallel in training, especially compared with data-parallel methods such as DeepSpeed and FSDP, I'm considering submitting a PR to the accelerate library. Note, however, that the current structure of the transformers models may need to change to fully realize these benefits (a minimal sketch of the PyTorch TP API follows this list).

  3. Root Cause of the Error:
    The error seems to stem from the fact that the rotary position embeddings are computed immediately after the token embeddings; once the model is sharded for Tensor Parallel, the precomputed cos/sin tensors no longer match the sharded shapes, producing the size mismatch shown above.

  4. Request for Assistance:
    Addressing this issue may require significant changes to the codebase, so I would greatly appreciate any feedback, guidance, or assistance.
    As a recent graduate, I've observed that many teams now train on 2-4 nodes of 8 GPUs each. In these setups, data-parallel methods like DeepSpeed and FSDP often suffer from high ring latency, and many runs are limited by the memory constraints of a single device. I believe Tensor Parallel, combined with data parallelism across nodes, could become a dominant approach in the near future. I'm eager to discuss and collaborate on making this approach compatible with the Transformers library or a similar framework.
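
To make the discussion above concrete, here is a minimal sketch of how a Llama decoder layer could be sharded with PyTorch's torch.distributed.tensor.parallel API. The module names follow LlamaDecoderLayer; the per-rank head-count bookkeeping that transformers would also need (num_heads divided by the TP degree) is omitted:

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Assumes `model` is a loaded LlamaForCausalLM and the script was
# launched with torchrun --nproc_per_node=2.
tp_mesh = init_device_mesh("cuda", (2,))

for layer in model.model.layers:
    parallelize_module(
        layer,
        tp_mesh,
        {
            # Attention: shard q/k/v column-wise, o_proj row-wise.
            "self_attn.q_proj": ColwiseParallel(),
            "self_attn.k_proj": ColwiseParallel(),
            "self_attn.v_proj": ColwiseParallel(),
            "self_attn.o_proj": RowwiseParallel(),
            # MLP: shard gate/up column-wise, down_proj row-wise.
            "mlp.gate_proj": ColwiseParallel(),
            "mlp.up_proj": ColwiseParallel(),
            "mlp.down_proj": RowwiseParallel(),
        },
    )
```

With this plan, each rank holds num_heads / 2 attention heads after the column-wise shard, which is exactly the kind of structural assumption the current modeling code does not yet account for.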

Before submitting

  • This PR is not just a typo fix or documentation improvement.
  • I have read the contributor guidelines.
  • This has been discussed and/or approved via a GitHub issue or the Hugging Face forum.
  • The documentation has been updated with any necessary changes. Here are the documentation guidelines.
  • I have written any new necessary tests.

Thank you in advance for your time and consideration. I look forward to any suggestions or feedback.

cc. @amyeroberts

@@ -732,7 +731,7 @@ def forward(

         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
+            hidden_states,
@SeungyounShin (Author) commented on the diff above:

PyTorch tensor parallel can't recognize keyword args, hence the change to a positional argument.
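
For context on why this matters: the TP styles redistribute a module's inputs through pre-forward hooks whose input_layouts are matched against positional arguments, so a tensor passed as a keyword argument can slip past the hook unconverted. A sketch of the relevant plan entry, assuming the PyTorch version used here (newer releases add input_kwarg_layouts to PrepareModuleInput, which should lift this restriction):

```python
from torch.distributed._tensor import Replicate
from torch.distributed.tensor.parallel import PrepareModuleInput

# Hypothetical plan entry: input_layouts is applied to the module's
# *positional* arguments, which is why the call site above now passes
# hidden_states positionally instead of as hidden_states=...
plan = {
    "self_attn": PrepareModuleInput(
        input_layouts=(Replicate(),),
        desired_input_layouts=(Replicate(),),
    ),
}
```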

@winglian (Contributor) commented:

I think you need to scale the rotary embeddings by tp_mesh.size(). Since you have a TP size of 2, the shapes are off by that factor: self_attn.rotary_emb has no parallelize plan, so it isn't accounting for the sharding.

@winglian (Contributor) commented:

If you look at torchtitan, they do some reshaping for broadcast for RoPE: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py#L112
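
For reference, the torchtitan helper linked above works roughly like this (paraphrased, not copied verbatim): it slices the precomputed frequencies to the local sequence length of the input and inserts singleton dimensions so they broadcast over the batch and head dimensions:

```python
import torch

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # x: (batch, seqlen, n_heads, head_dim // 2), complex-valued in torchtitan.
    seqlen = x.shape[1]
    # Slice to the *local* sequence length, so a sharded sequence still lines up.
    freqs_cis = freqs_cis[:seqlen]
    assert freqs_cis.shape == (seqlen, x.shape[-1])
    # Keep seqlen and the last dim; make everything else broadcastable.
    shape = [d if i in (1, x.ndim - 1) else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
```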
