Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

为 Paddle 支持 Zero-Bubble 并行编排 #62666

Closed
9 tasks done
AndSonder opened this issue Mar 12, 2024 · 0 comments
Closed
9 tasks done

为 Paddle 支持 Zero-Bubble 并行编排 #62666

AndSonder opened this issue Mar 12, 2024 · 0 comments
Labels
PFCC Paddle Framework Contributor Club,https://github.com/PaddlePaddle/community/tree/master/pfcc status/close 已关闭 type/feature-request 新需求申请

Comments

@AndSonder
Copy link
Contributor

AndSonder commented Mar 12, 2024

任务名称 为 Paddle 支持 Zero-Bubble 并行编排
提交作者 Sonder @AndSonder
提交时间 2024-03-12
版本号 v1.3
依赖飞桨版本 develop

*本 Issue 用于记录 为 Paddle 支持 Zero-Bubble 并行编排 任务实现进度

Milestones

一、任务背景

流水线并行 (Pipeline Parallelism) 是一种并行策略,它将模型的不同部分分配到不同的设备上,以便在不同设备上并行执行。这种并行策略的优势在于,它可以在不增加单个设备上的模型大小的情况下,提高模型的训练速度。之前的流水线并行策略, 如 1F1B 或者 FThenB 都会有 Bubble 的问题,即某些设备处于空闲状态。Bubble 一般认为是不可避免的,但是 Penghui Qi 等人在 Zero Bubble Pipeline Parallelism 这篇工作中提出了 Zero Bubble 的概念,即不产生 Bubble 的流水线并行策略。

Zero Bubble 是一种新的流水线并行策略,它是核心思想是将反向计算分为两部分,一部分计算输入的梯度,另一部分计算参数的梯度。Zero Bubble 还引入了一种新的技术,即在优化器步骤中绕过同步,以实现真正的 Zero Bubble。 为 Paddle 支持 Zero-Bubble 并行编排可以有效提升大模型场景下的训练速度。

二、Zero Bubble 并行编排

1. 将反向计算分为两部分

管道编排可以通过以更精细的粒度表示和调度计算图来进一步优化。经典的深度学习框架是以层为粒度设计的,而现代深度学习编译器使用不同的中间表示来在各个层面进行优化。尽管更细的粒度通常意味着更大的搜索空间,但由于缺乏导航该空间的优化工具,因此选择合适的粒度至关重要。

Zero-Bubble 中第一个优化点就是将反向计算分为两部分:

7e17024676f5a8edb5159339d0d9d70351e4b30443b447ea2f8f3df4e24a0db3

Zero Bubble 基于 1F1B 的思想进行改进, 将后向传递拆分为 B 和 W 传递,F 和 B 来自同一微批次的仍然必须在流水线阶段之间保持顺序依赖。然而,W可以在同一阶段的相应 B 之后的任何位置灵活安排。这允许策略性地安排 W 来填充流水线中的气泡。有许多可能的调度方案可以优于 1F1B。下图展示了手工调度方案,展示了更细粒度在减少流水线气泡方面的优势。

1e95c8c0bce08b4402c51de5d3e3032d97c57fa5a54b1185acb087c4da0c4672

图中顶部的是 ZB-H1,底部的是 ZB-H2。ZB-H1 确保所有工作节点的最大峰值内存使用量不超过 1F1B 的内存使用量。ZB-H1 通常遵循 1F1B 的调度,但根据预热微批次的数量调整 W 的起始点。这确保所有工作节点保持相同数量的在途微批次。Bubble 大小通常可以减小到 1F1B 大小的三分之一。这种减少是因为与1F1B 相比,B 在所有工作节点上更早启动,并且尾部 Bubble 由后启动的 W 传递填充。由于 W 通常使用的内存少于 B,第一个工作节点具有与 1F1B 一致的最大峰值内存使用量。

如果允许比 1F1B 更大的内存占用并且有足够数量的微批次时,就可以实现零气泡调度,我们将其标记为ZB-H2。在热身阶段引入更多的 F 传递来填充初始 B 之前的气泡。此外还重新排列尾部的 W 传递,将布局从梯形变为平行四边形,消除了管道中的所有气泡。

2. 移除优化器步骤之间的同步

在大多数流水线并行实践中,为了数值稳健性,通常会在优化器步骤中执行管道阶段上的同步。例如,需要计算全局梯度范数以进行梯度范数裁剪。在混合精度设置中执行 NAN 和 INF 值的全局检查。这两者都需要跨所有阶段进行全局归约通信。然而,在优化器步骤中的同步破坏了平行四边形),使零 Bubble 变得不可能。

