diff --git a/RFC-0020-Distributed-Pipeline-Parallelism.md b/RFC-0020-Distributed-Pipeline-Parallelism.md
index 4452609..1f6a2bc 100644
--- a/RFC-0020-Distributed-Pipeline-Parallelism.md
+++ b/RFC-0020-Distributed-Pipeline-Parallelism.md
@@ -63,6 +63,7 @@ Existing approaches that support this (in no particular order):
 * Sagemaker [model parallelism](https://arxiv.org/abs/2111.05972)
 * [DeepSpeed pipeline parallelism](https://www.deepspeed.ai/tutorials/pipeline/)
 * [OneFlow](https://github.com/Oneflow-Inc/oneflow)
+* [Varuna](https://github.com/microsoft/varuna)[13]
 
 Proposed approach short-list: (all approaches can be seen in [[RFC] Distributed Pipeline Parallel Training Technical Approach](https://github.com/pytorch/rfcs/blob/master/RFC-0021-Distributed-Pipeline-Parallel-Technical.md)
 
@@ -155,7 +156,7 @@ Proposed approach short-list:
 
 These approaches can be composed on top of an existing API that takes an `nn.Sequential`. We may consider in the future to develop a "v2" API that is centered more natively around non-`nn.Sequential` models using technologies from Sagemaker, OneFlow, or other research developments.
 
-### P1: Support arbitrary programmable schedules (e.g. fill-drain, 1F1B, interleaved 1F1B) 
+### P1: Support arbitrary programmable schedules (e.g. fill-drain, 1F1B, interleaved 1F1B)
 
 Existing approaches that support this (in no particular order):
 
@@ -208,4 +209,5 @@ Going into the future, we would like to develop theory and implementation for a
 10. Performance analysis of a pipelined backpropagation parallel algorithm https://ieeexplore.ieee.org/document/286892
 11. PipeMare: Asynchronous Pipeline Parallel DNN Training https://arxiv.org/abs/1910.05124
 12. Scaling Language Model Training to a Trillion Parameters Using Megatron -
-    https://developer.nvidia.com/blog/scaling-language-model-training-to-a-trillion-parameters-using-megatron/
\ No newline at end of file
+    https://developer.nvidia.com/blog/scaling-language-model-training-to-a-trillion-parameters-using-megatron/
+13. Varuna: Scalable, Low-cost Training of Massive Deep Learning Models https://arxiv.org/abs/2111.04007
diff --git a/RFC-0021-Distributed-Pipeline-Parallel-Technical.md b/RFC-0021-Distributed-Pipeline-Parallel-Technical.md
index c4523a0..1a9b221 100644
--- a/RFC-0021-Distributed-Pipeline-Parallel-Technical.md
+++ b/RFC-0021-Distributed-Pipeline-Parallel-Technical.md
@@ -72,7 +72,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
     # load best model weights
     model.load_state_dict(best_model_wts)
     return model
-    
+
 model_ft = models.resnet18(pretrained=True)
 num_ftrs = model_ft.fc.in_features
 # Here the size of each output sample is set to 2.
@@ -189,7 +189,7 @@ This scheme of zeroing grads on the first micro-batch can be trivially implement
 
 Predication of forward-propagation can be done as in Zach’s [proposal](https://colab.research.google.com/drive/1lGg2NqlvDwVmvBqejzni2yTmYE9rxfdr?usp=sharing).
 
-Note that we may extend the predicated training loop scheme to include schedules such as 1F1B or interleaved 1F1B, discussed later. 
+Note that we may extend the predicated training loop scheme to include schedules such as 1F1B or interleaved 1F1B, discussed later.
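+
+As a rough illustration of the scheme above, here is a minimal single-process sketch of a predicated micro-batch loop that zeroes gradients only on the first micro-batch and accumulates them across the rest (the helper name `run_minibatch` and its arguments are illustrative, not a proposed API; cross-rank communication and per-stage predication are elided):
+
+```
+import torch
+
+def run_minibatch(stage_module, optimizer, criterion, minibatch, target, chunks=4):
+    micro_inputs = torch.chunk(minibatch, chunks)
+    micro_targets = torch.chunk(target, chunks)
+    for i, (mb, tgt) in enumerate(zip(micro_inputs, micro_targets)):
+        if i == 0:
+            # Predicated on the micro-batch index: zero grads only once per
+            # mini-batch so later micro-batches accumulate into .grad.
+            optimizer.zero_grad()
+        out = stage_module(mb)      # forward for this micro-batch
+        loss = criterion(out, tgt)
+        loss.backward()             # grads accumulate across micro-batches
+    # One optimizer step after all micro-batches have contributed gradients.
+    optimizer.step()
+```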
 
 **Loss Calculation**
 
@@ -237,7 +237,7 @@ Note that forward and backward stages do not necessarily always run in a given r
 
 ## Approach 2 - RPC with RemoteModule and torchgpipe-style single coordinator (@pritamdamania87 RFC)
 
-One proposal for an API for pipeline parallelism is the `pipeline_sync` API proposed in @pritamdamania87’s [RFC](https://github.com/pytorch/pytorch/issues/44827) (Certain lines are called out with end-of-line comments containing an alphabetical identifier): 
+One proposal for an API for pipeline parallelism is the `pipeline_sync` API proposed in @pritamdamania87’s [RFC](https://github.com/pytorch/pytorch/issues/44827) (Certain lines are called out with end-of-line comments containing an alphabetical identifier):
 
 ```
 # Note: This API is very similar to torchgpipe and inspired from it.
@@ -245,14 +245,14 @@ One proposal for an API for pipeline parallelism is the `pipeline_sync` API prop
 
 torch.distributed.pipeline_sync(
     pipeline: nn.Sequential,
-    checkpoint: CheckpointEnum = EXCEPT_LAST, # ALWAYS, EXCEPT_LAST, NEVER 
+    checkpoint: CheckpointEnum = EXCEPT_LAST, # ALWAYS, EXCEPT_LAST, NEVER
     chunks: int = 1) -> PipelineSyncModel
 
 Arguments:
 
-pipeline: nn.Sequential (https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) where each nn.Module in the list is placed on the 
-          appropriate device(CPU or GPU)/machine by the user. Note that 
-          nn.Sequential could also consist of RemoteModule (https://github.com/pytorch/pytorch/blob/master/torch/distributed/nn/api/remote_module.py#L147) for cross host 
+pipeline: nn.Sequential (https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) where each nn.Module in the list is placed on the
+          appropriate device(CPU or GPU)/machine by the user. Note that
+          nn.Sequential could also consist of RemoteModule (https://github.com/pytorch/pytorch/blob/master/torch/distributed/nn/api/remote_module.py#L147) for cross host
           pipelining.
 checkpoint: Enum that determines which checkpointing mode to use.
 chunks: Number of micro-batches.
@@ -265,17 +265,17 @@ Forward Method
 PipelineSyncModel.forward(self, *input, **kwargs) -> RRef
 
 Returns:
-    RRef to output corresponding to the result of the minibatch. 
-    Since we plan to support cross host pipelining, the RRef could be on a 
+    RRef to output corresponding to the result of the minibatch.
+    Since we plan to support cross host pipelining, the RRef could be on a
     device on a different host.
-    
+
 Example:
 
 # This is an example of a pipeline across two machines each using one GPU.
 
 # On worker 0
 layer1 = nn.Linear(10, 5).cuda(0)
 # Need to enhance RemoteModule to include device for this purposes.
-layer2 = RemoteModule("worker1", device="cuda:0", nn.Linear, 5, 1) 
+layer2 = RemoteModule("worker1", device="cuda:0", nn.Linear, 5, 1)
 pipeline = nn.Sequential(layer1, layer2)
 model = torch.distributed.pipeline_sync(pipeline, chunks = 4) # A
@@ -300,7 +300,7 @@ for epoch in range(epochs):
         target_rref = rpc.remote("worker1", identity_fn, target) # C
         output_rref = model(minibatch) # D
         loss_rref = rpc.remote("worker1", compute_loss, output_rref, target_rref) # E
-        # Can enhance RRef to ensure this calls "dist_autograd.backward" on the last 
+        # Can enhance RRef to ensure this calls "dist_autograd.backward" on the last
         # node in the pipeline.
         loss_rref.backward(context_id) # F
         dist_optim.step() # G
@@ -375,7 +375,7 @@ class Loss(nn.Module):
 
 for epoch in range(epochs):
     loss_module = DistributedLoss(Loss, criterion, ntokens)
-    
+
     for minibatch, targets in dataloader:
         with dist_autograd.context() as context_id:
             minibatch = minibatch.transpose(0, 1)
@@ -389,7 +389,7 @@ This proposal has the training loop running on a single machine and makes copiou
 
 **A - Model Pipelining**
 
-As opposed to the torchgpipe-based Approach 2, this approach instantiates actors (specifically [PartitionHandler](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L140) instances) that execute the pipeline in an event-driven manner. PartitionHandler instances own a [DistributedPipelineRecord](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L27) instance, which has a “feed” method to be called via RPC to add a data item for processing. 
+As opposed to the torchgpipe-based Approach 2, this approach instantiates actors (specifically [PartitionHandler](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L140) instances) that execute the pipeline in an event-driven manner. PartitionHandler instances own a [DistributedPipelineRecord](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L27) instance, which has a “feed” method to be called via RPC to add a data item for processing.
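+
+A toy sketch of this event-driven structure is shown below (the class and method names are illustrative and heavily simplified; this is not the actual fairscale `PartitionHandler`/`DistributedPipelineRecord` implementation). Each stage owns a queue and a worker thread; upstream stages push micro-batches into it via a `feed`-style call, which in the real system arrives over RPC:
+
+```
+import queue
+import threading
+
+class StageActor:
+    """Owns one pipeline stage and processes micro-batches as they arrive."""
+
+    def __init__(self, module, next_stage=None):
+        self.module = module
+        self.next_stage = next_stage        # downstream StageActor, or None for the last stage
+        self.inbox = queue.Queue()
+        threading.Thread(target=self._run, daemon=True).start()
+
+    def feed(self, chunk_id, tensor):
+        # In the real system this would be invoked remotely (e.g. via torch.distributed.rpc).
+        self.inbox.put((chunk_id, tensor))
+
+    def _run(self):
+        while True:
+            chunk_id, tensor = self.inbox.get()
+            out = self.module(tensor)       # run this stage's forward
+            if self.next_stage is not None:
+                self.next_stage.feed(chunk_id, out)
+```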
 
 **B - Distributed Optimizer**
 
@@ -405,7 +405,7 @@ Loss calculation happens similarly to in Approach 2, the single driver calls int
 
 **E - Backprop**
 
-Backpropagation through the pipeline is similarly implemented via distributed autograd, as in Approach 2. Note that the same fork/join barrier approach is used to [serialize](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L103) execution of micro-batches on the backward pass. 
+Backpropagation through the pipeline is similarly implemented via distributed autograd, as in Approach 2. Note that the same fork/join barrier approach is used to [serialize](https://github.com/facebookresearch/fairscale/blob/6f3931a4d3464056231f1bbb92a67fbd14f30d96/fairscale/experimental/nn/distributed_pipeline/partition_handler.py#L103) execution of micro-batches on the backward pass.
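+
+For intuition, the sketch below shows the kind of "phony dependency" trick a fork/join barrier can use to impose an ordering between otherwise-independent branches of the autograd graph. This is an illustrative reimplementation of the idea in the spirit of torchgpipe's dependency mechanism, not the fairscale code linked above:
+
+```
+import torch
+
+class Fork(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        # Return the input (detached, so autograd rewires it to this node) plus an
+        # empty "phony" tensor whose only purpose is to carry a dependency edge.
+        phony = torch.empty(0, device=x.device)
+        return x.detach(), phony
+
+    @staticmethod
+    def backward(ctx, grad_x, grad_phony):
+        # The phony gradient is ignored; it only exists for ordering.
+        return grad_x
+
+class Join(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, phony):
+        return x.detach()
+
+    @staticmethod
+    def backward(ctx, grad_x):
+        # No gradient flows back for the phony input.
+        return grad_x, None
+
+# Usage: after `b, phony = Fork.apply(a)` and `c = Join.apply(d, phony)`, the
+# backward node created by Fork cannot fire until the Join node has executed,
+# which orders the two branches' backward passes without any data dependency.
+```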
 
 **NOTE**: I don’t believe that forward and backward jobs are serialized; they may run concurrently. Is this true?
 
@@ -482,7 +482,7 @@ The implementations for each of these instructions can be referenced from this [
 
 * (**D2**) (hypothetically) supports arbitrary schedules through the [PipeSchedule](https://github.com/microsoft/DeepSpeed/blob/488105ebd200bbd1f6d7cbe863412e41d9ab4221/deepspeed/runtime/pipe/schedule.py#L6) abstraction. However, there don’t seem to be any schedules implemented beyond the default
 * (**D3, D4?**) Usable in 3d parallelism, as detailed by the [blog post](https://www.deepspeed.ai/tutorials/pipeline/).
-* (**D6**) Since data is pulled from the data loader rather than being pushed by a synchronous call in the training loop, this approach could *hypothetically* support async PP. 
+* (**D6**) Since data is pulled from the data loader rather than being pushed by a synchronous call in the training loop, this approach could *hypothetically* support async PP.
 * (**D7**) The approach seems to account for many different types of parallelism.
 
 **Con**
@@ -545,6 +545,37 @@ An example of using OneFlow for pipeline parallelism can be seen in this [tutori
 * (split pro/con) (**D2**) Not clear if `1f1b` or other schedules are implemented/implementable?
 * (**D8**) Not clear what the training loop abstraction looks like. The optimizer is installed via an `nn.Graph` API. Loss calculation is created in the `nn.Graph.build()` method.
 
+## Approach 7: Varuna
+
+Varuna (https://arxiv.org/abs/2111.04007) proposes a system for large-scale training that focuses on commodity hardware. In particular, Varuna emphasizes pipeline parallelism (to work over commodity interconnects), tuning the pipeline to optimally trade off pipeline bubble size vs. allreduce bandwidth, dynamic scheduling of pipeline stages to account for network latency and jitter, and elastic scheduling.
+
+The workflow of Varuna looks like the following (a rough sketch is shown after this list):
+* The user manually annotates their model with [CutPoints](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/docs/cutpoint.rst). These are points in the program where the system _may_ place a pipeline stage boundary
+* The user wraps their model in the [Varuna](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/docs/varuna.rst#id9) class and configures pipeline parallelism via this interface. This includes setting parameters like the chunk size, installing the optimizer, and informing the system of the rank of the current running instance.
+  * Internally, the system does a round of profiling to determine the optimal pipeline balance and chooses a subset of the cut-points to delimit the code for the different pipeline stages.
+* The user calls the `Varuna.step()` method to run the program in pipeline-parallel execution (forward, loss, backward)
+  * Varuna uses an opportunistic scheduling policy (described in section 3.2 of the paper), which runs ahead with `forward()` micro-batches when `backward()` micro-batches are not yet available
+* The user calls the optimizer's `step()` method to update the parameters given the accumulated gradients.
+* Outside of the Python script, the user launches training with the `run_varuna` launcher script, which orchestrates (elastic) job scheduling
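+
+A rough sketch of the shape of this workflow is below. It is illustrative only: the `CutPoint`, `Varuna`, and `step()` names come from the docs linked above, but the constructor arguments, the optimizer-installation call, and the exact signatures are simplified assumptions here rather than the authoritative API (see the Varuna docs for the real interface). The loss is computed inside the wrapped model, since `step()` runs forward, loss, and backward together:
+
+```
+import torch
+import torch.nn as nn
+from varuna import CutPoint, Varuna
+
+class Net(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer1 = nn.Linear(1024, 1024)
+        self.cut = CutPoint()            # candidate stage boundary; Varuna may or may not cut here
+        self.layer2 = nn.Linear(1024, 1024)
+        self.loss_fn = nn.MSELoss()
+
+    def forward(self, inputs, targets):
+        x = self.layer1(inputs)
+        x = self.cut(x)                  # annotate the potential pipeline split
+        x = self.layer2(x)
+        return self.loss_fn(x, targets)  # the wrapped model produces the loss
+
+model = Net()
+optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+# Wrap the model; Varuna profiles the cut-points and assigns this rank a slice of
+# the model. The arguments below (stage_to_rank_map, get_batch, batch sizes, rank)
+# are placeholders for the configuration described in the Varuna docs.
+model = Varuna(model, stage_to_rank_map, get_batch, global_batch_size, chunk_size, rank)
+model.set_optimizer(optimizer)           # install the optimizer via the wrapper (illustrative)
+
+for batch in dataloader:
+    model.step(batch)                    # forward + loss + backward over all micro-batches
+    optimizer.step()                     # apply the accumulated gradients
+    optimizer.zero_grad()
+```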
+
+### Pros and Cons of the Approach
+
+**Pro**
+
+* (**D2**) Varuna implements scheduling, in particular its [opportunistic scheduling](https://github.com/microsoft/varuna/blob/79aaf45995a2b06bf5186e825b7baf81b9145837/varuna/pipeline.py#L280) policy. As noted in the "con" section, I'm not convinced by this scheme, but the system does support scheduling (and can probably be extended or hacked to support more traditional schedules)
+* (**D5**) The system nominally supports pipeline partitioning without wrapping the model into a `Sequential`, but see the "con" section for commentary on the soundness of this approach.
+
+
+**Con**
+
+* (**D5**) Varuna's approach to partitioning models does not seem sound.
+  * It assumes that the order of invocation of modules matches the order of the modules as [enumerated](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/varuna/partitioned_model.py#L396) in `nn.Module`. This is not necessarily the case, as there can be an arbitrary set of `use-def` relationships between modules and call sites in a PyTorch module
+  * It [nulls out](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/varuna/partitioned_model.py#L434) modules that are not relevant to the current rank's computation by replacing them with [PassThroughModule](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/varuna/partitioned_model.py#L635), which is implemented to simply return `None` from `forward()`. This works when the model is composed purely of module calls and the data dependencies between them, but if at any point there is a non-trivial construct in the code (e.g. any operation applied to the output of such a module), this will break.
+* (**D2**) I'm not convinced by [opportunistic scheduling](https://github.com/microsoft/varuna/blob/79aaf45995a2b06bf5186e825b7baf81b9145837/varuna/pipeline.py#L280). The concept is sound and reminds me of Tomasulo-style out-of-order execution for dealing with stochastic latencies (e.g. memory latencies in a processor), but pipeline-parallel execution for deep learning has an extra constraint: value lifetimes. Activations from forward jobs must be saved for use in the corresponding backward jobs, so the memory high-watermark of a pipeline stage increases with every `forward()` job that is admitted without a corresponding `backward()` job to release those values. The literature addresses this with static execution schedules such as the [1F1B strategy](https://arxiv.org/abs/1806.03377), and OneFlow solves it by implementing [Registers and back pressure](https://oneflow2020.medium.com/runtime-of-oneflow-based-on-boxing-and-actor-model-part-3-f2b786dc14a0) in the pipeline. As far as I can tell, Varuna will run ahead indiscriminately until it OOMs
+* (**D8**) The extent to which the training loop must be modified ([BERT](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/examples/BERT/bert.patch), [megatron](https://github.com/microsoft/varuna/blob/ea13ce6a8934bfe829b662a56af4dc29044fcb34/examples/Megatron-LM/megatron.patch)) to work with Varuna is pretty extreme
+* (**D3/D4/D7**) The system does not seem to compose with tensor parallelism; at least, this is not described in the paper or the README.
+* (**D6**) Does not support async PP
+
 ## Final Analysis
 
 ### General Design Axes