Add Varuna details to the RFC
James Reed committed Dec 6, 2021
1 parent ecfd1f8 commit 29126ef
Showing 2 changed files with 51 additions and 18 deletions.
6 changes: 4 additions & 2 deletions RFC-0020-Distributed-Pipeline-Parallelism.md
@@ -63,6 +63,7 @@ Existing approaches that support this (in no particular order):
* Sagemaker [model parallelism](https://arxiv.org/abs/2111.05972)
* [DeepSpeed pipeline parallelism](https://www.deepspeed.ai/tutorials/pipeline/)
* [OneFlow](https://github.com/Oneflow-Inc/oneflow)
* [Varuna](https://github.com/microsoft/varuna) [13]

Proposed approach short-list: (all approaches can be seen in [[RFC] Distributed Pipeline Parallel Training Technical Approach](https://github.com/pytorch/rfcs/blob/master/RFC-0021-Distributed-Pipeline-Parallel-Technical.md))

@@ -155,7 +156,7 @@ Proposed approach short-list:

These approaches can be composed on top of an existing API that takes an `nn.Sequential`. In the future, we may consider developing a "v2" API centered more natively around non-`nn.Sequential` models, using technologies from Sagemaker, OneFlow, or other research developments.

### P1: Support arbitrary programmable schedules (e.g. fill-drain, 1F1B, interleaved 1F1B)

Existing approaches that support this (in no particular order):

@@ -208,4 +209,5 @@ Going into the future, we would like to develop theory and implementation for a
10. Performance analysis of a pipelined backpropagation parallel algorithm https://ieeexplore.ieee.org/document/286892
11. PipeMare: Asynchronous Pipeline Parallel DNN Training https://arxiv.org/abs/1910.05124
12. Scaling Language Model Training to a Trillion Parameters Using Megatron https://developer.nvidia.com/blog/scaling-language-model-training-to-a-trillion-parameters-using-megatron/
13. Varuna: Scalable, Low-cost Training of Massive Deep Learning Models https://arxiv.org/abs/2111.04007
63 changes: 47 additions & 16 deletions RFC-0021-Distributed-Pipeline-Parallel-Technical.md
@@ -72,7 +72,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
# load best model weights
model.load_state_dict(best_model_wts)
return model
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
@@ -189,7 +189,7 @@ This scheme of zeroing grads on the first micro-batch can be trivially implement

Predication of forward-propagation can be done as in Zach’s [proposal](https://colab.research.google.com/drive/1lGg2NqlvDwVmvBqejzni2yTmYE9rxfdr?usp=sharing).

Note that we may extend the predicated training loop scheme to include schedules such as 1F1B or interleaved 1F1B, discussed later.
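
As a point of reference, here is a minimal single-process sketch of the zero-grads-on-first-micro-batch scheme described above (no RPC, predication, or pipelining shown): gradients are zeroed only for the first micro-batch and accumulate across the rest, with a single optimizer step at the end. The module, sizes, and loss are placeholders.

```
import torch
import torch.nn as nn

# Placeholder stage/model, loss, and data; in the pipelined setting each rank
# would run only its own stage, predicated on the schedule.
model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

minibatch, target = torch.randn(32, 16), torch.randn(32, 4)
chunks = 4  # number of micro-batches

for i, (x, t) in enumerate(zip(minibatch.chunk(chunks), target.chunk(chunks))):
    if i == 0:
        optimizer.zero_grad()          # zero grads only on the first micro-batch
    loss = criterion(model(x), t) / chunks
    loss.backward()                    # grads accumulate across micro-batches
optimizer.step()                       # one parameter update per mini-batch
```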

**Loss Calculation**

@@ -237,22 +237,22 @@ Note that forward and backward stages do not necessarily always run in a given r

## Approach 2 - RPC with RemoteModule and torchgpipe-style single coordinator (@pritamdamania87 RFC)

One proposal for an API for pipeline parallelism is the `pipeline_sync` API proposed in @pritamdamania87’s [RFC](https://github.com/pytorch/pytorch/issues/44827) (certain lines are called out with end-of-line comments containing an alphabetical identifier):

```
# Note: This API is very similar to torchgpipe and inspired from it.
# torchgpipe API for reference: https://torchgpipe.readthedocs.io/en/stable/api.html
torch.distributed.pipeline_sync(
pipeline: nn.Sequential,
checkpoint: CheckpointEnum = EXCEPT_LAST, # ALWAYS, EXCEPT_LAST, NEVER
chunks: int = 1) -> PipelineSyncModel
Arguments:
pipeline: nn.Sequential (https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) where each nn.Module in the list is placed on the
appropriate device (CPU or GPU)/machine by the user. Note that
nn.Sequential could also consist of RemoteModule (https://github.com/pytorch/pytorch/blob/master/torch/distributed/nn/api/remote_module.py#L147) for cross host
pipelining.
checkpoint: Enum that determines which checkpointing mode to use.
chunks: Number of micro-batches.
@@ -265,17 +265,17 @@ Forward Method
PipelineSyncModel.forward(self, *input, **kwargs) -> RRef
Returns:
RRef to output corresponding to the result of the minibatch.
Since we plan to support cross host pipelining, the RRef could be on a
device on a different host.
Example:
# This is an example of a pipeline across two machines each using one GPU.
# On worker 0
layer1 = nn.Linear(10, 5).cuda(0)
# Need to enhance RemoteModule to include device for this purpose.
layer2 = RemoteModule("worker1", device="cuda:0", nn.Linear, 5, 1)
pipeline = nn.Sequential(layer1, layer2)
model = torch.distributed.pipeline_sync(pipeline, chunks = 4) # A
@@ -300,7 +300,7 @@ for epoch in range(epochs):
target_rref = rpc.remote("worker1", identity_fn, target) # C
output_rref = model(minibatch) # D
loss_rref = rpc.remote("worker1", compute_loss, output_rref, target_rref) # E
# Can enhance RRef to ensure this calls "dist_autograd.backward" on the last
# node in the pipeline.
loss_rref.backward(context_id) # F
dist_optim.step() # G
@@ -375,7 +375,7 @@ class Loss(nn.Module):
for epoch in range(epochs):
loss_module = DistributedLoss(Loss, criterion, ntokens)
for minibatch, targets in dataloader:
minibatch = minibatch.transpose(0, 1)
@@ -389,7 +389,7 @@ This proposal has the training loop running on a single machine and makes copiou

**A - Model Pipelining**

As opposed to the torchgpipe-based Approach 2, this approach instantiates actors (specifically [PartitionHandler](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L140) instances) that execute the pipeline in an event-driven manner. PartitionHandler instances own a [DistributedPipelineRecord](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L27) instance, which has a “feed” method to be called via RPC to add a data item for processing.
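
To make the event-driven structure concrete, here is a hedged, generic sketch of the pattern (not fairscale's actual PartitionHandler/DistributedPipelineRecord API): each stage is an actor owned by a worker, and micro-batches are pushed into it via an RPC "feed" call, which runs the stage and forwards the result downstream. The `StageActor` name and the wiring are illustrative assumptions, and the sketch presumes `torch.distributed.rpc` has already been initialized on each worker.

```
import torch.nn as nn
import torch.distributed.rpc as rpc

class StageActor:
    def __init__(self, module, next_stage_rref=None):
        self.module = module                    # this pipeline stage's computation
        self.next_stage_rref = next_stage_rref  # RRef to the downstream actor, if any

    def feed(self, chunk_id, x):
        # Run this stage's forward for one micro-batch as soon as it arrives...
        y = self.module(x)
        # ...and push the result downstream; execution is driven by incoming
        # data rather than by a central coordinator stepping each stage.
        if self.next_stage_rref is not None:
            self.next_stage_rref.rpc_async().feed(chunk_id, y)
        return y

# Driver-side wiring (assumes workers "worker0"/"worker1" exist and RPC is up):
# stage1 = rpc.remote("worker1", StageActor, args=(nn.Linear(5, 1),))
# stage0 = rpc.remote("worker0", StageActor, args=(nn.Linear(10, 5), stage1))
# for chunk_id, chunk in enumerate(minibatch.chunk(4)):
#     stage0.rpc_async().feed(chunk_id, chunk)
```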

**B - Distributed Optimizer**

@@ -405,7 +405,7 @@ Loss calculation happens similarly to in Approach 2, the single driver calls int

**E - Backprop**

Backpropagation through the pipeline is similarly implemented via distributed autograd, as in Approach 2. Note that the same fork/join barrier approach is used to [serialize](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L103) execution of micro-batches on the backward pass.

**NOTE**: I don’t believe that forward and backward jobs are serialized; they may run concurrently. Is this true?

@@ -482,7 +482,7 @@ The implementations for each of these instructions can be referenced from this [

* (**D2**) (hypothetically) supports arbitrary schedules through the [PipeSchedule](https://github.com/microsoft/DeepSpeed/blob/488105ebd200bbd1f6d7cbe863412e41d9ab4221/deepspeed/runtime/pipe/schedule.py#L6) abstraction. However, there don't seem to be any schedules implemented beyond the default (a sketch of what a custom schedule might look like follows this list).
* (**D3, D4?**) Usable in 3d parallelism, as detailed by the [blog post](https://www.deepspeed.ai/tutorials/pipeline/).
* (**D6**) Since data is pulled from the data loader rather than being pushed by a synchronous call in the training loop, this approach could *hypothetically* support async PP.
* (**D7**) The approach seems to account for many different types of parallelism.
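
For reference, a hedged sketch of what plugging a custom schedule into this abstraction might look like. The class and instruction names follow the linked schedule.py but are not verified here, and the send/recv instructions a real multi-stage schedule needs are omitted, so this only illustrates the shape of the abstraction, not a working schedule.

```
from deepspeed.runtime.pipe import schedule

class FillDrainSchedule(schedule.PipeSchedule):
    """GPipe-style fill-drain: all forward micro-batches, then all backwards."""

    def steps(self):
        # Each yielded list is a group of instructions executed for one step.
        for mb in range(self.micro_batches):
            buf = mb % self.num_pipe_buffers()
            yield [schedule.LoadMicroBatch(buffer_id=buf),
                   schedule.ForwardPass(buffer_id=buf)]
        for mb in range(self.micro_batches):
            buf = mb % self.num_pipe_buffers()
            yield [schedule.BackwardPass(buffer_id=buf)]
        yield [schedule.ReduceGrads(), schedule.OptimizerStep()]

    def num_pipe_buffers(self):
        # Fill-drain keeps every in-flight micro-batch's activations alive at once.
        return self.micro_batches
```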

**Con**
@@ -545,6 +545,37 @@ An example of using OneFlow for pipeline parallelism can be seen in this [tutori
* (split pro/con) (**D2**) Not clear if `1f1b` or other schedules are implemented/implementable?
* (**D8**) Not clear what the training loop abstraction looks like. The optimizer is installed via an `nn.Graph` API. Loss calculation is created in the `nn.Graph.build()` method.

## Approach 7: Varuna

Varuna (https://arxiv.org/abs/2111.04007) proposes a system for large-scale training that focuses on commodity hardware. In particular, Varuna focuses on pipeline parallelism (to work over commodity interconnects), tuning the pipeline to optimally trade off pipeline bubble size vs. allreduce bandwidth, dynamic scheduling of pipeline stages to account for network latency and jitter, and elastic scheduling.

The workflow of Varuna looks like the following (a hedged sketch of this workflow follows the list):
* The user manually annotates their model with [CutPoints](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/docs/cutpoint.rst). These are points in the program where the system _may_ place a pipeline stage boundary.
* The user wraps their model in the [Varuna](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/docs/varuna.rst#id9) class and configures pipeline parallelism via this interface. This includes setting parameters like the chunk size, installing the optimizer, and informing the system of the rank of the current running instance.
* Internally, the system does a round of profiling to determine the optimal pipeline balance and chooses a subset of the cut-points to delimit the code for the different pipeline stages.
* The user calls the `Varuna.step()` method to run the program with pipeline-parallel execution (forward, loss, backward).
* Varuna uses an opportunistic scheduling policy (described in Section 3.2 of the paper), which runs ahead with `forward()` micro-batches when `backward()` micro-batches are not yet available.
* The user applies the optimizer's `step()` method to finally update the parameters given the accumulated gradients.
* Outside of the Python script, the user uses the `run_varuna` launcher script to orchestrate (elastic) job scheduling.
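
A hedged sketch of this workflow, as referenced above. The `CutPoint`/`Varuna`/`step()` names come from the linked docs; the constructor arguments, `set_optimizer` call, and batch/return formats shown here are assumptions and may not match the library exactly.

```
import torch
import torch.nn as nn
import torch.nn.functional as F
from varuna import Varuna, CutPoint   # assumed import path

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.cut1 = CutPoint()         # a point where Varuna *may* place a stage boundary
        self.layer2 = nn.Linear(1024, 1024)

    def forward(self, inputs, targets):
        x = self.layer1(inputs)
        x = self.cut1(x)               # cut-points are invoked inside forward()
        x = self.layer2(x)
        return F.mse_loss(x, targets)  # the wrapped forward computes the loss

model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Wrap the model: the wrapper is configured with e.g. the micro-batch ("chunk")
# size and this process's rank, and is handed the optimizer; internal profiling
# then picks which cut-points become real stage boundaries.
rank = 0                                   # placeholder; provided by the run_varuna launcher
model = Varuna(model, rank, chunk_size=4)  # argument names/order are assumptions
model.set_optimizer(optimizer)             # assumed method name for installing the optimizer

dataloader = [(torch.randn(32, 1024), torch.randn(32, 1024))]  # placeholder data
for inputs, targets in dataloader:
    # step() runs forward, loss, and backward over micro-batches using the
    # opportunistic schedule, accumulating gradients on each stage.
    loss = model.step({"inputs": inputs, "targets": targets})  # batch/return format assumed
    optimizer.step()       # the user applies the parameter update afterwards
    optimizer.zero_grad()
```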

### Pros and Cons of the Approach

**Pro**

* (**D2**) Varuna implements scheduling, particularly its [opportunistic scheduling](https://github.com/microsoft/varuna/blob/79aaf45995a2b06bf5186e825b7baf81b9145837/varuna/pipeline.py#L280) policy. As noted in the "Con" section, I'm not entirely convinced by this scheme, but the system does support scheduling (and can probably be extended or hacked to support more traditional schedules).
* (**D5**) The system nominally supports pipeline partitioning without wrapping the model into a `Sequential`, but see the "Con" section for commentary about the soundness of this approach.


**Con**

* (**D5**) Varuna's approach to partitioning models does not seem sound (both issues below are illustrated in the sketch after this list).
  * It assumes that the order of invocation of modules matches the order in which the modules are [enumerated](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/varuna/partitioned_model.py#L396) on the `nn.Module`. This is not necessarily the case, as there can be an arbitrary set of `use-def` relationships between modules and call-sites in a PyTorch model.
  * It [nulls out](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/varuna/partitioned_model.py#L434) modules that are not relevant to the current rank's computation by replacing them with [PassThroughModule](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/varuna/partitioned_model.py#L635), which is implemented to simply return `None` from `forward()`. This works when the model is composed purely of module calls and the data dependencies between them, but if at any point there is a non-trivial construct in the code (e.g. any operation on the output of a module), this will break.
* (**D2**) I'm not convinced by [opportunistic scheduling](https://github.com/microsoft/varuna/blob/79aaf45995a2b06bf5186e825b7baf81b9145837/varuna/pipeline.py#L280). The concept is sound, and reminds me of Tomasulo-style out-of-order execution for dealing with stochastic latencies (e.g. memory latencies in a processor), but pipeline-parallel execution for deep learning has an extra constraint: value lifetimes. Activations from forward jobs must be saved for use in the backward job, meaning the memory high-watermark of the pipeline stage increases for every forward() job that is admitted without a corresponding backward() job to release those values. The literature addresses this with static execution schedules such as the [1F1B strategy](https://arxiv.org/abs/1806.03377), and OneFlow solves it by implementing [registers and back-pressure](https://oneflow2020.medium.com/runtime-of-oneflow-based-on-boxing-and-actor-model-part-3-f2b786dc14a0) in the pipeline. As far as I can tell, Varuna will run ahead indiscriminately until it OOMs.
* (**D8**) The extent to which the training loop must be modified ([BERT](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/examples/BERT/bert.patch), [megatron](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/examples/Megatron-LM/megatron.patch)) to work with Varuna is pretty extreme.
* (**D3/D4/D7**) The system does not seem to compose with tensor parallelism, at least that is not described in the paper or the README.
* (**D6**) Does not support async
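
To make the (**D5**) concern above concrete, here are two toy modules, written for illustration rather than taken from Varuna: the first invokes its submodules in a different order than they are enumerated, and the second applies an operation to a submodule's output, which breaks if that submodule is replaced by a pass-through returning `None`.

```
import torch.nn as nn

class OutOfOrder(nn.Module):
    """Enumeration order (a, then b) differs from invocation order (b, then a),
    so partitioning by enumeration order mis-models the dataflow."""
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(8, 8)       # enumerated first...
        self.b = nn.Linear(8, 8)

    def forward(self, x):
        return self.a(self.b(x))       # ...but invoked second

class NonTrivialGlue(nn.Module):
    """Code between module calls operates on a submodule's output; replacing
    `inner` with a module whose forward() returns None breaks the `+ 1`."""
    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(8, 8)
        self.head = nn.Linear(8, 8)

    def forward(self, x):
        y = self.inner(x) + 1          # TypeError if inner's output is None
        return self.head(y)
```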

## Final Analysis

### General Design Axes
