Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.56&38】deformable_conv_v1 算子实现 float16 数据类型支持&前向运行加速 #46111

Merged
merged 24 commits into from
Oct 10, 2022
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 43 additions & 39 deletions paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ inline void ModulatedDeformableCol2imCPUKernel(
const int height_col,
const int width_col,
T* grad_im) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要求支持CPU的FP16 Kernel,因此这个文件暂时不要修改。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有的函数由于cc和cu文件共用,就在cpu的文件中增加了少量代码以适配。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

猜测有没有一种可能,单测出错的原因是因为你这里使用了MT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

猜测有没有一种可能,单测出错的原因是因为你这里使用了MT?

好像有可能🤔 我在cc文件里不用MT试试

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

猜测有没有一种可能,单测出错的原因是因为你这里使用了MT?

仔细看了下单测的过程,应该不是由于kernel中代码的问题。是因为op_test内check_grad使用的是(y_neg-y_pos) / delta /2的数值方式,而y_neg和y_pos本身是float16的小数位不够,精度不够高,导致分子无法表示真实差值。感觉在float16类型是可能用数值定义法来测试好像不太行😂。
目前我是通过人工核验float16和float32情况下的grad一致性,因为float32是可以利用数值法核验的。
image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前我是通过人工核验float16和float32情况下的grad一致性,因为float32是可以利用数值法核验的。

「人工校验」的方式,是说和python实现的reference版本结果进行对比?确保如下2个方面:
(1)np数据需要显式定义数据类型,默认是double
(2)reference实现也务必使用float16作为输入输出、但是使用float类型进行计算,保证reference实现和cuda kernel中实现的方式是一致。

另外,conv类算子可能本身误差较大,我看fp32的单测,max_relative_error已经设置到0.05、0.1这么大了。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前我是通过人工核验float16和float32情况下的grad一致性,因为float32是可以利用数值法核验的。

「人工校验」的方式,是说和python实现的reference版本结果进行对比?确保如下2个方面: (1)np数据需要显式定义数据类型,默认是double (2)reference实现也务必使用float16作为输入输出、但是使用float类型进行计算,保证reference实现和cuda kernel中实现的方式是一致。

另外,conv类算子可能本身误差较大,我看fp32的单测,max_relative_error已经设置到0.05、0.1这么大了。

不是的,是直接使用c++版本的float32和float16二者的grad打印出来进行比较,可以参考这里的截图。https://github.com/PaddlePaddle/Paddle/pull/46111#issuecomment-1253283724。
二者一致应该可以在人工上确认我fp16的实现上应该是没问题的。
ps:np目前python版只有前向infer的实现,fp32和fp16是没有问题的。目前单测的后向计算grad都是统一使用op_test.py中的数值定义法(本身存在精度不够和误差大的问题),这也是造成grad无法通过单测的原因。

for (int thread = 0; thread < num_kernels; thread++) {
const int j = (thread / width_col / height_col / batch_size) % kernel_w;
const int i =
Expand All @@ -67,17 +68,17 @@ inline void ModulatedDeformableCol2imCPUKernel(
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T cur_inv_h_data = h_in + i * dilation_h + offset_h;
const T cur_inv_w_data = w_in + j * dilation_w + offset_w;
const MT offset_h = static_cast<MT>(data_offset_ptr[data_offset_h_ptr]);
const MT offset_w = static_cast<MT>(data_offset_ptr[data_offset_w_ptr]);
const MT cur_inv_h_data = h_in + i * dilation_h + offset_h;
const MT cur_inv_w_data = w_in + j * dilation_w + offset_w;

T cur_top_grad = data_col[thread];
MT cur_top_grad = static_cast<MT>(data_col[thread]);
if (data_mask) {
const T* data_mask_ptr =
data_mask + (b * deformable_group + deformable_group_index) *
kernel_h * kernel_w * height_col * width_col;
const T mask = data_mask_ptr[data_mask_hw_ptr];
const MT mask = static_cast<MT>(data_mask_ptr[data_mask_hw_ptr]);
cur_top_grad *= mask;
}
const int cur_h = static_cast<int>(cur_inv_h_data);
Expand All @@ -89,22 +90,23 @@ inline void ModulatedDeformableCol2imCPUKernel(
abs(cur_inv_w_data - (cur_w + dx)) < 1) {
int cur_bottom_grad_pos =
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
T weight = DmcnGetGradientWeight(cur_inv_h_data,
cur_inv_w_data,
cur_h + dy,
cur_w + dx,
height,
width);
MT weight = DmcnGetGradientWeight(cur_inv_h_data,
cur_inv_w_data,
cur_h + dy,
cur_w + dx,
height,
width);

*(grad_im + cur_bottom_grad_pos) =
*(grad_im + cur_bottom_grad_pos) + weight * cur_top_grad;
*(grad_im + cur_bottom_grad_pos) +
static_cast<T>(weight * cur_top_grad);
}
}
}
}
}

