Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC-0020/0021/0022 RFCs for Pipeline Parallelism #32

Open
wants to merge 9 commits into
base: master
Choose a base branch
from

Conversation

jamesr66a
Copy link

@jamesr66a jamesr66a commented Nov 20, 2021

This PR consists of three RFCs:

  • RFC-0020 Pipeline Parallelism Strategic Plan
  • RFC-0021 Pipeline Parallelism Technical Approach Proposal
  • RFC-0022 Model Partitioning in Pipeline Parallelism Proposal

Please note that the details of these proposals are subject to revision given feedback from users and partners. Please feel free to comment on the RFCs with your feedback

* Approach 3: MPMD with RemoteModule and message passing (fairscale experimental)
* Approach 4: MPMD with a custom interpreter/instruction format and message passing (DeepSpeed)
* Approach 5: RPC with remote modules and generalized Module-server architecture (SageMaker)
* Approach 6: SPMD with Program capture/JIT compilation and message passing (OneFlow)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am curious how quickly we can prototype some of the approaches on typical transformer models such as GPT3-13B or GPT3-175B, and then evaluate them on both UX and performance aspects.

IMO from pure UX point of view we might want to preserve Python training loop and free-form (no sequential) authoring, otherwise the adoption barrier might be so high that users would just not try it.

I even feel that maximizing UX at a small cost of perf is fine, because user can always find a way (and hire help) to hyper-optimize their training pipeline later if needed.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @yf225! Indeed, we are also hearing feedback from other users that the front-end restrictions will be prohibitive. I think this week I will likely keep iterating on both strategy (prioritize non-Sequential models) and technical approaches that address the front-end issue and update this RFC.


