Add the complete code and related files of resnet_unit_op (#36366)
ZzSean authored Oct 14, 2021
1 parent bed4fb2 commit 12e6dbb
Showing 8 changed files with 768 additions and 40 deletions.
2 changes: 1 addition & 1 deletion cmake/operators.cmake
@@ -216,7 +216,7 @@ function(op_library TARGET)
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op"
"fused_bn_add_activation_op")
"fused_bn_add_activation_op" "resnet_unit_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
6 changes: 5 additions & 1 deletion paddle/fluid/operators/fused/CMakeLists.txt
@@ -16,7 +16,8 @@ register_operators(EXCLUDES
fusion_gru_op
fusion_lstm_op
fused_bn_add_activation_op
-fused_transformer_op)
+fused_transformer_op
+resnet_unit_op)

# fusion_gru_op does not have CUDA kernel
op_library(fusion_gru_op)
@@ -78,7 +79,10 @@ if (WITH_GPU OR WITH_ROCM)
nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
endif()
+# resnet_unit needs cudnn 8.0 above
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
+op_library(resnet_unit_op)
+file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(resnet_unit);\n")
cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory)
cc_test(test_cudnn_bn_add_relu SRCS cudnn_bn_add_relu_test.cc DEPS batch_norm_op fused_bn_add_activation_op tensor op_registry device_context generator memory)
endif()
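
Note: the build guard above compiles resnet_unit_op only for CUDA builds with cuDNN 8.0 or newer, since the op relies on cuDNN's fused-ops API. For illustration only (not part of this commit), the same requirement can be expressed at compile time in C++ via the standard CUDNN_VERSION macro from cudnn.h; the DEMO_HAS_CUDNN_FUSED_OPS flag below is hypothetical:

#include <cudnn.h>

// For cuDNN 8.x, cudnn.h defines
//   CUDNN_VERSION = CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL,
// so any 8.0.x release satisfies CUDNN_VERSION >= 8000. ROCm builds are
// already excluded by the CMake guard above, so no HIP check is shown here.
#if CUDNN_VERSION >= 8000
// Hypothetical flag, for illustration only: code relying on the cuDNN
// fused-ops API (as resnet_unit does) would be compiled in this branch.
#define DEMO_HAS_CUDNN_FUSED_OPS 1
#else
#define DEMO_HAS_CUDNN_FUSED_OPS 0
#endif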
6 changes: 3 additions & 3 deletions paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc
@@ -631,8 +631,8 @@ class CudnnBNAddReluTester {
op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type_, fuse_add_,
has_shortcut_, data_shape, param_shape,
bitmask_shape);
-sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, z, equiv_scale_z,
-equiv_bias_z, &y, &bitmask);
+sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, &z, &equiv_scale_z,
+&equiv_bias_z, &y, &bitmask);

TensorCopySync(mean_x, platform::CPUPlace(), cpu_mean_x);
TensorCopySync(var_x, platform::CPUPlace(), cpu_var_x);
@@ -690,7 +690,7 @@ class CudnnBNAddReluTester {
op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type, true, false, data_shape,
param_shape, bitmask_shape);
sbar_op.Backward(ctx, dy, x, bn_scale, bn_bias, saved_mean, saved_var,
-bitmask, &dx, &dz, &dscale, &dbias, eps_);
+&bitmask, &dx, &dz, &dscale, &dbias, eps_);

TensorCopySync(dx, platform::CPUPlace(), cpu_dx);
TensorCopySync(dz, platform::CPUPlace(), cpu_dz);
10 changes: 6 additions & 4 deletions paddle/fluid/operators/fused/cudnn_fusion_helper.h
@@ -38,10 +38,12 @@ class CudnnFusionOp {
&op_variant_params_, op_id));
}

-~CudnnFusionOp() {
-dynload::cudnnDestroyFusedOpsVariantParamPack(op_variant_params_);
-dynload::cudnnDestroyFusedOpsConstParamPack(op_const_params_);
-dynload::cudnnDestroyFusedOpsPlan(op_);
+~CudnnFusionOp() PADDLE_MAY_THROW {
+PADDLE_ENFORCE_CUDA_SUCCESS(
+dynload::cudnnDestroyFusedOpsVariantParamPack(op_variant_params_));
+PADDLE_ENFORCE_CUDA_SUCCESS(
+dynload::cudnnDestroyFusedOpsConstParamPack(op_const_params_));
+PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyFusedOpsPlan(op_));
}

// Execute fused op
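
Note: the destructor change above stops discarding the status codes returned by the cuDNN destroy calls. Each call is now wrapped in PADDLE_ENFORCE_CUDA_SUCCESS, and the destructor is annotated with PADDLE_MAY_THROW because a failed check raises an exception. A minimal standalone sketch of the same pattern, using a plain cuDNN tensor descriptor rather than the fused-ops handles (the class name is illustrative, not from this commit):

#include <cudnn.h>

#include <stdexcept>
#include <string>

// Illustrative RAII wrapper: the destroy call's status is checked instead of
// being ignored, mirroring PADDLE_ENFORCE_CUDA_SUCCESS in CudnnFusionOp.
class ScopedCudnnTensorDesc {
 public:
  ScopedCudnnTensorDesc() {
    Check(cudnnCreateTensorDescriptor(&desc_), "cudnnCreateTensorDescriptor");
  }

  // noexcept(false) plays the same role as PADDLE_MAY_THROW: it documents that
  // cleanup failures surface as exceptions rather than being swallowed.
  ~ScopedCudnnTensorDesc() noexcept(false) {
    Check(cudnnDestroyTensorDescriptor(desc_), "cudnnDestroyTensorDescriptor");
  }

  cudnnTensorDescriptor_t get() const { return desc_; }

 private:
  static void Check(cudnnStatus_t status, const char* what) {
    if (status != CUDNN_STATUS_SUCCESS) {
      throw std::runtime_error(std::string(what) + " failed: " +
                               cudnnGetErrorString(status));
    }
  }

  cudnnTensorDescriptor_t desc_ = nullptr;
};

A throwing destructor is normally avoided (it terminates the program if it fires during stack unwinding); the annotation makes that trade-off explicit instead of hiding the failure.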
35 changes: 18 additions & 17 deletions paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h
@@ -94,22 +94,22 @@ template <typename T>
class CudnnScaleBiasAddRelu {
public:
CudnnScaleBiasAddRelu(const platform::CUDADeviceContext &ctx,
-const std::string &act_type, bool fused_add,
+const std::string &act_type, bool fuse_add,
bool has_shortcut, const std::vector<int> &data_shape,
const std::vector<int> &param_shape,
const std::vector<int> &bitmask_shape)
: fwd_op_(CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK),
bwd_op_(CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM) {
-fused_add_ = fused_add;
+fuse_add_ = fuse_add;
has_shortcut_ = has_shortcut;
args_.Set(act_type, data_shape, param_shape, bitmask_shape);
}

~CudnnScaleBiasAddRelu() {}

void Forward(const platform::CUDADeviceContext &ctx, const Tensor &x,
-const Tensor &x_scale, const Tensor &x_bias, const Tensor &z,
-const Tensor &z_scale, const Tensor &z_bias, Tensor *out,
+const Tensor &x_scale, const Tensor &x_bias, const Tensor *z,
+const Tensor *z_scale, const Tensor *z_bias, Tensor *out,
Tensor *bitmask) {
ForwardInit(ctx);
auto handle = ctx.cudnn_handle();
@@ -125,15 +125,15 @@ class CudnnScaleBiasAddRelu {
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, x_scale_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, x_bias_ptr);
if (has_shortcut_) {
-T *z_ptr = const_cast<T *>(z.data<T>());
-T *z_scale_ptr = const_cast<T *>(z_scale.data<T>());
-T *z_bias_ptr = const_cast<T *>(z_bias.data<T>());
+T *z_ptr = const_cast<T *>(z->data<T>());
+T *z_scale_ptr = const_cast<T *>(z_scale->data<T>());
+T *z_bias_ptr = const_cast<T *>(z_bias->data<T>());
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQSCALE, z_scale_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQBIAS, z_bias_ptr);
} else {
-if (fused_add_) {
-T *z_ptr = const_cast<T *>(z.data<T>());
+if (fuse_add_) {
+T *z_ptr = const_cast<T *>(z->data<T>());
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
}
}
@@ -160,7 +160,7 @@ class CudnnScaleBiasAddRelu {
void Backward(const platform::CUDADeviceContext &ctx, const Tensor &dy,
const Tensor &x, const Tensor &scale, const Tensor &bias,
const Tensor &saved_mean, const Tensor &saved_invstd,
-const Tensor &bitmask, Tensor *dx, Tensor *dz, Tensor *dscale,
+const Tensor *bitmask, Tensor *dx, Tensor *dz, Tensor *dscale,
Tensor *dbias, double eps) {
BackwardInit(ctx);
auto handle = ctx.cudnn_handle();
@@ -175,7 +175,8 @@ class CudnnScaleBiasAddRelu {
float *bias_ptr = const_cast<float *>(bias.data<float>());
float *saved_mean_ptr = const_cast<float *>(saved_mean.data<float>());
float *saved_invstd_ptr = const_cast<float *>(saved_invstd.data<float>());
-int32_t *bitmask_ptr = const_cast<int32_t *>(bitmask.data<int32_t>());
+int32_t *bitmask_ptr =
+bitmask ? const_cast<int32_t *>(bitmask->data<int32_t>()) : nullptr;
T *dx_ptr = dx->mutable_data<T>(place);
T *dz_ptr = dz ? dz->mutable_data<T>(place) : nullptr;
float *dscale_ptr = dscale ? dscale->mutable_data<float>(place) : nullptr;
@@ -199,7 +200,7 @@ class CudnnScaleBiasAddRelu {
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_DBIAS, dbias_ptr);
bwd_op_.SetOpVariantParamAttrPtr<double>(CUDNN_SCALAR_DOUBLE_BN_EPSILON,
&eps);
-if (has_shortcut_ || fused_add_) {
+if (has_shortcut_ || fuse_add_) {
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DZDATA, dz_ptr);
}

@@ -226,14 +227,14 @@ class CudnnScaleBiasAddRelu {
{CUDNN_PARAM_ZDATA_PLACEHOLDER, CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER,
CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
-} else if (fused_add_) {
+} else if (fuse_add_) {
fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_ZDATA_PLACEHOLDER,
CUDNN_PTR_16B_ALIGNED);
}

// input desc
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
-if (has_shortcut_ || fused_add_) {
+if (has_shortcut_ || fuse_add_) {
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ZDESC, args_.in_desc.desc());
}

@@ -271,15 +272,15 @@ class CudnnScaleBiasAddRelu {
CUDNN_PARAM_BN_DSCALE_PLACEHOLDER, CUDNN_PARAM_BN_DBIAS_PLACEHOLDER,
CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
-if (has_shortcut_ || fused_add_) {
+if (has_shortcut_ || fuse_add_) {
bwd_op_.SetOpConstParamAttr(CUDNN_PARAM_DZDATA_PLACEHOLDER,
CUDNN_PTR_16B_ALIGNED);
}

// input desc
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DXDESC, args_.in_desc.desc());
-if (has_shortcut_ || fused_add_) {
+if (has_shortcut_ || fuse_add_) {
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DZDESC, args_.in_desc.desc());
}

@@ -303,7 +304,7 @@ class CudnnScaleBiasAddRelu {
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
}

-bool fused_add_ = false;
+bool fuse_add_ = false;
bool has_shortcut_ = false;
size_t fwd_workspace_byte_;
size_t bwd_workspace_byte_;
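
Note: with this commit, the optional inputs of CudnnScaleBiasAddRelu are taken by pointer: z, z_scale, and z_bias in Forward, and bitmask in Backward, so callers that have no shortcut or fused add can pass nullptr. Below is a minimal calling sketch based only on the signatures shown in this file; the op namespace alias, tensor variables, shape vectors, and the device context ctx are assumed to be set up beforehand as in cudnn_bn_add_relu_test.cc and are not defined here.

// Hedged sketch, not from the commit: assumes framework::Tensor objects
// x, x_scale, x_bias, z, z_scale, z_bias, y, bitmask, dy, bn_scale, bn_bias,
// saved_mean, saved_var, dx, dscale, dbias and a platform::CUDADeviceContext
// ctx are already allocated and filled.
using T = paddle::platform::float16;

op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, /*act_type=*/"relu",
                                     /*fuse_add=*/false,
                                     /*has_shortcut=*/false, data_shape,
                                     param_shape, bitmask_shape);

// No shortcut and no fused add: the z-side inputs may be null.
sbar_op.Forward(ctx, x, x_scale, x_bias, /*z=*/nullptr, /*z_scale=*/nullptr,
                /*z_bias=*/nullptr, &y, &bitmask);

// With has_shortcut = true (as exercised by the test), pass addresses instead:
// sbar_op.Forward(ctx, x, x_scale, x_bias, &z, &z_scale, &z_bias, &y, &bitmask);

// Backward similarly accepts nullable pointers for bitmask and dz.
sbar_op.Backward(ctx, dy, x, bn_scale, bn_bias, saved_mean, saved_var,
                 &bitmask, &dx, /*dz=*/nullptr, &dscale, &dbias, /*eps=*/1e-5);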
(Diffs for the remaining 3 changed files are not shown.)
