
[wip] [doc] Parallelism notes #9766

stas00 opened this issue Jan 24, 2021 · 0 comments · Fixed by #12524

stas00 commented Jan 24, 2021

Perhaps this will end up in a blog post and/or a new document, for now collecting notes. This is a work in progress. Please give me some time to write the bulk of it and then you'll be welcome to ask questions, add contributions, etc.


Parallelism overview

In modern machine learning, various approaches to parallelism are used to:

  1. fit very large models onto limited hardware - e.g. t5-11b is 45GB in just model params
  2. significantly speed up training - finish training that would take a year in hours

We will first discuss various 1D parallelism techniques in depth, along with their pros and cons, and then look at how they can be combined into 2D and 3D parallelism to enable even faster training and to support even bigger models.

While the main concepts will most likely apply to any other framework, this article focuses on PyTorch-based implementations.

Data Parallel

Most users with just 2 GPUs already enjoy the increased training speed delivered by DataParallel (DP) and DistributedDataParallel (DDP), which are almost trivial to use.
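To make this concrete, here is a minimal DDP sketch, assuming a single node with one process per GPU launched by torchrun (or torch.distributed.launch); MyModel is a stand-in for your own nn.Module:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# one process per GPU; the launcher sets LOCAL_RANK, RANK, WORLD_SIZE, etc.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = MyModel().to(local_rank)                # MyModel is a placeholder for your own model
ddp_model = DDP(model, device_ids=[local_rank])
# from here on the training loop is unchanged: each process consumes its own slice of the data
# and DDP all-reduces the gradients behind the scenes during backward()
```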

ZeRO Data Parallel

ZeRO-powered data parallelism (ZeRO-DP) is illustrated in the following diagram from this blog post:
DeepSpeed-Image-1

It can be difficult to wrap one's head around it, but in reality the concept is quite simple. This is just the usual DataParallel (DP), except, instead of replicating the full model params, gradients and optimizer states, each GPU stores only a slice of them. Then, at run-time, when the full layer params are needed for a given layer, all GPUs synchronize to give each other the parts they are missing - and that's it.

Consider this simple model with 3 layers, where each layer has 3 params:

La | Lb | Lc
---|----|---
a0 | b0 | c0
a1 | b1 | c1
a2 | b2 | c2

Here Lx denotes a layer - we have 3 layers - and each layer holds 3 weights (e.g. La holds a0, a1 and a2).

If we have 3 GPUs, Sharded DDP (= ZeRO-DP) splits the model onto the 3 GPUs like so:

GPU0:
La | Lb | Lc
---|----|---
a0 | b0 | c0

GPU1:
La | Lb | Lc
---|----|---
a1 | b1 | c1

GPU2:
La | Lb | Lc
---|----|---
a2 | b2 | c2

In a way this is horizontal slicing, if you imagine the typical DNN diagram. Vertical slicing is where one puts whole layer-groups on different GPUs. But this is just the starting point.

Now each of these GPUs will get the usual mini-batch as it works in DP:

x0 => GPU0
x1 => GPU1
x2 => GPU2

The inputs are unmodified - they think they are going to be processed by the normal model.

So the inputs first hit the first layer La.

Let's focus just on GPU0: x0 needs a0, a1 and a2 params to do its forward pass, but GPU0 has only a0 - so it gets sent a1 from GPU1 and a2 from GPU2. Now the forward step can happen.

In parallel GPU1 gets mini-batch x1 and it only has a1, but needs a0 and a2 params, so it gets those from GPU0 and GPU2.

Same happens to GPU2 that gets input x2. It gets a0 and a1 from GPU0 and GPU1.

As soon as the calculation is done, the params that are no longer needed get dropped - they are only used during the calculation.

The same is repeated at every other stage.

And the whole larger thing is repeated for layer Lb, then Lc forward-wise, and then backward Lc -> Lb -> La.
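Here is a toy sketch of that gather-compute-drop cycle for a single linear layer, using plain torch.distributed collectives rather than DeepSpeed's actual API; the shapes and helper names are made up for illustration:

```python
import torch
import torch.distributed as dist

def gather_full_weight(local_shard: torch.Tensor) -> torch.Tensor:
    """Reconstruct the full flat weight from the 1/world_size shard each rank owns."""
    shards = [torch.empty_like(local_shard) for _ in range(dist.get_world_size())]
    dist.all_gather(shards, local_shard)   # every rank contributes its shard, receives the rest
    return torch.cat(shards)

def sharded_linear_forward(x, my_shard, out_features, in_features):
    full_weight = gather_full_weight(my_shard).view(out_features, in_features)
    y = x @ full_weight.t()                # forward with the temporarily reconstructed weight
    del full_weight                        # drop the gathered copy - only the local shard is kept
    return y
```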

To me this sounds like an efficient group backpacking weight distribution strategy:

  1. person A carries the tent
  2. person B carries the stove
  3. person C carries the entertainment system

Now each night they all share what they have with the others and get from the others what they don't have, and in the morning they pack up their allocated type of gear and continue on their way. This is Sharded DDP / ZeRO-DP.

Compare this strategy to the simple one where each person has to carry their own tent, stove and entertainment system, which would be far more inefficient. This is DataParallel in pytorch.

Pretty much everywhere I read, Sharded == Partitioned, so those appear to be synonyms in the context of distributed models.

If you pay close attention to the way ZeRO partitions the model's data, you will see that it looks very similar to horizontal model parallelism, which is discussed later. This is because ZeRO partitions/shards each layer's data, unlike vertical model parallelism, which is discussed next.

Implementations:

Naive Model Parallel (Vertical) and Pipeline Parallel

Naive Model Parallel (MP) is where one spreads groups of model layers across multiple GPUs. The mechanism is relatively simple: move the desired layers .to() the desired devices, and whenever data goes in and out of those layers, move the data to the same device as the layer and leave the rest unmodified.

We refer to it as Vertical MP, because if you remember how most models are drawn, we slice the layers vertically. For example, if the following diagram shows an 8-layer model:

===================  ===================
|  0 | 1 | 2 | 3  |  |  4 | 5 | 6 | 7  | 
===================  ===================
        gpu0                 gpu1

we just sliced it in two vertically, placing layers 0-3 on gpu0 and layers 4-7 on gpu1.

Now while data travels from layer 0 to 1, 1 to 2 and 2 to 3, this is just the normal model. But when data needs to pass from layer 3 to layer 4, it has to travel from gpu0 to gpu1, which introduces a communication overhead. If the participating GPUs are on the same node (e.g. the same physical machine) this copying is pretty fast, but if the GPUs are located on different nodes (e.g. different machines), the communication overhead could be significantly larger.

Then layers 4 through 7 work as they would in a normal model, and when the 7th layer completes we often need to send the data back to layer 0 where the labels are (or, alternatively, send the labels to the last layer).
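Here is a minimal sketch of this naive vertical MP setup on 2 GPUs, using a hypothetical stack of Linear layers; the explicit .to() calls are the inter-GPU copies described above:

```python
import torch
import torch.nn as nn

class TwoGpuModel(nn.Module):
    def __init__(self, hidden=1024, layers_per_gpu=4):
        super().__init__()
        # layers 0-3 live on cuda:0, layers 4-7 on cuda:1
        self.part0 = nn.Sequential(*[nn.Linear(hidden, hidden) for _ in range(layers_per_gpu)]).to("cuda:0")
        self.part1 = nn.Sequential(*[nn.Linear(hidden, hidden) for _ in range(layers_per_gpu)]).to("cuda:1")

    def forward(self, x):
        x = self.part0(x.to("cuda:0"))
        x = self.part1(x.to("cuda:1"))     # this copy is the gpu0 -> gpu1 communication overhead
        return x

# the loss is computed wherever the labels live, e.g.:
# loss = torch.nn.functional.mse_loss(model(inputs), labels.to("cuda:1"))
```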

Problems:

  • the main deficiency, and the reason this one is called "naive" MP, is that all but one GPU are idle at any given moment. So if 4 GPUs are used, it's almost identical to quadrupling the amount of memory of a single GPU while ignoring the rest of the hardware. Plus there is the overhead of copying the data between devices. So 4x 6GB cards will be able to accommodate the same size model as 1x 24GB card using naive MP, except the latter will complete the training faster, since it doesn't have the data copying overhead. But, say, if you have 40GB cards and need to fit a 45GB model, you can with 4x 40GB cards (barely, because of the gradient and optimizer states)
  • shared embeddings may need to get copied back and forth between GPUs.

Pipeline Parallel (PP) is almost identical to naive MP, but it solves the GPU idling problem to a degree, by chunking the incoming batch into micro-batches and artificially creating a pipeline, which allows different GPUs to concurrently participate in the computation process.

The following illustration from the GPipe paper shows first the naive MP, then PP:

mp-pp

It's easy to see how PP has fewer dead zones, where GPUs are idle.

PP introduces a new hyper-parameter to tune: chunks, which defines how many chunks of data (micro-batches) are sent in a sequence through the same pipeline stages. E.g. in the 2nd diagram of the image above you can see that chunks=4.

With chunks=1 you end up with naive MP. With a very large value you will find that the overhead of slicing the tensors into tiny micro-batches will slow everything down. So one has to experiment to find the best value. It's also important to remember that to take full advantage of the GPU you need largish batches, ideally in multiples of 8.

So if the normal batch size is bs=64 and chunks=8, each stage will receive a micro-batch of 8. However, if you're tight on memory in the first place, you may end up with a normal bs=8, and then if you choose chunks=4, you will end up with 4 pipeline segments with a micro-batch of just 2 - which would be very inefficient. Also, bs=8 and chunks=3 won't go too well together either, as you will end up with uneven micro-batches of [3,3,2].

While the diagram shows that there is a bubble of "dead" time that can't be parallelized, because the last forward stage has to wait for backward to complete the pipeline, the purpose of finding the best value for chunks is to minimize that bubble and enable high concurrent GPU utilization across all participating GPUs.
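To illustrate chunks, here is a minimal sketch based on PyTorch's experimental pipeline API (torch.distributed.pipeline.sync.Pipe, available in recent releases at the time of writing); the tiny layer sizes are made up:

```python
import os
import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
from torch.distributed.pipeline.sync import Pipe

# Pipe is built on top of the RPC framework, which has to be initialized first
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker", rank=0, world_size=1)

# an nn.Sequential whose stages already live on their target devices
fc1 = nn.Linear(16, 8).cuda(0)
fc2 = nn.Linear(8, 4).cuda(1)
model = Pipe(nn.Sequential(fc1, fc2), chunks=8)   # bs=64 below -> micro-batches of 8

inputs = torch.rand(64, 16).cuda(0)
output = model(inputs).local_value()              # forward returns an RRef
```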

Problems:

  • have to modify the model quite heavily, because Pipeline requires one to rewrite the normal flow of modules into an nn.Sequential sequence of the same, which may require changes to the design of the model.
  • currently the Pipeline API is very restricted. If you had a bunch of python variables being passed in the very first stage of the Pipeline, you will have to find a way around it. Currently, the pipeline interface requires either a single Tensor or a tuple of Tensors as the only input and output. These tensors must have batch size as the very first dimension, since pipeline is going to chunk the normal batch into micro-batches. Possible improvements are being discussed here [wip] [Pipe] supporting None and non-Tensors in forward's input/output pytorch/pytorch#50693
  • have to arrange each layer so that the output of one stage becomes the input to the next stage

Implementations:

Other approaches:

SageMaker introduces the concept of an Interleaved Pipeline
interleaved-pipeline-execution

Here the bubble (idle time) is further minimized by prioritizing backward passes.

According to the same document, it might be able to automate the conversion of a model into a pipeline.

The only problem is that this is currently only available at AWS, so you can't run it on your own hardware.

Model Parallel (Horizontal)

Megatron-LM

2D Parallelism

The following diagram from the DeepSpeed pipeline tutorial demonstrates how one combines DP with PP.

dp-pp-2d

Here it's important to see how DP rank 0 doesn't see gpu2, and DP rank 1 doesn't see gpu3. To DP there are just gpus 0 and 1, and it feeds them data as if there were only 2 gpus. gpu0 "secretly" offloads some of its load to gpu2 using PP, and gpu1 does the same by enlisting gpu3 to its aid.
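A toy sketch of how the 4-GPU grid above could be expressed with plain torch.distributed process groups (an illustration of the topology, not DeepSpeed's actual setup code); each pipeline stage gets its own DP group for gradient synchronization:

```python
import torch.distributed as dist

dist.init_process_group("nccl")   # 4 processes, one per GPU
rank = dist.get_rank()

# pipeline groups: gpu0 -> gpu2 and gpu1 -> gpu3
pp_groups = [dist.new_group(ranks=[0, 2]), dist.new_group(ranks=[1, 3])]
my_pp_group = pp_groups[rank % 2]

# data-parallel groups: the first stages {0, 1} sync gradients with each other, as do the second stages {2, 3}
dp_groups = [dist.new_group(ranks=[0, 1]), dist.new_group(ranks=[2, 3])]
my_dp_group = dp_groups[rank // 2]

# activations flow inside my_pp_group; gradient all-reduce happens inside my_dp_group
```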

XXX: will update this section once I get it working

3D Parallelism

FlexFlow

FlexFlow also solves the parallelization problem, but with a slightly different approach.

Paper: "Beyond Data and Model Parallelism for Deep Neural Networks" by Zhihao Jia, Matei Zaharia, Alex Aiken

It performs a sort of 4D Parallelism over Sample-Operator-Attribute-Parameter.

  1. Sample = Data Parallelism
  2. Operator = in part vertical layer parallelism, but it can split within a layer too - a more refined level
  3. Attribute = horizontal Model Parallelism (Megatron-LM style)
  4. Parameter = Sharded model params

and they are working on Pipeline Parallelism. I guess ZeRO-DP is Sample+Parameter in this context.

flex-flow-soap

The significance of this framework is that it takes resources like (1) GPU/TPU/CPU vs. (2) RAM/DRAM vs. (3) fast-intra-connect/slow-inter-connect, and it automatically optimizes over all of these, algorithmically deciding which parallelisation to use where.

One very important aspect is that FlexFlow is designed for optimizing DNN parallelizations for models with a static and fixed workload, since models with dynamic behavior may prefer different parallelization strategies across iterations.

So the promise is very attractive - it runs, say, a 30min simulation on the cluster of choice and comes up with the best strategy to utilise this specific environment. If you add/remove/replace any parts, it will re-run and re-optimize the plan for the new setup, and then you can train. A different setup will have its own custom optimization.
