
Questions about implementing model parallelism in the inference engine #1161

Closed · hyunwoongko opened this issue Jun 15, 2021 · 30 comments

@hyunwoongko (Contributor) commented Jun 15, 2021

Hello. I would like to ask about the model parallelism feature in the inference engine. In general, the model parallelism approaches I can think of are inter-layer model parallelism, like GPipe (only the partitioning part, not the pipelining), and intra-layer model parallelism, like Megatron-LM.

dist.broadcast(p, 0)

However, in the current implementation of the inference engine, all parameters seem to be broadcast from device 0 when no mpu is passed in. As far as I know, this operation replicates the model rather than parallelizing it by slicing.

dist.broadcast(input, 0)

In this part of the code, the input is also broadcast to all devices. So I wonder why DeepSpeed broadcasts everything when no mpu is provided.
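
For readers following along, here is a minimal, hypothetical sketch (not DeepSpeed's actual code) contrasting the two behaviors discussed here: broadcasting parameters from rank 0 replicates the full model on every device, whereas Megatron-style slicing would leave each rank with only a shard of each weight.

```python
# Minimal sketch, not DeepSpeed code: broadcast replicates, slicing shards.
import torch
import torch.distributed as dist


def replicate(model: torch.nn.Module) -> None:
    # Broadcasting every parameter from rank 0 leaves a full copy on each rank.
    for p in model.parameters():
        dist.broadcast(p.data, src=0)


def column_slice(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Megatron-style slicing: each rank keeps only 1/world_size of the rows.
    return torch.chunk(weight, world_size, dim=0)[rank].contiguous()


if __name__ == "__main__":
    # Launch with e.g. `torchrun --nproc_per_node=2 sketch.py`.
    dist.init_process_group(backend="gloo")
    rank, world = dist.get_rank(), dist.get_world_size()

    layer = torch.nn.Linear(8, 8)
    replicate(layer)                                      # full 8x8 weight on every rank
    shard = column_slice(layer.weight.data, rank, world)  # 4x8 shard when world == 2
    print(rank, tuple(layer.weight.shape), tuple(shard.shape))
```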

@hyunwoongko (Contributor Author) commented Jun 15, 2021

  • --num_gpus=2 and mp_size=2
    [screenshot]

  • --num_gpus=1 and mp_size=1
    [screenshot]

@RezaYazdaniAminabadi (Contributor) commented Jun 15, 2021

Hi @hyunwoongko

The tensor slicing for inference does not happen in the engine class; it is done in the replace_module utility and in the ReplaceWithTensorSlicing class. The model parallelism is based on Megatron tensor slicing, and it works for various model architectures, not just GPT-Neo.

Thanks,
Reza

@RezaYazdaniAminabadi (Contributor) commented

To reduce the results across the different GPUs, we use all_reduce in the inference API. The model-parallel group is also created in the inference engine and passed to the inference module.
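
To make the mechanism described above concrete, here is a minimal sketch, not the actual replace_module / ReplaceWithTensorSlicing code, of Megatron-style slicing for an MLP block: the first weight is split along its output dimension, the second along its input dimension, and an all_reduce sums the partial results across the model-parallel group.

```python
# Sketch only: Megatron-style tensor slicing for a two-layer MLP plus an all_reduce.
import torch
import torch.distributed as dist


def shard_mlp(w1: torch.Tensor, w2: torch.Tensor, rank: int, world: int):
    # w1: (4h, h) is split by output rows; w2: (h, 4h) is split by input columns,
    # so the shards line up and each rank computes a valid partial result.
    w1_shard = torch.chunk(w1, world, dim=0)[rank].contiguous()
    w2_shard = torch.chunk(w2, world, dim=1)[rank].contiguous()
    return w1_shard, w2_shard


def mlp_forward(x: torch.Tensor, w1_shard: torch.Tensor, w2_shard: torch.Tensor,
                group=None) -> torch.Tensor:
    h = torch.relu(x @ w1_shard.t())   # (batch, 4h / world) on this rank
    partial = h @ w2_shard.t()         # (batch, h), a partial sum
    dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=group)  # combine partials
    return partial


if __name__ == "__main__":
    # Single-process smoke test; with torchrun and world > 1 each rank holds a shard.
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29501",
                            rank=0, world_size=1)
    h_dim = 8
    w1, w2 = torch.randn(4 * h_dim, h_dim), torch.randn(h_dim, 4 * h_dim)
    w1_s, w2_s = shard_mlp(w1, w2, dist.get_rank(), dist.get_world_size())
    print(mlp_forward(torch.randn(2, h_dim), w1_s, w2_s).shape)  # torch.Size([2, 8])
```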

@hyunwoongko (Contributor Author) commented

Thanks for your kind reply. :)

My question is: why doesn't the memory usage on each GPU decrease after the model is sliced in two? For example, if we slice a 30GB model in half, the memory allocated on each device should be less than 30GB, even if it does not drop all the way to 15GB.

Thanks,
Hyunwoong Ko

@RezaYazdaniAminabadi (Contributor) commented

The problem is that the model initially gets created on both GPUs by HuggingFace, and that takes the model's total memory, in your case 30GB. In DeepSpeed we then partition the parameters and send the corresponding parts to each GPU, so some of the initially allocated memory is just cached on the GPU and never used. I have not spent time releasing that memory when the model is initialized the HuggingFace way. By the way, I think you can still use that cached memory; the issue is that nvidia-smi does not report the amount of free memory precisely.

@RezaYazdaniAminabadi (Contributor) commented

I have used torch's memory management API (https://pytorch.org/docs/stable/cuda.html#memory-management), and it shows the memory reduction when using model parallelism of size 2:

Before deepspeed initialize: 14042529792 bytes allocated
After deepspeed initialize: 7719548928 bytes allocated

I think nvidia-smi is just showing all the cached memory!

Could you please print torch.cuda.memory_allocated() and torch.cuda.memory_cached() before and after the DeepSpeed inference initialization on your side to verify it?
Thanks,
Reza
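
A small sketch of the measurement asked for above, assuming a DeepSpeed version whose deepspeed.init_inference takes mp_size; the model name and mp_size value are only illustrative, and torch.cuda.memory_cached() is the older alias of torch.cuda.memory_reserved().

```python
# Sketch: print allocated/cached CUDA memory before and after DeepSpeed inference init.
# The model name ("gpt2") and mp_size=2 are illustrative, not this thread's exact setup.
import torch
import deepspeed
from transformers import AutoModelForCausalLM


def report(tag: str) -> None:
    print(f"{tag}: allocated={torch.cuda.memory_allocated()} "
          f"reserved/cached={torch.cuda.memory_reserved()}")


model = AutoModelForCausalLM.from_pretrained("gpt2").half().cuda()
report("before deepspeed.init_inference")

# Launch with e.g. `deepspeed --num_gpus 2 measure.py` so that mp_size=2 has two ranks.
engine = deepspeed.init_inference(model, mp_size=2, dtype=torch.half)
torch.cuda.empty_cache()  # drop cached-but-unused blocks that nvidia-smi still counts
report("after deepspeed.init_inference")
```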

@hyunwoongko (Contributor Author) commented Jun 16, 2021

When I call torch.cuda.empty_cache(), the memory usage is reduced, just as you said.

[screenshot]

Thanks !!
Hyunwoong Ko

@hyunwoongko (Contributor Author) commented Jun 16, 2021

Thanks for the very kind reply. Can't I just load the model on the CPU? If I deploy a large model like Blender (9B+) or T5 (10B+) and the HuggingFace model is loaded onto the GPU first, memory allocation will fail.

Thanks
Hyunwoong Ko

@hyunwoongko (Contributor Author) commented Jun 16, 2021

I succeeded in parallelizing from CPU to GPU. Thanks!

@andreamad8 commented

@hyunwoongko any insight into how you did it?

@hyunwoongko (Contributor Author) commented

I developed a completely new tool for model parallelism.
My tool can parallelize all the HuggingFace models. I will open-source it as soon as possible.

@andreamad8 commented

ok, thanks!

@RezaYazdaniAminabadi (Contributor) commented

@hyunwoongko, thanks for pushing to solve all these issues. Please let me know when you have open-sourced it, as I am also eager to see your approach.

Best,
Reza

@hyunwoongko (Contributor Author) commented Jul 18, 2021

@RezaYazdaniAminabadi

https://github.com/tunib-ai/parallelformers

We present parallelformers, a novel framework for model parallelization 🎉

@hyunwoongko reopened this Jul 18, 2021
@RezaYazdaniAminabadi (Contributor) commented

@hyunwoongko

Thanks for sharing this. I will look into it and let you know if I have questions.

@RezaYazdaniAminabadi (Contributor) commented

Hi @hyunwoongko,

Thanks for sharing the great work you did on parallelizing the Transformers models. I think we can use a lot of this in DeepSpeed to parallelize the different models.
I went through your description here. I agree with the limitation of our approach, which stems mostly from using the inference-optimized kernels. On the parallelism side, however, I think you take more or less the same approach as we do.
I see that you are mainly concerned about the way we move the model to the different devices (in the inference engine). I think this is easy to solve by moving that step after the module_inject function, which takes care of the model-parallel partitioning.
However, I still see an issue with creating the model on the CPU through the HuggingFace Transformers pipeline: some of the tensors on the pipeline side (which are not part of the model) will stay on the CPU, whereas the model's computation happens on the device side and produces tensors on different devices.
Please note that the pipeline way of doing inference gives a pretty easy way of testing the different models that users are normally eager to use.
I have made a PR to fix this issue in the DeepSpeed inference engine. We can now create the model on the CPU side, partition it based on the number of GPUs, and then move it to the devices. As you can see, the changes needed to resolve this limitation are pretty minimal!

I would really appreciate it if we could work together to bring your parallel implementation design into DeepSpeed and merge it with the high-performance kernels for the different models.

Thanks,
Reza
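
This is not the actual PR, but a sketch of the workflow described above, under the assumption that the fix is in place: the HuggingFace model is created on the CPU, deepspeed.init_inference partitions it across the model-parallel ranks, and only the local shard ends up on each GPU. The model name, mp_size, and prompt are illustrative.

```python
# Sketch of the CPU-create -> partition -> move-to-device flow described above.
# Assumes the fix from the PR; model name, mp_size and prompt are illustrative.
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

local_rank = int(os.getenv("LOCAL_RANK", "0"))

# from_pretrained keeps the weights on the CPU, so a 10B+ model never has to fit
# on a single GPU before it is partitioned.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Partition the parameters across 2 model-parallel ranks and move each shard to
# its GPU. Launch with: deepspeed --num_gpus 2 cpu_load_sketch.py
engine = deepspeed.init_inference(model, mp_size=2, dtype=torch.half,
                                  replace_method="auto")

inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(f"cuda:{local_rank}")
outputs = engine.module.generate(**inputs, max_length=20)
print(tokenizer.decode(outputs[0]))
```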

@RezaYazdaniAminabadi (Contributor) commented

By the way, you will notice that with this PR the GPU memory (even the cached portion) no longer increases the way it did before.

@hyunwoongko (Contributor Author) commented

Thanks for the positive reply. @stas00 and I were discussing the integration of parallelformers and Transformers here. We are thinking of this as part of the DeepSpeed and Transformers integration, and we would like to work on it with the DeepSpeed team. For example, you and I could implement and commit it in DeepSpeed, and Stas and I could use it in HuggingFace Transformers.

Also, as you said, I'm thinking hard about how to use the fused kernels with the mechanism I'm currently implementing on my side. If this works, we can get the speed of the fused kernels and the scalability of my implementation at the same time. In any case, we need to discuss how to collaborate. How would you like to proceed?

Thanks,
Hyunwoong Ko

@hyunwoongko (Contributor Author) commented

In addition, I also want to implement a training feature. What do you think? I believe many people want to train Transformers models with tensor model parallelism. Ultimately, I hope all the models in Transformers will support 3D parallelism through ZeRO + pipeline parallelism + tensor MP.

@hyunwoongko (Contributor Author) commented Jul 23, 2021

@RezaYazdaniAminabadi
I'm looking forward to your answer :)
Are you interested in this collaborative process?

@hyunwoongko (Contributor Author) commented Jul 23, 2021

I'm closing this issue because it's old; I'll continue the discussion in a new issue (#1248).

@RezaYazdaniAminabadi (Contributor) commented

Hi @hyunwoongko

Thanks for the great discussion. I am certainly interested. Let me also discuss this internally, and we can move forward with the collaboration soon :)

Thanks,
Reza

@hyunwoongko (Contributor Author) commented Jul 27, 2021

@RezaYazdaniAminabadi
Let me know when the discussion is over. I'll start working on it right away :)

@RezaYazdaniAminabadi (Contributor) commented

Hi @hyunwoongko

Thanks for checking in :)
I wonder if you will be free for a quick offline chat to go through this?

Thanks,
Reza

@hyunwoongko (Contributor Author) commented Jul 28, 2021

Sure! Why don't we arrange the meeting via email?
Mine is [email protected].

@RezaYazdaniAminabadi (Contributor) commented

Just sent an invite through email. Thanks

@hyunwoongko (Contributor Author) commented

Could you summarize what you said at the meeting?

@hyunwoongko (Contributor Author) commented Jul 30, 2021

Since this issue is closed, how about discussing it in a new issue?

@RezaYazdaniAminabadi (Contributor) commented

Yes, it's better to open another issue. I will send the summary by email.

Thanks,
Reza

@hyunwoongko (Contributor Author) commented

I already opened a new issue: #1248
