
Optimize the forward of log_softmax for the case when axis is not the last dimension. #32396

Merged (20 commits) on Jul 6, 2021

Conversation

@AshburnLee AshburnLee (Contributor) commented Apr 20, 2021

PR types

Performance optimization

PR changes

OPs

Describe

What this PR does

When axis is -1 and the number of elements along axis does not exceed 1024, the forward computation of log_softmax runs the path of paddle#31630 and the backward runs paddle#32180.

This PR adds the forward logic that log_softmax runs when axis is not -1, or when axis is -1 but the number of elements along axis is greater than 1024.

PR performance

The two configurations below come from the op benchmark; each test repeats 1000 times, and the numbers are op execution times.

| Configuration | axis | Consistency diff | GPU time before this PR | GPU time with this PR |
| --- | --- | --- | --- | --- |
| [128L, 128L, 16L, 16L] fp32 | 0 | 0.0 | 1.072 | 0.209 |
| [512L, 896L, 4L, 12L] fp32 | 1 | 1.9e-6 | 14.842 | 0.787 |

Method and logic of the PR

The method has two steps. The first step computes the kernel launch configuration, i.e., the block size, the shared memory size, and the grid size. The second step maps the grid onto the data and executes the kernel logic.

1. How the grid and block are obtained

Computing the launch configuration is a prerequisite for launching the kernel. In general, the block size is decided first, and the grid is then made as large as possible subject to the total number of active blocks on the GPU. The launch configuration here not only adapts to the input shape but also maximizes hardware utilization. It is implemented in ComputeLaunchConfigure() in the following four steps (a combined sketch appears after this list):

  • Step 1: GetBlockSize() computes the block
    The block depends only on dim_size and inner_size. For example, with shape = [2, 3, 4, 1] and axis = 1: the outer dimension is dimension 0, so outer_size = 2; axis corresponds to dimension 1, so dim_size = 3; the inner dimensions are dimensions 2 and 3, so inner_size = 4 * 1.

    • First, the threads along the block's y direction cover inner, processing the inner dimensions in parallel, with no more than 1024 threads (see Figure 1).
    • Then the threads along the block's x direction handle the dimension corresponding to axis (i.e., the log_softmax math runs along the block's x direction), with block.x capped at dim_size and block.x * block.y not exceeding 1024.
  • Step 2: compute the shared memory size

    • The block configuration was obtained in the previous step. When blockDim.x is 1, shared memory is set to 0 (why 0: the reduce in the execution logic later runs within a block along the x direction; since there is only one thread along x, no block reduce is needed and therefore no shared memory). Otherwise shared memory is set to (total threads per block) * sizeof(data type).
  • Step 3: compute max_active_blocks

    • Call cudaOccupancyMaxActiveBlocksPerMultiprocessor() to obtain the maximum number of blocks per SM.
    • This API computes occupancy. According to this blog post: "given the block size (from step 1) and the shared memory size (from step 2), the API predicts the occupancy of a kernel." According to the API documentation, it returns the maximum number of blocks per SM: blocks_per_SM.
    • The number of SMs, num_sm, is easy to obtain, so blocks_per_SM * num_sm gives the total number of active blocks on the GPU.
    • In practice, for op benchmark config #8 (a configuration where the active block count does influence the grid configuration), performance is unchanged whether or not this function is used, and the function's result matches the rule of thumb (num_sm * 2). The code therefore uses the empirical value active_blocks = num_sm * 2 instead.
      Taking the function call as the control group, the experimental groups set the active block count to 1x, 2x, 4x, and 8x the number of SMs. Since the function only determines the active block count, and that count only affects the grid configuration, the resulting grid configurations and GPU times are:
| Method | Grid | GPU time |
| --- | --- | --- |
| API call | (160, 1) | 14.9255 |
| num_sm | (80, 1) | 14.9438 |
| num_sm * 2 | (160, 1) | 14.9278 |
| num_sm * 4 | (320, 1) | 14.9239 |
| num_sm * 8 | (512, 1) | 14.9447 |
  • Step 4: GetGridSize() computes the grid
    • With the result of step 3, the GPU's total active block count, as a constraint, the grid configuration that maximizes GPU utilization can be computed.
    • How the grid is configured: step 1 explained the block's task split. Overall, the block's x direction processes the elements of the dimension corresponding to axis (computing the max, the running sum, and the final value), while the y threads cover inner. The grid's goal is then to cover as many of these dim_size * inner_size elements with blocks as possible. So blocks are tiled along the grid's y direction as densely as possible, and likewise along x, following blocks_along_x = (problem size along x + block size along x - 1) / block size along x, while keeping the total block count within the GPU's total active block count.
      Example: with shape = [2, 3, 4] and axis = 1, the result is grid (2, 1) and block (1, 4).
      (Screenshot omitted)
      Figure 1: block execution diagram
    • In the figure, each thread processes one column (the dimension corresponding to axis = 1); since there is only one thread along x, the 3 elements are processed in a loop. If there were 3 threads along x, i.e., blockDim.x = 3, no loop would be needed, only a reduce along the block's x direction. With blockDim.x = 2, thread 0 processes two elements and thread 1 processes one, followed by the block reduce along x.
      Hence, when blockDim.x = 1 no block reduce is needed; this case was first written separately and has since been merged with the general case.
    • The placement of gridDim.x versus gridDim.y affects performance: after swapping the two, the time for configuration [128L, 128L, 16L, 16L] rises from 0.2 to 3.
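
To make the four steps concrete, here is a minimal sketch of a launch-configuration routine that is consistent with the description above and with the Figure 1 example (shape [2, 3, 4], axis = 1 yields block (1, 4) and grid (2, 1)). The names GetBlockSize, GetGridSize, and ComputeLaunchConfigure mirror the functions mentioned in the text, but the bodies below are illustrative assumptions, not the PR's actual code.

```cpp
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

// Hypothetical reconstruction of the launch-configuration steps above.
// outer_size: product of the dimensions before `axis`
// dim_size  : extent of `axis`
// inner_size: product of the dimensions after `axis`
constexpr uint32_t kMaxThreads = 1024;

// Step 1: block.y covers inner, block.x covers `axis`,
// subject to block.x * block.y <= 1024.
inline dim3 GetBlockSize(uint64_t dim_size, uint64_t inner_size) {
  uint32_t inner_threads =
      static_cast<uint32_t>(std::min<uint64_t>(inner_size, kMaxThreads));
  uint32_t dim_threads = 1;
  // Spend threads along x only when inner is small and `axis` is long;
  // otherwise a single thread loops over the axis elements (Figure 1).
  if (inner_threads <= 64 && dim_size >= kMaxThreads) {
    dim_threads = kMaxThreads / inner_threads;
  }
  return dim3(dim_threads, inner_threads);
}

// Step 4: tile blocks over inner (grid.y) and outer (grid.x), using
// blocks = (size + block_size - 1) / block_size and keeping the total
// within the number of active blocks on the GPU.
inline dim3 GetGridSize(dim3 block, uint32_t max_active_blocks,
                        uint64_t outer_size, uint64_t inner_size) {
  uint32_t inner_blocks =
      static_cast<uint32_t>((inner_size + block.y - 1) / block.y);
  inner_blocks = std::min(inner_blocks, max_active_blocks);
  uint32_t outer_blocks =
      (max_active_blocks + inner_blocks - 1) / inner_blocks;
  outer_blocks =
      static_cast<uint32_t>(std::min<uint64_t>(outer_blocks, outer_size));
  return dim3(outer_blocks, inner_blocks);
}

// Steps 1-4 combined, along the lines of ComputeLaunchConfigure().
template <typename AccT>
void ComputeLaunchConfigure(uint64_t outer_size, uint64_t dim_size,
                            uint64_t inner_size, dim3 *grid, dim3 *block,
                            size_t *shared_mem) {
  *block = GetBlockSize(dim_size, inner_size);

  // Step 2: shared memory is needed only for the block reduce along x.
  *shared_mem =
      (block->x == 1) ? 0 : block->x * block->y * sizeof(AccT);

  // Step 3: total active blocks. The PR settles on the empirical value
  // num_sm * 2; the occupancy API would be the alternative:
  //   int blocks_per_sm = 0;
  //   cudaOccupancyMaxActiveBlocksPerMultiprocessor(
  //       &blocks_per_sm, SomeKernel /* hypothetical kernel symbol */,
  //       block->x * block->y, *shared_mem);
  //   max_active_blocks = blocks_per_sm * num_sm;
  int device = 0, num_sm = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, device);
  uint32_t max_active_blocks = static_cast<uint32_t>(num_sm) * 2;

  *grid = GetGridSize(*block, max_active_blocks, outer_size, inner_size);
}
```

With outer_size = 2, dim_size = 3, inner_size = 4 this sketch produces block (1, 4), shared_mem = 0, and grid (2, 1), matching the example above.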

2. Mapping and execution logic

  • How processing units map to the data and how a block executes
    See the description under "How the grid is configured" above.

  • Kernel math logic
    The math is the same as the forward computation for this case in paddle#31630. Along the axis direction the computation proceeds in three steps (max, sum, final value); not repeated here. A sketch of the simple blockDim.x == 1 case follows.
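
The sketch below illustrates this three-step math for the blockDim.x == 1 case pictured in Figure 1, where each thread owns one (outer, inner) column and walks the axis dimension serially. It is a hypothetical reconstruction: the kernel name and parameter list are assumptions, and the merged kernel in the PR additionally performs block-wide reduces along x when blockDim.x > 1.

```cpp
#include <cstdint>
#include <cuda_runtime.h>

// Minimal sketch of the kernel math for the blockDim.x == 1 case.
// Grid mapping follows the text: grid.x strides over outer,
// grid.y * block.y strides over inner.
template <typename T, typename AccT>
__global__ void LogSoftmaxForwardNotLastAxisSketch(T *output, const T *input,
                                                   uint32_t outer_size,
                                                   uint32_t dim_size,
                                                   uint32_t inner_size) {
  const uint32_t dim_stride = inner_size;              // stride between axis elements
  const uint32_t outer_stride = dim_size * inner_size; // stride between outer slices

  for (uint32_t outer = blockIdx.x; outer < outer_size; outer += gridDim.x) {
    for (uint32_t inner = blockIdx.y * blockDim.y + threadIdx.y;
         inner < inner_size; inner += gridDim.y * blockDim.y) {
      const uint32_t data_offset = outer * outer_stride + inner;

      // 1. max along `axis`
      AccT max_value = static_cast<AccT>(input[data_offset]);
      for (uint32_t d = 1; d < dim_size; ++d) {
        AccT v = static_cast<AccT>(input[data_offset + d * dim_stride]);
        max_value = v > max_value ? v : max_value;
      }

      // 2. log of the sum of exp(x - max)
      AccT sum = 0;
      for (uint32_t d = 0; d < dim_size; ++d) {
        sum += exp(static_cast<AccT>(input[data_offset + d * dim_stride]) -
                   max_value);
      }
      AccT log_sum = log(sum);

      // 3. input - max - log_sum, store
      for (uint32_t d = 0; d < dim_size; ++d) {
        output[data_offset + d * dim_stride] = static_cast<T>(
            static_cast<AccT>(input[data_offset + d * dim_stride]) -
            max_value - log_sum);
      }
    }
  }
}
```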

Other

MaxFunctor is added in functors.h.
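
For reference, such a functor is just a small callable wrapping the binary max; a plausible shape (the actual definition in functors.h may differ in naming and details) is:

```cpp
// Hypothetical sketch of the functor added in functors.h.
template <typename T>
struct MaxFunctor {
  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
    return a < b ? b : a;
  }
};
```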

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@xingfeng01 (Contributor)

Review comments:

  • The implementation logic is OK.
  • Suggest tweaking a few pieces of code or comments; this can be done in a separate PR.
  • A few cases in the CI op benchmark show performance regressions; this needs attention.

@Xreki Xreki (Contributor) left a comment


Please explain clearly in the PR description what this PR does, the optimization method, and its effect; see #30601 for reference.

```cpp
}

template <typename T>
__forceinline__ __device__ T BlockReduceAdd(T *shared, T val) {
```

The purpose of introducing the Max and Add functors should be to unify the implementations of BlockReduceMax and BlockReduceAdd.
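
As an illustration of that suggestion, a single block reduce parameterized by a functor could look like the sketch below. The name BlockReduce, the shared-memory tree reduction, and the power-of-two assumption on blockDim.x are assumptions for this sketch, not the PR's code.

```cpp
// Hypothetical functor-driven block reduce along threadIdx.x, unifying
// BlockReduceMax and BlockReduceAdd. `shared` must hold
// blockDim.x * blockDim.y elements of T; assumes blockDim.x is a power of two.
template <typename T, typename ReduceOp>
__forceinline__ __device__ T BlockReduce(T *shared, T val, ReduceOp op) {
  const int row = threadIdx.y * blockDim.x;  // this thread row's slice of smem
  shared[row + threadIdx.x] = val;
  __syncthreads();
  // Tree reduction over the x dimension of the block.
  for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
    if (threadIdx.x < offset) {
      shared[row + threadIdx.x] =
          op(shared[row + threadIdx.x], shared[row + threadIdx.x + offset]);
    }
    __syncthreads();
  }
  return shared[row];
}
```

It would then be invoked as BlockReduce(shared, val, MaxFunctor<AccT>()) or with an Add functor, removing the duplicated BlockReduceMax/BlockReduceAdd bodies.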

```cpp
// 3. input-max-log_sum and store
for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
  output[data_offset + d * dim_stride] = static_cast<T>(
      static_cast<AccT>(input[data_offset + d * dim_stride]) -
```

L229, L238, and L245 read input three times; this does not feel optimal in terms of efficiency.

@paddle-bot-old

Sorry to inform you that 91dfddf's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@paddle-bot-old

paddle-bot-old bot commented May 7, 2021

Sorry to inform you that 42fd6f9's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@Xreki Xreki requested a review from qili93 June 1, 2021 01:50
@qili93 qili93 (Contributor) commented Jun 2, 2021

I manually ran the ROCM unit tests for this PR; they pass.

(Screenshot of the unit test results)

qili93 previously approved these changes Jun 2, 2021

@qili93 qili93 (Contributor) left a comment

LGTM

@Xreki Xreki (Contributor) left a comment

LGTM. Merging this PR first; some of the review suggestions can be addressed together in a follow-up PR.

@Xreki Xreki changed the title from "Log_softmax forward case: axis != -1" to "Optimize the forward of log_softmax for the case when axis is not the last dimension." on Jul 6, 2021
@Xreki Xreki merged commit 69ffb38 into PaddlePaddle:develop Jul 6, 2021
@AshburnLee AshburnLee deleted the log-sftmx-case3 branch August 5, 2021 08:32