Commit

fix some potential bugs
ZzSean committed Oct 13, 2021
1 parent 8558ff3 commit ba751e4
Showing 3 changed files with 17 additions and 18 deletions.
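
The common thread across the three files below is that optional tensor inputs (the shortcut z, its equivalent scale/bias, and the bitmask) move from const Tensor & to const Tensor * parameters, with the implementation null-checking before touching the data, so callers no longer have to dereference a pointer that may legitimately be null. A minimal standalone sketch of that pattern, using an illustrative Tensor stand-in rather than the actual Paddle types:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Illustrative stand-in for the framework's Tensor type.
    struct Tensor {
      std::vector<int32_t> buf;
      const int32_t *data() const { return buf.data(); }
    };

    // Old style: an optional input taken by reference. A caller holding a
    // possibly-null Tensor* has to write f(*ptr), which is undefined
    // behaviour when ptr is null.
    int32_t FirstElemByRef(const Tensor &bitmask) { return bitmask.data()[0]; }

    // New style: the optional input is a pointer and is null-checked before
    // use, mirroring how this commit handles bitmask and the shortcut z.
    int32_t FirstElemByPtr(const Tensor *bitmask) {
      const int32_t *p = bitmask ? bitmask->data() : nullptr;
      return p ? p[0] : 0;  // safe fallback when the optional input is absent
    }

    int main() {
      Tensor t{{7, 8, 9}};
      const Tensor *absent = nullptr;
      std::cout << FirstElemByPtr(&t) << "\n";      // prints 7
      std::cout << FirstElemByPtr(absent) << "\n";  // prints 0, no null dereference
      // FirstElemByRef(*absent) would compile but is undefined behaviour.
      return 0;
    }
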
paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc (6 changes: 3 additions & 3 deletions)
@@ -631,8 +631,8 @@ class CudnnBNAddReluTester {
     op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type_, fuse_add_,
                                          has_shortcut_, data_shape, param_shape,
                                          bitmask_shape);
-    sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, z, equiv_scale_z,
-                    equiv_bias_z, &y, &bitmask);
+    sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, &z, &equiv_scale_z,
+                    &equiv_bias_z, &y, &bitmask);
 
     TensorCopySync(mean_x, platform::CPUPlace(), cpu_mean_x);
     TensorCopySync(var_x, platform::CPUPlace(), cpu_var_x);
@@ -690,7 +690,7 @@ class CudnnBNAddReluTester {
     op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type, true, false, data_shape,
                                          param_shape, bitmask_shape);
     sbar_op.Backward(ctx, dy, x, bn_scale, bn_bias, saved_mean, saved_var,
-                     bitmask, &dx, &dz, &dscale, &dbias, eps_);
+                     &bitmask, &dx, &dz, &dscale, &dbias, eps_);
 
     TensorCopySync(dx, platform::CPUPlace(), cpu_dx);
     TensorCopySync(dz, platform::CPUPlace(), cpu_dz);
paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h (17 changes: 9 additions & 8 deletions)
@@ -108,8 +108,8 @@ class CudnnScaleBiasAddRelu {
   ~CudnnScaleBiasAddRelu() {}
 
   void Forward(const platform::CUDADeviceContext &ctx, const Tensor &x,
-               const Tensor &x_scale, const Tensor &x_bias, const Tensor &z,
-               const Tensor &z_scale, const Tensor &z_bias, Tensor *out,
+               const Tensor &x_scale, const Tensor &x_bias, const Tensor *z,
+               const Tensor *z_scale, const Tensor *z_bias, Tensor *out,
                Tensor *bitmask) {
     ForwardInit(ctx);
     auto handle = ctx.cudnn_handle();
@@ -125,15 +125,15 @@ class CudnnScaleBiasAddRelu {
     fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, x_scale_ptr);
     fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, x_bias_ptr);
     if (has_shortcut_) {
-      T *z_ptr = const_cast<T *>(z.data<T>());
-      T *z_scale_ptr = const_cast<T *>(z_scale.data<T>());
-      T *z_bias_ptr = const_cast<T *>(z_bias.data<T>());
+      T *z_ptr = const_cast<T *>(z->data<T>());
+      T *z_scale_ptr = const_cast<T *>(z_scale->data<T>());
+      T *z_bias_ptr = const_cast<T *>(z_bias->data<T>());
       fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
       fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQSCALE, z_scale_ptr);
       fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQBIAS, z_bias_ptr);
     } else {
       if (fused_add_) {
-        T *z_ptr = const_cast<T *>(z.data<T>());
+        T *z_ptr = const_cast<T *>(z->data<T>());
         fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
       }
     }
@@ -160,7 +160,7 @@ class CudnnScaleBiasAddRelu {
   void Backward(const platform::CUDADeviceContext &ctx, const Tensor &dy,
                 const Tensor &x, const Tensor &scale, const Tensor &bias,
                 const Tensor &saved_mean, const Tensor &saved_invstd,
-                const Tensor &bitmask, Tensor *dx, Tensor *dz, Tensor *dscale,
+                const Tensor *bitmask, Tensor *dx, Tensor *dz, Tensor *dscale,
                 Tensor *dbias, double eps) {
     BackwardInit(ctx);
     auto handle = ctx.cudnn_handle();
@@ -175,7 +175,8 @@ class CudnnScaleBiasAddRelu {
     float *bias_ptr = const_cast<float *>(bias.data<float>());
     float *saved_mean_ptr = const_cast<float *>(saved_mean.data<float>());
     float *saved_invstd_ptr = const_cast<float *>(saved_invstd.data<float>());
-    int32_t *bitmask_ptr = const_cast<int32_t *>(bitmask.data<int32_t>());
+    int32_t *bitmask_ptr =
+        bitmask ? const_cast<int32_t *>(bitmask->data<int32_t>()) : nullptr;
     T *dx_ptr = dx->mutable_data<T>(place);
     T *dz_ptr = dz ? dz->mutable_data<T>(place) : nullptr;
     float *dscale_ptr = dscale ? dscale->mutable_data<float>(place) : nullptr;
paddle/fluid/operators/fused/resnet_unit_op.cu (12 changes: 5 additions & 7 deletions)
@@ -145,14 +145,12 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
                      momentum, ele_count, is_train);
       // 3.3 sbar
       sbar_op.Forward(dev_ctx, *conv_out_x, equiv_scale_x, equiv_bias_x,
-                      *conv_out_z, equiv_scale_z, equiv_bias_z, output,
+                      conv_out_z, &equiv_scale_z, &equiv_bias_z, output,
                       bitmask);
     } else {
       const Tensor *input_z = fused_add ? ctx.Input<Tensor>("Z") : nullptr;
-      Tensor equiv_scale_z;
-      Tensor equiv_bias_z;
       sbar_op.Forward(dev_ctx, *conv_out_x, equiv_scale_x, equiv_bias_x,
-                      *input_z, equiv_scale_z, equiv_bias_z, output, bitmask);
+                      input_z, nullptr, nullptr, output, bitmask);
     }
   }
 };
@@ -245,7 +243,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
       Tensor z_grad_temp;
       z_grad_temp.Resize(conv_out_z->dims());
       sbar_x_op.Backward(dev_ctx, *y_grad, *conv_out_x, *scale_x, *bias_x,
-                         *saved_mean_x, *saved_invstd_x, *bitmask,
+                         *saved_mean_x, *saved_invstd_x, bitmask,
                          &conv_out_x_grad, &z_grad_temp, scale_x_grad,
                          bias_x_grad, eps);
 
@@ -255,7 +253,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
       CudnnScaleBiasAddRelu<T> sbar_z_op(
           dev_ctx, "", false, false, output_shape, param_shape, bitmask_shape);
       sbar_z_op.Backward(dev_ctx, z_grad_temp, *conv_out_z, *scale_z, *bias_z,
-                         *saved_mean_z, *saved_invstd_z, *bitmask,
+                         *saved_mean_z, *saved_invstd_z, nullptr,
                          &conv_out_z_grad, nullptr, scale_z_grad, bias_z_grad,
                          eps);
 
@@ -273,7 +271,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
       Tensor *z_grad =
           fused_add ? ctx.Output<Tensor>(framework::GradVarName("Z")) : nullptr;
       sbar_x_op.Backward(dev_ctx, *y_grad, *conv_out_x, *scale_x, *bias_x,
-                         *saved_mean_x, *saved_invstd_x, *bitmask,
+                         *saved_mean_x, *saved_invstd_x, bitmask,
                          &conv_out_x_grad, z_grad, scale_x_grad, bias_x_grad,
                          eps);
     }
