Feature Description
The original PR implemented tensor parallelism for the T5 model incorrectly. This PR adds the missing feature.
Details
The goal is to ensure that the "model parallel" wrapper classes (`ColumnParallelLinear`, `RowParallelLinear`) around typical transformers operations work correctly, as described in the original paper: https://arxiv.org/pdf/1909.08053.pdf. This PR makes sure that we are:

- sharding the model so that the `MLP` and `Attention` blocks are distributed across GPUs (a sketch of the wiring is shown after this list),
- using the `ColumnParallelLinear` and `RowParallelLinear` modules in the `Attention` module (the modules take care of calculating the correct sharded tensor dimensions for us, and this PR takes advantage of that),
- switching `nn.Embedding` (non-parallelizable) to `VocabParallelEmbedding`, which allows the embedding GEMM to be parallelized.
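For concreteness, here is a minimal sketch of how the sharded attention and embedding could be wired up. The import path, class names, and constructor keywords (`gather_output`, `input_is_parallel`) follow the usual Megatron/vLLM-style conventions and are assumptions; they may not match this PR's code exactly.

```python
import torch.nn as nn

# NOTE: import path and constructor keywords are assumed (Megatron-style);
# the actual modules in this repo may differ.
from vllm.model_executor.parallel_utils.tensor_parallel import (
    ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)


class ShardedT5Attention(nn.Module):
    """Illustrative wiring only, not the PR's actual T5 attention class."""

    def __init__(self, d_model: int, num_heads: int, d_kv: int):
        super().__init__()
        inner_dim = num_heads * d_kv
        # Q/K/V: column-parallel, i.e. the output (head) dimension is sharded.
        # gather_output=False keeps each rank working on its local heads only.
        self.q = ColumnParallelLinear(d_model, inner_dim, bias=False, gather_output=False)
        self.k = ColumnParallelLinear(d_model, inner_dim, bias=False, gather_output=False)
        self.v = ColumnParallelLinear(d_model, inner_dim, bias=False, gather_output=False)
        # Output projection: row-parallel, i.e. the input dimension is sharded;
        # the module all-reduces the partial results across ranks.
        self.o = RowParallelLinear(inner_dim, d_model, bias=False, input_is_parallel=True)


class ShardedT5Embedding(nn.Module):
    """nn.Embedding replaced by its vocab-parallel counterpart."""

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(vocab_size, d_model)
```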
Checkpoint weights are loaded by:

a) taking an original checkpoint weight matrix, splitting it into `num_shards` parts, and sending them off to the appropriate devices. E.g. assuming four shards and column-wise sharding, if the matrix is `[N, M]`, we create four `[N/4, M]` shards and send them to the four model copies.

b) for the relative positional embeddings, if sharding is enabled, loading only the small chunk of the original matrix whose dimensions match the sharded hidden dimension.
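The following is a single-process sketch of the two loading rules above; `num_shards`, `rank`, and the tensor shapes are illustrative placeholders, not the PR's actual loader code.

```python
import torch

num_shards, rank = 4, 0            # rank would come from the distributed runtime
n, m, num_heads, rel_buckets = 1024, 512, 8, 32

# (a) column-wise sharding: split the original [N, M] checkpoint matrix along
#     the output dimension into num_shards [N/4, M] pieces and keep only the
#     piece that belongs to this rank.
checkpoint_weight = torch.randn(n, m)
shard = torch.chunk(checkpoint_weight, num_shards, dim=0)[rank]
assert shard.shape == (n // num_shards, m)

# (b) relative positional embeddings: the T5 bias table is [num_buckets, num_heads];
#     with sharding enabled, load only the columns for this rank's heads so the
#     table matches the sharded (per-rank) head dimension.
rel_pos_bias = torch.randn(rel_buckets, num_heads)
heads_per_shard = num_heads // num_shards
rel_shard = rel_pos_bias[:, rank * heads_per_shard:(rank + 1) * heads_per_shard]
assert rel_shard.shape == (rel_buckets, heads_per_shard)
```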
Testing
Made sure that the output of the model is correct not only for an unsharded model but also for models distributed across two and four shards (see `examples/offline_inference_enc_dec.py`).
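A hypothetical version of that check, assuming a vLLM-style offline `LLM`/`SamplingParams` API; the model name and prompt are placeholders, and the actual comparison is done in `examples/offline_inference_enc_dec.py`.

```python
from vllm import LLM, SamplingParams

prompts = ["translate English to German: The house is wonderful."]
params = SamplingParams(temperature=0.0, max_tokens=32)  # greedy, so outputs are deterministic

# Run once per configuration (tensor_parallel_size in {1, 2, 4}) and compare
# the generated text against the unsharded reference run.
llm = LLM(model="t5-small", tensor_parallel_size=2)
outputs = [out.outputs[0].text for out in llm.generate(prompts, params)]
print(outputs)
```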