
T5 Tensor Parallelism #2

Closed

Conversation

@dbogunowicz (Collaborator) commented Mar 22, 2024

Feature Description

The original PR implemented tensor parallelism for the T5 model incorrectly. This PR adds the missing functionality.

Details

The goal is to ensure that the "model parallel" wrapper classes (ColumnParallelLinear, RowParallelLinear) around the standard transformer operations work correctly, as described in the Megatron-LM paper: https://arxiv.org/pdf/1909.08053.pdf.
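
For intuition, here is a minimal single-process sketch of the two sharding schemes from the paper. It is illustrative only: the real ColumnParallelLinear / RowParallelLinear modules run one shard per GPU and replace the `torch.cat` / `sum` below with all-gather / all-reduce collectives.

```python
import torch

# Single-process sketch of Megatron-LM-style sharding.
# For readability, w is stored [in, out]; nn.Linear stores [out, in].
num_shards = 4
x = torch.randn(2, 16)                   # [batch, hidden]
w = torch.randn(16, 32)                  # full weight of a 16 -> 32 linear

# Column parallelism: shard the output dimension, concatenate the results.
col_shards = torch.chunk(w, num_shards, dim=1)            # 4 x [16, 8]
col_out = torch.cat([x @ s for s in col_shards], dim=1)   # "all-gather"

# Row parallelism: shard the input dimension, sum the partial products.
row_shards = torch.chunk(w, num_shards, dim=0)            # 4 x [4, 32]
x_shards = torch.chunk(x, num_shards, dim=1)              # 4 x [2, 4]
row_out = sum(xs @ ws for xs, ws in zip(x_shards, row_shards))  # "all-reduce"

assert torch.allclose(col_out, x @ w, atol=1e-5)
assert torch.allclose(row_out, x @ w, atol=1e-5)
```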

This PR makes sure that we are:

  • parallelizing the atomic operations of the transformer architecture (the MLP and Attention blocks are distributed across GPUs)
  • correctly initializing the ColumnParallelLinear and RowParallelLinear modules in the Attention module (these modules compute the correct sharded tensor dimensions for us; this PR takes advantage of that)
  • swapping the token embedding layer from nn.Embedding (non-parallelizable) to VocabParallelEmbedding, which allows the embedding GEMM to be parallelized (see the embedding sketch after this list)
  • correctly loading the weights for the layers that need to be sharded across multiple GPUs (see the sharding sketch after this list). This means either:
    a) taking an original checkpoint weight matrix, splitting it into num_shards parts, and sending each part to the appropriate device. E.g. with four shards and column-wise sharding, an [N, M] matrix is split into four [N/4, M] shards, one per model copy.
    b) for the relative positional embeddings, if sharding is enabled, loading only the chunk of the original matrix whose dimensions match the sharded hidden dimension.
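
A minimal single-process sketch of the vocab-parallel embedding idea (illustrative only; the real VocabParallelEmbedding shards across GPUs and all-reduces the partial results): each rank owns a contiguous slice of the vocabulary, masks out ids outside its slice, and the partial lookups are summed.

```python
import torch

vocab_size, hidden, num_shards = 8, 4, 2
table = torch.randn(vocab_size, hidden)   # full embedding table
ids = torch.tensor([0, 3, 5, 7])

out = torch.zeros(len(ids), hidden)
rows_per_shard = vocab_size // num_shards
for rank in range(num_shards):            # one iteration per "GPU"
    lo = rank * rows_per_shard
    local = table[lo:lo + rows_per_shard]  # this shard's slice of the table
    mask = (ids >= lo) & (ids < lo + rows_per_shard)
    local_ids = (ids - lo).clamp(0, rows_per_shard - 1)
    # out-of-range ids hit a dummy row and are zeroed out by the mask
    out += torch.where(mask.unsqueeze(1), local[local_ids],
                       torch.zeros(1, hidden))

assert torch.allclose(out, table[ids])
```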
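And a sketch of the load-time weight sharding described in a); shard_weight is a hypothetical helper for illustration, not the actual function added in this PR:

```python
import torch

def shard_weight(full_weight: torch.Tensor, num_shards: int,
                 shard_id: int, dim: int) -> torch.Tensor:
    """Return the slice of `full_weight` owned by shard `shard_id`.

    With PyTorch's [out, in] Linear weight layout, dim=0 shards the
    output dimension (column parallelism) and dim=1 shards the input
    dimension (row parallelism).
    """
    shard_size = full_weight.shape[dim] // num_shards
    return full_weight.narrow(dim, shard_id * shard_size, shard_size)

# Four shards, column-wise sharding: an [N, M] checkpoint matrix
# becomes four [N/4, M] slices, one per model copy.
w = torch.randn(1024, 512)
shards = [shard_weight(w, 4, i, dim=0) for i in range(4)]
assert all(s.shape == (256, 512) for s in shards)
```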

Testing

Verified that the model output is correct not only for an unsharded model, but also for models distributed across two and four shards (see examples/offline_inference_enc_dec.py).
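
A hedged sketch of what such a check might look like, assuming the vLLM-style LLM entry point used by examples/offline_inference_enc_dec.py; the model name and arguments here are illustrative, not taken from this PR:

```python
from vllm import LLM, SamplingParams

# tensor_parallel_size=1 gives the unsharded baseline; 2 and 4 were tested.
llm = LLM(model="t5-small", tensor_parallel_size=2)
outputs = llm.generate(
    ["translate English to German: The house is wonderful."],
    SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```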
