Optimization of pool2d grad #35389

Merged
merged 19 commits into from Sep 19, 2021

Changes from 10 commits (19 commits total)

299 changes: 209 additions & 90 deletions paddle/fluid/operators/math/pooling.cu
@@ -16,13 +16,104 @@ limitations under the License. */
#include <vector>

#include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/fluid/platform/gpu_launch_config.h"

#ifdef __HIPCC__
#define BLOCK_SIZE 256
Contributor:
Add a qualifier to the macro name, something like POOL_BLOCK_SIZE.

Contributor Author:
Revised as suggested.

#else
#define BLOCK_SIZE 512
#endif

namespace paddle {
namespace operators {
namespace math {

struct FastDivModOfPool {
Contributor:
Generally only concepts like thread pools are named XxxPool, so this name is not a good fit.

Contributor Author:
Renamed to FastDivModForPoolingGrad as suggested.

public:
platform::FastDivMod channel;
platform::FastDivMod input_w;
platform::FastDivMod input_h;
platform::FastDivMod ksize_w;
platform::FastDivMod ksize_h;
platform::FastDivMod stride_w;
platform::FastDivMod stride_h;

HOSTDEVICE FastDivModOfPool(const int channels, const int input_width,
const int input_height, const int ksize_width,
const int ksize_height, const int stride_width,
const int stride_height) {
channel = platform::FastDivMod(channels);
input_w = platform::FastDivMod(input_width);
input_h = platform::FastDivMod(input_height);
ksize_w = platform::FastDivMod(ksize_width);
ksize_h = platform::FastDivMod(ksize_height);
stride_w = platform::FastDivMod(stride_width);
stride_h = platform::FastDivMod(stride_height);
}
};
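For context, platform::FastDivMod (from the fast_divmod.h header included above) replaces per-thread integer division and modulo by a divisor fixed at launch time with a precomputed multiply and shift. A minimal host-side sketch of the idea, not Paddle's actual implementation (which uses __umulhi on the device), assuming dividends stay within the index ranges used here:

```cpp
#include <cassert>
#include <cstdint>

// Sketch: precompute (multiplier, shift) once per divisor so that Div()
// costs one 32x32->64 multiply plus a shift instead of a hardware divide.
struct FastDivModSketch {
  explicit FastDivModSketch(uint32_t d) : divisor(d), shift(0) {
    while ((1u << shift) < d) ++shift;  // smallest s with 2^s >= d
    uint64_t one = 1;
    multiplier =
        static_cast<uint32_t>(((one << 32) * ((one << shift) - d)) / d + 1);
  }
  uint32_t Div(uint32_t n) const {
    // High 32 bits of n * multiplier; this is __umulhi(n, multiplier) on GPU.
    uint32_t hi =
        static_cast<uint32_t>((static_cast<uint64_t>(n) * multiplier) >> 32);
    return (hi + n) >> shift;
  }
  uint32_t Mod(uint32_t n) const { return n - Div(n) * divisor; }
  uint32_t divisor, multiplier, shift;
};

int main() {
  FastDivModSketch dm(7);
  for (uint32_t n = 0; n < 1000000; ++n) {
    assert(dm.Div(n) == n / 7 && dm.Mod(n) == n % 7);
  }
  return 0;
}
```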

template <typename T, typename PoolProcess, typename Enable = void>
struct PoolingFunctor {
Contributor:
This functor appears to be used for the backward computation, but the name does not convey its actual meaning. Also, what is each of its member functions for?

Contributor Author:
Renamed to PoolingGradProcess as suggested, and added comments to the member functions.

const T* __restrict__ input_data;
const T* __restrict__ output_data;
T input;

explicit PoolingFunctor(const T* __restrict__ _input_data,
const T* __restrict__ _output_data)
: input_data(_input_data), output_data(_output_data) {}

inline DEVICE void ParameterUpdate(int tid, int output_stride) {
input = input_data[tid];
output_data += output_stride;
Contributor:
Isn't output_data declared const? How can it still be modified with +=? This approach also does not feel very safe.

Contributor Author:
The purpose of += here is to advance the pointer address that output_data holds (the declaration makes the pointee const, not the pointer itself).

Contributor:
This type feels like a forced wrapper built just to avoid memory accesses to input_data and output_data; on its own it has neither complete semantics nor clear interpretability.

}

inline HOSTDEVICE void operator()(const T* __restrict__ output_grad,
T* __restrict__ gradient, int pool_size,
Contributor:
Whose gradient is gradient? This function also does not read like a conventional operator(); replacing it with a concrete function name would be more appropriate.

Contributor Author:
gradient will be renamed to input_grad_data, and operator() will be renamed to Compute.

int index) const {
*gradient +=
output_grad[index] * static_cast<T>(input == output_data[index]);
}
};
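For reference, the operator() above (renamed Compute in a later commit) implements the standard max-pooling gradient: the upstream gradient of an output element flows back only to inputs that attained the window maximum,

$$\frac{\partial L}{\partial x_i} \mathrel{+}= \sum_{j\,:\,i \in R_j} \frac{\partial L}{\partial y_j}\,\mathbf{1}\!\left[x_i = y_j\right],$$

where $R_j$ is the pooling window that produces output $y_j$. The average-pooling specialization below replaces the indicator with $1/|R_j|$, i.e. the 1/pool_size factor.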

/*
Different from MaxPoolGrad, parameters such as input_data and
output_data are unnecessary in AvgPoolGrad, so a dedicated template
specialization of AvgPoolGrad yields better kernel performance.
*/
template <typename T, typename PoolProcess>
struct PoolingFunctor<T, PoolProcess,
typename std::enable_if<std::is_same<
PoolProcess, math::AvgPoolGrad<T>>::value>::type> {
Contributor:
PoolProcess is no longer actually used inside the kernel, is it? It only distinguishes the Avg and Max pool definitions and plays no part in the computation, so it seems unnecessary.

Contributor Author:
Right, the PoolProcess inside the kernel can be eliminated.

explicit PoolingFunctor(const T* __restrict__ _input_data,
const T* __restrict__ _output_data) {}
inline DEVICE void ParameterUpdate(int tid, int output_stride) {}

inline HOSTDEVICE void operator()(const T* __restrict__ output_grad,
T* __restrict__ gradient, int pool_size,
int index) const {
*gradient += output_grad[index] / static_cast<T>(pool_size);
}
};
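For readers unfamiliar with the trick, here is a standalone sketch (hypothetical names, not PR code) of how an enable_if partial specialization like the one above selects a stateless functor for the average case, so that the kernel argument carries no unused pointers:

```cpp
#include <type_traits>

template <typename T> struct MaxGrad {};  // stand-ins for math::MaxPoolGrad
template <typename T> struct AvgGrad {};  // and math::AvgPoolGrad

// Primary template: the max-pool path keeps pointers for the x == y test.
template <typename T, typename Process, typename Enable = void>
struct GradFunctor {
  const T* input_data = nullptr;
  const T* output_data = nullptr;
};

// Partial specialization matched only when Process is AvgGrad<T>: SFINAE
// resolves Enable to void and this more specialized, stateless version wins.
template <typename T, typename Process>
struct GradFunctor<T, Process,
                   typename std::enable_if<
                       std::is_same<Process, AvgGrad<T>>::value>::type> {};

static_assert(sizeof(GradFunctor<float, AvgGrad<float>>) == 1,
              "avg specialization carries no state");
static_assert(sizeof(GradFunctor<float, MaxGrad<float>>) == 2 * sizeof(void*),
              "max primary template carries two pointers");
```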

int GetThreadsPerBlock(const platform::CUDADeviceContext& ctx,
int threads_per_block, int64_t numel) {
int sm_count = ctx.GetSMCount();
if (numel / (sm_count << 1) < threads_per_block) {
// Round the thread count up to a power of 2, so that the number
// of active blocks is about twice the SM count, for better performance.
threads_per_block = platform::RoundToPowerOfTwo(numel / (sm_count << 1));
} else if (numel / (sm_count << 2) < threads_per_block) {
// Round the thread count up to a power of 2, so that the number
// of active blocks is about 4 times the SM count, for better performance.
threads_per_block = platform::RoundToPowerOfTwo(numel / (sm_count << 2));
}
// The number of threads per block shall be at least 64.
return std::max(64, threads_per_block);
}
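As a concrete walk-through under assumed figures (not from the PR): with sm_count = 80 and numel = 40960, we get 40960 / (80 << 1) = 256 < 512, so the block size becomes RoundToPowerOfTwo(256) = 256 and the launch then uses (40960 + 255) / 256 = 160 blocks, about two active blocks per SM. A host-only restatement, with RoundToPowerOfTwo approximated by a loop:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Rounds n up to the next power of two (stand-in for the platform helper).
static int RoundToPowerOfTwo(int n) {
  int p = 1;
  while (p < n) p <<= 1;
  return p;
}

static int ThreadsPerBlock(int sm_count, int max_threads, int64_t numel) {
  int threads = max_threads;
  if (numel / (sm_count << 1) < threads) {
    threads = RoundToPowerOfTwo(static_cast<int>(numel / (sm_count << 1)));
  } else if (numel / (sm_count << 2) < threads) {
    threads = RoundToPowerOfTwo(static_cast<int>(numel / (sm_count << 2)));
  }
  return std::max(64, threads);
}

int main() {
  int blocks = ThreadsPerBlock(/*sm_count=*/80, /*max_threads=*/512,
                               /*numel=*/40960);
  printf("%d threads/block, %d blocks\n", blocks,
         (40960 + blocks - 1) / blocks);
  // Prints: 256 threads/block, 160 blocks
  return 0;
}
```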

template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads, const T* input_data,
const int channels, const int input_height,
@@ -85,88 +176,110 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
output_data[index] = ele;
}
}
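Both the forward kernel above and the backward kernel below iterate with a grid-stride loop, so any (grids, blocks) configuration covers all nthreads elements. A minimal standalone illustration of the pattern (hypothetical kernel, not PR code):

```cuda
// Each thread starts at its global id and strides by the whole grid, so
// the kernel is correct for any grid size, merely faster or slower.
__global__ void GridStrideScale(int nthreads, const float* in, float* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += blockDim.x * gridDim.x) {
    out[i] = 2.0f * in[i];
  }
}
```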

template <typename PoolProcess, typename T>
__global__ void KernelPool2DGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, const int channels, const int input_height,
const int input_width, const int output_height, const int output_width,
const int ksize_height, const int ksize_width, const int stride_height,
const int stride_width, const int padding_height, const int padding_width,
PoolProcess pool_process, bool exclusive, bool adaptive, T* input_grad,
bool channel_last = false) {
const int nthreads, const T* __restrict__ output_grad,
const int output_height, const int output_width, const int input_width,
Contributor:
Unify the order of height and width in the parameter list.

Contributor Author:
Revised as suggested.

const int input_height, const int ksize_width, const int ksize_height,
const int stride_width, const int stride_height, FastDivModOfPool divmods,
const int padding_height, const int padding_width,
PoolingFunctor<T, PoolProcess> functor, bool exclusive, bool adaptive,
T* __restrict__ input_grad, bool channel_last = false) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int w_offset, h_offset, offsetC, batch_idx;
T gradient = static_cast<T>(0);
int w_offset, h_offset, offsetC;
int phstart, phend, pwstart, pwend;
int output_stride;

if (!channel_last) { /* NCHW */
Contributor:
Index calculations like these for NHWC or NCHW are probably fairly common? They could be encapsulated, for example by defining an IndexCalculator4d that provides some basic helpers for NHWC and NCHW.

Contributor Author:
Will encapsulate this as suggested.

w_offset = index % input_width + padding_width;
h_offset = (index / input_width) % input_height + padding_height;
offsetC = (index / input_width / input_height) % channels;
batch_idx = index / input_width / input_height / channels;
auto input_width_divmod = divmods.input_w.Divmod(index);
auto input_height_divmod =
divmods.input_h.Divmod(input_width_divmod.val[0]);
auto channel_divmod = divmods.channel.Divmod(input_height_divmod.val[0]);
w_offset = input_width_divmod.val[1] + padding_width;
h_offset = input_height_divmod.val[1] + padding_height;
offsetC = channel_divmod.val[1];
Contributor:
This name is inconsistent with the other variables.

Contributor Author:
It will be renamed to channel_offset.

output_stride =
(channel_divmod.val[0] * divmods.channel.divisor + offsetC) *
output_height * output_width;
} else { /* NHWC */
offsetC = index % channels;
w_offset = (index / channels) % input_width + padding_width;
h_offset =
(index / channels / input_width) % input_height + padding_height;
batch_idx = index / channels / input_width / input_height;
auto c_divmod = divmods.channel.Divmod(index);
auto input_width_divmod = divmods.input_w.Divmod(c_divmod.val[0]);
auto input_height_divmod =
divmods.input_h.Divmod(input_width_divmod.val[0]);
offsetC = c_divmod.val[1];
w_offset = input_width_divmod.val[1] + padding_width;
h_offset = input_height_divmod.val[1] + padding_height;
output_stride = input_height_divmod.val[0] * output_height *
output_width * divmods.channel.divisor;
}
functor.ParameterUpdate(index, output_stride);
output_grad += output_stride;

int phstart, phend;
int pwstart, pwend;
if (adaptive) {
phstart = AdaptStartIndex(h_offset, output_height, input_height);
phend = AdaptEndIndex(h_offset, output_height, input_height);
auto tmp_phend = divmods.input_h.Divmod((h_offset + 1) * output_height);
auto tmp_pwend = divmods.input_w.Divmod((w_offset + 1) * output_width);
phstart = divmods.input_h.Div(h_offset * output_height);
pwstart = divmods.input_w.Div(w_offset * output_width);
phend = tmp_phend.val[1] > 0 ? tmp_phend.val[0] + 1 : tmp_phend.val[0];
pwend = tmp_pwend.val[1] > 0 ? tmp_pwend.val[0] + 1 : tmp_pwend.val[0];

pwstart = AdaptStartIndex(w_offset, output_width, input_width);
pwend = AdaptEndIndex(w_offset, output_width, input_width);
} else {
phstart = (h_offset < ksize_height)
? 0
: (h_offset - ksize_height) / stride_height + 1;
pwstart = (w_offset < ksize_width)
? 0
: (w_offset - ksize_width) / stride_width + 1;
phend = min(h_offset / stride_height + 1, output_height);
pwend = min(w_offset / stride_width + 1, output_width);
}
T gradient = static_cast<T>(0.0);
T input = input_data[index];

int output_stride;
if (!channel_last) {
output_stride =
(batch_idx * channels + offsetC) * output_height * output_width;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
auto ksize_w_divmod = divmods.ksize_w.Divmod(input_width);
auto ksize_h_divmod = divmods.ksize_h.Divmod(input_height);
auto tmp_width = ksize_w_divmod.val[1] > 0 ? ksize_w_divmod.val[0] + 1
: ksize_w_divmod.val[0];
auto tmp_height = ksize_h_divmod.val[1] > 0
? ksize_h_divmod.val[0] + 1
: ksize_h_divmod.val[0];
int pool_size = tmp_height * tmp_width;
int tmp_idx = ph * output_width + pw;
int output_sub_idx = channel_last
? tmp_idx * divmods.channel.divisor + offsetC
: tmp_idx;
functor(output_grad, &gradient, pool_size, output_sub_idx);
}
}
} else {
output_stride = batch_idx * output_height * output_width * channels;
}

output_data += output_stride;
output_grad += output_stride;

for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int pool_size;
if (adaptive) {
pool_size = static_cast<int>(ceil(static_cast<double>(input_height) /
ksize_height)) *
static_cast<int>(
ceil(static_cast<double>(input_width) / ksize_width));
} else {
int hstart = ph * stride_height - padding_height;
int wstart = pw * stride_width - padding_width;
int hend = min(hstart + ksize_height, input_height);
int wend = min(wstart + ksize_width, input_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
auto stride_height_div = divmods.stride_h.Div(h_offset - ksize_height);
auto stride_width_div = divmods.stride_w.Div(w_offset - ksize_width);
phstart = (h_offset < ksize_height) ? 0 : stride_height_div + 1;
pwstart = (w_offset < ksize_width) ? 0 : stride_width_div + 1;
phend = min(divmods.stride_h.Div(h_offset) + 1, output_height);
pwend = min(divmods.stride_w.Div(w_offset) + 1, output_width);

if (exclusive) {
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int hstart = ph * stride_height - padding_height;
int wstart = pw * stride_width - padding_width;
int hend = min(hstart + ksize_height, input_height);
int wend = min(wstart + ksize_width, input_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
int tmp_idx = ph * output_width + pw;
int output_sub_idx =
channel_last ? tmp_idx * divmods.channel.divisor + offsetC
: tmp_idx;
functor(output_grad, &gradient, pool_size, output_sub_idx);
}
}
} else {
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int pool_size = ksize_height * ksize_width;
int tmp_idx = ph * output_width + pw;
int output_sub_idx =
channel_last ? tmp_idx * divmods.channel.divisor + offsetC
: tmp_idx;
functor(output_grad, &gradient, pool_size, output_sub_idx);
}
}

int output_sub_idx = channel_last
? (ph * output_width + pw) * channels + offsetC
: ph * output_width + pw;
pool_process.compute(input, output_data[output_sub_idx],
output_grad[output_sub_idx],
static_cast<T>(1.0 / pool_size), &gradient);
}
}
input_grad[index] = gradient;
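To make the divmod index arithmetic above concrete, here is a standalone sketch with plain / and % and hypothetical tensor sizes, covering both the NCHW decomposition and the remainder-based ceil used to compute phend and pwend:

```cpp
#include <cassert>

int main() {
  // NCHW layout: index = ((n * C + c) * H + h) * W + w.
  const int C = 3, H = 4, W = 5;
  const int index = ((1 * C + 2) * H + 3) * W + 4;  // n=1, c=2, h=3, w=4
  int w = index % W;  // divmods.input_w.Divmod(index).val[1]
  int q = index / W;  // divmods.input_w.Divmod(index).val[0]
  int h = q % H;      // divmods.input_h.Divmod(q).val[1]
  q /= H;
  int c = q % C;      // divmods.channel.Divmod(q).val[1]
  int n = q / C;      // divmods.channel.Divmod(q).val[0]
  assert(w == 4 && h == 3 && c == 2 && n == 1);

  // The adaptive phend/pwend bound is a ceil division done via Divmod:
  // ceil(a / b) == (a % b > 0) ? a / b + 1 : a / b.
  int a = 7, b = 2;
  int phend = (a % b > 0) ? a / b + 1 : a / b;
  assert(phend == 4);
  return 0;
}
```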
@@ -402,15 +515,19 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

int nthreads = batch_size * input_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);

KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_channels,
input_height, input_width, output_height, output_width, ksize_height,
ksize_width, stride_height, stride_width, padding_height, padding_width,
pool_process, exclusive, adaptive, input_grad_data);
int blocks = GetThreadsPerBlock(context, BLOCK_SIZE, nthreads);
int grids = (nthreads + blocks - 1) / blocks;

auto pool_divmod =
FastDivModOfPool(input_channels, input_width, input_height, ksize_width,
ksize_height, stride_width, stride_height);
auto pool_functor = PoolingFunctor<T, PoolProcess>(input_data, output_data);
Contributor (@Xreki, Sep 12, 2021):
Was this functor introduced to reduce IO? It feels like this could be done by reworking the original PoolProcess instead.

Contributor Author (@JamesLim-sy, Sep 12, 2021):
I tried modifying the original PoolProcess, but doing so seemed to require adding too many members to the class, so I switched to an implementation specialized for the CUDA computation.
The original PoolProcess is the class shown in the link:

template <class T>
class MaxPoolGrad {
public:
DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
T* dx) {
*dx += dy * static_cast<T>(x == y);
}
};
template <class T>
class AvgPoolGrad {
public:
DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
T* dx) {
*dx += (scale * dy);
}
};

The class has to support both CPU and CUDA computation, and the two computation paths differ considerably. In the CPU logic, input_data is read first and the offset of the output_data pointer is applied afterwards; moreover, the output_data pointer has to keep shifting as the loops advance, as follows:
float scale = 1.0 / pool_size;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_idx = (d * input_height + h) * input_width + w;
int output_idx =
(pd * output_height + ph) * output_width + pw;
pool_grad_process.compute(
input_data[input_idx], output_data[output_idx],
output_grad_data[output_idx], static_cast<T>(scale),
input_grad_data + input_idx);
}
}
}
}
}
}
input_data += input_stride;
output_data += output_stride;
input_grad_data += input_stride;
output_grad_data += output_stride;

This differs from the CUDA logic, where the data read and the pointer offset each happen only once, as follows:
output_data += output_stride;
output_grad += output_stride;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int pool_size;
if (adaptive) {
pool_size = static_cast<int>(ceil(static_cast<double>(input_height) /
ksize_height)) *
static_cast<int>(
ceil(static_cast<double>(input_width) / ksize_width));
} else {
int hstart = ph * stride_height - padding_height;
int wstart = pw * stride_width - padding_width;
int hend = min(hstart + ksize_height, input_height);
int wend = min(wstart + ksize_width, input_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
}
int output_sub_idx = channel_last
? (ph * output_width + pw) * channels + offsetC
: ph * output_width + pw;
pool_process.compute(input, output_data[output_sub_idx],
output_grad[output_sub_idx],
static_cast<T>(1.0 / pool_size), &gradient);


KernelPool2DGrad<PoolProcess, T><<<grids, blocks, 0, context.stream()>>>(
nthreads, output_grad_data, output_height, output_width, input_width,
input_height, ksize_width, ksize_height, stride_width, stride_height,
pool_divmod, padding_height, padding_width, pool_functor, exclusive,
adaptive, input_grad_data);
}
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input,
@@ -424,7 +541,6 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
bool channel_last = (data_format == "NHWC");

const int batch_size = input.dims()[0];

const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
const int input_width = channel_last ? input.dims()[2] : input.dims()[3];
@@ -447,19 +563,22 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();

T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

int nthreads = batch_size * input_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);

KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_channels,
input_height, input_width, output_height, output_width, ksize_height,
ksize_width, stride_height, stride_width, padding_height, padding_width,
pool_process, exclusive, adaptive, input_grad_data, channel_last);
int blocks = GetThreadsPerBlock(context, BLOCK_SIZE, nthreads);
int grids = (nthreads + blocks - 1) / blocks;

auto pool_divmod =
FastDivModOfPool(input_channels, input_width, input_height, ksize_width,
ksize_height, stride_width, stride_height);
auto pool_functor = PoolingFunctor<T, PoolProcess>(input_data, output_data);

KernelPool2DGrad<PoolProcess, T><<<grids, blocks, 0, context.stream()>>>(
nthreads, output_grad_data, output_height, output_width, input_width,
input_height, ksize_width, ksize_height, stride_width, stride_height,
pool_divmod, padding_height, padding_width, pool_functor, exclusive,
adaptive, input_grad_data, channel_last);
}
};