由于大多数情况下全局状态没有影响,例如,全局检查 NAN 和 INF 很少触发。Zero-Bubble 采取了后验证的方式。在每个阶段在优化器步骤之前,从前一阶段接收到一个部分归约的全局状态,与当前阶段的局部状态结合,然后传递到下一个阶段。每个阶段的优化器步骤由部分归约状态控制,例如,当发现 NAN 或部分归约梯度范数超过裁剪阈值时跳过更新。在下一次迭代的热身阶段,完全归约的全局状态然后从最后一个阶段传播回第一个阶段。收到全局状态后,每个阶段执行验证以决定前一个优化器步骤是否合法。如果需要对梯度进行修正,则会发出回滚,然后根据完全归约的全局状态重新执行优化器步骤。整体流程如下图所示:

a6f81040600bc0946dc0343e2dccb70588ff712753f25979110f38aca1643df7

Zero Buffle 中针对回滚操作也进行了优化,其利用优化器可逆的特性,快速恢复到上一个状态。这一点在后续实现中将会详细介绍。

3. 自动搜索最优调度策略

尽管手工制定的调度方案提供了简单性和更好的可理解性, 但在实际应用中面临几个问题。首先, 在假设 F、B、W 所用的时间相同即 $T_F=T_B=T_W$ 的情况下进行调度会引入不必要的气泡, 特别是对于这些值差异显著的模型。此外, 在手工制定的调度中通常忽略了在阶段之间传输激活/梯度所需的通信时间(表示为 $T_{\mathrm{comm}}$ ), 导致管道流中出现明显的延迟。最后, 在可用内存不足以容纳足够的微批次以实现无气泡调度的情况下, 平衡减小气泡大小和遵守内存限制变得特别具有挑战性。

为了解决这些问题并适应实际场景,Zero-Bubble 提出一种启发式策略,可在微批次数足够大时生成接近最优解。启发式算法有以下步骤:

  1. 热身阶段: 在内存允许的情况下,尽可能安排更多的 $F$,以减少第一个 $B$ 前的等待时间。如果内存还有余量,可以安排额外的 $F$,但可能会延迟后续的 $B$

  2. 稳定阶段: 在热身阶段后,我们交替安排 $F$$B$。当有空闲时间超过 $T_W$ 时,插入 $W$ 填充等待时间。即使等待时间不足 $T_W$,但当前等待时间会增加所有阶段中最大的等待时间时,我们也会插入 $W$。当内存接近饱和时,也会插入 $W$ 释放一些内存。

  3. 阶段间调度: 确保每个阶段在用尽 $F$ 之前至少安排一个比下一个阶段更多的 $F$。当差异超过一定阈值时,考虑跳过某些阶段中的 $F$

  4. 资源用尽: 在每个阶段,当 $F$$B$ 任务完成时,按顺序安排所有剩余的 $W$ 任务。

三、开关接口设计

Zero Bubble 作为一个新的并行编排策略,需要提供一个开关接口来控制是否使用 Zero Bubble。该接口与 1F1B 等编排方式的接口保持一致。具体来说,我们新注册一个 Pass pipeline_scheduler_zero_bubble。在 PaddleNLP Llama 模型中我们可以通过 --pipeline_schedule_mode "ZBH1" 来开启 Zero Bubble。这与 --pipeline_schedule_mode "1F1B" 的使用方式保持一致。

该 Pass 中有参数 enable_optimizer_post_validation: 是否启用优化器后验证

四、实现方案

1. 为 Paddle 适配 ZB-H1 编排策略

1.1 将反向计算分为两部分

Paddle 中在流水并行的编排的时候会将 op 分为 ForwardBackwardOptimize 三种。 在 Zero Bubble 中,我们需要将 Backward 进一步分为 BW 两部分。具体来说我们可以根据算子输入参数名字来进行判断。一个变量去掉@Grad,就可以得到它的前向变量,如果前向变量是参数,那这个梯度就是参数的梯度。一个变量是否是参数,可以通过 is_parameter 接口来判断。

$^*$将 matmul_v2_grad 进行拆分

由于 matmul_v2_grad 这个算子会同时输出 $dX$$dW$ ,所以我们需要对这个算子拆分成两个 matmul 算子。算子的拆分的实现可以复用 allreduce_matmul_grad_overlap Pass 中拆分 matmul_v2_grad 的实现。在进行算子拆分的时候需要考虑如下几点:

  • 拆分的操作应该放在 allreduce_matmul_grad_overlap 之后
  • 拆分后是否会有类似 PR #61865 中出现的问题

