[RFC] adding Tensor and Pipeline Parallelism to transformers #13690
@stas00 - I like this a lot! And as we've been dragging our feet with implementing some of the Megatron 3D parallelism, I think my (and the Mistral team's) addition over the next few weeks will be trying to do some benchmarking of Megatron and the existing gains with various subsets of parallelism (at a very fundamental level - profiling which kernels are being called, etc.) and maybe creating a set of unit tests to verify correctness? Separately - it might be worth keeping logs of how to "3D-ify" new models, and ways we might make that procedure even easier moving forward. Let me know if this makes sense!
Thanks for the feedback, Sidd. The reason @hyunwoongko thought of starting with GPTNeo was that GPT2 already has the naive PP. Note that the intention is to do simple things first and not do too many things at once. So I think starting with GPTNeo on a clean slate is a better idea. Once it's happy, it'd be trivial to replicate that to GPT2. And it's already done, as you can see from the link in the OP.

Here is my vision of 3Difying:

step 1. implement TP in one model

Note how step 2 can be done in parallel by different people. So I can see that the Mistral team's efforts would be parallel work and not sequential. So for example:

step 3b. implement Mistral's GPT2 improvements to GPT2

If we were to start with GPT2 we would interfere with your work, Sidd, so I think it's actually best if we pick 2 different starting models. But let's stay focused in this discussion on TP+PP, otherwise it'd be too easy to get side-tracked. We already spent too much time talking - let's see some code going in. Wrt trainers, it'll be a natural part of the work - I'm not worried too much about it. I don't know much about accelerate yet, but HF Trainer should be relatively easy.
This makes a lot of sense to me - thanks @stas00 and @hyunwoongko for the clarifications! The steps above form a pretty good concrete plan - but if you both are already planning on tackling it, maybe it makes sense for us to tackle some of the other Megatron-LM improvements first, like the custom loss scaling/kernels/etc. (in mistral, so we can break things 😅)? And as y'all build the "main API" for 3D parallelism, we can just drop that in, and train larger models! The PR with Mistral's first set of GPT-2 improvements is waiting on approval right now - once that's in we can move a bit faster as well.
That sounds like a perfect plan to me, Sidd.
@stas00 I think the following method is not a good fit for the megatron-friendly approach.
Ultimately, implementing PP requires rewriting all the modeling code (including
The transformers-friendly method (=parallelformers) has the advantage of being able to extend to new models quickly because it does not need to rewrite the modeling code (it uses the existing transformers code), but it is not compatible with PP. So we would have to remove all the transformers-friendly TP when implementing PP. Which strategy we take is a matter of choice. We can quickly expand them in a transformers-friendly way, and then change them one by one to be megatron-friendly like
Alternatively, we could skip the transformers-friendly method entirely, because it will be removed anyway. But since the megatron-friendly approach requires thousands of lines of code while the transformers-friendly one requires only tens of lines, the megatron-friendly approach will scale very slowly.
One thing to note is that the transformers-friendly TP implementation is completely eliminated when implementing the megatron-friendly TP. A megatron-friendly TP is implemented differently from a transformers-friendly TP.
Adding a GPTNeo3D to experiment seems like a good idea to me. At the end of the day, that modeling file can live in the same folder as the GPTNeo one. Note that while you experiment, you can leverage #13467 to share models on the Hub that have no implementation in Transformers and still work with the auto-model API.
Great!
The 3D GPTNeo model's weights are the same as a normal GPTNeo model's - i.e. it can be used w/ or w/o PP/TP, so I'm not sure why we need a special API? And I guess we won't be able to use
@hyunwoongko, you're bringing up excellent points. I suppose the main question is how much of a benefit we can give to users by having just TP. My thinking is that if it's easy to add TP to all models, and since you have already done this, let's do it. I'm concerned that adding PP will be a very slow process because, as you said, it requires massive rewrites to the model's code, and meanwhile the models that are waiting their turn won't be very scalable (except with Deepspeed ZeRO). Besides, we can delegate adding TP to the rest of the models to others (other developers and even the community), since it's mostly just replaying the code you have already written. It still requires work, at least in adding tests and documentation, and then PRs. The only concern with adding it the transformers-friendly way is making sure that the external API remains the same when we add PP. How does that sound?
@stas00 But anyway, I don't prefer PP. As you know, PP is memory-inefficient because it is not compatible with ZeRO 2 and 3. In fact, we also decided not to use PP when developing language models. So adding just TP would be helpful for many people. So let's go with the following strategy. But, as you said, the API for both methods should remain the same.
But transformers-friendly TPs have no reason to rewrite their modeling code. What should we do?
That's great, @hyunwoongko! And once we complete
In my opinion, a transformers-friendly TP has no reason to have its own modeling code like GPTNeo3D.
I'm thinking of an API like this:

```python
from transformers import GPTNeoModel

model = GPTNeoModel.from_pretrained("EleutherAI/gpt-neo-1.3B", tensor_model_parallel_size=4)
```

or

```python
model = GPTNeoModel.from_pretrained("EleutherAI/gpt-neo-1.3B", tp=4)
```

I implemented the megatron-friendly model internally like this:

```python
@classmethod
def from_yaml(
    cls,
    cfg_path: str,
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    tp: int = None,
    pp: int = None,
):
    """
    Create model from yaml config file

    Args:
        cfg_path: path of configurations
        tensor_model_parallel_size: tensor model parallel world size
        pipeline_model_parallel_size: pipeline model parallel world size
        tp (int): equivalent to `tensor_model_parallel_size`
        pp (int): equivalent to `pipeline_model_parallel_size`
    """
    if tp is not None:
        assert tensor_model_parallel_size == 1, (
            "you can't use the params `tensor_model_parallel_size` and `tp` at the same time. "
            "they are equivalent, so please use only one of them."
        )
        tensor_model_parallel_size = tp
    if pp is not None:
        assert pipeline_model_parallel_size == 1, (
            "you can't use the params `pipeline_model_parallel_size` and `pp` at the same time. "
            "they are equivalent, so please use only one of them."
        )
        pipeline_model_parallel_size = pp
```
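A hypothetical usage sketch of the aliasing above (the class name and config path are made-up placeholders; only the `tp`/`pp` keyword handling mirrors the snippet):

```python
# Hypothetical call - `MegatronFriendlyGPTNeo` and the yaml path are placeholders.
# Passing the short alias `tp=4` is equivalent to tensor_model_parallel_size=4;
# passing both forms at once would trip the assert in `from_yaml` above.
model = MegatronFriendlyGPTNeo.from_yaml(
    "configs/gpt_neo_1_3b.yaml",
    tp=4,  # tensor model parallel world size
    pp=2,  # pipeline model parallel world size
)
```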
I totally agree that this is a much better way to proceed. @sgugger, is it ok if we change the initial proposal and add TP to the normal model classes? As we continued discussing this, and based on my experience with trying to add PP to transformers, it'll be a huge amount of work to do it for all models, and so it's very likely many models will never get it. And since TP requires no changes to the models, there is no reason to make it difficult on users and maintainers to fork the model for that feature to work. And we believe just having TP+DP will already be a great boon to the scalability of the models (if Deepspeed ZeRO doesn't already address this for whatever reason). For PP, new classes will be needed 100%. Thank you.
As long as the changes are minimal, no objection from my side. I agree it makes much more sense to get that out if it's faster and deliver the PP later on.
The problem is the `parallelize()` method, the API for layerwise naive parallelism in GPT2 and T5. Do you agree to remove this method? The megatron-friendly TP + PP cannot handle it that way. This is because, in the case of PP, parallelization occurs at the time of model creation. That's why I let
I think
The naive PP is experimental (see transformers/src/transformers/models/gpt2/modeling_gpt2.py, lines 527 to 529 at 50c746e), but we shouldn't remove it until we replace it with real PP, because users actively use the naive PP at the moment. That's why we proposed to work on GPTNeo first, so that it's easier to take our time and not have that older code interfere.
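For context, this is roughly how the existing naive layerwise parallelism is used today via `parallelize()` (the device split below is just an illustration):

```python
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2-xl")

# Naive "vertical" model parallelism: whole transformer blocks are pinned to GPUs,
# so only one GPU is busy at a time (no micro-batch pipelining).
device_map = {
    0: list(range(0, 24)),   # blocks 0-23 on cuda:0
    1: list(range(24, 48)),  # blocks 24-47 on cuda:1
}
model.parallelize(device_map)

# and to undo it later:
# model.deparallelize()
```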
So I made it support both variables (long name and short name). Not good?
I totally agree with you. Let's start with GPTNeo. The second thing to discuss is the embedding layer. When I implemented parallelformers, I didn't actually parallelize the embedding layer. In this case, the embedding layer is copied to all GPUs, so it is memory-inefficient. But in fact we can apply

I didn't tell you guys, but I actually experimented little by little. I already figured out that I can do
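For reference, a minimal Megatron-LM-style sketch of what parallelizing the embedding layer along the vocabulary dimension can look like (an illustration only, not the parallelformers/oslo implementation; the constructor arguments and process-group plumbing are assumptions, and the real thing wraps the all-reduce in a custom autograd function):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

class VocabParallelEmbedding(nn.Module):
    """Each tensor-parallel rank stores a contiguous slice of the vocabulary rows."""

    def __init__(self, vocab_size, hidden_size, tp_rank, tp_size, tp_group=None):
        super().__init__()
        assert vocab_size % tp_size == 0, "pad the vocab so it divides evenly"
        self.shard_size = vocab_size // tp_size
        self.vocab_start = tp_rank * self.shard_size
        self.vocab_end = self.vocab_start + self.shard_size
        self.tp_group = tp_group
        self.weight = nn.Parameter(torch.empty(self.shard_size, hidden_size))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, input_ids):
        # Tokens outside this rank's slice contribute zeros; the all-reduce sums the
        # partial results so every rank ends up with the full embedding output.
        mask = (input_ids < self.vocab_start) | (input_ids >= self.vocab_end)
        local_ids = (input_ids - self.vocab_start).masked_fill(mask, 0)
        out = F.embedding(local_ids, self.weight)
        out = out.masked_fill(mask.unsqueeze(-1), 0.0)
        dist.all_reduce(out, group=self.tp_group)  # real impls use an autograd-aware wrapper
        return out
```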
At the moment I don't recall
Was CrossEntropy the reason for not doing it in the first place in parallelformers? I guess the integration will allow us to overcome this then, if I understood your comment correctly. But otherwise, by all means, let's make TP as efficient as possible.
How about implementing it with options first?
Ah, ok, we can use
Ah, right, I forgot that
Is there a technical reason for not always doing the latter?
Because of

So I thought of making the default value of embedding_parallelism false and turning it on when the user wants to.
Thank you for the explanation, Hyunwoongko. Then yes, we need that arg. Should the default be

Further,
Oh, I was wrong. Only tying the output embeddings has problems with the loss function. I checked, and it doesn't matter since neither GPT2 nor GPTNeo ties its output embeddings. In most cases, we don't need to worry about the loss function. Therefore, I will implement embedding parallelism so it works every time; this option is unnecessary and users do not need to worry about it. If I find a model that ties input and output embeddings without an lm head later, I will think about it then.
But maybe Meg-DS and GPT-NeoX use embedding tying. So this option will be needed in the future.
If I'm not mistaken, many models have input and output embeddings tied.
Hi all, I helped implement pipeline parallelism in Megatron (and was also one of the lead authors on the PipeDream project). Happy to answer any questions. I had a question too: what is the current plan for the new PP-friendly model classes? What is going into these, and how will they be different from the vanilla model classes? Thanks!
Hi @hyunwoongko. Thank you for your reply. I got your idea about the linear or conv1d layers in the Attention module or MLP module, and I think it is great.
@lucasleesw I totally agree with you. Defining both columns and rows is probably the safest and most extensible way.
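For readers following along, here is a minimal Megatron-style sketch of what "column" vs "row" parallel layers mean (illustrative only; the real implementations also handle input scatter/gather and wrap the communication in autograd functions so gradients flow correctly):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

class ColumnParallelLinear(nn.Module):
    """Weight split along the output dim; each rank computes a slice of the outputs."""

    def __init__(self, in_features, out_features, tp_size):
        super().__init__()
        assert out_features % tp_size == 0
        self.weight = nn.Parameter(torch.empty(out_features // tp_size, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features // tp_size))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        # Output stays sharded and typically feeds straight into a RowParallelLinear.
        return F.linear(x, self.weight, self.bias)

class RowParallelLinear(nn.Module):
    """Weight split along the input dim; partial outputs are summed with an all-reduce."""

    def __init__(self, in_features, out_features, tp_size, tp_group=None):
        super().__init__()
        assert in_features % tp_size == 0
        self.weight = nn.Parameter(torch.empty(out_features, in_features // tp_size))
        self.bias = nn.Parameter(torch.zeros(out_features))
        nn.init.normal_(self.weight, std=0.02)
        self.tp_group = tp_group

    def forward(self, x_shard):
        out = F.linear(x_shard, self.weight)
        dist.all_reduce(out, group=self.tp_group)  # sum partial results across TP ranks
        return out + self.bias
```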
@stas00 I initially used the method of defining both column- and row-parallel parameters, but since the process of defining them is quite tedious, I experimented with many ways to create a simple tensor parallelization map. But the simplification created more possibilities for exceptions. So, like @lucasleesw's method, it would be best to use all three pieces of information: column, row, and mp_param. We all parallelize in a similar way, and so does sagemaker. Therefore, it would be convenient if we unify and manage this inside transformers.
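As an illustration of the kind of per-model metadata being discussed, a hypothetical mapping entry might look like the following (the schema and module names are assumptions, loosely based on GPTNeo's layer names, not the actual PR):

```python
# Hypothetical tensor-parallel mapping for one architecture.
TP_MAPPING = {
    "gpt_neo": {
        # weights split along the output dimension (column parallel)
        "column_parallel": [
            "attn.attention.q_proj",
            "attn.attention.k_proj",
            "attn.attention.v_proj",
            "mlp.c_fc",
        ],
        # weights split along the input dimension, followed by an all-reduce (row parallel)
        "row_parallel": ["attn.attention.out_proj", "mlp.c_proj"],
        # per-module attributes that must be divided by the TP world size
        "mp_params": ["num_heads", "embed_dim"],
    },
}
```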
@lucasleesw Omitting column parallel linear and the tracing method won't cause any problems in ViT, because the embedding is not in the module list: https://github.com/huggingface/transformers/blob/master/src/transformers/models/vit/modeling_vit.py#L146 They parallelize only the layers inside the base layer module (like BertLayer), not all existing layers. Even so, these simplifications can always create exceptions.
@lucasleesw I'm also wondering about your PP implementation; could you let me know? I used deepspeed PP in the beginning, but now we are implementing the same method as sagemaker PP.
@hyunwoongko You are right, thanks again for your inspiration.
@lucasleesw I will upload a PR for tensor parallel mapping today. It would be great if you could reply so we can make a more general PR. How did you deal with the fused attention module (in gpt2, transfo_xl)? I mean an attention layer that has a size like
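To make the fused-attention question concrete, here is one way such a weight could be sliced for TP (a sketch under the assumption of GPT-2's Conv1D layout, where `c_attn.weight` has shape `[hidden, 3*hidden]`; how the actual mapping handles this is exactly what is being asked):

```python
import torch

def shard_fused_qkv(weight: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    """Slice a fused QKV weight of shape [hidden, 3*hidden] so each TP rank keeps
    matching column slices of q, k and v (a naive contiguous cut of the fused
    dimension would hand different ranks mismatched pieces of q/k/v)."""
    hidden = weight.shape[0]
    assert weight.shape[1] == 3 * hidden and hidden % tp_size == 0
    q, k, v = weight.split(hidden, dim=1)                 # each [hidden, hidden]
    pick = lambda t: t.chunk(tp_size, dim=1)[tp_rank]     # [hidden, hidden // tp_size]
    return torch.cat([pick(q), pick(k), pick(v)], dim=1)  # [hidden, 3*hidden // tp_size]
```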
So we have at least 3 possible "consumers" of this additional model metadata at the moment: @hyunwoongko, @lucasleesw and @RezaYazdaniAminabadi - so perhaps, instead of maintaining 3 different tables, would you agree on having one that contains all the fields that you need? You can discuss between yourselves how you prefer to call those. We can give the new names a period of "experimental and subject to change" until the dust settles, and then they will get carved in stone at a later date to support backward compatibility. And I'm sure there will be other consumers for that type of metadata. I don't see any reason not to have all the desired components written out explicitly, instead of being derived automatically. There is absolutely no reason to take a risk here; this is software engineering and not the stock market. I propose to start with a dedicated file for it with the first few models, and then down the road we can see if it makes sense to move these into their model files. I just want to create minimal disturbance to the models' code until we are ready to do so.
I opened a PR about tensor parallel mappings!
@stas00 For TP, it will go with megatron-lm's way of doing the parallelism in a python way, if I understood correctly. If that's the case, it leaves us the opportunity to support TPUs etc., since the only question is about the allgather & allreduce API.
At the moment we have 2 projects that support TP (tensor parallelism):
Both are not yet integrated into transformers. Oslo we are just slow to integrate, since I'm busy with BigScience and @jaketae is backing me up and has started to work on the integration. Deepspeed-Inference is still a work in progress on the core side; I have an initial PR that integrates it, but there are some hanging issues as the HF Trainer is not MPU-aware yet. So at the moment Deepspeed-ZeRO is the only solid and working solution for scalability on the free side, and Sagemaker on the paid side (though I have never tried the latter myself). PP is much more difficult, and we are leaving it to the end, in the hope that pytorch will provide us a new, much easier to use PP-api that is somewhat similar to sagemaker's paper https://arxiv.org/abs/2111.05972
I was just saying that it doesn't have an MPU at the moment ;) And it's needed to sync the ds-inference TP processes, I think. But my mind is on BigScience at the moment, so I don't have the brain cycles for an in-depth analysis. Making the ds-inference integration depend on oslo would be odd, but it could be ok at the beginning, and eventually we'd have an internal one - it's just one module that's already written. Why integrate ds-inference w/o the Trainer? Trainer is just a name; it covers both inference and training.
Any progress now? I think integrating TP and PP with transformers and deepspeed or megatron, offering easier access to users today, would be a great contribution.
Given the complexity of TP/PP, which requires massive changes to the modeling code, and the fact that there are hundreds of architectures, it's hard to tell if this will ever happen in transformers. You already have the Deepspeed integration in transformers, so it's trivial to scale pretty much any model in transformers' arsenal to any number of GPUs. And you should get about the same throughput with DeepSpeed ZeRO-3 as you would with TP/PP/DP, as long as you are on a fast internode connection.
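For completeness, a minimal sketch of what that existing path looks like - enabling ZeRO-3 through the HF Trainer's DeepSpeed integration (the config values are illustrative, not a tuned recipe):

```python
from transformers import Trainer, TrainingArguments

# Minimal ZeRO-3 config; "auto" lets the integration fill values from TrainingArguments.
ds_config = {
    "zero_optimization": {"stage": 3, "overlap_comm": True},
    "fp16": {"enabled": "auto"},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "train_batch_size": "auto",
}

args = TrainingArguments(
    output_dir="output",
    per_device_train_batch_size=4,
    fp16=True,
    deepspeed=ds_config,  # a dict or a path to a json file both work
)

# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()
# launched with e.g.: deepspeed --num_gpus 8 train.py
```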
I'm actually going to close this, since it's too old and clearly isn't happening.
Hey @stas00 |
That's an excellent question, @marianna13

First, Megatron-LM does more than just 3D parallelism - it has a superb set of various other features, so we can't compare projects just on the basis of how well they solve one specific problem.

The main con of 3D parallelism is that it requires the modeling code to be modified, which is far from trivial, and it's very easy to introduce bugs; unfortunately we have seen some of those bugs remain invisible until it's too late to fix them.

The main pro of ZeRO (Pytorch FSDP or Deepspeed ZeRO-DP) is that modeling code and scalability code are separate, and the user only needs to get the modeling code right for a successful training. You just write the code as if you were going to train on a single gpu, and ZeRO can then scale it to any number of gpus without you needing to do anything about it.

Now specifically to the claim "you should get about the same throughput with DeepSpeed ZeRO-3 as you'd with TP/PP/DP as long as you are on a fast internode connection" - I personally haven't seen that yet, because I'm yet to be given a chance to run on a cluster with high inter-node speed. The fastest I have used so far was 340Gbps w/ A100s, which is very slow. Given that ZeRO implementations prefetch the sharded data, this comms overhead is overlapped with the compute of the previous stage, and if the inter-node speed is fast enough it'll be mostly hidden and not contribute additional overhead. I walked through the math here:

Additionally, recently pytorch and Deepspeed released hybrid versions of ZeRO (Hybrid FSDP and ZeRO++) which, if you can fit the sharded model on a single node, will use the super-fast NVLink for all comms and only do grads reduction DDP-style over the slower inter-node link if multiple nodes are used (except it uses a faster version of comms than DDP, since each gpu needs only its gradient shard - so it's 1/2 of the DDP traffic). This should lead to much higher TFLOPS, for models up to a certain size - e.g. an 8x80GB node can at most fit a ~30B-param model for mixed half-precision training. For larger models this won't help, and the slower inter-node link becomes the defining speed for all comms :(

And even if the inter-node speed is slow and one gets lower TFLOPS using ZeRO, the important consideration to make is whether your team will finish the training sooner or later using 3D parallelism, because it'll certainly take longer development/testing time than the ZeRO equivalent. If you just need to train something that Megatron-LM has implemented fully, you should have close to 0 dev overhead. But if you need to introduce changes, you might spend weeks and months of dev time depending on the complexity of the changes and the skills of your engineers. So it's possible that the slower ZeRO training will still deliver a faster outcome, and less hair will be lost in the process.

Also, did you know Megatron-LM implemented ZeRO-1 and ZeRO-2 as well? And there is https://github.com/microsoft/Megatron-DeepSpeed/ - in other words, the awesome developers who give us these incredibly useful tools work on combining the best features of both approaches.
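For a rough sense of where that ~30B figure comes from, a back-of-the-envelope sketch (assuming the commonly cited ~18 bytes per parameter of model states for mixed-precision Adam; activations and buffers come on top):

```python
# fp16 weights (2) + fp16 grads (2) + fp32 master weights (4)
# + fp32 Adam momentum (4) + fp32 Adam variance (4) + some overhead ~= 18 bytes/param
params = 30e9
bytes_per_param = 18
model_states_gb = params * bytes_per_param / 1e9
node_gb = 8 * 80
print(f"~{model_states_gb:.0f} GB of model states vs {node_gb} GB on the node")
# -> ~540 GB vs 640 GB, leaving some headroom for activations and buffers
```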
Thank you very much for this detailed answer, @stas00!
Following up on this proposal #12772, I just had a discussion with @hyunwoongko (with great help from @jaketae who patiently translated for us), and we tried to discuss a strategy of how to best integrate Tensor Parallelism (TP) and Pipeline Parallelism (PP) into transformers, making it easy for reviewers and contributors. Note that parallelformers currently implements only TP.
So here is a great example of how the TP can be added, as @hyunwoongko already implemented it in his fork for GPTNeo: tunib-ai@5bf8655 (he didn't use GPT2 since it already has the naive PP implemented). So you can see exactly what we want to merge. It's a very thin layer to the model, and most of the functionality is in the helper parallel utils. The rest of the change is multiple tests/examples that need to be converted to our test framework.

Now, while adding TP is relatively easy, adding PP is very complex in the current state of HF models, because they include many features that interfere with implementing PP, due to the requirements: nn.Sequential and ...

So to implement PP we will most likely have to fork each model, strip the features that are unnecessary for scalability, and only then be able to implement PP.
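To illustrate why those requirements bite, here is a toy sketch (not transformers code) of the shape a pipeline engine such as DeepSpeed's PipelineModule or torch's Pipe expects: a flat nn.Sequential whose stages exchange plain tensors - which is exactly what current modeling code, with its tuple outputs, attention masks, head masks and caches, does not provide:

```python
import torch
import torch.nn as nn

class BlockAdapter(nn.Module):
    """Wraps one transformer block so that only a single hidden-states tensor
    flows between pipeline stages (masks/caches would have to be folded in)."""

    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        out = self.block(hidden_states)
        # Many HF blocks return tuples; a PP-friendly fork has to flatten this away.
        return out[0] if isinstance(out, tuple) else out

def as_pipeline_friendly(embed: nn.Module, blocks, head: nn.Module) -> nn.Sequential:
    # Pipeline engines partition an nn.Sequential across GPUs, stage by stage.
    return nn.Sequential(embed, *[BlockAdapter(b) for b in blocks], head)
```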
So my thinking is that perhaps we do it from the get-go? Instead of integrating TP into the normal model - say GPTNeo - we fork it to, say, GPTNeo3D from the start and do all the work, including TP and PP, on that new model. Once everybody is happy we can rinse and repeat for other models.

I added 3D to GPTNeo to make GPTNeo3D - 3D = DP/TP/PP. I'm not exactly sure about this particular name or attached to it; it's just something to start with.

Also, once TP is implemented in, say, GPTNeo3D, we can start replicating it to other models, because parallelformers has them all covered already. PP will be much harder, and we can do this in parallel.

I wanted to check in with the team to see if this approach resonates better, rather than modifying the existing models.
Thank you!
Also see this blog post explaining parallelformers.
Additionally see the main pytorch Parallelism discussion at pytorch/rfcs#32
@LysandreJik, @sgugger, @patrickvonplaten