template <typename T, typename Context>
template <typename T, typename MT, typename Context>
void ModulatedDeformableCol2im(const Context& dev_ctx,
const T* data_col,
const T* data_offset,
Expand All @@ -116,7 +118,7 @@ void ModulatedDeformableCol2im(const Context& dev_ctx,
const std::vector<int>& stride,
const std::vector<int>& dilation,
const int deformable_group,
T* grad_im) {
MT* grad_im) {
int channel_per_deformable_group = im_shape[0] / deformable_group;
int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3];

Expand Down Expand Up @@ -169,8 +171,9 @@ void ModulatedDeformableCol2imCoordCPUKernel(
const int width_col,
T* grad_offset,
T* grad_mask) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
for (int i = 0; i < num_kernels; i++) {
T val = 0, mval = 0;
MT val = 0, mval = 0;
const int w = i % width_col;
const int h = (i / width_col) % height_col;
const int c = (i / width_col / height_col) % offset_channels;
Expand Down Expand Up @@ -215,40 +218,41 @@ void ModulatedDeformableCol2imCoordCPUKernel(
const int data_offset_w_ptr =
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
w_out);
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
T inv_h = h_in + i * dilation_h + offset_h;
T inv_w = w_in + j * dilation_w + offset_w;
const MT offset_h = static_cast<MT>(data_offset_ptr[data_offset_h_ptr]);
const MT offset_w = static_cast<MT>(data_offset_ptr[data_offset_w_ptr]);
MT inv_h = h_in + i * dilation_h + offset_h;
MT inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
inv_h = inv_w = -2;
} else {
mval += data_col_ptr[col_pos] *
funcs::DmcnIm2colBilinear(data_im_ptr + cnt * height * width,
width,
height,
width,
inv_h,
inv_w);
mval +=
static_cast<MT>(data_col_ptr[col_pos]) *
funcs::DmcnIm2colBilinear<T, MT>(data_im_ptr + cnt * height * width,
width,
height,
width,
inv_h,
inv_w);
}
const T weight =
DmcnGetCoordinateWeight(inv_h,
inv_w,
height,
width,
data_im_ptr + cnt * height * width,
width,
bp_dir);
const MT weight =
DmcnGetCoordinateWeight<T, MT>(inv_h,
inv_w,
height,
width,
data_im_ptr + cnt * height * width,
width,
bp_dir);
if (data_mask_ptr) {
const int data_mask_hw_ptr =
(((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
const T mask = data_mask_ptr[data_mask_hw_ptr];
val += weight * data_col_ptr[col_pos] * mask;
const MT mask = static_cast<MT>(data_mask_ptr[data_mask_hw_ptr]);
val += weight * static_cast<MT>(data_col_ptr[col_pos]) * mask;
} else {
val += weight * data_col_ptr[col_pos];
val += weight * static_cast<MT>(data_col_ptr[col_pos]);
}
cnt += 1;
}
grad_offset[i] = val;
grad_offset[i] = static_cast<T>(val);
if (grad_mask && offset_c % 2 == 0)
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h *
kernel_w +
Expand Down
25 changes: 14 additions & 11 deletions paddle/phi/kernels/funcs/deformable_conv_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
// limitations under the License.

#include "paddle/phi/kernels/funcs/deformable_conv_functor.h"

#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"

namespace phi {
namespace funcs {
Expand Down Expand Up @@ -42,6 +44,7 @@ inline void ModulatedDeformableIm2colCPUKernel(
const int height_col,
const int width_col,
T* data_col) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
for (int i = 0; i < num_kernels; i++) {
const int w_col = i % width_col;
const int h_col = (i / width_col) % height_col;
Expand Down Expand Up @@ -76,22 +79,22 @@ inline void ModulatedDeformableIm2colCPUKernel(
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;

const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
T val = static_cast<T>(0);
const T h_im = h_in + i * dilation_h + offset_h;
const T w_im = w_in + j * dilation_w + offset_w;
const MT offset_h = static_cast<MT>(data_offset_ptr[data_offset_h_ptr]);
const MT offset_w = static_cast<MT>(data_offset_ptr[data_offset_w_ptr]);
MT val = static_cast<MT>(0);
const MT h_im = h_in + i * dilation_h + offset_h;
const MT w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
val =
DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im);
val = DmcnIm2colBilinear<T, MT>(
data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val;
if (data_mask_ptr) {
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const T mask = data_mask_ptr[data_mask_hw_ptr];
*data_col_ptr *= mask;
const MT mask = static_cast<MT>(data_mask_ptr[data_mask_hw_ptr]);
val *= mask;
}
*data_col_ptr = static_cast<T>(val);
data_col_ptr += batch_size * height_col * width_col;
}
}
Expand Down
41 changes: 30 additions & 11 deletions paddle/phi/kernels/funcs/deformable_conv_functor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/deformable_conv_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"

namespace phi {
namespace funcs {
Expand Down Expand Up @@ -51,6 +54,8 @@ __global__ void ModulatedDeformableIm2colGpuKernel(
T* data_col) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;

using MT = typename phi::dtype::MPTypeTrait<T>::Type;
for (size_t i = index; i < nthreads; i += offset) {
const int w_col = i % width_col;
const int h_col = (i / width_col) % height_col;
Expand Down Expand Up @@ -85,22 +90,22 @@ __global__ void ModulatedDeformableIm2colGpuKernel(
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;

const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
T val = static_cast<T>(0);
const T h_im = h_in + i * dilation_h + offset_h;
const T w_im = w_in + j * dilation_w + offset_w;
const MT offset_h = static_cast<MT>(data_offset_ptr[data_offset_h_ptr]);
const MT offset_w = static_cast<MT>(data_offset_ptr[data_offset_w_ptr]);
MT val = static_cast<MT>(0);
const MT h_im = h_in + i * dilation_h + offset_h;
const MT w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
val =
DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im);
val = DmcnIm2colBilinear<T, MT>(
data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val;
if (data_mask_ptr) {
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const T mask = data_mask_ptr[data_mask_hw_ptr];
*data_col_ptr *= mask;
const MT mask = static_cast<MT>(data_mask_ptr[data_mask_hw_ptr]);
val *= mask;
}
*data_col_ptr = static_cast<T>(val);
data_col_ptr += batch_size * height_col * width_col;
}
}
Expand Down Expand Up @@ -164,6 +169,20 @@ template void ModulatedDeformableIm2col(
const int deformable_groups,
float* data_col);

template void ModulatedDeformableIm2col(
const phi::GPUContext& dev_ctx,
const phi::dtype::float16* data_im,
const phi::dtype::float16* data_offset,
const phi::dtype::float16* data_mask,
const std::vector<int64_t>& im_shape,
const std::vector<int64_t>& col_shape,
const std::vector<int64_t>& filter_shape,
const std::vector<int>& paddings,
const std::vector<int>& strides,
const std::vector<int>& dilations,
const int deformable_groups,
phi::dtype::float16* data_col);

template void ModulatedDeformableIm2col(
const phi::GPUContext& dev_ctx,
const double* data_im,
Expand Down
55 changes: 29 additions & 26 deletions paddle/phi/kernels/funcs/deformable_conv_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,44 +14,47 @@

#pragma once

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {
namespace funcs {

template <typename T>
HOSTDEVICE T DmcnIm2colBilinear(const T* bottom_data,
const int data_width,
const int height,
const int width,
T h,
T w) {
template <typename T, typename MT>
HOSTDEVICE MT DmcnIm2colBilinear(const T* bottom_data,
const int data_width,
const int height,
const int width,
MT h,
MT w) {
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;

T lh = h - h_low;
T lw = w - w_low;
T hh = 1 - lh;
T hw = 1 - lw;
MT lh = h - h_low;
MT lw = w - w_low;
MT hh = 1 - lh;
MT hw = 1 - lw;

T v1 =
(h_low >= 0 && w_low >= 0) ? bottom_data[h_low * data_width + w_low] : 0;
T v2 = (h_low >= 0 && w_high <= width - 1)
? bottom_data[h_low * data_width + w_high]
: 0;
T v3 = (h_high <= height - 1 && w_low >= 0)
? bottom_data[h_high * data_width + w_low]
: 0;
T v4 = (h_high <= height - 1 && w_high <= width - 1)
? bottom_data[h_high * data_width + w_high]
: 0;
MT v1 = (h_low >= 0 && w_low >= 0)
? static_cast<MT>(bottom_data[h_low * data_width + w_low])
: 0;
MT v2 = (h_low >= 0 && w_high <= width - 1)
? static_cast<MT>(bottom_data[h_low * data_width + w_high])
: 0;
MT v3 = (h_high <= height - 1 && w_low >= 0)
? static_cast<MT>(bottom_data[h_high * data_width + w_low])
: 0;
MT v4 = (h_high <= height - 1 && w_high <= width - 1)
? static_cast<MT>(bottom_data[h_high * data_width + w_high])
: 0;

T w1 = hh * hw;
T w2 = hh * lw;
T w3 = lh * hw;
T w4 = lh * lw;
MT w1 = hh * hw;
MT w2 = hh * lw;
MT w3 = lh * hw;
MT w4 = lh * lw;

return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}
Expand Down
Loading