由于 allreduce_matmul_grad_overlap Pass 中已经实现了 matmul_v2_grad 的拆分,我们可以将其提取出来作为一个公用的函数。在 Zero Bubble 中我们可以直接调用这个函数来进行 matmul_v2_grad 的拆分。拆分出来的函数 split_matmul_v2_grad 放在 pass_utils.py 当中。由于之前 matmul_v2_grad 的拆分是在 allreduce_matmul_grad_overlap Pass 中的实现依赖 allreduce op 的位置,所以我们需要对 allreduce_matmul_grad_overlap Pass 进行一定的改造。具体来说针对每一对的 matmul_v2_grad allreduce 首先调用 split_matmul_v2_grad 函数将 matmul_v2_grad 拆分成两个 matmul 算子,然后再移动 allreduce 算子到第二个 matmul 算子的前面。相关 PR:

1.2 并行化改造

Zero-Bubble 是基于 1F1B 的思想进行改进,整体 Pass 的实现我们可以参考 1F1B Pass 的实现。在 1F1B Pass 的基础上,我们需要对 Backward 进一步分为 BW 两部分。具体来说我们可以根据算子输入参数名字来进行判断。在 _split_ops 的时候将 Backward 分类俩个类型 BackwardBBackwardW。在 _split_ops 之后,需要将 BackwardBBackwardW 作为不同类型的 Job。相关代码如下:

def _create_job_list(self):
    ...
    for micro_step in range(warmup_steps):
        ... # 创建 Forward Job 

    for micro_step in range(steady_steps):
        ...
        bwd_job_type = ... # 'B' or 'W' or ''
        bwd_job = core.Job(BACKWARD + bwd_job_type) # Backward, BackwardB, BackwardW
        bwd_job.set_micro_batch_id(bwd_micro_batch_id)
        job_list.append(bwd_job)

在 rank 0 和 BWF 中,我们不对 B 和 W 进行分割的原因如下:

  1. 为了利用批量点对点操作(send_backward_recv_forward)
  2. 为了同时进行梯度全局归约以支持张量并行
  3. 为了避免重复执行序列并行的梯度全聚合操作

并行化改造后一共有 5 种 Job,分别是 ForwardBackwardBackwardBBackwardWOptimize

在部分在实现时有以下几个注意点:

  1. 这种做法改变了梯度累积的顺序,所以可能会导致与1F1B模式相比的轻微精度误差(但是在数学上是正确的)
  2. 需要考虑到改操作是否会影响 gradient_merge 等 Pass 的正确性

Llama2 下 4卡实际调度结果如下:

d378681e1c2a9cdb1d47f971c9836ea7

相关 PR:

1.3 Llama2 下性能测试分析

在 PaddleNLP Llama2 模型上进行测试结果如下(pp4, batch 1, hidden_layer=4):

1.精度测试

精度可以对齐,有时候小数点后2位会有误查(符合论文的描述)

Llama2 下 10000 步 Loss 对比:

  • ZBH1: 2.67412233
  • 1F1B: 2.65833998

以下为前1000步,loss 曲线图

image

2.速度测试

测试机器: 4卡 3090,去

调度方案 interval_runtime interval_samples_per_second interval_steps_per_second
1F1B 3.17 5.1 0.3
ZBH1 2.75 5.8 0.4

3.显存占用

调度方案 卡号 max_memory_allocated max_memory_reserved
1F1B 0 12605.69 MB 13405.76 MB
1F1B 1 8809.68 MB 9611.76 MB
1F1B 2 7013.66 MB 7785.76 MB
1F1B 3 7806.72 MB 8561.76 MB
ZBH1 0 12921.69 MB (↑ 316 ) 13831.76 MB (↑ 426 )
ZBH1 1 9639.7 MB (↑ 830 ) 10463.76 MB (↑ 852 )
ZBH1 2 8357.72 MB (↑ 1344 ) 9149.76 MB (↑ 1364 )
ZBH1 3 10597.38 MB (↑ 1790 ) 11219.76 MB (↑ 1658 )
  • 1F1B 总 max_memory_allocated: 36035.75 MB
  • ZBH1 总 max_memory_allocated: 41516.49 MB
  • 1F1B 总 max_memory_reserved: 35064.04 MB
  • ZBH1 总 max_memory_reserved: 44650.04 MB

