

【Hackathon No.34】Optimize the poisson op #45160

Merged: 9 commits, Aug 24, 2022

Conversation

@Rayman96 (Contributor) commented Aug 15, 2022

PR types

Performance optimization

PR changes

OPs

Describe

Experimental environment:
Hardware: Tesla-P4
Software: CUDA 11.2, cuDNN 8

Two optimization schemes were considered for the Poisson operator.
Scheme 1: use the GetGpuLaunchConfig1D method already provided in Paddle's gpu_launch_config.h to obtain a near-optimal launch configuration (a usage sketch follows the tables below).
Performance results:
In testing, this scheme gives roughly a 5% speedup on float32 data but roughly a 10% slowdown on float64 data, so it is not the preferred scheme.
Tesla-P4:

| Case No. | input_shape | data_type | Paddle_modify Perf(s) | Perf_over_paddle_origin(%) | Perf_over_pytorch(%) |
| --- | --- | --- | --- | --- | --- |
| 1 | [16, 16, 16, 16] | float32 | 0.3190 | +8.00 | -2.00 |
| 2 | [16, 35, 1500] | float32 | 2.7361 | +8.17 | -2.34 |
| 3 | [16, 16, 16, 16] | float64 | 0.3261 | -13.78 | -1.86 |

Tesla-P40:

| Case No. | input_shape | data_type | Paddle_modify Perf(s) | Perf_over_paddle_origin(%) | Perf_over_pytorch(%) |
| --- | --- | --- | --- | --- | --- |
| 1 | [16, 16, 16, 16] | float32 | 0.1684 | -5.00 | +1.00 |
| 2 | [16, 35, 1500] | float32 | 1.3766 | +8.74 | -1.28 |
| 3 | [16, 16, 16, 16] | float64 | 0.1743 | +5.27 | -19.1 |
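For reference, scheme 1 boils down to the launch pattern below. This is a minimal sketch, not the PR's diff: the wrapper name LaunchPoissonAuto is hypothetical, the block_per_grid / thread_per_block members are assumed from phi's gpu_launch_config.h, and GetPoisson is the op's CUDA kernel shown in the review threads below.

```cpp
#include "paddle/phi/backends/gpu/gpu_launch_config.h"

// Sketch of scheme 1: let Paddle derive the 1-D launch configuration.
template <typename T>
void LaunchPoissonAuto(const phi::GPUContext& ctx, const T* in, T* out,
                       int numel, unsigned int seed, unsigned int offset) {
  // GetGpuLaunchConfig1D picks a block/grid pair tuned to the current device;
  // the member names below are assumed from phi's header.
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel);
  GetPoisson<T><<<config.block_per_grid, config.thread_per_block, 0,
                  ctx.stream()>>>(in, out, numel, seed, offset);
}
```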

Scheme 2: manually benchmark launch parameters for this scenario. BlockSize values that usually perform well are 128, 256, and 512. Testing all three shows that a 1-D grid with BlockSize = 256 gives a large performance gain in every case.
Performance results:

Tesla-P4:

| Case No. | input_shape | data_type | Paddle_modify Perf(s) | Perf_over_paddle_origin(%) | Perf_over_pytorch(%) |
| --- | --- | --- | --- | --- | --- |
| 1 | [16, 16, 16, 16] | float32 | 0.2205 | +36.62 | +29.27 |
| 2 | [16, 35, 1500] | float32 | 2.044 | +31.40 | +23.54 |
| 3 | [16, 16, 16, 16] | float64 | 0.2159 | +24.68 | +32.57 |

Tesla-P40:

| Case No. | input_shape | data_type | Paddle_modify Perf(s) | Perf_over_paddle_origin(%) | Perf_over_pytorch(%) |
| --- | --- | --- | --- | --- | --- |
| 1 | [16, 16, 16, 16] | float32 | 0.1323 | +17.29 | +20.94 |
| 2 | [16, 35, 1500] | float32 | 1.0011 | +33.63 | +26.34 |
| 3 | [16, 16, 16, 16] | float64 | 0.1324 | +28.00 | +9.49 |

Comparing the two, Scheme 2 was chosen; a sketch of the resulting launch configuration follows.
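A minimal sketch of the chosen configuration; the wrapper name LaunchPoissonFixed is hypothetical, and GetPoisson is the op's CUDA kernel shown in the review threads below. Only the fixed 256-thread block and the 1-D grid come from the measurements above.

```cpp
// Sketch of scheme 2: a fixed 256-thread block on a 1-D grid.
constexpr int kBlockSize = 256;  // best of the tested {128, 256, 512}

template <typename T>
void LaunchPoissonFixed(const phi::GPUContext& ctx, const T* in, T* out,
                        int numel, unsigned int seed, unsigned int offset) {
  dim3 block(kBlockSize);
  dim3 grid((numel + kBlockSize - 1) / kBlockSize);  // ceil(numel / 256)
  GetPoisson<T><<<grid, block, 0, ctx.stream()>>>(in, out, numel, seed, offset);
}
```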

@CLAassistant commented Aug 15, 2022

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

1 similar comment

@Rayman96 (Contributor, Author)

All CI checks have passed. Thanks for reviewing 🙏

paddle/phi/kernels/gpu/poisson_kernel.cu: 4 review threads (outdated, resolved)
@ZzSean (Contributor) commented Aug 17, 2022

I suggest trying the ElementwiseKernel template. If the performance difference is small, it would be better to switch to that form so the code style stays more consistent.

@Rayman96 (Contributor, Author)

I will try the ElementwiseKernel approach and report back.

@Rayman96 (Contributor, Author) commented Aug 19, 2022

> I suggest trying the ElementwiseKernel template. If the performance difference is small, it would be better to switch to that form so the code style stays more consistent.

In testing, ElementwiseKernel cannot reach the required performance. Repeated runs show its speed is unstable, with a large gap between the fastest and slowest runs. On average the float32 cases meet the target but still trail the current implementation, and the float64 improvement over the original falls short of the target.
Test environment 1: Tesla P4
Best case

| Case No. | input_shape | data_type | Paddle_ElementWise Perf(s) | Perf_over_paddle_origin(%) |
| --- | --- | --- | --- | --- |
| 1 | [16, 16, 16, 16] | float32 | 0.25622 | +26 (10% behind the current implementation) |
| 2 | [16, 35, 1500] | float32 | 2.24506 | +25 (6% behind the current implementation) |
| 3 | [16, 16, 16, 16] | float64 | 0.27475 | +4 (20% behind the current implementation) |

Worst case

| Case No. | input_shape | data_type | Paddle_ElementWise Perf(s) | Perf_over_paddle_origin(%) |
| --- | --- | --- | --- | --- |
| 1 | [16, 16, 16, 16] | float32 | 0.26173 | +25 (11% behind the current implementation) |
| 2 | [16, 35, 1500] | float32 | 2.2529 | +24 (7% behind the current implementation) |
| 3 | [16, 16, 16, 16] | float64 | 0.28583 | 0 (24% behind the current implementation) |

So the current implementation should be kept.

```cpp
__global__ void GetPoisson(
    const T* in, T* out, const int N, unsigned int seed, unsigned int offset) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < N) {
```
Contributor (review comment):

This needs to be rewritten as a loop so that all elements are still processed when numel is very large; wrapping the body with CUDA_KERNEL_LOOP_TYPE is enough (see the sketch below).
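For illustration, a minimal grid-stride version along those lines; a sketch, assuming Paddle's CUDA_KERNEL_LOOP_TYPE macro expands to a standard grid-stride for-loop and that the kernel body samples with cuRAND's Philox Poisson generator as in the existing kernel:

```cpp
#include <curand_kernel.h>

// CUDA_KERNEL_LOOP_TYPE(i, n, IndexT) is assumed to come from Paddle's GPU
// primitives header and expand to a grid-stride for-loop over [0, n).
template <typename T>
__global__ void GetPoisson(
    const T* in, T* out, const int64_t N, unsigned int seed, unsigned int offset) {
  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
    curandStatePhilox4_32_10_t state;
    curand_init(seed, idx, offset, &state);  // independent Philox subsequence per element
    out[idx] = static_cast<T>(curand_poisson(&state, in[idx]));
  }
}
```

With a capped grid, each thread now loops over idx, idx + stride, idx + 2*stride, ..., so every element is covered regardless of N.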

Contributor Author:

OK, I'll make that change.


```cpp
int block_size = std::min(kMaxBlockDim, ctx.GetMaxThreadsPerBlock());
dim3 dim_block(block_size);
dim3 dim_grid((size + block_size - 1) / block_size);
```
Contributor (review comment):

The grid size also needs an upper bound; see for reference:

```cpp
paddle::platform::LimitGridDim(ctx, &grid_dim);
```
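For concreteness, what such a cap amounts to, sketched by hand; the helper name ClampGridDim is hypothetical, and GetCUDAMaxGridDimSize() is assumed to be the phi::GPUContext accessor for the device's maximum grid dimensions:

```cpp
#include <algorithm>
#include <array>

// Hypothetical stand-in for paddle::platform::LimitGridDim: clamp the 1-D
// grid size to the device's reported maximum grid dimension in x.
inline void ClampGridDim(const phi::GPUContext& ctx, dim3* grid_dim) {
  const std::array<int, 3> max_grid = ctx.GetCUDAMaxGridDimSize();  // assumed accessor
  grid_dim->x = std::min(grid_dim->x, static_cast<unsigned int>(max_grid[0]));
}
```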

Contributor Author:

> The grid size also needs an upper bound; see for reference: paddle::platform::LimitGridDim(ctx, &grid_dim);

OK, I'll make that change.

Contributor Author:

Done.

@Rayman96 (Contributor, Author)

CI has passed; please review.

```diff
@@ -19,6 +19,7 @@ limitations under the License. */
 #include <hiprand_kernel.h>
 #endif
 
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
```
Contributor (review comment):

In principle, phi does not allow including fluid header files; try replacing this with #include "paddle/phi/backends/gpu/gpu_launch_config.h"?

Contributor (review comment):

paddle::platform::LimitGridDim is still in a header file under fluid; a copy of this function needs to be brought over first, with its namespace updated.

Contributor Author:

OK.

Contributor Author:

> paddle::platform::LimitGridDim is still in a header file under fluid; a copy of this function needs to be brought over first, with its namespace updated.

Done.

@Rayman96 (Contributor, Author)

@chenwhql All CI checks have passed after the changes.

@ZzSean (Contributor) left a comment:

LGTM

@ZzSean merged commit 3c14b09 into PaddlePaddle:develop on Aug 24, 2022