![Module-server request-response execution in SageMaker pipeline parallelism](https://i.imgur.com/y9MZJ3b.png)

**NOTE**: The schedules here do not necessarily run stage in a given order on each stage. Network latency and other affects may change the order of when micro-batches are executed.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel that non-determinism in scheduling is probably okay as long as we can prove it's correct and we can reduce bubble and memory usage as much as possible.

Particularly for memory, one thing that's tricky these days is to reason through the memory usage behavior and to optimize it (e.g. to fit a bigger batch size). Since in general forward accumulates activation tensor memory and backward releases activation tensor memory, iiuc under this scheme, doing forward twice might cause OOM while using a strict 1F1B schedule wouldn't. In that case I think having a deterministic (or at least a formally reasonable) schedule is still better.

@deepakn94
Copy link

Thanks for this, @jamesr66a! It is quite detailed; it will take me a few more days to give it the full attention it deserves.

Some very quick high-level thoughts:

  • I prefer figuring out a clean API to support pipeline parallelism for any model, rather than first trying to support torch.nn.Sequential. Things like skip connections are pretty common, so figuring out how to support these now rather than later seems like time well spent.
  • I also think supporting this with as few code changes as possible to the main training loop is good, similar to torch.distributed.DistributedDataParallel. Of course, some things will probably need to change given that every rank is not loading input data, etc.
  • I like the idea of hiding the specifics of what happens in a given batch (e.g., how many microbatches are in a batch, how these microbatches are scheduled, etc.) behind some API. IMO this achieves a clean separation of concerns: users can use free-form Python for the per-batch processing (as before), while hiding parallelization-strategy-specific implementation details behind an API. And the API can provide some way for users to override the schedule if they like. I wonder if it's possible to support asynchronous pipelining schemes using this approach.

Given the existing work and their limitations, we introduce a set of APIs that’s flexible enough for future improvements and intuitive to use. We expose a single API `create_balanced_partition`, to take the model and do the partition under the hood.

```
def create_balanced_partition(model: nn.Module,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could perhaps be a bit more expressive if you also considered data parallelism degree. So given a total number of devices, you figure out the data parallelism and pipeline parallelism degrees as well.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@deepakn94 Do you mean if one were to feed too much devices where one would be able to leverage data parallelism? If yes then this makes sense to me.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the idea would be to give the partitioning function all devices that you have available (let's say 128 GPUs), and the system automatically figures out how to split GPUs between data and pipeline parallelism (and perhaps tensor parallelism at some point as well).

#device = torch.device("cuda")):
# if model itself if larger than a single cuda device memory,
# should we allow the user to profile the model on cpu?
# Decision: probably no, as the characteristics are different

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think you still want to do profiling on the same target GPU type. Perhaps you can stream execution per operator / inner module in order to keep memory footprint low?


Pipeline parallelism auto partition capability needs more advanced knowledge in order to accurately divide a model into a balanced partition set. In order to achieve auto partitioning with the most balanced approach, we need to get the model execution graph and try to split the model base on the graph. Using torch.fx can give us a helpful graph representation in order for us to do more advanced partition with extensive analysis with tracing and profiling.

The model that passed in should be fx compatible in order to generate the partitions with the most accurate estimation. What models does not have fx compatibility currently:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably want to collect memory profiles as well (especially for large models that don't fit in a single accelerator).


where `memory_weight + time_weight = 1`. How to weigh between time and memory cost? Undecided, a heuristic number `time_weight=0.6`

* @mrshenli: regarding balancing time between memory, I kind of feel we should prioritize time over memory. Because people usually use Pipeline parallel to accelerate training, and the training speed the ultimate goal. Balanced execution time has a directly impact on total pipeline makespan (suppose one phase is slower than others, then all other devices will need to wait for that phase). If the above assumption is correct, it looks like we should first try to balance time, and only try to balance memory when a shard of the optimally time-balanced model cannot fit in some devices.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

People use pipeline parallelism often to train large models like GPT-3, which don't fit on a single accelerator. I think of memory_cost as a constraint rather than an objective (minimize time-per-iteration while ensuring that any worker's memory footprint does not exceed its memory capacity).

```


Since we assigned the cost to each top-level module, we can do the partition, basic partition function will be like the below:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of hardcoding these constants, can we not run a few training iterations with the relevant optimizer?

After deliberation, we want to build the API with the least complexity, at least initially. We will modify/build the API in FairScale experimental with a few modifications:


* (**DA5**) Rather than using torchgpipe-style virtual dependencies in the distributed autograd graph, we want each stage to manually handle running `forward` and `backward` stages (the latter by explicitly calling `torch.autograd.backward()`). This will give easier and finer-grained control over the execution schedule of the pipeline

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can an example of what the main training loop looks like with these modifications be added?

And also, how is training launched? Some example commands for each required rank / host would be useful.

Copy link

@stas00 stas00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for working on this!!! So happy to see this RFC.

I left a few small suggestions if you will find those useful.

RFC-0020-Distributed-Pipeline-Parallelism.md Show resolved Hide resolved
RFC-0020-Distributed-Pipeline-Parallelism.md Show resolved Hide resolved
* Non-requirements:
* Composability with ZeRO-2/3 is not required. Theoretically possible, but reportedly will not give any perf gain.
* Success Criteria:
* **to be determined**: Feedback on this would be appreciated
Copy link

@stas00 stas00 Dec 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, the biggest issue with integrating PP into Transformers is the cheer number of models to support - 70+ and growing (albeit there aren't all unique models).

So the key measure of success would be either passing a model as is and having it just work - similar to how it works with Deepspeed ZeRO, or having a sort of table of policies for each model architecture, that can be easily defined which when applied to the model will restructure it on the fly to do the right thing.

(edited to add ZeRO above)

One of the things we have been looking into is having a magical torch.fx rewrite of the graph to make the model work with PP on the fly. We haven't made much progress here yet.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the key measure of success would be either passing a model as is and having it just work - similar to how it works with Deepspeed

So for my understanding, DeepSpeed requires you to still pass in a linear sequence of layers, and that would require some modification, right? The prior sentence seems to suggest that having "out-of-the-box" pipelining without having to do such modifications. I'm happy to set that as a goal, just want to clarify.

Copy link

@stas00 stas00 Dec 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My apologies. I meant Deepspeed ZeRO - which requires no mods to the model (most of the time) - I edited my comment.

Deepspeed PP is not user friendly at all compared to that as it too wants a sequential model.

But Sagemaker PP offers that of out the box. But it's a proprietary solution.

Copy link

@stas00 stas00 Dec 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only reason not to use ZeRO to scale a huge model - is when interconnects are slow. When interconnects are fast there is no real need for PP+TP, ZeRO-DP and PP+TP+DP perform on par.

Copy link

@stas00 stas00 Dec 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is why we ended up using Megatron-Deepspeed for the BigScience project instead of using Transformers+Deepspeed ZeRO. The network on JeanZay is slow, so PP was the only way to keep TFLOPs high.

We are now dealing with 100-200B models.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @jamesr66a,

I agree with @stas00 on this one. I believe adding PP support for more models would definitely be a criterion of success and unblock an integration within PyTorch Lightning as we used to support the experimental Pipe version from FairScale, but it was too limited and unstable.

RFC-0020-Distributed-Pipeline-Parallelism.md Show resolved Hide resolved

DAPPLE combines DDP with pipeline parallelism to form a bigger search space and use a more efficient schedule if possible. Their partition approach is a DP-based algorithm, it first tries to find the “pivotal” stage, then optimize the overall latency, the “overall latency” optimization here tries to reduce the bubbles in the pipeline as small as possible.


Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here too Varuna that I linked too earlier provides a partitioning mechanism using cutPoints. see their paper : https://arxiv.org/abs/2111.04007 for details.

RFC-0020-Distributed-Pipeline-Parallelism.md Show resolved Hide resolved

**Data Loader**

The data loader should only load input data on rank 0. We can view this as predicating `true` on rank 0 and false on all other ranks. We can also commingle this with `recv` for stages != 0. i.e. under the hood the dataloader object will return an input micro-batch on rank 0, but will return the intermediate value received from `rank - 1` on all ranks != 0.
Copy link

@thomasw21 thomasw21 Dec 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a non sequential model, the model graph becomes a directed acyclic graph, where parents of rank are in [0, rank-1], not just rank - 1. A typical example would be:

def forward(x):
  y = self.layer_0(x) # rank 0
  z = self.layer_1(x) # rank 1
  a = self.layer_2(z + y) # rank 2 (receiving value from both rank 0 and rank 1) 
  return a

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thomasw21 if this is something that is an expected use case, we should discuss this further as it has broader implications on the rest of the RFC and we should likely re-do all of the designs.

To help me understand: what would the execution pattern look like when pipelining a non-sequential (arbitrary DAG) model? In particular, can there exist disconnected nodes which must be scheduled somehow? For example:

def forward(self, x):
  y = torch.zeros(3, 4) # On forward, must be scheduled for execution independently of the function inputs
  z = self.layer_1(x) # rank 0
  a = self.layer_2(z + y) # rank 1
  return a

(Hypothetically we could also have some sort of sink node in the forward that results in a disconnected edge in the autograd graph, but I can't think of a concrete scenario where that would happen)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would expect that disconnected graphs should work. Typically when using absolute positional embeddings, we obtain the position embeddings like the following:

def forward(self, x):
   pos_emb = self.get_position_embeddings() # returns nn.Parameter
   token_emb = self.compute_token_embeddings(x)
   return pos_emb + token_emb

However now that I think about it, it's highly unlikely that there exist a rank that holds a parentless node (otherwise you would just fuse it with one of its children, the first one in the rank ordering, and use that as a proxy parent):

pos_emb -------> output
             /
token_emb --

to

token_emb ----> pos_emb + token_emb ----> output

The reason I think of DAG, is because let's say you have ensembling techniques, where you have 5 experts, which are then aggregated. I'd expect the 5 experts to be able to run in parallel. Now imagine each expert is held inside one rank each. You'd have something like 6 ranks (5 + 1 to run experts and then aggregate).

Current framework would work like this in my understanding:

expert 0 ----> expert 1 + expert 0 ----> expert 2 + expert 1 + expert 0 ---> etc 

Each rank would actually hold its own expert, run that expert, but actually keep forwarding previous expert results to next rank no? This means that fundamentally we should only be able to run that in a sequential fashion? I would have expected all experts to be able to run in parallel. The agregator can still require to receive values from its experts in an arbitrary order (but that means we're bottlenecked by the slowest expert, instead of the sum all experts).

Does that make sense?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thomasw21 I think modeling arbitrary graph programs in a pipeline parallel API is certainly possible (e.g. it was done in the fairscale experimental distributed pipeline prototype https://github.com/facebookresearch/fairscale/blob/7c2c3e0046eb3a1b1c98fd24f61fb09f90f6817f/fairscale/experimental/nn/distributed_pipeline/graph.py#L56).

OTOH, I would ask whether we can decompose this problem into simpler solutions that can be composed with each other. For example, in the first example, could we group the get_position_embeddings() call and the compute_token_embeddings call into a single pipeline stage and use RPCs to run those in parallel on remote hosts. Similarly, for an MoE architecture, can we commingle each expert into a single pipeline stage, but within that pipeline stage use remote RPCs to run the experts on remote hosts in parallel?

Copy link

@stas00 stas00 Dec 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't that be incredibly inefficient speed-wise comparatively to other stages with a lot more logic in them? Also wrt to memory usage - as you will want to partition each stage so that memory usage will be balanced.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah sorry, when mentioning embeddings I was trying to think of examples of parentless nodes. I think that it's highly unlikely that you would create a pipeline stage that doesn't require a previous stage to be ran before or depend on any inputs.

I think the part I'm perhaps a bit unfamiliar is the concept of having multiple hosts on a single pipeline stage. I'm mostly familiar with deepspeed's version. Stages are assigned a single device. We mainly get parallel execution thanks to microbatching/scheduling.

Taking the multi experts again. If I understand correctly there'd be two stages, one to run all the experts, and another one to run the aggregator? If so that makes sense to me. The rules should become like:

  • Within a single pipeline stage, multiple hosts run in parallel.
  • Ranks feed their results sequentially to next ranks.

But then if you manually have to setup RPCs, aren't you manually doing Pipeline parallelism? ie the whole model is a single stage and use RPCs to schedule calls on different devices?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thomasw21 here's what i'm describing, maybe you have something else in mind:

Given the definitions of pipelining used in the literature, pipeline parallelism looks like turning the following:

def my_model(x):
  x = op1(x)
  x = op2(x)
  x = op3(x)
  return x

Into something like (suppose we cut between op2 and op3 into 2 stages):

def my_model0(x):
  x = op1(x)
  x = op2(x)
  return x

def my_model1(x):
  x = op3(x)
  return x

And then the composition my_model1(my_model0(x)) would run in pipelined execution. I think in the MoE example, the original source would look something like:

def my_model(x):
  x = op1(x)
  x = op2(x)
  x = op3(x)

  gates = gate_fn(x)
  moe_results = []
  for i, expert in enumerate(experts):
    moe_results.append(expert(x[i]))
  x = combine(moe_experts)
  return x

Suppose we apply the asme pipelining policy, i.e.:

def my_model0(x):
  x = op1(x)
  x = op2(x)
  return x

def my_model1(x):
  x = op3(x)

  gates = gate_fn(x)
  moe_results = []
  for i, expert in enumerate(experts):
    moe_results.append(expert(x[i]))
  x = combine(moe_results)
  return x

As you point out, there is inherent parallelism in the parallel application of the experts, i.e. in the enumerate(experts) loop. However, with respect to how the code appears on the screen, this is "horizontal" parallelism vs. the "vertical" parallelism of pipelining. So my proposal is to basically implement horizontal parallelism in a way that is orthogonal to pipelining, i.e.:

def my_model0(x):
  x = op1(x)
  x = op2(x)
  return x

def my_model1(x):
  x = op3(x)

  gates = gate_fn(x)
  futures = []
  for i, expert in enumerate(experts):
    futures.append(rpc(expert(x[i])))

  moe_results = []
  for fut in futures:
    moe_results = fut.wait()

  x = combine(moe_results)
  return x

In this way, horizontal parallelism is achieved while keeping the pipeline parallel system simple and focused on vertical parallelism. The same argument can be made for model parallelism, i.e. pipeline parallelism can operate on ShardedTensor instances that dispatch to collectives, but pipeline parallelism itself does not need to be aware of this scheme

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thomasw21 Put another way, pipeline parallelism is about parallelizing code across micro-batches, whereas the MoE parallelism is within a micro-batch, thus I think these should be orthogonal to each other. Does that make sense?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, this makes sense. Would the partitioning algorithm be able to work orthogonally with having these kind of systems but in place?

Copy link

@stas00 stas00 Dec 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To support the discussion of using MoE in the pipeline

Microsoft and Google have just published architectures that use MoE:

https://ai.googleblog.com/2021/12/more-efficient-in-context-learning-with.html
https://www.deepspeed.ai/news/2021/12/09/deepspeed-moe-nlg.html

Deepspeed's version uses Megatron-Deepspeed so it's already based on Deepspeed PP implementation - so you could probably check how they dealt with MoE - the Moe Branch is here: https://github.com/microsoft/Megatron-DeepSpeed/tree/moe-training

But reading closer Deepspeed's version activates only top-1 expert, so it's then irrelevant to this discussion, as it runs only 1 expert concurrently.

Google's arch runs 2 experts concurrently and thus is relevant as it needs to parallelize the processing of these 2.

Given the existing work and their limitations, we introduce a set of APIs that’s flexible enough for future improvements and intuitive to use. We expose a single API `create_balanced_partition`, to take the model and do the partition under the hood.

```
def create_balanced_partition(model: nn.Module,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@deepakn94 Do you mean if one were to feed too much devices where one would be able to leverage data parallelism? If yes then this makes sense to me.

Given the existing work and their limitations, we introduce a set of APIs that’s flexible enough for future improvements and intuitive to use. We expose a single API `create_balanced_partition`, to take the model and do the partition under the hood.

```
def create_balanced_partition(model: nn.Module,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that mean that one needs to be able to load a model in memory? This might be problematic when trying to load huge model on a single host?


## Proposing: Automated Model Partitioning in PyTorch Pipeline Parallelism

Given the existing work and their limitations, we introduce a set of APIs that’s flexible enough for future improvements and intuitive to use. We expose a single API `create_balanced_partition`, to take the model and do the partition under the hood.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be able to "override" the partitioning, ie similarly to feed to a regex like rule?


* Use fairscale/gpipe's [Block Partition of Sequences](https://arxiv.org/pdf/1308.2452.pdf) as our first step
* See if we can apply [Fiduccia-Mattheyses Heuristic](https://en.wikipedia.org/wiki/Fiduccia%E2%80%93Mattheyses_algorithm) to partition the graph and how the performance compare with the default one
* See if we can do operation-level tracing, and partition the model into several balanced `torch.fx.GraphModule` instead of using the original module architecture

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would that mean that we lose all methods linked to the original class? For example transformers fit a certain interface which allows them to use a generate method. If one uses torch.fx.GraphModule it loses all methods linked to it's original module?


## FX compatibility

Pipeline parallelism auto partition capability needs more advanced knowledge in order to accurately divide a model into a balanced partition set. In order to achieve auto partitioning with the most balanced approach, we need to get the model execution graph and try to split the model base on the graph. Using torch.fx can give us a helpful graph representation in order for us to do more advanced partition with extensive analysis with tracing and profiling.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.fx gives a representation of a function no? Typically in case of a module we'd trace forward. However a lot of model have "secondary" forward function, typically infer which might be slightly different from the forward function. How does optimizing on the forward function, generalises to the infer function.


```
cost(m) = time_weight * execution_time(m)
+ memory_weight * memory_cost(m)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would memory weight be used? Isn't it binary as in "does it fit in memory"?

Can we use ShardedTensor to shard the module parameters? or we could do operation level tracing, partition this module into 2 submodules, then assign to different devices.

**What if constructing the model on CPU itself is hard?**
User created the model first on meta device, we use the model that’s not materializing to do symbolic tracing, but we couldn’t do profiling since it’s not materialized yet, after we do symbolic tracing, we should use a simple partition algorithm (i.e. only count the param sizes) to do the partition, then materialize afterwards.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a documentation on "meta" device? I don't seem to find any.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@deepakn94
Copy link

deepakn94 commented Dec 2, 2021

Oh, another thing that I don't see mentioned anywhere: it is common for language models to have layers that "share" parameters. For example, the embedding layer at the front of the model and the LM head often share weights. There are a couple of different ways of handling this: a) ensure that all layers sharing parameters are on the same device, and let autograd do the right thing, b) allow layers sharing parameters to reside on different devices, but then synchronize gradients at the end of a step (and also initialize parameters the same way on all devices with the same "shared" parameters).

I am sure there are other ways of handling this as well.

* (split pro/con) (**D2**) Not clear if `1f1b` or other schedules are implemented/implementable?
* (**D8**) Not clear what the training loop abstraction looks like. The optimizer is installed via an `nn.Graph` API. Loss calculation is created in the `nn.Graph.build()` method.

## Approach 7: Varuna
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stas00 I've added details about Varuna in this section, please have a look and let me know if I've missed anything

Copy link

@stas00 stas00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for adding Varuna, @jamesr66a!

The notes look good

RFC-0021-Distributed-Pipeline-Parallel-Technical.md Outdated Show resolved Hide resolved
@SeanNaren
Copy link

This is a fantastic summary of attempts for PP, great work @jamesr66a :)

Within Lightning we experimented with the initial RemoteModule from Fairscale which relied on RPC, however as discussed the requirement of everything being sequential made it very restrictive.

There is another example (albeit very intrusive in the user's code) by the GraphCore team where PP is a requirement for many models when using IPUs https://docs.graphcore.ai/projects/poptorch-user-guide/en/1.0.0/overview.html#poptorch-block-and-poptorch-beginblock

Even though this approach is extremely intrusive, it does have its merits of support a wider range of models + being a bit more expressive. I would relate this to the idea behind FlexFlow's torch.fx support as we would need to traverse the graph.

@stas00
Copy link

stas00 commented Dec 7, 2021

Even though this approach is extremely intrusive, it does have its merits of support a wider range of models + being a bit more expressive. I would relate this to the idea behind FlexFlow's torch.fx support as we would need to traverse the graph.

Indeed, I was considering mentioning https://github.com/flexflow/flexflow as another framework to consider but last we checked it still had the PP-support only planned.

The design paper is interesting and it too uses a simulation to automatically partition the graph.

@@ -0,0 +1,211 @@
# [RFC] Ceci n'est pas pipeline parallelism (Pipeline Parallelism 2021Q4/2022 Plan)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

French ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RFC-0020-Distributed-Pipeline-Parallelism.md Outdated Show resolved Hide resolved
* Non-requirements:
* Composability with ZeRO-2/3 is not required. Theoretically possible, but reportedly will not give any perf gain.
* Success Criteria:
* **to be determined**: Feedback on this would be appreciated
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @jamesr66a,

I agree with @stas00 on this one. I believe adding PP support for more models would definitely be a criterion of success and unblock an integration within PyTorch Lightning as we used to support the experimental Pipe version from FairScale, but it was too limited and unstable.

@rohan-varma
Copy link
Member

cc @pritamdamania87 @pbelevich @mrshenli @zhaojuanmao

* [DeepSpeed pipeline parallelism](https://www.deepspeed.ai/tutorials/pipeline/)
* [OneFlow](https://github.com/Oneflow-Inc/oneflow)

Proposed approach short-list: (all approaches can be seen in [[RFC] Distributed Pipeline Parallel Training Technical Approach](https://github.com/pytorch/rfcs/blob/master/RFC-0021-Distributed-Pipeline-Parallel-Technical.md)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, the link is now broken.

@facebook-github-bot
Copy link
Contributor

Hi @jamesr66a!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@facebook-github-bot
Copy link
Contributor

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.