2. 为 Paddle 适配 ZB-VPP 编排策略

ZB VPP 是一种根据计算图自动调度任务的并行训练方案,反向计算分为两部分 b 和 w。w 可以用于填充计算图中的空洞,以此来降低 Bubble 率。ZB VPP 会把 Forward 和 Backward 拆分为多个 chunk,然后根据显存占用情况来进行任务调度。

方案设计文档:

ZB VPP 是一种根据计算图自动调度任务的并行训练方案,反向计算分为两部分 b 和 w。w 可以用于填充计算图中的空洞,以此来降低 Bubble 率。ZB VPP 的手动模拟结果如下图所示:

image

每个设备被分配到正好2个块,其中白色文本颜色代表第一个块,黑色文本颜色代表第二个块。模型块之间的依赖顺序在前向和后向传递中都遵循“V”形状模式。

2.1 ZB-VPP 模块设计

image

ZB-VPP 编排主要由两个模块组成,分别是显存估计模块和自动编排模块。

显存估计模块用于估计子图运行中的显存信息。该模块会统计每个变量的显存使用情况,并在估计时考虑不同子图间变量的依赖关系。通过模拟实际运行时变量的申请和释放情况,我们可以获得程序运行后的显存变化及运行时的峰值显存。这些显存信息在自动编排阶段用于控制最大显存使用。

自动编排模块会对任务进行自动编排,获取最优的编排策略,实现 V 型编排并对 V 型之间的空白进行任务填充。为了实现更低的 Bubble 率,自动编排模块中有多种 W 填充策略。在编排时,这些策略会进行排列组合,以选取最优策略。此外,对于较小的 Bubble,算法会尝试将 W 任务挤进这些小 Bubble 内,最终搜索出整体耗时最短的编排方案。

2.2 显存估计模块

2.2.1 显存估计模块实现

显存估计模块通过 PipelineMemoryEstimator 类实现,主要用于估计子图在运行时的显存使用情况,并将这些信息用于编排过程中的显存控制。

显存估计前,需要设置每种类型子图需要跳过垃圾回收的变量。通过提取子图所需的变量,并按照子图类型的顺序处理,确定哪些变量需要跳过垃圾回收。

显存估计的主要流程是首先获取子图中的所有操作,并按照执行顺序进行排序。然后,通过分析这些操作,获取每个变量的显存使用信息,包括大小和是否持久化等属性。根据这些信息,更新之前设置的跳过垃圾回收的变量的显存大小。接下来,记录前一子图类型中已访问的变量,以避免重复计算显存。

在最大显存使用估计过程中,模块会遍历子图中的每个操作,并根据操作的输入和输出变量,更新显存使用量。对于未被访问且非持久化的变量,会增加其显存使用量,并更新最大显存使用量。对于不再使用的变量,则释放其显存,并更新当前显存使用量。遍历完成后,计算出子图执行过程中最大显存使用量。最后,显存估计模块返回子图的总显存使用量和最大显存使用量,这些信息将在自动编排阶段用于控制最大显存。

预估结果与模型实际运行结果如下(单位 MB)运行后显存变化表示以 program 开始运行时候的显存作为基准,这个 program 运行完之后显存的变化。运行中 max 值的意思是以 program 开始运行时候的显存为基准,program 过程中的最大显存占用。

2.2.2 显存估计模块测试

pp4, gradient accumulation 8, 开启 recompute,batch1, num_hidden_layers 4

image

pp4, gradient accumulation 8, 不开启 recompute,batch1, num_hidden_layers 4

image

pp2, mp2, gradient accumulation 8, 不开启 recompute,batch2, num_hidden_layers 4

image

2.2.3 相比源码的改进

图中为实际运行时的显存变化,红色虚线框内的部分显示了一个显存使用的高峰。从这张图中可以看出,子图运行时的显存使用情况有明显的波动和峰值。我们在优化 ZB-VPP 编排时,特别关注了这些波动和峰值显存的控制,以实现更加精细的显存管理。

image

相关 PR:

2.3 ZB-VPP 自动编排模块

2.3.1 ZB-VPP 自动编排模块实现

自动编排模块通过一系列智能策略,确保任务在运行过程中高效利用显存和计算资源。其核心目的是通过优化任务调度,减少显存波动和 Bubble 率,提升整体系统性能。

任务自动编排

自动编排模块的核心是任务自动编排。系统会根据子图的计算需求,自动生成任务列表并排列这些任务。任务列表按照前向、后向及优化任务的顺序生成,确保每个阶段的任务能够有序执行,并充分考虑任务间的依赖关系,避免执行问题。

在任务自动编排过程中,系统首先插入所有微批次的前向任务,确保前向计算任务能够按计划执行。前向任务的插入顺序根据不同阶段的需求进行调整,以保证资源利用最大化。接着,系统会插入后向任务,并动态调整前向和后向任务的比例,确保高效利用显存,同时减少 Bubble 率。

多种填充策略组合

为了进一步优化任务调度,自动编排模块采用了多种填充策略组合。这些策略通过排列组合,选出最优的任务执行顺序。具体包括以下几种填充策略:

  1. Forward 后填充 W 任务
  2. Backward 后填充 W 任务
  3. 填充损失计算阶段

通过组合多种填充策略,系统能够生成最优的调度方案。这些策略通过排列组合,系统会尝试不同的任务顺序,最终选出能够最小化执行时间和显存使用的最优策略。

自搜索小 Bubble 填充

自动编排模块还具备自搜索小 Bubble 填充的功能。系统在执行任务时,会自动检测显存使用情况,并搜索小的 Bubble 区域。对于这些小 Bubble 区域,系统会尝试将任务插入其中,以最大化显存利用率。

2.3.2 ZB-VPP 自动编排模块测试

自动编排模块在实际应用中的性能通过一系列测试进行验证,测试结果如下:

编排方式 vpp_degree 性能
zbv 3 较vpp3提升1.74%
zbv 4 较vpp4提升4.6%
zbv 5 较vpp5提升4.24%

性能测试结果展示了不同 vpp_degree 下的性能提升情况。vpp_degree 是指虚拟流水线的并行度。使用 zbv 编排策略可以显著提升系统的整体性能,特别是在高并行度配置下,提升效果更加明显。

测试配置为 Llama2模型,4卡A100 80GB,pp4,dp1,mp1。测试的性能指标是每秒处理的样本数(interval_samples_per_second)。

*注:测试中,由于显存问题以及 hidden_layers 必须为 pp_degree * vpp_degree 的倍数,上述数据的 hidden_layers 分别为 24、16 和 20。

2.3.3 相比源码的改进

改进点1: 满足显存限制的情况下在 b 之前插入更多 f

在满足显存限制的情况下,我们的策略允许在第一次后向计算 B 之前插入更多的前向计算 F。这种策略在某些情况下有助于降低 Bubble 率。例如,下图展示了一个案例,上半部分是使用原始 ZBV 源码实现的调度效果,下半部分是我们的优化实现。

image

改进点2: 调整原论文中的 w 插入逻辑,优化较小 acc_step 下的编排

在实际业务中,有时我们需要限制全局批处理大小(global batch size),这会导致我们只能使用较小的累积步数(acc_step)。在源码中的填充策略下,当累积步数较小时,有时会错误地延迟后向计算(b)的插入。经过调整后,我们实现了更快的训练速度和更低的 Bubble 率。训练速度从与 vpp5 大致持平提升到了比 vpp5 快 3.65%。

image

改进点3: 解决由于 loss 计算时间引起的 “计算时间不均衡问题”

在实际运行 Llama2 时,我们发现 ZB-VPP 编排存在较为严重的“计算时间不均衡”问题。为了解决这一问题,我们首先将损失计算时间纳入编排方案中。然后,我们引入了 fill loss stage 策略,以解决由损失计算时间引起的不均衡。这一策略通过在计算损失的阶段,用多余的 W 任务填充中间的小 Bubble,从而优化计算时间。

算法会自动对比 fill loss stage 和不使用该策略的方案,最终选择耗时最小的方案。通过观察编排图的变化,我们可以发现,后置的 W 任务减少了,stage 0 中间的 Bubble 也变小了,从而实现了更均衡的计算时间和更高效的资源利用。

image

优化前后 Bubble rate 对比如下

image

相关 PR:

3. 移除优化器步骤之间的同步

在大多数流水线并行实践中,为了数值稳健性,通常会在优化器步骤中执行管道阶段上的同步。

a6f81040600bc0946dc0343e2dccb70588ff712753f25979110f38aca1643df7

可参考实现:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
PFCC Paddle Framework Contributor Club,https://github.com/PaddlePaddle/community/tree/master/pfcc status/close 已关闭 type/feature-request 新需求申请
Projects
None yet
Development

No branches or pull requests

2 participants